diff options
author | Kegsay <kegan@matrix.org> | 2020-09-03 18:27:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-03 18:27:02 +0100 |
commit | 33b8143a9597ff8c6b75ea47a588d50dc6e72259 (patch) | |
tree | 8c8b2862291297f857844f656bb64cb1da9b8bd3 /roomserver/storage/postgres/membership_table.go | |
parent | b20386123e0cbdc53016231f0087d0047b5667e9 (diff) |
Implement more CSS storage functions in roomserver (#1388)
Diffstat (limited to 'roomserver/storage/postgres/membership_table.go')
-rw-r--r-- | roomserver/storage/postgres/membership_table.go | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0799647e..5164f654 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -18,7 +18,9 @@ package postgres import ( "context" "database/sql" + "fmt" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership ( ); ` +var selectJoinedUsersSetForRoomsSQL = "" + + "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" + + " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid" + // Insert a row in to membership table so that it can be locked by the // SELECT FOR UPDATE const insertMembershipSQL = "" + @@ -102,6 +108,16 @@ const updateMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" + "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2" +// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is +// joined to. Since this information is used to populate the user directory, we will +// only return users that the user would ordinarily be able to see anyway. +var selectKnownUsersSQL = "" + + "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " + + "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + + " WHERE room_nid = ANY(" + + " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + + ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3" + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -112,6 +128,8 @@ type membershipStatements struct { selectLocalMembershipsFromRoomStmt *sql.Stmt updateMembershipStmt *sql.Stmt selectRoomsWithMembershipStmt *sql.Stmt + selectJoinedUsersSetForRoomsStmt *sql.Stmt + selectKnownUsersStmt *sql.Stmt } func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { @@ -131,6 +149,8 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL}, {&s.updateMembershipStmt, updateMembershipSQL}, {&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL}, + {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL}, + {&s.selectKnownUsersStmt, selectKnownUsersSQL}, }.Prepare(db) } @@ -246,3 +266,42 @@ func (s *membershipStatements) SelectRoomsWithMembership( } return roomNIDs, nil } + +func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) { + roomIDarray := make([]int64, len(roomNIDs)) + for i := range roomNIDs { + roomIDarray[i] = int64(roomNIDs[i]) + } + rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") + result := make(map[types.EventStateKeyNID]int) + for rows.Next() { + var userID types.EventStateKeyNID + var count int + if err := rows.Scan(&userID, &count); err != nil { + return nil, err + } + result[userID] = count + } + return result, rows.Err() +} + +func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) { + rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit) + if err != nil { + return nil, err + } + result := []string{} + defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} |