aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authorSam Wedgwood <28223854+swedgwood@users.noreply.github.com>2023-08-15 12:37:04 +0100
committerGitHub <noreply@github.com>2023-08-15 12:37:04 +0100
commit9a12420428f1832c76fc0c84ad85db200e261ecb (patch)
tree38ce262c515d74865920f6ebaf336f1887dee11b /syncapi
parentfa6c7ba45671c8fbf13cb7ba456355a04941b535 (diff)
[pseudoID] More pseudo ID fixes (#3167)
Signed-off-by: `Sam Wedgwood <sam@wedgwood.dev>`
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/internal/history_visibility.go59
-rw-r--r--syncapi/internal/keychange_test.go4
-rw-r--r--syncapi/routing/context.go6
-rw-r--r--syncapi/routing/getevent.go43
-rw-r--r--syncapi/routing/messages.go13
-rw-r--r--syncapi/routing/relations.go33
-rw-r--r--syncapi/streams/stream_pdu.go7
-rw-r--r--syncapi/syncapi_test.go14
8 files changed, 118 insertions, 61 deletions
diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go
index 3c230895..91a2d63c 100644
--- a/syncapi/internal/history_visibility.go
+++ b/syncapi/internal/history_visibility.go
@@ -16,6 +16,7 @@ package internal
import (
"context"
+ "fmt"
"math"
"time"
@@ -101,13 +102,15 @@ func (ev eventVisibility) allowed() (allowed bool) {
// ApplyHistoryVisibilityFilter applies the room history visibility filter on types.HeaderedEvents.
// Returns the filtered events and an error, if any.
+//
+// This function assumes that all provided events are from the same room.
func ApplyHistoryVisibilityFilter(
ctx context.Context,
syncDB storage.DatabaseTransaction,
rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent,
alwaysIncludeEventIDs map[string]struct{},
- userID, endpoint string,
+ userID spec.UserID, endpoint string,
) ([]*types.HeaderedEvent, error) {
if len(events) == 0 {
return events, nil
@@ -115,15 +118,29 @@ func ApplyHistoryVisibilityFilter(
start := time.Now()
// try to get the current membership of the user
- membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID, math.MaxInt64)
+ membershipCurrent, _, err := syncDB.SelectMembershipForUser(ctx, events[0].RoomID(), userID.String(), math.MaxInt64)
if err != nil {
return nil, err
}
// Get the mapping from eventID -> eventVisibility
eventsFiltered := make([]*types.HeaderedEvent, 0, len(events))
- visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID())
+ firstEvRoomID, err := spec.NewRoomID(events[0].RoomID())
+ if err != nil {
+ return nil, err
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(ctx, *firstEvRoomID, userID)
+ if err != nil {
+ return nil, err
+ }
+ visibilities := visibilityForEvents(ctx, rsAPI, events, senderID, *firstEvRoomID)
+
for _, ev := range events {
+ // Validate same room assumption
+ if ev.RoomID() != firstEvRoomID.String() {
+ return nil, fmt.Errorf("events from different rooms supplied to ApplyHistoryVisibilityFilter")
+ }
+
evVis := visibilities[ev.EventID()]
evVis.membershipCurrent = membershipCurrent
// Always include specific state events for /sync responses
@@ -133,23 +150,15 @@ func ApplyHistoryVisibilityFilter(
continue
}
}
- // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
- user, err := spec.NewUserID(userID, true)
- if err != nil {
- return nil, err
- }
- roomID, err := spec.NewRoomID(ev.RoomID())
- if err != nil {
- return nil, err
- }
- senderID, err := rsAPI.QuerySenderIDForUser(ctx, *roomID, *user)
- if err == nil && senderID != nil {
+ // NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
+ if senderID != nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(*senderID)) {
eventsFiltered = append(eventsFiltered, ev)
continue
}
}
+
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive
// of the old vs new.
@@ -178,13 +187,13 @@ func ApplyHistoryVisibilityFilter(
}
// visibilityForEvents returns a map from eventID to eventVisibility containing the visibility and the membership
-// of `userID` at the given event.
+// of `senderID` at the given event. If provided sender ID is nil, assume that membership is Leave
// Returns an error if the roomserver can't calculate the memberships.
func visibilityForEvents(
ctx context.Context,
rsAPI api.SyncRoomserverAPI,
events []*types.HeaderedEvent,
- userID, roomID string,
+ senderID *spec.SenderID, roomID spec.RoomID,
) map[string]eventVisibility {
eventIDs := make([]string, len(events))
for i := range events {
@@ -194,15 +203,13 @@ func visibilityForEvents(
result := make(map[string]eventVisibility, len(eventIDs))
// get the membership events for all eventIDs
- membershipResp := &api.QueryMembershipAtEventResponse{}
-
- err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{
- RoomID: roomID,
- EventIDs: eventIDs,
- UserID: userID,
- }, membershipResp)
- if err != nil {
- logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'")
+ var err error
+ membershipEvents := make(map[string]*types.HeaderedEvent)
+ if senderID != nil {
+ membershipEvents, err = rsAPI.QueryMembershipAtEvent(ctx, roomID, eventIDs, *senderID)
+ if err != nil {
+ logrus.WithError(err).Error("visibilityForEvents: failed to fetch membership at event, defaulting to 'leave'")
+ }
}
// Create a map from eventID -> eventVisibility
@@ -212,7 +219,7 @@ func visibilityForEvents(
membershipAtEvent: spec.Leave, // default to leave, to not expose events by accident
visibility: event.Visibility,
}
- ev, ok := membershipResp.Membership[eventID]
+ ev, ok := membershipEvents[eventID]
if !ok || ev == nil {
result[eventID] = vis
continue
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index 3f5e990c..81b82bf6 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -69,8 +69,8 @@ func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spe
}
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
-func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
- return nil
+func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
+ return nil, nil
}
// QueryBulkStateContent does a bulk query for state event content in the given rooms.
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go
index 649d77b4..b0c91c40 100644
--- a/syncapi/routing/context.go
+++ b/syncapi/routing/context.go
@@ -138,7 +138,7 @@ func Context(
// verify the user is allowed to see the context for this room/event
startTime := time.Now()
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, device.UserID, "context")
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, []*rstypes.HeaderedEvent{&requestedEvent}, nil, *userID, "context")
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{
@@ -176,7 +176,7 @@ func Context(
}
startTime = time.Now()
- eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, device.UserID)
+ eventsBeforeFiltered, eventsAfterFiltered, err := applyHistoryVisibilityOnContextEvents(ctx, snapshot, rsAPI, eventsBefore, eventsAfter, *userID)
if err != nil {
logrus.WithError(err).Error("unable to apply history visibility filter")
return util.JSONResponse{
@@ -257,7 +257,7 @@ func Context(
func applyHistoryVisibilityOnContextEvents(
ctx context.Context, snapshot storage.DatabaseTransaction, rsAPI roomserver.SyncRoomserverAPI,
eventsBefore, eventsAfter []*rstypes.HeaderedEvent,
- userID string,
+ userID spec.UserID,
) (filteredBefore, filteredAfter []*rstypes.HeaderedEvent, err error) {
eventIDsBefore := make(map[string]struct{}, len(eventsBefore))
eventIDsAfter := make(map[string]struct{}, len(eventsAfter))
diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go
index 09c2aef0..4fa282f3 100644
--- a/syncapi/routing/getevent.go
+++ b/syncapi/routing/getevent.go
@@ -37,7 +37,7 @@ import (
func GetEvent(
req *http.Request,
device *userapi.Device,
- roomID string,
+ rawRoomID string,
eventID string,
cfg *config.SyncAPI,
syncDB storage.Database,
@@ -47,7 +47,7 @@ func GetEvent(
db, err := syncDB.NewDatabaseTransaction(ctx)
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": eventID,
- "room_id": roomID,
+ "room_id": rawRoomID,
})
if err != nil {
logger.WithError(err).Error("GetEvent: syncDB.NewDatabaseTransaction failed")
@@ -57,6 +57,14 @@ func GetEvent(
}
}
+ roomID, err := spec.NewRoomID(rawRoomID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("invalid room ID"),
+ }
+ }
+
events, err := db.Events(ctx, []string{eventID})
if err != nil {
logger.WithError(err).Error("GetEvent: syncDB.Events failed")
@@ -76,13 +84,22 @@ func GetEvent(
}
// If the request is coming from an appservice, get the user from the request
- userID := device.UserID
+ rawUserID := device.UserID
if asUserID := req.FormValue("user_id"); device.AppserviceID != "" && asUserID != "" {
- userID = asUserID
+ rawUserID = asUserID
+ }
+
+ userID, err := spec.NewUserID(rawUserID, true)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("invalid device.UserID")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.Unknown("internal server error"),
+ }
}
// Apply history visibility to determine if the user is allowed to view the event
- events, err = internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, events, nil, userID, "event")
+ events, err = internal.ApplyHistoryVisibilityFilter(ctx, db, rsAPI, events, nil, *userID, "event")
if err != nil {
logger.WithError(err).Error("GetEvent: internal.ApplyHistoryVisibilityFilter failed")
return util.JSONResponse{
@@ -101,18 +118,14 @@ func GetEvent(
}
}
- sender := spec.UserID{}
- validRoomID, err := spec.NewRoomID(roomID)
- if err != nil {
+ senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, events[0].SenderID())
+ if err != nil || senderUserID == nil {
+ util.GetLogger(req.Context()).WithError(err).WithField("senderID", events[0].SenderID()).WithField("roomID", *roomID).Error("QueryUserIDForSender errored or returned nil-user ID when user should be part of a room")
return util.JSONResponse{
- Code: http.StatusBadRequest,
- JSON: spec.BadJSON("roomID is invalid"),
+ Code: http.StatusInternalServerError,
+ JSON: spec.Unknown("internal server error"),
}
}
- senderUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, events[0].SenderID())
- if err == nil && senderUserID != nil {
- sender = *senderUserID
- }
sk := events[0].StateKey()
if sk != nil && *sk != "" {
@@ -131,6 +144,6 @@ func GetEvent(
}
return util.JSONResponse{
Code: http.StatusOK,
- JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
+ JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, *senderUserID, sk),
}
}
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index 23a09544..3333cb54 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -50,6 +50,7 @@ type messagesReq struct {
from *types.TopologyToken
to *types.TopologyToken
device *userapi.Device
+ deviceUserID spec.UserID
wasToProvided bool
backwardOrdering bool
filter *synctypes.RoomEventFilter
@@ -77,6 +78,15 @@ func OnIncomingMessagesRequest(
) util.JSONResponse {
var err error
+ deviceUserID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.Unknown("internal server error"),
+ }
+ }
+
// NewDatabaseTransaction is used here instead of NewDatabaseSnapshot as we
// expect to be able to write to the database in response to a /messages
// request that requires backfilling from the roomserver or federation.
@@ -240,6 +250,7 @@ func OnIncomingMessagesRequest(
filter: filter,
backwardOrdering: backwardOrdering,
device: device,
+ deviceUserID: *deviceUserID,
}
clientEvents, start, end, err := mReq.retrieveEvents(req.Context(), rsAPI)
@@ -359,7 +370,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv
// Apply room history visibility filter
startTime := time.Now()
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages")
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.deviceUserID, "messages")
if err != nil {
return []synctypes.ClientEvent{}, *r.from, *r.to, nil
}
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go
index 17933b2f..e3d1069a 100644
--- a/syncapi/routing/relations.go
+++ b/syncapi/routing/relations.go
@@ -43,9 +43,25 @@ func Relations(
req *http.Request, device *userapi.Device,
syncDB storage.Database,
rsAPI api.SyncRoomserverAPI,
- roomID, eventID, relType, eventType string,
+ rawRoomID, eventID, relType, eventType string,
) util.JSONResponse {
- var err error
+ roomID, err := spec.NewRoomID(rawRoomID)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("invalid room ID"),
+ }
+ }
+
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("device.UserID invalid")
+ return util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: spec.Unknown("internal server error"),
+ }
+ }
+
var from, to types.StreamPosition
var limit int
dir := req.URL.Query().Get("dir")
@@ -93,7 +109,7 @@ func Relations(
}
var events []types.StreamEvent
events, res.PrevBatch, res.NextBatch, err = snapshot.RelationsFor(
- req.Context(), roomID, eventID, relType, eventType, from, to, dir == "b", limit,
+ req.Context(), roomID.String(), eventID, relType, eventType, from, to, dir == "b", limit,
)
if err != nil {
return util.ErrorResponse(err)
@@ -105,12 +121,7 @@ func Relations(
}
// Apply history visibility to the result events.
- filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, device.UserID, "relations")
- if err != nil {
- return util.ErrorResponse(err)
- }
-
- validRoomID, err := spec.NewRoomID(roomID)
+ filteredEvents, err := internal.ApplyHistoryVisibilityFilter(req.Context(), snapshot, rsAPI, headeredEvents, nil, *userID, "relations")
if err != nil {
return util.ErrorResponse(err)
}
@@ -120,14 +131,14 @@ func Relations(
res.Chunk = make([]synctypes.ClientEvent, 0, len(filteredEvents))
for _, event := range filteredEvents {
sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, event.SenderID())
+ userID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
- skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *validRoomID, spec.SenderID(*event.StateKey()))
+ skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), *roomID, spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index 48daf857..4622c21a 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -562,8 +562,13 @@ func applyHistoryVisibilityFilter(
}
}
+ parsedUserID, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return nil, err
+ }
+
startTime := time.Now()
- events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
+ events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, *parsedUserID, "sync")
if err != nil {
return nil, err
}
diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go
index 996b21e9..ea1183cd 100644
--- a/syncapi/syncapi_test.go
+++ b/syncapi/syncapi_test.go
@@ -44,6 +44,11 @@ func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID spe
return spec.NewUserID(string(senderID), true)
}
+func (s *syncRoomserverAPI) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (*spec.SenderID, error) {
+ senderID := spec.SenderID(userID.String())
+ return &senderID, nil
+}
+
func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error {
var room *test.Room
for _, r := range s.rooms {
@@ -74,8 +79,13 @@ func (s *syncRoomserverAPI) QueryMembershipForUser(ctx context.Context, req *rsa
return nil
}
-func (s *syncRoomserverAPI) QueryMembershipAtEvent(ctx context.Context, req *rsapi.QueryMembershipAtEventRequest, res *rsapi.QueryMembershipAtEventResponse) error {
- return nil
+func (s *syncRoomserverAPI) QueryMembershipAtEvent(
+ ctx context.Context,
+ roomID spec.RoomID,
+ eventIDs []string,
+ senderID spec.SenderID,
+) (map[string]*rstypes.HeaderedEvent, error) {
+ return map[string]*rstypes.HeaderedEvent{}, nil
}
type syncUserAPI struct {