diff options
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r-- | roomserver/internal/query/query.go | 77 |
1 files changed, 57 insertions, 20 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index caea6b52..19fd456b 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -16,6 +16,7 @@ package query import ( "context" + "crypto/ed25519" "database/sql" "errors" "fmt" @@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState( request *api.QueryLatestEventsAndStateRequest, response *api.QueryLatestEventsAndStateResponse, ) error { - return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response) + return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response) } // QueryStateAfterEvents implements api.RoomserverInternalAPI @@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, info) + roomState := state.NewStateResolution(r.DB, info, r) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents( } stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser( request *api.QueryMembershipForUserRequest, response *api.QueryMembershipForUserResponse, ) error { - senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID) + roomID, err := spec.NewRoomID(request.RoomID) if err != nil { return err } - - roomID, err := spec.NewRoomID(request.RoomID) + senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID) if err != nil { return err } + return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response) } @@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent( } response.Membership = make(map[string]*types.HeaderedEvent) - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom( return fmt.Errorf("r.DB.Events: %w", err) } for _, event := range events { - clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) @@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom( events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs) } else { - stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID) + stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r) if err != nil { logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event") return err @@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom( } for _, event := range events { - clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { + clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return r.QueryUserIDForSender(ctx, roomID, senderID) }, event) response.JoinEvents = append(response.JoinEvents, clientEvent) @@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( } return helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, info, roomID, eventID, serverName, isInRoom, + ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r, ) } @@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents( return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID) } - resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName) + resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r) if err != nil { return err } @@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain( if request.ResolveState { stateEvents, err = gomatrixserverlib.ResolveConflicts( - info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) + info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + return r.QueryUserIDForSender(ctx, roomID, senderID) }, ) if err != nil { @@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain( // first bool: is rejected, second bool: state missing func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) { - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(r.DB, roomInfo, r) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { @@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID) } -func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) { - return r.DB.GetSenderIDForUser(ctx, roomID, userID) +func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) { + version, err := r.DB.GetRoomVersion(ctx, roomID.String()) + if err != nil { + return "", err + } + + switch version { + case gomatrixserverlib.RoomVersionPseudoIDs: + key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID) + if err != nil { + return "", err + } + return spec.SenderID(spec.Base64Bytes(key).Encode()), nil + default: + return spec.SenderID(userID.String()), nil + } } -func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) { - return r.DB.GetUserIDForSender(ctx, roomID, senderID) +func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { + userID, err := spec.NewUserID(string(senderID), true) + if err == nil { + return userID, nil + } + + bytes := spec.Base64Bytes{} + err = bytes.Decode(string(senderID)) + if err != nil { + return nil, err + } + queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}} + result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap) + if err != nil { + return nil, err + } + + if userKeys, ok := result[roomID]; ok { + if userID, ok := userKeys[string(senderID)]; ok { + return spec.NewUserID(userID, true) + } + } + + return nil, nil } |