diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-09-28 10:18:03 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-28 10:18:03 +0100 |
commit | 3f9e38e80a7be356aaf1294038888df27e0697a8 (patch) | |
tree | 529efc1f8841e409c590d698495c4a9e96383308 /syncapi/storage/postgres | |
parent | a574ed53696c06e6be6dbe313af0caaa56a659ec (diff) |
Consistent `*sql.Tx` usage across sync API (#2744)
This tidies up the `storage` package so that everything takes a
transaction parameter instead of something things that do and some that
don't.
Diffstat (limited to 'syncapi/storage/postgres')
-rw-r--r-- | syncapi/storage/postgres/account_data_table.go | 5 | ||||
-rw-r--r-- | syncapi/storage/postgres/backwards_extremities_table.go | 4 | ||||
-rw-r--r-- | syncapi/storage/postgres/current_room_state_table.go | 12 | ||||
-rw-r--r-- | syncapi/storage/postgres/filter_table.go | 14 | ||||
-rw-r--r-- | syncapi/storage/postgres/invites_table.go | 2 | ||||
-rw-r--r-- | syncapi/storage/postgres/output_room_events_table.go | 4 | ||||
-rw-r--r-- | syncapi/storage/postgres/output_room_events_topology_table.go | 8 | ||||
-rw-r--r-- | syncapi/storage/postgres/peeks_table.go | 4 | ||||
-rw-r--r-- | syncapi/storage/postgres/receipt_table.go | 4 |
9 files changed, 30 insertions, 27 deletions
diff --git a/syncapi/storage/postgres/account_data_table.go b/syncapi/storage/postgres/account_data_table.go index e9c72058..aa54cb08 100644 --- a/syncapi/storage/postgres/account_data_table.go +++ b/syncapi/storage/postgres/account_data_table.go @@ -99,14 +99,15 @@ func (s *accountDataStatements) InsertAccountData( } func (s *accountDataStatements) SelectAccountDataInRange( - ctx context.Context, + ctx context.Context, txn *sql.Tx, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter, ) (data map[string][]string, pos types.StreamPosition, err error) { data = make(map[string][]string) - rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(), + rows, err := sqlutil.TxStmt(txn, s.selectAccountDataInRangeStmt).QueryContext( + ctx, userID, r.Low(), r.High(), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), accountDataEventFilter.Limit, diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index d4515735..8fc92091 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -79,9 +79,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( - ctx context.Context, roomID string, + ctx context.Context, txn *sql.Tx, roomID string, ) (bwExtrems map[string][]string, err error) { - rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID) + rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID) if err != nil { return } diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 5e6daaaf..4ffd2961 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -185,9 +185,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. func (s *currentRoomStateStatements) SelectJoinedUsers( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx) if err != nil { return nil, err } @@ -209,9 +209,9 @@ func (s *currentRoomStateStatements) SelectJoinedUsers( // SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room. func (s *currentRoomStateStatements) SelectJoinedUsersInRoom( - ctx context.Context, roomIDs []string, + ctx context.Context, txn *sql.Tx, roomIDs []string, ) (map[string][]string, error) { - rows, err := s.selectJoinedUsersInRoomStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersInRoomStmt).QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } @@ -387,9 +387,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) { } func (s *currentRoomStateStatements) SelectStateEvent( - ctx context.Context, roomID, evType, stateKey string, + ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string, ) (*gomatrixserverlib.HeaderedEvent, error) { - stmt := s.selectStateEventStmt + stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt) var res []byte err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res) if err == sql.ErrNoRows { diff --git a/syncapi/storage/postgres/filter_table.go b/syncapi/storage/postgres/filter_table.go index c82ef092..86cec362 100644 --- a/syncapi/storage/postgres/filter_table.go +++ b/syncapi/storage/postgres/filter_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -73,11 +74,11 @@ func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) { } func (s *filterStatements) SelectFilter( - ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string, + ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string, ) error { // Retrieve filter from database (stored as canonical JSON) var filterData []byte - err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData) + err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData) if err != nil { return err } @@ -90,7 +91,7 @@ func (s *filterStatements) SelectFilter( } func (s *filterStatements) InsertFilter( - ctx context.Context, filter *gomatrixserverlib.Filter, localpart string, + ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string, ) (filterID string, err error) { var existingFilterID string @@ -111,8 +112,9 @@ func (s *filterStatements) InsertFilter( // This can result in a race condition when two clients try to insert the // same filter and localpart at the same time, however this is not a // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) + err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext( + ctx, localpart, filterJSON, + ).Scan(&existingFilterID) if err != nil && err != sql.ErrNoRows { return "", err } @@ -122,7 +124,7 @@ func (s *filterStatements) InsertFilter( } // Otherwise insert the filter and return the new ID - err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart). + err = sqlutil.TxStmt(txn, s.insertFilterStmt).QueryRowContext(ctx, filterJSON, localpart). Scan(&filterID) return } diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index 97001ae2..f87ccf96 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -99,7 +99,7 @@ func (s *inviteEventsStatements) InsertInviteEvent( return } - err = s.insertInviteEventStmt.QueryRowContext( + err = sqlutil.TxStmt(txn, s.insertInviteEventStmt).QueryRowContext( ctx, inviteEvent.RoomID(), inviteEvent.EventID(), diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 20a9ea42..cb092150 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -222,12 +222,12 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { }.Prepare(db) } -func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error { +func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error { headeredJSON, err := json.Marshal(event) if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + _, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID()) return err } diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index a1fc9b2a..6fab900e 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -173,7 +173,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( func (s *outputRoomEventsTopologyStatements) SelectPositionInTopology( ctx context.Context, txn *sql.Tx, eventID string, ) (pos, spos types.StreamPosition, err error) { - err = s.selectPositionInTopologyStmt.QueryRowContext(ctx, eventID).Scan(&pos, &spos) + err = sqlutil.TxStmt(txn, s.selectPositionInTopologyStmt).QueryRowContext(ctx, eventID).Scan(&pos, &spos) return } @@ -183,9 +183,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool, ) (topoPos types.StreamPosition, err error) { if backwardOrdering { - err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } else { - err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) + err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos) } return } @@ -193,6 +193,6 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( ctx context.Context, txn *sql.Tx, roomID string, ) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = s.selectMaxPositionInTopologyStmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) + err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) return } diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index 75eeac98..e20a4882 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -152,9 +152,9 @@ func (s *peekStatements) SelectPeeksInRange( } func (s *peekStatements) SelectPeekingDevices( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (peekingDevices map[string][]types.PeekingDevice, err error) { - rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx) + rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx) if err != nil { return nil, err } diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index bbddaa93..327a7a37 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -104,9 +104,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room return } -func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { +func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) { var lastPos types.StreamPosition - rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) + rows, err := sqlutil.TxStmt(txn, r.selectRoomReceipts).QueryContext(ctx, pq.Array(roomIDs), streamPos) if err != nil { return 0, nil, fmt.Errorf("unable to query room receipts: %w", err) } |