diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2020-08-21 10:42:08 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-21 10:42:08 +0100 |
commit | 9d53351dc20283103bf2eec6b92831033d06c5a8 (patch) | |
tree | 653cf0ddca3f777bcdba188187fb78fe39ae2b02 /syncapi | |
parent | 5aaf32bbed4d704d5a22ad7dff79f7a68002a213 (diff) |
Component-wide TransactionWriters (#1290)
* Offset updates take place using TransactionWriter
* Refactor TransactionWriter in current state server
* Refactor TransactionWriter in federation sender
* Refactor TransactionWriter in key server
* Refactor TransactionWriter in media API
* Refactor TransactionWriter in server key API
* Refactor TransactionWriter in sync API
* Refactor TransactionWriter in user API
* Fix deadlocking Sync API tests
* Un-deadlock device database
* Fix appservice API
* Rename TransactionWriters to Writers
* Move writers up a layer in sync API
* Document sqlutil.Writer interface
* Add note to Writer documentation
Diffstat (limited to 'syncapi')
-rw-r--r-- | syncapi/storage/postgres/syncserver.go | 8 | ||||
-rw-r--r-- | syncapi/storage/shared/syncserver.go | 31 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/account_data_table.go | 18 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/backwards_extremities_table.go | 17 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/current_room_state_table.go | 40 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/filter_table.go | 58 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/invites_table.go | 56 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/output_room_events_table.go | 55 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/output_room_events_topology_table.go | 16 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/send_to_device_table.go | 22 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/stream_id_table.go | 15 | ||||
-rw-r--r-- | syncapi/storage/sqlite3/syncserver.go | 8 |
12 files changed, 144 insertions, 200 deletions
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 26ef082f..36e8de67 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -30,7 +30,8 @@ import ( // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { shared.Database - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements } @@ -41,7 +42,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { + d.writer = sqlutil.NewDummyWriter() + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { return nil, err } accountData, err := NewPostgresAccountDataTable(d.db) @@ -78,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewDummyWriter(), Invites: invites, AccountData: accountData, OutputEvents: events, @@ -86,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e BackwardExtremities: backwardExtremities, Filter: filter, SendToDevice: sendToDevice, - SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index fdbf6758..699a6647 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -37,6 +37,7 @@ import ( // For now this contains the shared functions type Database struct { DB *sql.DB + Writer sqlutil.Writer Invites tables.Invites AccountData tables.AccountData OutputEvents tables.Events @@ -45,7 +46,6 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter - SendToDeviceWriter sqlutil.TransactionWriter EDUCache *cache.EDUCache } @@ -129,10 +129,7 @@ func (d *Database) GetStateEvent( func (d *Database) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) - return err - }) + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter) return } @@ -171,9 +168,9 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition func (d *Database) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) - return err + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) + return nil }) return } @@ -182,8 +179,12 @@ func (d *Database) AddInviteEvent( // Returns an error if there was a problem communicating with the database. func (d *Database) RetireInviteEvent( ctx context.Context, inviteEventID string, -) (types.StreamPosition, error) { - return d.Invites.DeleteInviteEvent(ctx, inviteEventID) +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID) + return nil + }) + return } // GetAccountDataInRange returns all account data for a given user inserted or @@ -207,7 +208,7 @@ func (d *Database) GetAccountDataInRange( func (d *Database) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (sp types.StreamPosition, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) return err }) @@ -237,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. +// This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { return err @@ -275,7 +277,7 @@ func (d *Database) WriteEvent( addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, ) (pduPosition types.StreamPosition, returnErr error) { - returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error pos, err := d.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, @@ -304,6 +306,7 @@ func (d *Database) WriteEvent( return pduPosition, returnErr } +// This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, @@ -1114,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage( } // Delegate the database write task to the SendToDeviceWriter. It'll guarantee // that we don't lock the table for writes in more than one place. - err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.AddSendToDeviceEvent( ctx, txn, userID, deviceID, string(j), ) @@ -1179,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates( // If we need to write to the database then we'll ask the SendToDeviceWriter to // do that for us. It'll guarantee that we don't lock the table for writes in // more than one place. - err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // Delete any send-to-device messages marked for deletion. if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 248ec926..72c46e48 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt @@ -61,7 +59,6 @@ type accountDataStatements struct { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(accountDataSchema) @@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { - return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - var err error - pos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) - return err - }) + pos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) + return } func (s *accountDataStatements) SelectAccountDataInRange( diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index d96f2fe5..116c33dc 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" + type backwardExtremitiesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt @@ -57,8 +55,7 @@ type backwardExtremitiesStatements struct { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { s := &backwardExtremitiesStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err := db.Exec(backwardExtremitiesSchema) if err != nil { @@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - return err - }) + _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + return err } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( @@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - return err - }) + _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + return err } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 2f0068ed..6f822c90 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -98,7 +97,6 @@ type currentRoomStateStatements struct { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(currentRoomStateSchema) @@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err } func (s *currentRoomStateStatements) UpsertRoomState( @@ -225,22 +221,20 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err := stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - containsURL, - *event.StateKey(), - headeredJSON, - membership, - addedAt, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err = stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + containsURL, + *event.StateKey(), + headeredJSON, + membership, + addedAt, + ) + return err } func minOfInts(a, b int) int { diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 338b0b50..3092bcd7 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -52,7 +51,6 @@ const insertFilterSQL = "" + type filterStatements struct { db *sql.DB - writer sqlutil.TransactionWriter selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt @@ -64,8 +62,7 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { return nil, err } s := &filterStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return nil, err @@ -114,33 +111,30 @@ func (s *filterStatements) InsertFilter( return "", err } - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - // Check if filter already exists in the database using its localpart and content - // - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) - if err != nil && err != sql.ErrNoRows { - return err - } - // If it does, return the existing ID - if existingFilterID != "" { - return nil - } - - // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) - if err != nil { - return err - } - rowid, err := res.LastInsertId() - if err != nil { - return err - } - filterID = fmt.Sprintf("%d", rowid) - return nil - }) + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, nil + } + + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return "", err + } + rowid, err := res.LastInsertId() + if err != nil { + return "", err + } + filterID = fmt.Sprintf("%d", rowid) return } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 0bbd79f7..45862efb 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt @@ -70,7 +69,6 @@ type inviteEventsStatements struct { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(inviteEventsSchema) @@ -95,45 +93,37 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv func (s *inviteEventsStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - var err error - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } - var headeredJSON []byte - headeredJSON, err = json.Marshal(inviteEvent) - if err != nil { - return err - } + var headeredJSON []byte + headeredJSON, err = json.Marshal(inviteEvent) + if err != nil { + return + } - _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( - ctx, - streamPos, - inviteEvent.RoomID(), - inviteEvent.EventID(), - *inviteEvent.StateKey(), - headeredJSON, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + _, err = stmt.ExecContext( + ctx, + streamPos, + inviteEvent.RoomID(), + inviteEvent.EventID(), + *inviteEvent.StateKey(), + headeredJSON, + ) return } func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) (types.StreamPosition, error) { - var streamPos types.StreamPosition - err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - var err error - streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil) - if err != nil { - return err - } - _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) - return err - }) + streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) + if err != nil { + return streamPos, err + } + _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 0d154650..f10d0106 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -120,7 +119,6 @@ type outputRoomEventsStatements struct { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(outputRoomEventsSchema) @@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event if err != nil { return err } - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) - return err - }) + _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + return err } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. @@ -304,32 +300,27 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } - var streamPos types.StreamPosition - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } - - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - _, ierr := insertStmt.ExecContext( - ctx, - streamPos, - event.RoomID(), - event.EventID(), - headeredJSON, - event.Type(), - event.Sender(), - containsURL, - string(addStateJSON), - string(removeStateJSON), - sessionID, - txnID, - excludeFromSync, - excludeFromSync, - ) - return ierr - }) + streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return 0, err + } + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + _, err = insertStmt.ExecContext( + ctx, + streamPos, + event.RoomID(), + event.EventID(), + headeredJSON, + event.Type(), + event.Sender(), + containsURL, + string(addStateJSON), + string(removeStateJSON), + sessionID, + txnID, + excludeFromSync, + excludeFromSync, + ) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 5c4ab005..d8c97b7e 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" + type outputRoomEventsTopologyStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt @@ -77,8 +76,7 @@ type outputRoomEventsTopologyStatements struct { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { s := &outputRoomEventsTopologyStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { @@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) - _, err := stmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) + _, err = stmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ) + return } func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 53786589..fbc759b1 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt @@ -81,8 +80,7 @@ type sendToDeviceStatements struct { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { s := &sendToDeviceStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err := db.Exec(sendToDeviceSchema) if err != nil { @@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) - return err - }) + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return } func (s *sendToDeviceStatements) CountSendToDeviceMessages( @@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( for k, v := range nids { params[k+1] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.ExecContext(ctx, query, params...) - return err - }) + _, err = txn.ExecContext(ctx, query, params...) + return } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( @@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( for k, v := range nids { params[k] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.ExecContext(ctx, query, params...) - return err - }) + _, err = txn.ExecContext(ctx, query, params...) + return } diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 1971e7f3..e6bdc4fc 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -28,14 +28,12 @@ const selectStreamIDStmt = "" + type streamIDStatements struct { db *sql.DB - writer sqlutil.TransactionWriter increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } func (s *streamIDStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(streamIDTableSchema) if err != nil { return @@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { - return ierr - } - if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { - return serr - } - return nil - }) + if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 9564a23a..81197bb7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -31,7 +31,8 @@ import ( // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { shared.Database - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements streamID streamIDStatements } @@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + d.writer = sqlutil.NewExclusiveWriter() if err = d.prepare(); err != nil { return nil, err } @@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } func (d *SyncServerDatasource) prepare() (err error) { - if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { return err } if err = d.streamID.prepare(d.db); err != nil { @@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) { } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewExclusiveWriter(), Invites: invites, AccountData: accountData, OutputEvents: events, @@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) { Topology: topology, Filter: filter, SendToDevice: sendToDevice, - SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), } return nil |