aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/postgres/current_room_state_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage/postgres/current_room_state_table.go')
-rw-r--r--syncapi/storage/postgres/current_room_state_table.go12
1 files changed, 6 insertions, 6 deletions
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go
index 5e6daaaf..4ffd2961 100644
--- a/syncapi/storage/postgres/current_room_state_table.go
+++ b/syncapi/storage/postgres/current_room_state_table.go
@@ -185,9 +185,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
func (s *currentRoomStateStatements) SelectJoinedUsers(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
) (map[string][]string, error) {
- rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
+ rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
if err != nil {
return nil, err
}
@@ -209,9 +209,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
- ctx context.Context, roomIDs []string,
+ ctx context.Context, txn *sql.Tx, roomIDs []string,
) (map[string][]string, error) {
- rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs))
+ rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
@@ -387,9 +387,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
}
func (s *currentRoomStateStatements) SelectStateEvent(
- ctx context.Context, roomID, evType, stateKey string,
+ ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
) (*gomatrixserverlib.HeaderedEvent, error) {
- stmt := s.selectStateEventStmt
+ stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
var res []byte
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
if err == sql.ErrNoRows {