diff options
author | devonh <devon.dmytro@gmail.com> | 2023-06-07 17:14:35 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-07 17:14:35 +0000 |
commit | 8ea1a11105ea7e66aa459537bcbef0de606147cd (patch) | |
tree | 62a539b761d078978226752c8d3ef23f1a80e677 /syncapi | |
parent | 7a1fd7f512ce06a472a2051ee63eae4a270eb71a (diff) |
Use SenderID Type (#3105)
Diffstat (limited to 'syncapi')
-rw-r--r-- | syncapi/consumers/roomserver.go | 2 | ||||
-rw-r--r-- | syncapi/routing/context.go | 8 | ||||
-rw-r--r-- | syncapi/routing/memberships.go | 19 | ||||
-rw-r--r-- | syncapi/routing/messages.go | 6 | ||||
-rw-r--r-- | syncapi/routing/search.go | 8 | ||||
-rw-r--r-- | syncapi/routing/search_test.go | 4 | ||||
-rw-r--r-- | syncapi/storage/interface.go | 6 | ||||
-rw-r--r-- | syncapi/storage/shared/storage_consumer.go | 75 | ||||
-rw-r--r-- | syncapi/storage/shared/storage_sync.go | 19 | ||||
-rw-r--r-- | syncapi/storage/storage_test.go | 2 | ||||
-rw-r--r-- | syncapi/streams/stream_pdu.go | 30 | ||||
-rw-r--r-- | syncapi/syncapi_test.go | 4 | ||||
-rw-r--r-- | syncapi/types/types.go | 2 |
13 files changed, 98 insertions, 87 deletions
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index c0836465..8a2a0b1f 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -523,7 +523,7 @@ func (s *OutputRoomEventConsumer) updateStateEvent(event *rstypes.HeaderedEvent) prev := types.PrevEventRef{ PrevContent: prevEvent.Content(), ReplacesState: prevEvent.EventID(), - PrevSender: prevEvent.SenderID(), + PrevSenderID: string(prevEvent.SenderID()), } event.PDU, err = event.SetUnsigned(prev) diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 27e99a35..7fb88faa 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -193,10 +193,10 @@ func Context( } } - eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + eventsBeforeClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBeforeFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) - eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + eventsAfterClient := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfterFiltered), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) @@ -204,7 +204,7 @@ func Context( if filter.LazyLoadMembers { allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) allEvents = append(allEvents, &requestedEvent) - evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + evs := synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(allEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) @@ -227,7 +227,7 @@ func Context( Event: &ev, EventsAfter: eventsAfterClient, EventsBefore: eventsBeforeClient, - State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + State: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(newState), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), } diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 9c2319dd..813167a5 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -144,7 +144,22 @@ func GetMemberships( JSON: spec.InternalServerError{}, } } - res.Joined[ev.SenderID()] = joinedMember(content) + + userID, err := rsAPI.QueryUserIDForSender(req.Context(), ev.RoomID(), ev.SenderID()) + if err != nil || userID == nil { + util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryUserIDForSender failed") + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: spec.InternalServerError{}, + } + } + if err != nil { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: spec.Forbidden("You don't have permission to kick this user, unknown senderID"), + } + } + res.Joined[userID.String()] = joinedMember(content) } return util.JSONResponse{ Code: http.StatusOK, @@ -153,7 +168,7 @@ func GetMemberships( } return util.JSONResponse{ Code: http.StatusOK, - JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + JSON: getMembershipResponse{synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(result), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })}, } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 879739d0..781fd53e 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -273,7 +273,7 @@ func OnIncomingMessagesRequest( JSON: spec.InternalServerError{}, } } - res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + res.State = append(res.State, synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(membershipEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) })...) } @@ -385,7 +385,7 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv "events_before": len(events), "events_after": len(filteredEvents), }).Debug("applied history visibility (messages)") - return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID, senderID string) (*spec.UserID, error) { + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }), start, end, err } @@ -495,7 +495,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent } // Append the events ve previously retrieved locally. - events = append(events, r.snapshot.StreamEventsToEvents(nil, streamEvents)...) + events = append(events, r.snapshot.StreamEventsToEvents(r.ctx, nil, streamEvents, r.rsAPI)...) sort.Sort(eventsByDepth(events)) return diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go index 9cf3eabe..add50b18 100644 --- a/syncapi/routing/search.go +++ b/syncapi/routing/search.go @@ -213,7 +213,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts profile, ok := knownUsersProfiles[userID.String()] if !ok { - stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, ev.SenderID()) + stateEvent, stateErr := snapshot.GetStateEvent(ctx, ev.RoomID(), spec.MRoomMember, string(ev.SenderID())) if stateErr != nil { logrus.WithError(stateErr).WithField("sender_id", event.SenderID()).Warn("failed to query userprofile") continue @@ -239,10 +239,10 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts Context: SearchContextResponse{ Start: startToken.String(), End: endToken.String(), - EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + EventsAfter: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsAfter), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), - EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + EventsBefore: synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(eventsBefore), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }), ProfileInfo: profileInfos, @@ -263,7 +263,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts JSON: spec.InternalServerError{}, } } - stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + stateForRooms[event.RoomID()] = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(state), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(req.Context(), roomID, senderID) }) } diff --git a/syncapi/routing/search_test.go b/syncapi/routing/search_test.go index b36be823..5eb094ca 100644 --- a/syncapi/routing/search_test.go +++ b/syncapi/routing/search_test.go @@ -25,8 +25,8 @@ import ( type FakeSyncRoomserverAPI struct{ rsapi.SyncRoomserverAPI } -func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (f *FakeSyncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func TestSearch(t *testing.T) { diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 302b9bad..8798b62e 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -44,8 +44,8 @@ type DatabaseTransaction interface { MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *synctypes.StateFilter, excludeEventIDs []string) ([]*rstypes.HeaderedEvent, error) - GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) - GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter) ([]types.StateDelta, []string, error) + GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error) + GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) @@ -90,7 +90,7 @@ type DatabaseTransaction interface { // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. - StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent + StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the // relevant events within the given ranges for the supplied user ID and device ID. SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 17a6a69c..5bd3b1f0 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -99,7 +99,41 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*rstypes.He // We don't include a device here as we only include transaction IDs in // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil +} + +func (d *Database) StreamEventsToEvents(ctx context.Context, device *userapi.Device, in []types.StreamEvent, rsAPI api.SyncRoomserverAPI) []*rstypes.HeaderedEvent { + out := make([]*rstypes.HeaderedEvent, len(in)) + for i := 0; i < len(in); i++ { + out[i] = in[i].HeaderedEvent + if device != nil && in[i].TransactionID != nil { + userID, err := spec.NewUserID(device.UserID, true) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + deviceSenderID, err := rsAPI.QuerySenderIDForUser(ctx, in[i].RoomID(), *userID) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + continue + } + if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { + err := out[i].SetUnsignedField( + "transaction_id", in[i].TransactionID.TransactionID, + ) + if err != nil { + logrus.WithFields(logrus.Fields{ + "event_id": out[i].EventID(), + }).WithError(err).Warnf("Failed to add transaction ID to event") + } + } + } + } + return out } // AddInviteEvent stores a new invite event for a user. @@ -190,45 +224,6 @@ func (d *Database) UpsertAccountData( return } -func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*rstypes.HeaderedEvent { - out := make([]*rstypes.HeaderedEvent, len(in)) - for i := 0; i < len(in); i++ { - out[i] = in[i].HeaderedEvent - if device != nil && in[i].TransactionID != nil { - userID, err := spec.NewUserID(device.UserID, true) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - continue - } - deviceSenderID, err := d.getSenderIDForUser(in[i].RoomID(), *userID) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - continue - } - if deviceSenderID == in[i].SenderID() && device.SessionID == in[i].TransactionID.SessionID { - err := out[i].SetUnsignedField( - "transaction_id", in[i].TransactionID.TransactionID, - ) - if err != nil { - logrus.WithFields(logrus.Fields{ - "event_id": out[i].EventID(), - }).WithError(err).Warnf("Failed to add transaction ID to event") - } - } - } - } - return out -} - -func (d *Database) getSenderIDForUser(roomID string, userID spec.UserID) (string, error) { // nolint - // TODO: Repalce with actual logic for pseudoIDs - return userID.String(), nil -} - // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index f2b1c58d..df961385 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -10,6 +10,7 @@ import ( "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" + "github.com/matrix-org/dendrite/roomserver/api" rstypes "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/syncapi/synctypes" "github.com/matrix-org/dendrite/syncapi/types" @@ -186,7 +187,7 @@ func (d *DatabaseTransaction) Events(ctx context.Context, eventIDs []string) ([] // We don't include a device here as we only include transaction IDs in // incremental syncs. - return d.StreamEventsToEvents(nil, streamEvents), nil + return d.StreamEventsToEvents(ctx, nil, streamEvents, nil), nil } func (d *DatabaseTransaction) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { @@ -325,7 +326,7 @@ func (d *DatabaseTransaction) GetBackwardTopologyPos( func (d *DatabaseTransaction) GetStateDeltas( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *synctypes.StateFilter, + stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI, ) (deltas []types.StateDelta, joinedRoomsIDs []string, err error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // - Get membership list changes for this user in this sync response @@ -417,7 +418,7 @@ func (d *DatabaseTransaction) GetStateDeltas( if !peek.Deleted { deltas = append(deltas, types.StateDelta{ Membership: spec.Peek, - StateEvents: d.StreamEventsToEvents(device, state[peek.RoomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, state[peek.RoomID], rsAPI), RoomID: peek.RoomID, }) } @@ -462,7 +463,7 @@ func (d *DatabaseTransaction) GetStateDeltas( deltas = append(deltas, types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateFiltered[roomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[roomID], rsAPI), RoomID: roomID, }) break @@ -474,7 +475,7 @@ func (d *DatabaseTransaction) GetStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: spec.Join, - StateEvents: d.StreamEventsToEvents(device, stateFiltered[joinedRoomID]), + StateEvents: d.StreamEventsToEvents(ctx, device, stateFiltered[joinedRoomID], rsAPI), RoomID: joinedRoomID, NewlyJoined: newlyJoinedRooms[joinedRoomID], }) @@ -490,7 +491,7 @@ func (d *DatabaseTransaction) GetStateDeltas( func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( ctx context.Context, device *userapi.Device, r types.Range, userID string, - stateFilter *synctypes.StateFilter, + stateFilter *synctypes.StateFilter, rsAPI api.SyncRoomserverAPI, ) ([]types.StateDelta, []string, error) { // Look up all memberships for the user. We only care about rooms that a // user has ever interacted with — joined to, kicked/banned from, left. @@ -531,7 +532,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( } deltas[peek.RoomID] = types.StateDelta{ Membership: spec.Peek, - StateEvents: d.StreamEventsToEvents(device, s), + StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI), RoomID: peek.RoomID, } } @@ -560,7 +561,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( deltas[roomID] = types.StateDelta{ Membership: membership, MembershipPos: ev.StreamPosition, - StateEvents: d.StreamEventsToEvents(device, stateStreamEvents), + StateEvents: d.StreamEventsToEvents(ctx, device, stateStreamEvents, rsAPI), RoomID: roomID, } } @@ -581,7 +582,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync( } deltas[joinedRoomID] = types.StateDelta{ Membership: spec.Join, - StateEvents: d.StreamEventsToEvents(device, s), + StateEvents: d.StreamEventsToEvents(ctx, device, s, rsAPI), RoomID: joinedRoomID, } } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 08ca99a7..bc64aa50 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -214,7 +214,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) } - gots := snapshot.StreamEventsToEvents(nil, paginatedEvents) + gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) }) }) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 8f83a089..d214980b 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -175,12 +175,12 @@ func (p *PDUStreamProvider) IncrementalSync( eventFilter := req.Filter.Room.Timeline if req.WantFullState { - if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltasForFullStateSync(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltasForFullStateSync failed") return from } } else { - if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter); err != nil { + if stateDeltas, syncJoinedRooms, err = snapshot.GetStateDeltas(ctx, req.Device, r, req.Device.UserID, &stateFilter, p.rsAPI); err != nil { req.Log.WithError(err).Error("p.DB.GetStateDeltas failed") return from } @@ -275,7 +275,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( limited := dbEvents[delta.RoomID].Limited recEvents := gomatrixserverlib.ReverseTopologicalOrdering( - gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(device, recentStreamEvents)), + gomatrixserverlib.ToPDUs(snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI)), gomatrixserverlib.TopologicalOrderByPrevEvents, ) recentEvents := make([]*rstypes.HeaderedEvent, len(recEvents)) @@ -376,13 +376,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } jr.Timeline.PrevBatch = &prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = (limited && len(events) == len(recentEvents)) || delta.NewlyJoined - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Join[delta.RoomID] = jr @@ -391,11 +391,11 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( jr := types.NewJoinResponse() jr.Timeline.PrevBatch = &prevBatch // TODO: Apply history visibility on peeked rooms - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(recentEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) jr.Timeline.Limited = limited - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Peek[delta.RoomID] = jr @@ -406,13 +406,13 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case spec.Ban: lr := types.NewLeaveResponse() lr.Timeline.PrevBatch = &prevBatch - lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + lr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. lr.Timeline.Limited = limited && len(events) == len(recentEvents) - lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + lr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(delta.StateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) req.Response.Rooms.Leave[delta.RoomID] = lr @@ -437,7 +437,7 @@ func applyHistoryVisibilityFilter( for _, ev := range recentEvents { if ev.StateKey() != nil { stateTypes = append(stateTypes, ev.Type()) - senders = append(senders, ev.SenderID()) + senders = append(senders, string(ev.SenderID())) } } @@ -512,7 +512,7 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: // "Can sync a room with a message with a transaction id" - which does a complete sync to check. - recentEvents := snapshot.StreamEventsToEvents(device, recentStreamEvents) + recentEvents := snapshot.StreamEventsToEvents(ctx, device, recentStreamEvents, p.rsAPI) events := recentEvents // Only apply history visibility checks if the response is for joined rooms @@ -564,13 +564,13 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( } jr.Timeline.PrevBatch = prevBatch - jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.Timeline.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(events), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) // If we are limited by the filter AND the history visibility filter // didn't "remove" events, return that the response is limited. jr.Timeline.Limited = limited && len(events) == len(recentEvents) - jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID, senderID string) (*spec.UserID, error) { + jr.State.Events = synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(stateEvents), synctypes.FormatSync, func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { return p.rsAPI.QueryUserIDForSender(ctx, roomID, senderID) }) return jr, nil @@ -593,8 +593,8 @@ func (p *PDUStreamProvider) lazyLoadMembers( // Add all users the client doesn't know about yet to a list for _, event := range timelineEvents { // Membership is not yet cached, add it to the list - if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, event.SenderID()); !ok { - timelineUsers[event.SenderID()] = struct{}{} + if _, ok := p.lazyLoadCache.IsLazyLoadedUserCached(device, roomID, string(event.SenderID())); !ok { + timelineUsers[string(event.SenderID())] = struct{}{} } } // Preallocate with the same amount, even if it will end up with fewer values diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 78c857ab..b9f13c51 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -40,8 +40,8 @@ type syncRoomserverAPI struct { rooms []*test.Room } -func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { - return spec.NewUserID(senderID, true) +func (s *syncRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { + return spec.NewUserID(string(senderID), true) } func (s *syncRoomserverAPI) QueryLatestEventsAndState(ctx context.Context, req *rsapi.QueryLatestEventsAndStateRequest, res *rsapi.QueryLatestEventsAndStateResponse) error { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 526a120d..a3dc7f54 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -343,7 +343,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { type PrevEventRef struct { PrevContent json.RawMessage `json:"prev_content"` ReplacesState string `json:"replaces_state"` - PrevSender string `json:"prev_sender"` + PrevSenderID string `json:"prev_sender"` } type DeviceLists struct { |