aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/internal/query/query.go')
-rw-r--r--roomserver/internal/query/query.go77
1 files changed, 57 insertions, 20 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index caea6b52..19fd456b 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -16,6 +16,7 @@ package query
import (
"context"
+ "crypto/ed25519"
"database/sql"
"errors"
"fmt"
@@ -89,7 +90,7 @@ func (r *Queryer) QueryLatestEventsAndState(
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
- return helpers.QueryLatestEventsAndState(ctx, r.DB, request, response)
+ return helpers.QueryLatestEventsAndState(ctx, r.DB, r, request, response)
}
// QueryStateAfterEvents implements api.RoomserverInternalAPI
@@ -106,7 +107,7 @@ func (r *Queryer) QueryStateAfterEvents(
return nil
}
- roomState := state.NewStateResolution(r.DB, info)
+ roomState := state.NewStateResolution(r.DB, info, r)
response.RoomExists = true
response.RoomVersion = info.RoomVersion
@@ -159,8 +160,8 @@ func (r *Queryer) QueryStateAfterEvents(
}
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
},
)
if err != nil {
@@ -271,15 +272,15 @@ func (r *Queryer) QueryMembershipForUser(
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
- senderID, err := r.DB.GetSenderIDForUser(ctx, request.RoomID, request.UserID)
+ roomID, err := spec.NewRoomID(request.RoomID)
if err != nil {
return err
}
-
- roomID, err := spec.NewRoomID(request.RoomID)
+ senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
if err != nil {
return err
}
+
return r.QueryMembershipForSenderID(ctx, *roomID, senderID, response)
}
@@ -320,7 +321,7 @@ func (r *Queryer) QueryMembershipAtEvent(
}
response.Membership = make(map[string]*types.HeaderedEvent)
- stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID])
+ stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
if err != nil {
return fmt.Errorf("unable to get state before event: %w", err)
}
@@ -407,7 +408,7 @@ func (r *Queryer) QueryMembershipsForRoom(
return fmt.Errorf("r.DB.Events: %w", err)
}
for _, event := range events {
- clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
@@ -445,7 +446,7 @@ func (r *Queryer) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, info.RoomVersion, eventNIDs)
} else {
- stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
+ stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID, r)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
@@ -458,7 +459,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
for _, event := range events {
- clientEvent := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
+ clientEvent := synctypes.ToClientEventDefault(func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
return r.QueryUserIDForSender(ctx, roomID, senderID)
}, event)
response.JoinEvents = append(response.JoinEvents, clientEvent)
@@ -532,7 +533,7 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
}
return helpers.CheckServerAllowedToSeeEvent(
- ctx, r.DB, info, roomID, eventID, serverName, isInRoom,
+ ctx, r.DB, info, roomID, eventID, serverName, isInRoom, r,
)
}
@@ -573,7 +574,7 @@ func (r *Queryer) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
}
- resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
+ resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName, r)
if err != nil {
return err
}
@@ -651,8 +652,8 @@ func (r *Queryer) QueryStateAndAuthChain(
if request.ResolveState {
stateEvents, err = gomatrixserverlib.ResolveConflicts(
- info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+ info.RoomVersion, gomatrixserverlib.ToPDUs(stateEvents), gomatrixserverlib.ToPDUs(authEvents), func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ return r.QueryUserIDForSender(ctx, roomID, senderID)
},
)
if err != nil {
@@ -673,7 +674,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// first bool: is rejected, second bool: state missing
func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) ([]gomatrixserverlib.PDU, bool, bool, error) {
- roomState := state.NewStateResolution(r.DB, roomInfo)
+ roomState := state.NewStateResolution(r.DB, roomInfo, r)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
switch err.(type) {
@@ -989,10 +990,46 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.Ro
return verImpl.CheckRestrictedJoin(ctx, r.Cfg.Global.ServerName, &api.JoinRoomQuerier{Roomserver: r}, roomID, senderID)
}
-func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID string, userID spec.UserID) (spec.SenderID, error) {
- return r.DB.GetSenderIDForUser(ctx, roomID, userID)
+func (r *Queryer) QuerySenderIDForUser(ctx context.Context, roomID spec.RoomID, userID spec.UserID) (spec.SenderID, error) {
+ version, err := r.DB.GetRoomVersion(ctx, roomID.String())
+ if err != nil {
+ return "", err
+ }
+
+ switch version {
+ case gomatrixserverlib.RoomVersionPseudoIDs:
+ key, err := r.DB.SelectUserRoomPublicKey(ctx, userID, roomID)
+ if err != nil {
+ return "", err
+ }
+ return spec.SenderID(spec.Base64Bytes(key).Encode()), nil
+ default:
+ return spec.SenderID(userID.String()), nil
+ }
}
-func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
- return r.DB.GetUserIDForSender(ctx, roomID, senderID)
+func (r *Queryer) QueryUserIDForSender(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) {
+ userID, err := spec.NewUserID(string(senderID), true)
+ if err == nil {
+ return userID, nil
+ }
+
+ bytes := spec.Base64Bytes{}
+ err = bytes.Decode(string(senderID))
+ if err != nil {
+ return nil, err
+ }
+ queryMap := map[spec.RoomID][]ed25519.PublicKey{roomID: {ed25519.PublicKey(bytes)}}
+ result, err := r.DB.SelectUserIDsForPublicKeys(ctx, queryMap)
+ if err != nil {
+ return nil, err
+ }
+
+ if userKeys, ok := result[roomID]; ok {
+ if userID, ok := userKeys[string(senderID)]; ok {
+ return spec.NewUserID(userID, true)
+ }
+ }
+
+ return nil, nil
}