diff options
author | devonh <devon.dmytro@gmail.com> | 2023-06-06 20:55:18 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-06 20:55:18 +0000 |
commit | 7a1fd7f512ce06a472a2051ee63eae4a270eb71a (patch) | |
tree | 20128b0d3f7c69dd776aa7b2b9bc3194dda7dd75 /roomserver | |
parent | 725ff5567d2a3bc9992b065e72ccabefb595ec1c (diff) |
PDU Sender split (#3100)
Initial cut of splitting PDU Sender into SenderID & looking up UserID where required.
Diffstat (limited to 'roomserver')
-rw-r--r-- | roomserver/api/alias.go | 2 | ||||
-rw-r--r-- | roomserver/api/api.go | 12 | ||||
-rw-r--r-- | roomserver/api/query.go | 4 | ||||
-rw-r--r-- | roomserver/internal/alias.go | 21 | ||||
-rw-r--r-- | roomserver/internal/helpers/auth.go | 4 | ||||
-rw-r--r-- | roomserver/internal/input/input_events.go | 32 | ||||
-rw-r--r-- | roomserver/internal/input/input_events_test.go | 2 | ||||
-rw-r--r-- | roomserver/internal/input/input_missing.go | 24 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_admin.go | 8 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_backfill.go | 12 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_create_room.go | 4 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_invite.go | 12 | ||||
-rw-r--r-- | roomserver/internal/perform/perform_upgrade.go | 10 | ||||
-rw-r--r-- | roomserver/internal/query/query.go | 30 | ||||
-rw-r--r-- | roomserver/producers/roomevent.go | 2 | ||||
-rw-r--r-- | roomserver/state/state.go | 9 | ||||
-rw-r--r-- | roomserver/storage/interface.go | 5 | ||||
-rw-r--r-- | roomserver/storage/shared/membership_updater.go | 2 | ||||
-rw-r--r-- | roomserver/storage/shared/room_updater.go | 5 | ||||
-rw-r--r-- | roomserver/storage/shared/storage.go | 28 |
20 files changed, 173 insertions, 55 deletions
diff --git a/roomserver/api/alias.go b/roomserver/api/alias.go index 37892a44..1b947540 100644 --- a/roomserver/api/alias.go +++ b/roomserver/api/alias.go @@ -62,7 +62,7 @@ type GetAliasesForRoomIDResponse struct { // RemoveRoomAliasRequest is a request to RemoveRoomAlias type RemoveRoomAliasRequest struct { // ID of the user removing the alias - UserID string `json:"user_id"` + SenderID string `json:"user_id"` // The room alias to remove Alias string `json:"alias"` } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a37ade3a..d61a0553 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -49,6 +49,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + QuerySenderIDAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -75,6 +76,11 @@ type InputRoomEventsAPI interface { ) } +type QuerySenderIDAPI interface { + QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) + QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) +} + // Query the latest events and state for a room from the room server. type QueryLatestEventsAndStateAPI interface { QueryLatestEventsAndState(ctx context.Context, req *QueryLatestEventsAndStateRequest, res *QueryLatestEventsAndStateResponse) error @@ -102,6 +108,7 @@ type QueryEventsAPI interface { type SyncRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI // QuerySharedUsers returns a list of users who share at least 1 room in common with the given user. QuerySharedUsers(ctx context.Context, req *QuerySharedUsersRequest, res *QuerySharedUsersResponse) error // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine @@ -142,6 +149,7 @@ type SyncRoomserverAPI interface { } type AppserviceRoomserverAPI interface { + QuerySenderIDAPI // QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine // which room to use by querying the first events roomID. QueryEventsByID( @@ -168,6 +176,7 @@ type ClientRoomserverAPI interface { QueryLatestEventsAndStateAPI QueryBulkStateContentAPI QueryEventsAPI + QuerySenderIDAPI QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomsForUser(ctx context.Context, req *QueryRoomsForUserRequest, res *QueryRoomsForUserResponse) error @@ -200,6 +209,7 @@ type ClientRoomserverAPI interface { } type UserRoomserverAPI interface { + QuerySenderIDAPI QueryLatestEventsAndStateAPI KeyserverRoomserverAPI QueryCurrentState(ctx context.Context, req *QueryCurrentStateRequest, res *QueryCurrentStateResponse) error @@ -213,6 +223,8 @@ type FederationRoomserverAPI interface { InputRoomEventsAPI QueryLatestEventsAndStateAPI QueryBulkStateContentAPI + QuerySenderIDAPI + // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error QueryMembershipForUser(ctx context.Context, req *QueryMembershipForUserRequest, res *QueryMembershipForUserResponse) error diff --git a/roomserver/api/query.go b/roomserver/api/query.go index e741c140..d79dcebb 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -491,10 +491,10 @@ type MembershipQuerier struct { Roomserver FederationRoomserverAPI } -func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (string, error) { +func (mq *MembershipQuerier) CurrentMembership(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { req := QueryMembershipForUserRequest{ RoomID: roomID.String(), - UserID: userID.String(), + UserID: string(senderID), } res := QueryMembershipForUserResponse{} err := mq.Roomserver.QueryMembershipForUser(ctx, &req, &res) diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 52b90cf4..dcfb26b8 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -119,11 +119,6 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { - _, virtualHost, err := r.Cfg.Global.SplitLocalID('@', request.UserID) - if err != nil { - return err - } - roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -134,13 +129,19 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return nil } + sender, err := r.QueryUserIDForSender(ctx, roomID, request.SenderID) + if err != nil { + return fmt.Errorf("r.QueryUserIDForSender: %w", err) + } + virtualHost := sender.Domain() + response.Found = true creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetCreatorIDForAlias: %w", err) } - if creatorID != request.UserID { + if creatorID != request.SenderID { var plEvent *types.HeaderedEvent var pls *gomatrixserverlib.PowerLevelContent @@ -154,7 +155,7 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return fmt.Errorf("plEvent.PowerLevels: %w", err) } - if pls.UserLevel(request.UserID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { + if pls.UserLevel(request.SenderID) < pls.EventLevel(spec.MRoomCanonicalAlias, true) { response.Removed = false return nil } @@ -172,9 +173,9 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - sender := request.UserID - if request.UserID != ev.Sender() { - sender = ev.Sender() + sender := request.SenderID + if request.SenderID != ev.SenderID() { + sender = ev.SenderID() } _, senderDomain, err := r.Cfg.Global.SplitLocalID('@', sender) diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 7ec0892e..932ce615 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -76,7 +76,9 @@ func CheckForSoftFail( } // Check if the event is allowed. - if err = gomatrixserverlib.Allowed(event.PDU, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event.PDU, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { // return true, nil return true, err } diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 386083f6..764bdfe2 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,9 +128,13 @@ func (r *Inputer) processRoomEvent( if roomInfo == nil && !isCreateEvent { return fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) } - _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()) + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return fmt.Errorf("event has invalid sender %q", input.Event.Sender()) + return fmt.Errorf("failed getting userID for sender %q. %w", event.SenderID(), err) + } + senderDomain := spec.ServerName("") + if sender != nil { + senderDomain = sender.Domain() } // If we already know about this outlier and it hasn't been rejected @@ -193,7 +197,9 @@ func (r *Inputer) processRoomEvent( serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) delete(servers, input.Origin) } - if senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { + // Only perform this check if the sender mxid_mapping can be resolved. + // Don't fail processing the event if we have no mxid_maping. + if sender != nil && senderDomain != input.Origin && senderDomain != r.Cfg.Matrix.ServerName { serverRes.ServerNames = append(serverRes.ServerNames, senderDomain) delete(servers, senderDomain) } @@ -276,7 +282,9 @@ func (r *Inputer) processRoomEvent( // Check if the event is allowed by its auth events. If it isn't then // we consider the event to be "rejected" — it will still be persisted. - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { isRejected = true rejectionErr = err logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) @@ -493,7 +501,7 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error { oldRoomID := event.RoomID() newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str - return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.SenderID()) } // processStateBefore works out what the state is before the event and @@ -579,7 +587,9 @@ func (r *Inputer) processStateBefore( stateBeforeAuth := gomatrixserverlib.NewAuthEvents( gomatrixserverlib.ToPDUs(stateBeforeEvent), ) - if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth); rejectionErr != nil { + if rejectionErr = gomatrixserverlib.Allowed(event, &stateBeforeAuth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); rejectionErr != nil { rejectionErr = fmt.Errorf("Allowed() failed for stateBeforeEvent: %w", rejectionErr) return } @@ -690,7 +700,9 @@ nextAuthEvent: // Check the signatures of the event. If this fails then we'll simply // skip it, because gomatrixserverlib.Allowed() will notice a problem // if a critical event is missing anyway. - if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing()); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.FSAPI.KeyRing(), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue nextAuthEvent } @@ -706,7 +718,9 @@ nextAuthEvent: } // Check if the auth event should be rejected. - err := gomatrixserverlib.Allowed(authEvent, auth) + err := gomatrixserverlib.Allowed(authEvent, auth, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }) if isRejected = err != nil; isRejected { logger.WithError(err).Warnf("Auth event %s rejected", authEvent.EventID()) } @@ -828,11 +842,13 @@ func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, r continue } + // TODO: pseudoIDs: get userID for room using state key (which is now senderID) localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) if err != nil { continue } + // TODO: pseudoIDs: query account by state key (which is now senderID) accountRes := &userAPI.QueryAccountByLocalpartResponse{} if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ Localpart: localpart, diff --git a/roomserver/internal/input/input_events_test.go b/roomserver/internal/input/input_events_test.go index 56803813..0ba7d19f 100644 --- a/roomserver/internal/input/input_events_test.go +++ b/roomserver/internal/input/input_events_test.go @@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) { } // Finally check that the event is NOT allowed - if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil { + if err := gomatrixserverlib.Allowed(ev.PDU, &allower, func(roomID, senderID string) (*spec.UserID, error) { return spec.NewUserID(senderID, true) }); err == nil { t.Fatalf("event should not be allowed, but it was") } } diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 10486138..ac0670fc 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -473,14 +473,18 @@ func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion stateEventList = append(stateEventList, state.StateEvents...) } resolvedStateEvents, err := gomatrixserverlib.ResolveConflicts( - roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), + roomVersion, gomatrixserverlib.ToPDUs(stateEventList), gomatrixserverlib.ToPDUs(authEventList), func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return nil, err } // apply the current event retryAllowedState: - if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents); err != nil { + if err = checkAllowedByState(backwardsExtremity, resolvedStateEvents, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { switch missing := err.(type) { case gomatrixserverlib.MissingAuthEventError: h, err2 := t.lookupEvent(ctx, roomVersion, backwardsExtremity.RoomID(), missing.AuthEventID, true) @@ -565,7 +569,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserver // will be added and duplicates will be removed. missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } missingEvents = append(missingEvents, t.cacheAndReturn(ev)) @@ -654,7 +660,9 @@ func (t *missingStateReq) lookupMissingStateViaState( authEvents, stateEvents, err := gomatrixserverlib.CheckStateResponse(ctx, &fclient.RespState{ StateEvents: state.GetStateEvents(), AuthEvents: state.GetAuthEvents(), - }, roomVersion, t.keys, nil) + }, roomVersion, t.keys, nil, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }) if err != nil { return nil, err } @@ -889,14 +897,16 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs t.log.WithField("missing_event_id", missingEventID).Warnf("Failed to get missing /event for event ID from %d server(s)", len(t.servers)) return nil, fmt.Errorf("wasn't able to find event via %d server(s)", len(t.servers)) } - if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys); err != nil { + if err := gomatrixserverlib.VerifyEventSignatures(ctx, event, t.keys, func(roomID, senderID string) (*spec.UserID, error) { + return t.db.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { t.log.WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } return t.cacheAndReturn(event), nil } -func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error { +func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU, userIDForSender spec.UserIDForSender) error { authUsingState := gomatrixserverlib.NewAuthEvents(nil) for i := range stateEvents { err := authUsingState.AddEvent(stateEvents[i]) @@ -904,7 +914,7 @@ func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverli return err } } - return gomatrixserverlib.Allowed(e, &authUsingState) + return gomatrixserverlib.Allowed(e, &authUsingState, userIDForSender) } func (t *missingStateReq) hadEvent(eventID string) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 575525e2..ca736cb6 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -262,13 +262,17 @@ func (r *Admin) PerformAdminDownloadState( return fmt.Errorf("r.Inputer.FSAPI.LookupState (%q): %s", fwdExtremity, err) } for _, authEvent := range state.GetAuthEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, authEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } authEventMap[authEvent.EventID()] = authEvent } for _, stateEvent := range state.GetStateEvents().UntrustedEvents(roomInfo.RoomVersion) { - if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing); err != nil { + if err = gomatrixserverlib.VerifyEventSignatures(ctx, stateEvent, r.Inputer.KeyRing, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { continue } stateEventMap[stateEvent.EventID()] = stateEvent diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index fb579f03..0f743f4e 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -121,7 +121,9 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( ctx, req.VirtualHost, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) // Only return an error if we really couldn't get any events. if err != nil && len(events) == 0 { @@ -210,7 +212,9 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue } loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false) - result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents) + result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }) if err != nil { logger.WithError(err).Warn("failed to load and verify event") continue @@ -484,8 +488,8 @@ FindSuccessor: // Store the server names in a temporary map to avoid duplicates. serverSet := make(map[spec.ServerName]bool) for _, event := range memberEvents { - if _, senderDomain, err := gomatrixserverlib.SplitID('@', event.Sender()); err == nil { - serverSet[senderDomain] = true + if sender, err := b.db.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()); err == nil { + serverSet[sender.Domain()] = true } } var servers []spec.ServerName diff --git a/roomserver/internal/perform/perform_create_room.go b/roomserver/internal/perform/perform_create_room.go index 41194832..897bd3a0 100644 --- a/roomserver/internal/perform/perform_create_room.go +++ b/roomserver/internal/perform/perform_create_room.go @@ -308,7 +308,9 @@ func (c *Creator) PerformCreateRoom(ctx context.Context, userID spec.UserID, roo } } - if err = gomatrixserverlib.Allowed(ev, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(ev, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return c.DB.GetUserIDForSender(ctx, roomID, senderID) + }); err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.Allowed failed") return "", &util.JSONResponse{ Code: http.StatusInternalServerError, diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index 1930b5ac..e8e20ede 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -97,11 +97,12 @@ func (r *Inviter) ProcessInviteMembership( ) ([]api.OutputEvent, error) { var outputUpdates []api.OutputEvent var updater *shared.MembershipUpdater - _, domain, err := gomatrixserverlib.SplitID('@', *inviteEvent.StateKey()) + + userID, err := r.RSAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey()) if err != nil { return nil, api.ErrInvalidID{Err: fmt.Errorf("the user ID %s is invalid", *inviteEvent.StateKey())} } - isTargetLocal := r.Cfg.Matrix.IsLocalServerName(domain) + isTargetLocal := r.Cfg.Matrix.IsLocalServerName(userID.Domain()) if updater, err = r.DB.MembershipUpdater(ctx, inviteEvent.RoomID(), *inviteEvent.StateKey(), isTargetLocal, inviteEvent.Version()); err != nil { return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err) } @@ -125,9 +126,9 @@ func (r *Inviter) PerformInvite( ) error { event := req.Event - sender, err := spec.NewUserID(event.Sender(), true) + sender, err := r.DB.GetUserIDForSender(ctx, event.RoomID(), event.SenderID()) if err != nil { - return spec.InvalidParam("The user ID is invalid") + return spec.InvalidParam("The sender user ID is invalid") } if !r.Cfg.Matrix.IsLocalServerName(sender.Domain()) { return api.ErrInvalidID{Err: fmt.Errorf("the invite must be from a local user")} @@ -155,6 +156,9 @@ func (r *Inviter) PerformInvite( StrippedState: req.InviteRoomState, MembershipQuerier: &api.MembershipQuerier{Roomserver: r.RSAPI}, StateQuerier: &QueryState{r.DB}, + UserIDQuerier: func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, } inviteEvent, err := gomatrixserverlib.PerformInvite(ctx, input, r.FSAPI) if err != nil { diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index ff4a6a1d..8c0df1c4 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -176,7 +176,7 @@ func moveLocalAliases(ctx context.Context, } for _, alias := range aliasRes.Aliases { - removeAliasReq := api.RemoveRoomAliasRequest{UserID: userID, Alias: alias} + removeAliasReq := api.RemoveRoomAliasRequest{SenderID: userID, Alias: alias} removeAliasRes := api.RemoveRoomAliasResponse{} if err = URSAPI.RemoveRoomAlias(ctx, &removeAliasReq, &removeAliasRes); err != nil { return fmt.Errorf("Failed to remove old room alias: %w", err) @@ -484,7 +484,9 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user } - if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + if err = gomatrixserverlib.Allowed(event, &authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return fmt.Errorf("Failed to auth new %q event: %w", builder.Type, err) } @@ -567,7 +569,9 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user stateEvents[i] = queryRes.StateEvents[i].PDU } provider := gomatrixserverlib.NewAuthEvents(stateEvents) - if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider); err != nil { + if err = gomatrixserverlib.Allowed(headeredEvent.PDU, &provider, func(roomID, senderID string) (*spec.UserID, error) { + return r.URSAPI.QueryUserIDForSender(ctx, roomID, senderID) + }); err != nil { return nil, api.ErrNotAllowed{Err: fmt.Errorf("failed to auth new %q event: %w", proto.Type, err)} // TODO: Is this error string comprehensible to the client? } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 6d898e8a..707e95b2 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -159,7 +159,9 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return fmt.Errorf("state.ResolveConflictsAdhoc: %w", err) @@ -386,7 +388,12 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + sender := spec.UserID{} + userID, queryErr := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if queryErr == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } return nil @@ -435,7 +442,12 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll) + sender := spec.UserID{} + userID, err := r.QueryUserIDForSender(ctx, event.RoomID(), event.SenderID()) + if err == nil && userID != nil { + sender = *userID + } + clientEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender) response.JoinEvents = append(response.JoinEvents, clientEvent) } @@ -625,7 +637,9 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) + }, ) if err != nil { return err @@ -960,3 +974,11 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, userID) } + +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + return r.DB.GetSenderIDForUser(ctx, roomID, userID) +} + +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return r.DB.GetUserIDForSender(ctx, roomID, senderID) +} diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index febe8ddf..165304d4 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -60,7 +60,7 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu "adds_state": len(update.NewRoomEvent.AddsStateEventIDs), "removes_state": len(update.NewRoomEvent.RemovesStateEventIDs), "send_as_server": update.NewRoomEvent.SendAsServer, - "sender": update.NewRoomEvent.Event.Sender(), + "sender": update.NewRoomEvent.Event.SenderID(), }) if update.NewRoomEvent.Event.StateKey() != nil { logger = logger.WithField("state_key", *update.NewRoomEvent.Event.StateKey()) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index f38d8f96..3131cbff 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -24,6 +24,7 @@ import ( "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" @@ -43,6 +44,7 @@ type StateResolutionStorage interface { AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error) Events(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, eventNIDs []types.EventNID) ([]types.Event, error) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]types.Event, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type StateResolution struct { @@ -945,7 +947,9 @@ func (v *StateResolution) resolveConflictsV1( } // Resolve the conflicts. - resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents) + resolvedEvents := gomatrixserverlib.ResolveStateConflicts(conflictedEvents, authEvents, func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) + }) // Map from the full events back to numeric state entries. for _, resolvedEvent := range resolvedEvents { @@ -1057,6 +1061,9 @@ func (v *StateResolution) resolveConflictsV2( conflictedEvents, nonConflictedEvents, authEvents, + func(roomID, senderID string) (*spec.UserID, error) { + return v.db.GetUserIDForSender(ctx, roomID, senderID) + }, ) }() diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 7d22df00..2d007bed 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -166,6 +166,10 @@ type Database interface { GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName spec.ServerName) (bool, error) // GetKnownUsers searches all users that userID knows about. GetKnownUsers(ctx context.Context, userID, searchString string, limit int) ([]string, error) + // GetKnownUsers tries to obtain the current mxid for a given user. + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) + // GetKnownUsers tries to obtain the current senderID for a given user. + GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) // ForgetRoom sets a flag in the membership table, that the user wishes to forget a specific room @@ -211,6 +215,7 @@ type RoomDatabase interface { GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error) GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) + GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) } type EventDatabase interface { diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index f9c889cb..105e61df 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -101,7 +101,7 @@ func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event * var inserted bool // Did the query result in a membership change? var retired []string // Did we retire any updates in the process? return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender()) + senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.SenderID()) if err != nil { return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 70672a33..73500138 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -250,3 +251,7 @@ func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error { func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) { return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal) } + +func (u *RoomUpdater) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + return u.d.GetUserIDForSender(ctx, roomID, senderID) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index cefa58a3..406d7cf1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -988,8 +988,18 @@ func (d *EventDatabase) MaybeRedactEvent( return nil } - _, sender1, _ := gomatrixserverlib.SplitID('@', redactedEvent.Sender()) - _, sender2, _ := gomatrixserverlib.SplitID('@', redactionEvent.Sender()) + // TODO: Don't hack senderID into userID here (pseudoIDs) + sender1Domain := "" + sender1, err1 := spec.NewUserID(redactedEvent.SenderID(), true) + if err1 == nil { + sender1Domain = string(sender1.Domain()) + } + // TODO: Don't hack senderID into userID here (pseudoIDs) + sender2Domain := "" + sender2, err2 := spec.NewUserID(redactionEvent.SenderID(), true) + if err2 == nil { + sender2Domain = string(sender2.Domain()) + } var powerlevels *gomatrixserverlib.PowerLevelContent powerlevels, err = plResolver.Resolve(ctx, redactionEvent.EventID()) if err != nil { @@ -997,9 +1007,9 @@ func (d *EventDatabase) MaybeRedactEvent( } switch { - case powerlevels.UserLevel(redactionEvent.Sender()) >= powerlevels.Redact: + case powerlevels.UserLevel(redactionEvent.SenderID()) >= powerlevels.Redact: // 1. The power level of the redaction event’s sender is greater than or equal to the redact level. - case sender1 == sender2: + case sender1Domain == sender2Domain: // 2. The domain of the redaction event’s sender matches that of the original event’s sender. default: ignoreRedaction = true @@ -1514,6 +1524,16 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit) } +func (d *Database) GetUserIDForSender(ctx context.Context, roomID string, senderID string) (*spec.UserID, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return spec.NewUserID(senderID, true) +} + +func (d *Database) GetSenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (string, error) { + // TODO: Use real logic once DB for pseudoIDs is in place + return userID.String(), nil +} + // GetKnownRooms returns a list of all rooms we know about. func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) { return d.RoomsTable.SelectRoomIDsWithEvents(ctx, nil) |