aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage')
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go4
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go43
2 files changed, 25 insertions, 22 deletions
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index d13b7be4..58f40451 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -112,7 +112,7 @@ const selectEventsWithEventIDsSQL = "" +
const selectSharedUsersSQL = "" +
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
- ") AND state_key = ANY($2) AND membership='join';"
+ ") AND state_key = ANY($2) AND membership IN ('join', 'invite');"
type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt
@@ -407,7 +407,7 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
- rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
+ rows, err := stmt.QueryContext(ctx, userID, pq.Array(otherUserIDs))
if err != nil {
return nil, err
}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index e19298ae..3a10b232 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -94,9 +94,9 @@ const selectEventsWithEventIDsSQL = "" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
const selectSharedUsersSQL = "" +
- "SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
+ "SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" +
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
- ") AND state_key IN ($2) AND membership='join';"
+ ") AND state_key IN ($2) AND membership IN ('join', 'invite');"
type currentRoomStateStatements struct {
db *sql.DB
@@ -420,25 +420,28 @@ func (s *currentRoomStateStatements) SelectStateEvent(
func (s *currentRoomStateStatements) SelectSharedUsers(
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
) ([]string, error) {
- query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
- stmt, err := s.db.Prepare(query)
- if err != nil {
- return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
- }
- defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
- rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
- if err != nil {
- return nil, err
+
+ params := make([]interface{}, len(otherUserIDs)+1)
+ params[0] = userID
+ for k, v := range otherUserIDs {
+ params[k+1] = v
}
- defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
- var stateKey string
result := make([]string, 0, len(otherUserIDs))
- for rows.Next() {
- if err := rows.Scan(&stateKey); err != nil {
- return nil, err
- }
- result = append(result, stateKey)
- }
- return result, rows.Err()
+ query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
+ err := sqlutil.RunLimitedVariablesQuery(
+ ctx, query, s.db, params, sqlutil.SQLite3MaxVariables,
+ func(rows *sql.Rows) error {
+ var stateKey string
+ for rows.Next() {
+ if err := rows.Scan(&stateKey); err != nil {
+ return err
+ }
+ result = append(result, stateKey)
+ }
+ return nil
+ },
+ )
+
+ return result, err
}