aboutsummaryrefslogtreecommitdiff
path: root/syncapi
diff options
context:
space:
mode:
authordevonh <devon.dmytro@gmail.com>2023-06-12 11:19:25 +0000
committerGitHub <noreply@github.com>2023-06-12 11:19:25 +0000
commit77d9e4e93dd01f6baa82bd6236850c1007346cac (patch)
tree20be66224646cc82199028cf89f4cd7fab80b97f /syncapi
parent832ccc32f6a023665e250eee44b5f678e985d50e (diff)
Cleanup remaining statekey usage for senderIDs (#3106)
Diffstat (limited to 'syncapi')
-rw-r--r--syncapi/consumers/roomserver.go29
-rw-r--r--syncapi/internal/history_visibility.go14
-rw-r--r--syncapi/internal/keychange.go16
-rw-r--r--syncapi/internal/keychange_test.go4
-rw-r--r--syncapi/notifier/notifier.go45
-rw-r--r--syncapi/notifier/notifier_test.go22
-rw-r--r--syncapi/routing/context.go18
-rw-r--r--syncapi/routing/getevent.go11
-rw-r--r--syncapi/routing/memberships.go13
-rw-r--r--syncapi/routing/messages.go6
-rw-r--r--syncapi/routing/relations.go11
-rw-r--r--syncapi/routing/search.go11
-rw-r--r--syncapi/storage/shared/storage_consumer.go16
-rw-r--r--syncapi/storage/shared/storage_sync.go4
-rw-r--r--syncapi/streams/stream_invite.go11
-rw-r--r--syncapi/streams/stream_pdu.go12
-rw-r--r--syncapi/syncapi.go2
-rw-r--r--syncapi/synctypes/clientevent.go35
-rw-r--r--syncapi/synctypes/clientevent_test.go6
-rw-r--r--syncapi/types/types.go4
-rw-r--r--syncapi/types/types_test.go8
21 files changed, 231 insertions, 67 deletions
diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go
index 8a2a0b1f..c5f2db9c 100644
--- a/syncapi/consumers/roomserver.go
+++ b/syncapi/consumers/roomserver.go
@@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
if membership == spec.Join {
// check it's a local join
- if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil {
+ if ev.StateKey() == nil {
+ return sp, fmt.Errorf("unexpected nil state_key")
+ }
+
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
+ if err != nil || userID == nil {
+ return sp, fmt.Errorf("failed getting userID for sender: %w", err)
+ }
+ if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return sp, nil
}
@@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
if msg.Event.StateKey() == nil {
return
}
- if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
+
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
+ if err != nil || userID == nil {
+ return
+ }
+ if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return
}
+
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil {
sentry.CaptureException(err)
@@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos)
- s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
+ userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
+ if err != nil || userID == nil {
+ log.WithFields(log.Fields{
+ "event_id": msg.EventID,
+ "sender_id": msg.TargetSenderID,
+ log.ErrorKey: err,
+ }).Errorf("failed to find userID for sender")
+ return
+ }
+ s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
}
func (s *OutputRoomEventConsumer) onNewPeek(
diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go
index 7449b464..ab1a7f83 100644
--- a/syncapi/internal/history_visibility.go
+++ b/syncapi/internal/history_visibility.go
@@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter(
}
}
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
- if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) {
- eventsFiltered = append(eventsFiltered, ev)
- continue
+
+ user, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return nil, err
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
+ if err == nil {
+ if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
+ eventsFiltered = append(eventsFiltered, ev)
+ continue
+ }
}
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive
diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go
index ad5935cd..f4b6ace5 100644
--- a/syncapi/internal/keychange.go
+++ b/syncapi/internal/keychange.go
@@ -169,12 +169,16 @@ func TrackChangedUsers(
if err != nil {
return nil, nil, err
}
- for _, state := range stateRes.Rooms {
+ for roomID, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != spec.Join {
continue
}
- queryRes.UserIDsToCount[tuple.StateKey]--
+ user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
+ if queryErr != nil || user == nil {
+ continue
+ }
+ queryRes.UserIDsToCount[user.String()]--
}
}
@@ -211,14 +215,18 @@ func TrackChangedUsers(
if err != nil {
return nil, left, err
}
- for _, state := range stateRes.Rooms {
+ for roomID, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != spec.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
- changed = append(changed, tuple.StateKey) // changed is returned
+ user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
+ if err != nil || user == nil {
+ continue
+ }
+ changed = append(changed, user.String()) // changed is returned
}
}
}
diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go
index 23c2ecba..efa64147 100644
--- a/syncapi/internal/keychange_test.go
+++ b/syncapi/internal/keychange_test.go
@@ -64,6 +64,10 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string
}
+func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return spec.NewUserID(string(senderID), true)
+}
+
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
return nil
diff --git a/syncapi/notifier/notifier.go b/syncapi/notifier/notifier.go
index f7645685..4ee7c860 100644
--- a/syncapi/notifier/notifier.go
+++ b/syncapi/notifier/notifier.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
@@ -36,7 +37,8 @@ import (
// the event, but the token has already advanced by the time they fetch it, resulting
// in missed events.
type Notifier struct {
- lock *sync.RWMutex
+ lock *sync.RWMutex
+ rsAPI api.SyncRoomserverAPI
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
@@ -55,8 +57,9 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
-func NewNotifier() *Notifier {
+func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier {
return &Notifier{
+ rsAPI: rsAPI,
roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
@@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent(
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
- targetUserID := *ev.StateKey()
- membership, err := ev.Membership()
+ targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
- "Notifier.OnNewEvent: Failed to unmarshal member event",
+ "Notifier.OnNewEvent: Failed to find the userID for this event",
)
} else {
- // Keep the joined user map up-to-date
- switch membership {
- case spec.Invite:
- usersToNotify = append(usersToNotify, targetUserID)
- case spec.Join:
- // Manually append the new user's ID so they get notified
- // along all members in the room
- usersToNotify = append(usersToNotify, targetUserID)
- n._addJoinedUser(ev.RoomID(), targetUserID)
- case spec.Leave:
- fallthrough
- case spec.Ban:
- n._removeJoinedUser(ev.RoomID(), targetUserID)
+ membership, err := ev.Membership()
+ if err != nil {
+ log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
+ "Notifier.OnNewEvent: Failed to unmarshal member event",
+ )
+ } else {
+ // Keep the joined user map up-to-date
+ switch membership {
+ case spec.Invite:
+ usersToNotify = append(usersToNotify, targetUserID.String())
+ case spec.Join:
+ // Manually append the new user's ID so they get notified
+ // along all members in the room
+ usersToNotify = append(usersToNotify, targetUserID.String())
+ n._addJoinedUser(ev.RoomID(), targetUserID.String())
+ case spec.Leave:
+ fallthrough
+ case spec.Ban:
+ n._removeJoinedUser(ev.RoomID(), targetUserID.String())
+ }
}
}
}
diff --git a/syncapi/notifier/notifier_test.go b/syncapi/notifier/notifier_test.go
index 36577a0e..7076f713 100644
--- a/syncapi/notifier/notifier_test.go
+++ b/syncapi/notifier/notifier_test.go
@@ -22,9 +22,11 @@ import (
"testing"
"time"
+ "github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
@@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
}
}
+type TestRoomServer struct{ api.SyncRoomserverAPI }
+
+func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return spec.NewUserID(string(senderID), true)
+}
+
// Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
if err != nil {
@@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) {
// Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
}
func TestCorrectStream(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
stream := lockedFetchUserStream(n, bob, bobDev)
if stream.UserID != bob {
@@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) {
}
func TestCorrectStreamWakeup(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
awoken := make(chan string)
@@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
// Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) {
// Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) {
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well.
- n := NewNotifier()
+ n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go
index 7fb88faa..55fd3c5a 100644
--- a/syncapi/routing/context.go
+++ b/syncapi/routing/context.go
@@ -85,9 +85,16 @@ func Context(
*filter.Rooms = append(*filter.Rooms, roomID)
}
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Device UserID is invalid"),
+ }
+ }
ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{}
- membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID}
+ membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to query membership")
return util.JSONResponse{
@@ -217,12 +224,9 @@ func Context(
}
}
- sender := spec.UserID{}
- userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
- if err == nil && userID != nil {
- sender = *userID
- }
- ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
+ ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }, requestedEvent)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,
diff --git a/syncapi/routing/getevent.go b/syncapi/routing/getevent.go
index 63df7e83..de790e5c 100644
--- a/syncapi/routing/getevent.go
+++ b/syncapi/routing/getevent.go
@@ -106,8 +106,17 @@ func GetEvent(
if err == nil && senderUserID != nil {
sender = *senderUserID
}
+
+ sk := events[0].StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
return util.JSONResponse{
Code: http.StatusOK,
- JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
+ JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
}
}
diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go
index 813167a5..cf6769ba 100644
--- a/syncapi/routing/memberships.go
+++ b/syncapi/routing/memberships.go
@@ -59,14 +59,21 @@ func GetMemberships(
syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
joinedOnly bool, membership, notMembership *string, at string,
) util.JSONResponse {
+ userID, err := spec.NewUserID(device.UserID, true)
+ if err != nil {
+ return util.JSONResponse{
+ Code: http.StatusBadRequest,
+ JSON: spec.InvalidParam("Device UserID is invalid"),
+ }
+ }
queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: device.UserID,
+ UserID: *userID,
}
var queryRes api.QueryMembershipForUserResponse
- if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
+ if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
+ util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},
diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go
index 781fd53e..6784a27b 100644
--- a/syncapi/routing/messages.go
+++ b/syncapi/routing/messages.go
@@ -296,9 +296,13 @@ func OnIncomingMessagesRequest(
}
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
+ fullUserID, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return resp, err
+ }
req := api.QueryMembershipForUserRequest{
RoomID: roomID,
- UserID: userID,
+ UserID: *fullUserID,
}
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return api.QueryMembershipForUserResponse{}, err
diff --git a/syncapi/routing/relations.go b/syncapi/routing/relations.go
index f21c684c..6efa065a 100644
--- a/syncapi/routing/relations.go
+++ b/syncapi/routing/relations.go
@@ -119,9 +119,18 @@ func Relations(
if err == nil && userID != nil {
sender = *userID
}
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
res.Chunk = append(
res.Chunk,
- synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
+ synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
)
}
diff --git a/syncapi/routing/search.go b/syncapi/routing/search.go
index add50b18..7d9182f4 100644
--- a/syncapi/routing/search.go
+++ b/syncapi/routing/search.go
@@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
if err == nil && userID != nil {
sender = *userID
}
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
@@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
- Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
+ Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())
diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go
index 5bd3b1f0..799e3d16 100644
--- a/syncapi/storage/shared/storage_consumer.go
+++ b/syncapi/storage/shared/storage_consumer.go
@@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates(
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
-func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) {
- if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
+func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) {
+ if ev.StateKey() == nil || *ev.StateKey() == "" {
+ return "", ""
+ }
+ fullUser, err := spec.NewUserID(userID, true)
+ if err != nil {
+ return "", ""
+ }
+ senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
+ if err != nil {
+ return "", ""
+ }
+
+ if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) {
return "", ""
}
membership, err := ev.Membership()
diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go
index df961385..8e79b71d 100644
--- a/syncapi/storage/shared/storage_sync.go
+++ b/syncapi/storage/shared/storage_sync.go
@@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, ev := range stateStreamEvents {
// Look for our membership in the state events and skip over any
// membership events that are not related to us.
- membership, prevMembership := getMembershipFromEvent(ev.PDU, userID)
+ membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI)
if membership == "" {
continue
}
@@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
- if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" {
+ if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" {
if membership != spec.Join { // We've already added full state for all joined rooms above.
deltas[roomID] = types.StateDelta{
Membership: membership,
diff --git a/syncapi/streams/stream_invite.go b/syncapi/streams/stream_invite.go
index a8b0a7b6..3a5badd9 100644
--- a/syncapi/streams/stream_invite.go
+++ b/syncapi/streams/stream_invite.go
@@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender
}
+ sk := inviteEvent.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+
// skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
- ir := types.NewInviteResponse(inviteEvent, user)
+ ir := types.NewInviteResponse(inviteEvent, user, sk)
req.Response.Rooms.Invite[roomID] = ir
}
diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go
index d214980b..f728d4ae 100644
--- a/syncapi/streams/stream_pdu.go
+++ b/syncapi/streams/stream_pdu.go
@@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
- stateKey := *event.StateKey()
- if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID {
+ userID := ""
+ stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
+ if err == nil && stateKeyUserID != nil {
+ userID = stateKeyUserID.String()
+ }
+ if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
newStateEvents = append(newStateEvents, event)
if !stateFilter.IncludeRedundantMembers {
- p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
+ p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID())
}
- delete(timelineUsers, stateKey)
+ delete(timelineUsers, userID)
}
} else {
newStateEvents = append(newStateEvents, event)
diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go
index ecbe05dd..64a4af75 100644
--- a/syncapi/syncapi.go
+++ b/syncapi/syncapi.go
@@ -60,7 +60,7 @@ func AddPublicRoutes(
}
eduCache := caching.NewTypingCache()
- notifier := notifier.NewNotifier()
+ notifier := notifier.NewNotifier(rsAPI)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {
diff --git a/syncapi/synctypes/clientevent.go b/syncapi/synctypes/clientevent.go
index 66fb1d01..358a0c97 100644
--- a/syncapi/synctypes/clientevent.go
+++ b/syncapi/synctypes/clientevent.go
@@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if err == nil && userID != nil {
sender = *userID
}
- evs = append(evs, ToClientEvent(se, format, sender))
+
+ sk := se.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+ evs = append(evs, ToClientEvent(se, format, sender, sk))
}
return evs
}
// ToClientEvent converts a single server event to a client event.
-func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
+func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
Sender: sender.String(),
Type: se.Type(),
- StateKey: se.StateKey(),
+ StateKey: stateKey,
Unsigned: spec.RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),
@@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
}
return ce
}
+
+// ToClientEvent converts a single server event to a client event.
+// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
+func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
+ sender := spec.UserID{}
+ userID, err := userIDQuery(event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+
+ sk := event.StateKey()
+ if sk != nil && *sk != "" {
+ skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
+ if err == nil && skUserID != nil {
+ skString := skUserID.String()
+ sk = &skString
+ }
+ }
+ return ToClientEvent(event, FormatAll, sender, sk)
+}
diff --git a/syncapi/synctypes/clientevent_test.go b/syncapi/synctypes/clientevent_test.go
index 34179508..63c65b2a 100644
--- a/syncapi/synctypes/clientevent_test.go
+++ b/syncapi/synctypes/clientevent_test.go
@@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
- ce := ToClientEvent(ev, FormatAll, *userID)
+ sk := ""
+ ce := ToClientEvent(ev, FormatAll, *userID, &sk)
if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
}
@@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
- ce := ToClientEvent(ev, FormatSync, *userID)
+ sk := ""
+ ce := ToClientEvent(ev, FormatSync, *userID, &sk)
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}
diff --git a/syncapi/types/types.go b/syncapi/types/types.go
index a3dc7f54..cb3c362d 100644
--- a/syncapi/types/types.go
+++ b/syncapi/types/types.go
@@ -539,7 +539,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
-func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
+func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
- inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
+ inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)
diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go
index a79ce541..c1b7f70b 100644
--- a/syncapi/types/types_test.go
+++ b/syncapi/types/types_test.go
@@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+ skString := skUserID.String()
+ sk := &skString
- res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
+ res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)