aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/shared/storage.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/shared/storage.go')
-rw-r--r--roomserver/storage/shared/storage.go20
1 files changed, 13 insertions, 7 deletions
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 6e84b283..6dc40816 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -1104,13 +1104,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil
}
-// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
-func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
+// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
+func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
- userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
+ userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
+ if err != nil {
+ return nil, err
+ }
+ userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
+ nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
+ for id, nid := range userNIDsMap {
+ userNIDs = append(userNIDs, nid)
+ nidToUserID[nid] = id
+ }
+ userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil {
return nil, err
}
@@ -1120,10 +1130,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
- nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
- if err != nil {
- return nil, err
- }
if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
}