diff options
Diffstat (limited to 'syncapi/storage/postgres/current_room_state_table.go')
-rw-r--r-- | syncapi/storage/postgres/current_room_state_table.go | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 5e6daaaf..4ffd2961 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -185,9 +185,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx) if err != nil { return nil, err } @@ -209,9 +209,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( - ctx context.Context, roomIDs []string, + ctx context.Context, txn *sql.Tx, roomIDs []string, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } @@ -387,9 +387,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { } func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, + ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == sql.ErrNoRows { |