aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-03-01 13:01:38 +0000
committerGitHub <noreply@github.com>2022-03-01 13:01:38 +0000
commit530f05885dccba91559aff09eaaa20540a08a419 (patch)
tree1f59853bfddd88e3b9f39fedb71696cb139ea62d /roomserver/storage
parent58bf91a585ec78f6ca6ff0c9ad0c10c5db9715a7 (diff)
Limit `JoinedUsersSetInRooms` to interested users (#2234)
* Limit database work in `JoinedUsersSetInRooms` to changed user IDs only * Comments * Fix variadic params for SQLite, update comments
Diffstat (limited to 'roomserver/storage')
-rw-r--r--roomserver/storage/interface.go4
-rw-r--r--roomserver/storage/postgres/membership_table.go10
-rw-r--r--roomserver/storage/shared/storage.go20
-rw-r--r--roomserver/storage/sqlite3/membership_table.go21
-rw-r--r--roomserver/storage/tables/interface.go5
5 files changed, 34 insertions, 26 deletions
diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go
index 685505d5..cd232e3e 100644
--- a/roomserver/storage/interface.go
+++ b/roomserver/storage/interface.go
@@ -151,8 +151,8 @@ type Database interface {
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
- // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
- JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
+ // JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
+ JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index 48c2c35c..12717874 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
`
var selectJoinedUsersSetForRoomsSQL = "" +
- "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
+ "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
+ " WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
+ userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
- roomIDarray := make([]int64, len(roomNIDs))
- for i := range roomNIDs {
- roomIDarray[i] = int64(roomNIDs[i])
- }
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
- rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
+ rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if err != nil {
return nil, err
}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 6e84b283..6dc40816 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -1104,13 +1104,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil
}
-// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
-func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
+// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
+func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
- userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
+ userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
+ if err != nil {
+ return nil, err
+ }
+ userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
+ nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
+ for id, nid := range userNIDsMap {
+ userNIDs = append(userNIDs, nid)
+ nidToUserID[nid] = id
+ }
+ userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil {
return nil, err
}
@@ -1120,10 +1130,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
- nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
- if err != nil {
- return nil, err
- }
if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 181b4b4c..43567a94 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -42,7 +42,8 @@ const membershipSchema = `
`
var selectJoinedUsersSetForRoomsSQL = "" +
- "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
+ "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
+ " WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
-func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
- iRoomNIDs := make([]interface{}, len(roomNIDs))
- for i, v := range roomNIDs {
- iRoomNIDs[i] = v
+func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
+ params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
+ for _, v := range roomNIDs {
+ params = append(params, v)
}
- query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
+ for _, v := range userNIDs {
+ params = append(params, v)
+ }
+ query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
+ query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
var rows *sql.Rows
var err error
if txn != nil {
- rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
+ rows, err = txn.QueryContext(ctx, query, params...)
} else {
- rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
+ rows, err = s.db.QueryContext(ctx, query, params...)
}
if err != nil {
return nil, err
diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go
index e3fed700..04e3c96c 100644
--- a/roomserver/storage/tables/interface.go
+++ b/roomserver/storage/tables/interface.go
@@ -127,9 +127,8 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
- // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
- // counts of how many rooms they are joined.
- SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
+ // SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
+ SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)