diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-07-21 15:48:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-21 15:48:21 +0100 |
commit | b6bc132485ec4d6b37815929f6a4e73e5a062d3b (patch) | |
tree | dddea5f8e9fdbd25905822971a908218eb6697a6 /syncapi | |
parent | 1d72ce8b7ab759555503df37af666529749b489c (diff) |
Use TransactionWriter in other component SQLite (#1209)
* Use TransactionWriter on other component SQLites
* Fix sync API tests
* Fix panic in media API
* Fix a couple of transactions
* Fix wrong query, add some logging output
* Add debug logging into StoreEvent
* Adjust InsertRoomNID
* Update logging
Diffstat (limited to 'syncapi')
-rw-r--r-- | syncapi/storage/shared/syncserver.go | 12 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/account_data_table.go | 20 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/backwards_extremities_table.go | 20 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/current_room_state_table.go | 42 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/filter_table.go | 61 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/invites_table.go | 57 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/output_room_events_table.go | 70 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/output_room_events_topology_table.go | 19 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/send_to_device_table.go | 25 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/stream_id_table.go | 19 | ||||
-rw-r--r-- | syncapi/storage/storage_test.go | 9 |
11 files changed, 222 insertions, 132 deletions
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 32079291..e1312671 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -281,16 +281,16 @@ func (d *Database) WriteEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, ) if err != nil { - return err + return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err) } pduPosition = pos if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil { - return err + return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err) } if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil { - return err + return fmt.Errorf("d.handleBackwardExtremities: %w", err) } if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 { @@ -313,7 +313,7 @@ func (d *Database) updateRoomState( // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removedEventIDs { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { - return err + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err) } } @@ -326,13 +326,13 @@ func (d *Database) updateRoomState( if event.Type() == "m.room.member" { value, err := event.Membership() if err != nil { - return err + return fmt.Errorf("event.Membership: %w", err) } membership = &value } if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil { - return err + return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err) } } diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index ae5caa4e..609cef14 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" + "SELECT MAX(id) FROM syncapi_account_data_type" type accountDataStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt @@ -57,6 +60,8 @@ type accountDataStatements struct { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(accountDataSchema) @@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { - pos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return - } - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) - return + return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + var err error + pos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) + return err + }) } func (s *accountDataStatements) SelectAccountDataInRange( diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index e16e54a6..1aeb041f 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -47,13 +48,18 @@ const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" type backwardExtremitiesStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { - s := &backwardExtremitiesStatements{} + s := &backwardExtremitiesStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(backwardExtremitiesSchema) if err != nil { return nil, err @@ -73,8 +79,10 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + return err + }) } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( @@ -102,6 +110,8 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + return err + }) } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 85f212ad..08b42f5b 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -84,6 +84,8 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id IN ($1)" type currentRoomStateStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -95,6 +97,8 @@ type currentRoomStateStatements struct { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(currentRoomStateSchema) @@ -196,9 +200,11 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err + }) } func (s *currentRoomStateStatements) UpsertRoomState( @@ -219,20 +225,22 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err = stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - containsURL, - *event.StateKey(), - headeredJSON, - membership, - addedAt, - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err := stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + containsURL, + *event.StateKey(), + headeredJSON, + membership, + addedAt, + ) + return err + }) } func (s *currentRoomStateStatements) SelectEventsWithEventIDs( diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 8b26759d..3e8a4655 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -50,6 +51,8 @@ const insertFilterSQL = "" + "INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)" type filterStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt @@ -60,7 +63,10 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { if err != nil { return nil, err } - s := &filterStatements{} + s := &filterStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return nil, err } @@ -108,30 +114,33 @@ func (s *filterStatements) InsertFilter( return "", err } - // Check if filter already exists in the database using its localpart and content - // - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) - if err != nil && err != sql.ErrNoRows { - return "", err - } - // If it does, return the existing ID - if existingFilterID != "" { - return existingFilterID, err - } - - // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) - if err != nil { - return "", err - } - rowid, err := res.LastInsertId() - if err != nil { - return "", err - } - filterID = fmt.Sprintf("%d", rowid) + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return err + } + // If it does, return the existing ID + if existingFilterID != "" { + return nil + } + + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return err + } + rowid, err := res.LastInsertId() + if err != nil { + return err + } + filterID = fmt.Sprintf("%d", rowid) + return nil + }) return } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index aa051388..19e7a7c6 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -58,6 +58,8 @@ const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" type inviteEventsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt @@ -67,6 +69,8 @@ type inviteEventsStatements struct { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(inviteEventsSchema) @@ -91,36 +95,45 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv func (s *inviteEventsStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return - } + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + var err error + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } - var headeredJSON []byte - headeredJSON, err = json.Marshal(inviteEvent) - if err != nil { - return - } + var headeredJSON []byte + headeredJSON, err = json.Marshal(inviteEvent) + if err != nil { + return err + } - _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( - ctx, - streamPos, - inviteEvent.RoomID(), - inviteEvent.EventID(), - *inviteEvent.StateKey(), - headeredJSON, - ) + _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( + ctx, + streamPos, + inviteEvent.RoomID(), + inviteEvent.EventID(), + *inviteEvent.StateKey(), + headeredJSON, + ) + return err + }) return } func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) (types.StreamPosition, error) { - streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) - if err != nil { - return streamPos, err - } - _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) + var streamPos types.StreamPosition + err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + var err error + streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil) + if err != nil { + return err + } + _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) + return err + }) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index da2ea3f6..12b4dbab 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -104,6 +104,8 @@ const selectStateInRangeSQL = "" + " LIMIT $8" // limit type outputRoomEventsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -117,6 +119,8 @@ type outputRoomEventsStatements struct { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(outputRoomEventsSchema) @@ -155,8 +159,10 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event if err != nil { return err } - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + return err + }) } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. @@ -267,7 +273,7 @@ func (s *outputRoomEventsStatements) InsertEvent( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool, -) (streamPos types.StreamPosition, err error) { +) (types.StreamPosition, error) { var txnID *string var sessionID *int64 if transactionID != nil { @@ -284,43 +290,47 @@ func (s *outputRoomEventsStatements) InsertEvent( } var headeredJSON []byte - headeredJSON, err = json.Marshal(event) - if err != nil { - return - } - - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + headeredJSON, err := json.Marshal(event) if err != nil { - return + return 0, err } addStateJSON, err := json.Marshal(addState) if err != nil { - return + return 0, err } removeStateJSON, err := json.Marshal(removeState) if err != nil { - return + return 0, err } - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - _, err = insertStmt.ExecContext( - ctx, - streamPos, - event.RoomID(), - event.EventID(), - headeredJSON, - event.Type(), - event.Sender(), - containsURL, - string(addStateJSON), - string(removeStateJSON), - sessionID, - txnID, - excludeFromSync, - excludeFromSync, - ) - return + var streamPos types.StreamPosition + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return err + } + + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + _, ierr := insertStmt.ExecContext( + ctx, + streamPos, + event.RoomID(), + event.EventID(), + headeredJSON, + event.Type(), + event.Sender(), + containsURL, + string(addStateJSON), + string(removeStateJSON), + sessionID, + txnID, + excludeFromSync, + excludeFromSync, + ) + return ierr + }) + return streamPos, err } func (s *outputRoomEventsStatements) SelectRecentEvents( diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 811dfa4f..2e71e8f3 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -66,6 +66,8 @@ const selectMaxPositionInTopologySQL = "" + " WHERE room_id = $1 ORDER BY stream_position DESC" type outputRoomEventsTopologyStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt @@ -74,7 +76,10 @@ type outputRoomEventsTopologyStatements struct { } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { - s := &outputRoomEventsTopologyStatements{} + s := &outputRoomEventsTopologyStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { return nil, err @@ -102,11 +107,13 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) - _, err = stmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, - ) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) + _, err := stmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ) + return err + }) } func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 42bd3c19..88b319fb 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -72,13 +72,18 @@ const deleteSendToDeviceMessagesSQL = ` ` type sendToDeviceStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt } func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { - s := &sendToDeviceStatements{} + s := &sendToDeviceStatements{ + db: db, + writer: sqlutil.NewTransactionWriter(), + } _, err := db.Exec(sendToDeviceSchema) if err != nil { return nil, err @@ -98,8 +103,10 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return err + }) } func (s *sendToDeviceStatements) CountSendToDeviceMessages( @@ -156,8 +163,10 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( for k, v := range nids { params[k+1] = v } - _, err = txn.ExecContext(ctx, query, params...) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, query, params...) + return err + }) } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( @@ -168,6 +177,8 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( for k, v := range nids { params[k] = v } - _, err = txn.ExecContext(ctx, query, params...) - return + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := txn.ExecContext(ctx, query, params...) + return err + }) } diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 57abd9c4..cf3eed5b 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -27,11 +27,15 @@ const selectStreamIDStmt = "" + "SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1" type streamIDStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } func (s *streamIDStatements) prepare(db *sql.DB) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(streamIDTableSchema) if err != nil { return @@ -48,11 +52,14 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { - return - } - if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { - return - } + err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { + return ierr + } + if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { + return serr + } + return nil + }) return } diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index feacbc18..474d3222 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "crypto/ed25519" "encoding/json" "fmt" + "os" "testing" "time" @@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head } func MustCreateDatabase(t *testing.T) storage.Database { - db, err := sqlite3.NewDatabase("file::memory:") + dbname := fmt.Sprintf("test_%s.db", t.Name()) + if _, err := os.Stat(dbname); err == nil { + if err = os.Remove(dbname); err != nil { + t.Fatalf("tried to delete stale test database but failed: %s", err) + } + } + db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname)) if err != nil { t.Fatalf("NewSyncServerDatasource returned %s", err) } |