aboutsummaryrefslogtreecommitdiff
path: root/roomserver/internal/query
diff options
context:
space:
mode:
authorSam Wedgwood <28223854+swedgwood@users.noreply.github.com>2023-08-15 12:37:04 +0100
committerGitHub <noreply@github.com>2023-08-15 12:37:04 +0100
commit9a12420428f1832c76fc0c84ad85db200e261ecb (patch)
tree38ce262c515d74865920f6ebaf336f1887dee11b /roomserver/internal/query
parentfa6c7ba45671c8fbf13cb7ba456355a04941b535 (diff)
[pseudoID] More pseudo ID fixes (#3167)
Signed-off-by: `Sam Wedgwood <sam@wedgwood.dev>`
Diffstat (limited to 'roomserver/internal/query')
-rw-r--r--roomserver/internal/query/query.go128
1 files changed, 78 insertions, 50 deletions
diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index 0fe0f4f2..f87a3f7e 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -230,6 +230,33 @@ func (r *Queryer) QueryMembershipForSenderID(
senderID spec.SenderID,
response *api.QueryMembershipForUserResponse,
) error {
+ return r.queryMembershipForOptionalSenderID(ctx, roomID, &senderID, response)
+}
+
+// QueryMembershipForUser implements api.RoomserverInternalAPI
+func (r *Queryer) QueryMembershipForUser(
+ ctx context.Context,
+ request *api.QueryMembershipForUserRequest,
+ response *api.QueryMembershipForUserResponse,
+) error {
+ roomID, err := spec.NewRoomID(request.RoomID)
+ if err != nil {
+ return err
+ }
+ senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
+ if err != nil {
+ return err
+ }
+
+ return r.queryMembershipForOptionalSenderID(ctx, *roomID, senderID, response)
+}
+
+// Query membership information for provided sender ID and room ID
+//
+// If sender ID is nil, then act as if the provided sender is not a member of the room.
+func (r *Queryer) queryMembershipForOptionalSenderID(ctx context.Context, roomID spec.RoomID, senderID *spec.SenderID, response *api.QueryMembershipForUserResponse) error {
+ response.SenderID = senderID
+
info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
return err
@@ -240,7 +267,11 @@ func (r *Queryer) QueryMembershipForSenderID(
}
response.RoomExists = true
- membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, senderID)
+ if senderID == nil {
+ return nil
+ }
+
+ membershipEventNID, stillInRoom, isRoomforgotten, err := r.DB.GetMembership(ctx, info.RoomNID, *senderID)
if err != nil {
return err
}
@@ -268,70 +299,55 @@ func (r *Queryer) QueryMembershipForSenderID(
return err
}
-// QueryMembershipForUser implements api.RoomserverInternalAPI
-func (r *Queryer) QueryMembershipForUser(
- ctx context.Context,
- request *api.QueryMembershipForUserRequest,
- response *api.QueryMembershipForUserResponse,
-) error {
- roomID, err := spec.NewRoomID(request.RoomID)
- if err != nil {
- return err
- }
- senderID, err := r.QuerySenderIDForUser(ctx, *roomID, request.UserID)
- if err != nil {
- return err
- }
-
- return r.QueryMembershipForSenderID(ctx, *roomID, *senderID, response)
-}
-
// QueryMembershipAtEvent returns the known memberships at a given event.
// If the state before an event is not known, an empty list will be returned
// for that event instead.
+//
+// Returned map from eventID to membership event. Events that
+// do not have known state will return a nil event, resulting in a "leave" membership
+// when calculating history visibility.
func (r *Queryer) QueryMembershipAtEvent(
ctx context.Context,
- request *api.QueryMembershipAtEventRequest,
- response *api.QueryMembershipAtEventResponse,
-) error {
- response.Membership = make(map[string]*types.HeaderedEvent)
-
- info, err := r.DB.RoomInfo(ctx, request.RoomID)
+ roomID spec.RoomID,
+ eventIDs []string,
+ senderID spec.SenderID,
+) (map[string]*types.HeaderedEvent, error) {
+ info, err := r.DB.RoomInfo(ctx, roomID.String())
if err != nil {
- return fmt.Errorf("unable to get roomInfo: %w", err)
+ return nil, fmt.Errorf("unable to get roomInfo: %w", err)
}
if info == nil {
- return fmt.Errorf("no roomInfo found")
+ return nil, fmt.Errorf("no roomInfo found")
}
// get the users stateKeyNID
- stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.UserID})
+ stateKeyNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{string(senderID)})
if err != nil {
- return fmt.Errorf("unable to get stateKeyNIDs for %s: %w", request.UserID, err)
+ return nil, fmt.Errorf("unable to get stateKeyNIDs for %s: %w", senderID, err)
}
- if _, ok := stateKeyNIDs[request.UserID]; !ok {
- return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID)
+ if _, ok := stateKeyNIDs[string(senderID)]; !ok {
+ return nil, fmt.Errorf("requested stateKeyNID for %s was not found", senderID)
}
- response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...)
+ eventIDMembershipMap, err := r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[string(senderID)], info, eventIDs...)
switch err {
case nil:
- return nil
+ return eventIDMembershipMap, nil
case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event
default:
- return err
+ return eventIDMembershipMap, err
}
- response.Membership = make(map[string]*types.HeaderedEvent)
- stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID], r)
+ eventIDMembershipMap = make(map[string]*types.HeaderedEvent)
+ stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, eventIDs, stateKeyNIDs[string(senderID)], r)
if err != nil {
- return fmt.Errorf("unable to get state before event: %w", err)
+ return eventIDMembershipMap, fmt.Errorf("unable to get state before event: %w", err)
}
// If we only have one or less state entries, we can short circuit the below
// loop and avoid hitting the database
allStateEventNIDs := make(map[types.EventNID]types.StateEntry)
- for _, eventID := range request.EventIDs {
+ for _, eventID := range eventIDs {
stateEntry := stateEntries[eventID]
for _, s := range stateEntry {
allStateEventNIDs[s.EventNID] = s
@@ -344,10 +360,10 @@ func (r *Queryer) QueryMembershipAtEvent(
}
var memberships []types.Event
- for _, eventID := range request.EventIDs {
+ for _, eventID := range eventIDs {
stateEntry, ok := stateEntries[eventID]
if !ok || len(stateEntry) == 0 {
- response.Membership[eventID] = nil
+ eventIDMembershipMap[eventID] = nil
continue
}
@@ -361,7 +377,7 @@ func (r *Queryer) QueryMembershipAtEvent(
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
if err != nil {
- return fmt.Errorf("unable to get memberships at state: %w", err)
+ return eventIDMembershipMap, fmt.Errorf("unable to get memberships at state: %w", err)
}
// Iterate over all membership events we got. Given we only query the membership for
@@ -369,13 +385,13 @@ func (r *Queryer) QueryMembershipAtEvent(
// a given event, overwrite any other existing membership events.
for i := range memberships {
ev := memberships[i]
- if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(request.UserID) {
- response.Membership[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
+ if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
+ eventIDMembershipMap[eventID] = &types.HeaderedEvent{PDU: ev.PDU}
}
}
}
- return nil
+ return eventIDMembershipMap, nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
@@ -830,13 +846,20 @@ func (r *Queryer) QueryCurrentState(ctx context.Context, req *api.QueryCurrentSt
return nil
}
-func (r *Queryer) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
- roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, req.WantMembership)
+func (r *Queryer) QueryRoomsForUser(ctx context.Context, userID spec.UserID, desiredMembership string) ([]spec.RoomID, error) {
+ roomIDStrs, err := r.DB.GetRoomsByMembership(ctx, userID, desiredMembership)
if err != nil {
- return err
+ return nil, err
}
- res.RoomIDs = roomIDs
- return nil
+ roomIDs := make([]spec.RoomID, len(roomIDStrs))
+ for i, roomIDStr := range roomIDStrs {
+ roomID, err := spec.NewRoomID(roomIDStr)
+ if err != nil {
+ return nil, err
+ }
+ roomIDs[i] = *roomID
+ }
+ return roomIDs, nil
}
func (r *Queryer) QueryKnownUsers(ctx context.Context, req *api.QueryKnownUsersRequest, res *api.QueryKnownUsersResponse) error {
@@ -879,7 +902,12 @@ func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersReq
}
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
- roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
+ parsedUserID, err := spec.NewUserID(req.UserID, true)
+ if err != nil {
+ return err
+ }
+
+ roomIDs, err := r.DB.GetRoomsByMembership(ctx, *parsedUserID, "join")
if err != nil {
return err
}