aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/postgres/membership_table.go
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-09-03 18:27:02 +0100
committerGitHub <noreply@github.com>2020-09-03 18:27:02 +0100
commit33b8143a9597ff8c6b75ea47a588d50dc6e72259 (patch)
tree8c8b2862291297f857844f656bb64cb1da9b8bd3 /roomserver/storage/postgres/membership_table.go
parentb20386123e0cbdc53016231f0087d0047b5667e9 (diff)
Implement more CSS storage functions in roomserver (#1388)
Diffstat (limited to 'roomserver/storage/postgres/membership_table.go')
-rw-r--r--roomserver/storage/postgres/membership_table.go59
1 files changed, 59 insertions, 0 deletions
diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go
index 0799647e..5164f654 100644
--- a/roomserver/storage/postgres/membership_table.go
+++ b/roomserver/storage/postgres/membership_table.go
@@ -18,7 +18,9 @@ package postgres
import (
"context"
"database/sql"
+ "fmt"
+ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@@ -62,6 +64,10 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
);
`
+var selectJoinedUsersSetForRoomsSQL = "" +
+ "SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
+ " membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " GROUP BY target_nid"
+
// Insert a row in to membership table so that it can be locked by the
// SELECT FOR UPDATE
const insertMembershipSQL = "" +
@@ -102,6 +108,16 @@ const updateMembershipSQL = "" +
const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2"
+// selectKnownUsersSQL uses a sub-select statement here to find rooms that the user is
+// joined to. Since this information is used to populate the user directory, we will
+// only return users that the user would ordinarily be able to see anyway.
+var selectKnownUsersSQL = "" +
+ "SELECT DISTINCT event_state_key FROM roomserver_membership INNER JOIN roomserver_event_state_keys ON " +
+ "roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
+ " WHERE room_nid = ANY(" +
+ " SELECT DISTINCT room_nid FROM roomserver_membership WHERE target_nid=$1 AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) +
+ ") AND membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " AND event_state_key LIKE $2 LIMIT $3"
+
type membershipStatements struct {
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
@@ -112,6 +128,8 @@ type membershipStatements struct {
selectLocalMembershipsFromRoomStmt *sql.Stmt
updateMembershipStmt *sql.Stmt
selectRoomsWithMembershipStmt *sql.Stmt
+ selectJoinedUsersSetForRoomsStmt *sql.Stmt
+ selectKnownUsersStmt *sql.Stmt
}
func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
@@ -131,6 +149,8 @@ func NewPostgresMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
{&s.updateMembershipStmt, updateMembershipSQL},
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
+ {&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
+ {&s.selectKnownUsersStmt, selectKnownUsersSQL},
}.Prepare(db)
}
@@ -246,3 +266,42 @@ func (s *membershipStatements) SelectRoomsWithMembership(
}
return roomNIDs, nil
}
+
+func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
+ roomIDarray := make([]int64, len(roomNIDs))
+ for i := range roomNIDs {
+ roomIDarray[i] = int64(roomNIDs[i])
+ }
+ rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
+ result := make(map[types.EventStateKeyNID]int)
+ for rows.Next() {
+ var userID types.EventStateKeyNID
+ var count int
+ if err := rows.Scan(&userID, &count); err != nil {
+ return nil, err
+ }
+ result[userID] = count
+ }
+ return result, rows.Err()
+}
+
+func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
+ rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
+ if err != nil {
+ return nil, err
+ }
+ result := []string{}
+ defer internal.CloseAndLogIfError(ctx, rows, "SelectKnownUsers: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}