diff options
Diffstat (limited to 'syncapi/storage/postgres')
-rw-r--r-- | syncapi/storage/postgres/current_room_state_table.go | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 8ee387b3..c4667baf 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -107,6 +107,11 @@ const selectEventsWithEventIDsSQL = "" + "SELECT event_id, added_at, headered_event_json, 0 AS session_id, false AS exclude_from_sync, '' AS transaction_id" + " FROM syncapi_current_room_state WHERE event_id = ANY($1)" +const selectSharedUsersSQL = "" + + "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" + + " SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + + ") AND state_key = ANY($2) AND membership='join';" + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -118,6 +123,7 @@ type currentRoomStateStatements struct { selectJoinedUsersInRoomStmt *sql.Stmt selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt + selectSharedUsersStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -156,6 +162,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { return nil, err } + if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil { + return nil, err + } return s, nil } @@ -379,3 +388,24 @@ func (s *currentRoomStateStatements) SelectStateEvent( } return &ev, err } + +func (s *currentRoomStateStatements) SelectSharedUsers( + ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string, +) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt) + rows, err := stmt.QueryContext(ctx, userID, otherUserIDs) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed") + + var stateKey string + result := make([]string, 0, len(otherUserIDs)) + for rows.Next() { + if err := rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} |