aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r--roomserver/internal/query/query.go30
1 files changed, 26 insertions, 4 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 6d898e8a..707e95b2 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -159,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents(
}
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@@ -386,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if queryErr == nil && userID != nil {
+ sender = *userID
+ }
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
@@ -435,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
@@ -625,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return err
@@ -960,3 +974,11 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
}
+
+func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) {
+ return r.DB.GetSenderIDForUser(ctx, roomID, userID)
+}
+
+func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+}