aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres/membership_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'roomserver/storage/postgres/membership_table.go')
-rw-r--r--roomserver/storage/postgres/membership_table.go22
1 files changed, 20 insertions, 2 deletions
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index c01753c3..ce626ad1 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
);
`
-var selectJoinedUsersSetForRoomsSQL = "" +
+var selectJoinedUsersSetForRoomsAndUserSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
+var selectJoinedUsersSetForRoomsSQL = "" +
+ "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
+ " WHERE room_nid = ANY($1) AND" +
+ " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
+ " GROUP BY target_nid"
+
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@@ -153,6 +159,7 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
+ selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
selectJoinedUsersSetForRoomsStmt *sql.Stmt
selectKnownUsersStmt *sql.Stmt
updateMembershipForgetRoomStmt *sql.Stmt
@@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
+ {&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
@@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
+ var (
+ rows *sql.Rows
+ err error
+ )
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
- rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
+ if len(userNIDs) > 0 {
+ stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
+ rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
+ } else {
+ rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
+ }
+
if err != nil {
return nil, err
}