aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal')
-rw-r--r--roomserver/internal/alias.go21
-rw-r--r--roomserver/internal/helpers/auth.go4
-rw-r--r--roomserver/internal/input/input_events.go32
-rw-r--r--roomserver/internal/input/input_events_test.go2
-rw-r--r--roomserver/internal/input/input_missing.go24
-rw-r--r--roomserver/internal/perform/perform_admin.go8
-rw-r--r--roomserver/internal/perform/perform_backfill.go12
-rw-r--r--roomserver/internal/perform/perform_create_room.go4
-rw-r--r--roomserver/internal/perform/perform_invite.go12
-rw-r--r--roomserver/internal/perform/perform_upgrade.go10
-rw-r--r--roomserver/internal/query/query.go30
11 files changed, 114 insertions, 45 deletions
diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go
index 52b90cf4..dcfb26b8 100644
--- a/roomserver/internal/alias.go
+++ b/roomserver/internal/alias.go
@@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
- _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID)
- if err != nil {
- return err
- }
-
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil {
return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err)
@@ -134,13 +129,19 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return nil
}
+ sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID)
+ if err != nil {
+ return fmt.Errorf("r.QueryUserIDForSender: %w", err)
+ }
+ virtualHost := sender.Domain()
+
response.Found = true
creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
if err != nil {
return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err)
}
- if creatorID != request.UserID {
+ if creatorID != request.SenderID {
var plEvent *types.HeaderedEvent
var pls *gomatrixserverlib.PowerLevelContent
@@ -154,7 +155,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return fmt.Errorf("plEvent.PowerLevels: %w", err)
}
- if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
+ if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) {
response.Removed = false
return nil
}
@@ -172,9 +173,9 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias(
return err
}
- sender := request.UserID
- if request.UserID != ev.Sender() {
- sender = ev.Sender()
+ sender := request.SenderID
+ if request.SenderID != ev.SenderID() {
+ sender = ev.SenderID()
}
_, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender)
diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go
index 7ec0892e..932ce615 100644
--- a/roomserver/internal/helpers/auth.go
+++ b/roomserver/internal/helpers/auth.go
@@ -76,7 +76,9 @@ func CheckForSoftFail(
}
// Check if the event is allowed.
- if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
// return true, nil
return true, err
}
diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go
index 386083f6..764bdfe2 100644
--- a/roomserver/internal/input/input_events.go
+++ b/roomserver/internal/input/input_events.go
@@ -128,9 +128,13 @@ func (r *Inputer) processRoomEvent(
if roomInfo == nil && !isCreateEvent {
return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID())
}
- _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender())
+ sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
- return fmt.Errorf("event has invalid sender %q", input.Event.Sender())
+ return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err)
+ }
+ senderDomain := spec.ServerName("")
+ if sender != nil {
+ senderDomain = sender.Domain()
}
// If we already know about this outlier and it hasn't been rejected
@@ -193,7 +197,9 @@ func (r *Inputer) processRoomEvent(
serverRes.ServerNames = append(serverRes.ServerNames, input.Origin)
delete(servers, input.Origin)
}
- if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName {
+ // Only perform this check if the sender mxid_mapping can be resolved.
+ // Don't fail processing the event if we have no mxid_maping.
+ if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName {
serverRes.ServerNames = append(serverRes.ServerNames, senderDomain)
delete(servers, senderDomain)
}
@@ -276,7 +282,9 @@ func (r *Inputer) processRoomEvent(
// Check if the event is allowed by its auth events. If it isn't then
// we consider the event to be "rejected" — it will still be persisted.
- if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
isRejected = true
rejectionErr = err
logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID())
@@ -493,7 +501,7 @@ func (r *Inputer) processRoomEvent(
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
- return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender())
+ return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID())
}
// processStateBefore works out what the state is before the event and
@@ -579,7 +587,9 @@ func (r *Inputer) processStateBefore(
stateBeforeAuth := gomatrixserverlib.NewAuthEvents(
gomatrixserverlib.ToPDUs(stateBeforeEvent),
)
- if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil {
+ if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); rejectionErr != nil {
rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr)
return
}
@@ -690,7 +700,9 @@ nextAuthEvent:
// Check the signatures of the event. If this fails then we'll simply
// skip it, because gomatrixserverlib.Allowed() will notice a problem
// if a critical event is missing anyway.
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil {
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue nextAuthEvent
}
@@ -706,7 +718,9 @@ nextAuthEvent:
}
// Check if the auth event should be rejected.
- err := gomatrixserverlib.Allowed(authEvent, auth)
+ err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ })
if isRejected = err != nil; isRejected {
logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID())
}
@@ -828,11 +842,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r
continue
}
+ // TODO: pseudoIDs: get userID for room using state key (which is now senderID)
localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey())
if err != nil {
continue
}
+ // TODO: pseudoIDs: query account by state key (which is now senderID)
accountRes := &userAPI.QueryAccountByLocalpartResponse{}
if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{
Localpart: localpart,
diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go
index 56803813..0ba7d19f 100644
--- a/roomserver/internal/input/input_events_test.go
+++ b/roomserver/internal/input/input_events_test.go
@@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
}
// Finally check that the event is NOT allowed
- if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil {
+ if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil {
t.Fatalf("event should not be allowed, but it was")
}
}
diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go
index 10486138..ac0670fc 100644
--- a/roomserver/internal/input/input_missing.go
+++ b/roomserver/internal/input/input_missing.go
@@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion
stateEventList = append(stateEventList, state.StateEvents...)
}
resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts(
- roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList),
+ roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return nil, err
}
// apply the current event
retryAllowedState:
- if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil {
+ if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
switch missing := err.(type) {
case gomatrixserverlib.MissingAuthEventError:
h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true)
@@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver
// will be added and duplicates will be removed.
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue
}
missingEvents = append(missingEvents, t.cacheAndReturn(ev))
@@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState(
authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{
StateEvents: state.GetStateEvents(),
AuthEvents: state.GetAuthEvents(),
- }, roomVersion, t.keys, nil)
+ }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ })
if err != nil {
return nil, err
}
@@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers))
return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers))
}
- if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil {
+ if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) {
+ return t.db.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID())
return nil, verifySigError{event.EventID(), err}
}
return t.cacheAndReturn(event), nil
}
-func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error {
+func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i])
@@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli
return err
}
}
- return gomatrixserverlib.Allowed(e, &authUsingState)
+ return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender)
}
func (t *missingStateReq) hadEvent(eventID string) {
diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go
index 575525e2..ca736cb6 100644
--- a/roomserver/internal/perform/perform_admin.go
+++ b/roomserver/internal/perform/perform_admin.go
@@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState(
return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err)
}
for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue
}
authEventMap[authEvent.EventID()] = authEvent
}
for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) {
- if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil {
+ if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
continue
}
stateEventMap[stateEvent.EventID()] = stateEvent
diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go
index fb579f03..0f743f4e 100644
--- a/roomserver/internal/perform/perform_backfill.go
+++ b/roomserver/internal/perform/perform_backfill.go
@@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
// Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill(
ctx, req.VirtualHost, requester,
- r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100,
+ r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
// Only return an error if we really couldn't get any events.
if err != nil && len(events) == 0 {
@@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom
continue
}
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
- result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
+ result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ })
if err != nil {
logger.WithError(err).Warn("failed to load and verify event")
continue
@@ -484,8 +488,8 @@ FindSuccessor:
// Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[spec.ServerName]bool)
for _, event := range memberEvents {
- if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil {
- serverSet[senderDomain] = true
+ if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil {
+ serverSet[sender.Domain()] = true
}
}
var servers []spec.ServerName
diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go
index 41194832..897bd3a0 100644
--- a/roomserver/internal/perform/perform_create_room.go
+++ b/roomserver/internal/perform/perform_create_room.go
@@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo
}
}
- if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return c.DB.GetUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed")
return "", &util.JSONResponse{
Code: http.StatusInternalServerError,
diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go
index 1930b5ac..e8e20ede 100644
--- a/roomserver/internal/perform/perform_invite.go
+++ b/roomserver/internal/perform/perform_invite.go
@@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership(
) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent
var updater *shared.MembershipUpdater
- _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey())
+
+ userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey())
if err != nil {
return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())}
}
- isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain)
+ isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain())
if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
@@ -125,9 +126,9 @@ func (r *Inviter) PerformInvite(
) error {
event := req.Event
- sender, err := spec.NewUserID(event.Sender(), true)
+ sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID())
if err != nil {
- return spec.InvalidParam("The user ID is invalid")
+ return spec.InvalidParam("The sender user ID is invalid")
}
if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) {
return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")}
@@ -155,6 +156,9 @@ func (r *Inviter) PerformInvite(
StrippedState: req.InviteRoomState,
MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI},
StateQuerier: &QueryState{r.DB},
+ UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
}
inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI)
if err != nil {
diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go
index ff4a6a1d..8c0df1c4 100644
--- a/roomserver/internal/perform/perform_upgrade.go
+++ b/roomserver/internal/perform/perform_upgrade.go
@@ -176,7 +176,7 @@ func moveLocalAliases(ctx context.Context,
}
for _, alias := range aliasRes.Aliases {
- removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias}
+ removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias}
removeAliasRes := api.RemoveRoomAliasResponse{}
if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil {
return fmt.Errorf("Failed to remove old room alias: %w", err)
@@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user
}
- if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil {
+ if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err)
}
@@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user
stateEvents[i] = queryRes.StateEvents[i].PDU
}
provider := gomatrixserverlib.NewAuthEvents(stateEvents)
- if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil {
+ if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) {
+ return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID)
+ }); err != nil {
return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client?
}
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 6d898e8a..707e95b2 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -159,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents(
}
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err)
@@ -386,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if queryErr == nil && userID != nil {
+ sender = *userID
+ }
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
@@ -435,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll)
+ sender := spec.UserID{}
+ userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID())
+ if err == nil && userID != nil {
+ sender = *userID
+ }
+ clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
@@ -625,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents),
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ },
)
if err != nil {
return err
@@ -960,3 +974,11 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID)
}
+
+func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) {
+ return r.DB.GetSenderIDForUser(ctx, roomID, userID)
+}
+
+func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) {
+ return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+}