From 9d53351dc20283103bf2eec6b92831033d06c5a8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 21 Aug 2020 10:42:08 +0100 Subject: Component-wide TransactionWriters (#1290) * Offset updates take place using TransactionWriter * Refactor TransactionWriter in current state server * Refactor TransactionWriter in federation sender * Refactor TransactionWriter in key server * Refactor TransactionWriter in media API * Refactor TransactionWriter in server key API * Refactor TransactionWriter in sync API * Refactor TransactionWriter in user API * Fix deadlocking Sync API tests * Un-deadlock device database * Fix appservice API * Rename TransactionWriters to Writers * Move writers up a layer in sync API * Document sqlutil.Writer interface * Add note to Writer documentation --- appservice/storage/postgres/storage.go | 4 +- .../storage/sqlite3/appservice_events_table.go | 6 +- appservice/storage/sqlite3/storage.go | 8 ++- appservice/storage/sqlite3/txn_id_counter_table.go | 6 +- currentstateserver/storage/postgres/storage.go | 7 ++- currentstateserver/storage/shared/storage.go | 3 +- .../storage/sqlite3/current_room_state_table.go | 4 +- currentstateserver/storage/sqlite3/storage.go | 7 ++- federationsender/storage/postgres/storage.go | 7 ++- federationsender/storage/shared/storage.go | 18 ++++-- .../storage/sqlite3/blacklist_table.go | 20 +++---- .../storage/sqlite3/joined_hosts_table.go | 26 ++++----- .../storage/sqlite3/queue_edus_table.go | 30 ++++------ .../storage/sqlite3/queue_json_table.go | 33 +++++------ .../storage/sqlite3/queue_pdus_table.go | 30 ++++------ federationsender/storage/sqlite3/room_table.go | 18 ++---- federationsender/storage/sqlite3/storage.go | 7 ++- internal/sqlutil/partition_offset_table.go | 13 ++++- internal/sqlutil/sql.go | 4 -- internal/sqlutil/writer.go | 46 +++++++++++++++ internal/sqlutil/writer_dummy.go | 16 +++-- internal/sqlutil/writer_exclusive.go | 21 +++---- keyserver/storage/sqlite3/device_keys_table.go | 6 +- keyserver/storage/sqlite3/key_changes_table.go | 6 +- keyserver/storage/sqlite3/one_time_keys_table.go | 6 +- keyserver/storage/sqlite3/stale_device_lists.go | 17 ++++-- keyserver/storage/sqlite3/storage.go | 9 +-- mediaapi/storage/sqlite3/media_repository_table.go | 6 +- mediaapi/storage/sqlite3/sql.go | 8 ++- mediaapi/storage/sqlite3/storage.go | 7 ++- mediaapi/storage/sqlite3/thumbnail_table.go | 34 ++++++----- roomserver/storage/postgres/storage.go | 2 +- roomserver/storage/shared/storage.go | 2 +- roomserver/storage/sqlite3/storage.go | 6 +- serverkeyapi/storage/sqlite3/keydb.go | 7 ++- serverkeyapi/storage/sqlite3/server_key_table.go | 6 +- syncapi/storage/postgres/syncserver.go | 8 ++- syncapi/storage/shared/syncserver.go | 31 +++++----- syncapi/storage/sqlite3/account_data_table.go | 18 ++---- .../storage/sqlite3/backwards_extremities_table.go | 17 ++---- .../storage/sqlite3/current_room_state_table.go | 40 ++++++------- syncapi/storage/sqlite3/filter_table.go | 58 +++++++++--------- syncapi/storage/sqlite3/invites_table.go | 56 ++++++++---------- .../storage/sqlite3/output_room_events_table.go | 55 ++++++++--------- .../sqlite3/output_room_events_topology_table.go | 16 ++--- syncapi/storage/sqlite3/send_to_device_table.go | 22 +++---- syncapi/storage/sqlite3/stream_id_table.go | 15 ++--- syncapi/storage/sqlite3/syncserver.go | 8 ++- userapi/storage/accounts/postgres/storage.go | 25 ++++---- .../storage/accounts/sqlite3/account_data_table.go | 6 +- userapi/storage/accounts/sqlite3/accounts_table.go | 6 +- userapi/storage/accounts/sqlite3/profile_table.go | 6 +- userapi/storage/accounts/sqlite3/storage.go | 33 +++++------ userapi/storage/accounts/sqlite3/threepid_table.go | 6 +- userapi/storage/devices/sqlite3/devices_table.go | 68 +++++++++------------- userapi/storage/devices/sqlite3/storage.go | 18 +++--- 56 files changed, 484 insertions(+), 484 deletions(-) create mode 100644 internal/sqlutil/writer.go diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go index 9fda87ae..95215816 100644 --- a/appservice/storage/postgres/storage.go +++ b/appservice/storage/postgres/storage.go @@ -32,6 +32,7 @@ type Database struct { events eventsStatements txnID txnStatements db *sql.DB + writer sqlutil.Writer } // NewDatabase opens a new database @@ -41,10 +42,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { if result.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + result.writer = sqlutil.NewDummyWriter() if err = result.prepare(); err != nil { return nil, err } - if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { + if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil { return nil, err } return &result, nil diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go index 5cc07ed3..5dfb72f6 100644 --- a/appservice/storage/sqlite3/appservice_events_table.go +++ b/appservice/storage/sqlite3/appservice_events_table.go @@ -67,7 +67,7 @@ const ( type eventsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer selectEventsByApplicationServiceIDStmt *sql.Stmt countEventsByApplicationServiceIDStmt *sql.Stmt insertEventStmt *sql.Stmt @@ -75,9 +75,9 @@ type eventsStatements struct { deleteEventsBeforeAndIncludingIDStmt *sql.Stmt } -func (s *eventsStatements) prepare(db *sql.DB) (err error) { +func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(appserviceEventsSchema) if err != nil { return diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go index 59af9016..916845ab 100644 --- a/appservice/storage/sqlite3/storage.go +++ b/appservice/storage/sqlite3/storage.go @@ -32,6 +32,7 @@ type Database struct { events eventsStatements txnID txnStatements db *sql.DB + writer sqlutil.Writer } // NewDatabase opens a new database @@ -41,21 +42,22 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { if result.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + result.writer = sqlutil.NewExclusiveWriter() if err = result.prepare(); err != nil { return nil, err } - if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil { + if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil { return nil, err } return &result, nil } func (d *Database) prepare() error { - if err := d.events.prepare(d.db); err != nil { + if err := d.events.prepare(d.db, d.writer); err != nil { return err } - return d.txnID.prepare(d.db) + return d.txnID.prepare(d.db, d.writer) } // StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go index 0ae0feee..b2940e35 100644 --- a/appservice/storage/sqlite3/txn_id_counter_table.go +++ b/appservice/storage/sqlite3/txn_id_counter_table.go @@ -38,13 +38,13 @@ const selectTxnIDSQL = ` type txnStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer selectTxnIDStmt *sql.Stmt } -func (s *txnStatements) prepare(db *sql.DB) (err error) { +func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(txnIDSchema) if err != nil { return diff --git a/currentstateserver/storage/postgres/storage.go b/currentstateserver/storage/postgres/storage.go index 0cd7e555..cb5ebff0 100644 --- a/currentstateserver/storage/postgres/storage.go +++ b/currentstateserver/storage/postgres/storage.go @@ -10,7 +10,8 @@ import ( type Database struct { shared.Database - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements } @@ -21,7 +22,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil { + d.writer = sqlutil.NewDummyWriter() + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil { return nil, err } currRoomState, err := NewPostgresCurrentRoomStateTable(d.db) @@ -30,6 +32,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: d.db, + Writer: d.writer, CurrentRoomState: currRoomState, } return &d, nil diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go index 46ef9e6c..2cf40ccc 100644 --- a/currentstateserver/storage/shared/storage.go +++ b/currentstateserver/storage/shared/storage.go @@ -27,6 +27,7 @@ import ( type Database struct { DB *sql.DB + Writer sqlutil.Writer CurrentRoomState tables.CurrentRoomState } @@ -59,7 +60,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent, removeStateEventIDs []string) error { - return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add. for _, eventID := range removeStateEventIDs { if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil { diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go index 9d2fe6e0..c6cf40ed 100644 --- a/currentstateserver/storage/sqlite3/current_room_state_table.go +++ b/currentstateserver/storage/sqlite3/current_room_state_table.go @@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt selectRoomIDsWithMembershipStmt *sql.Stmt @@ -96,7 +96,7 @@ type currentRoomStateStatements struct { func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: sqlutil.NewExclusiveWriter(), } _, err := db.Exec(currentRoomStateSchema) if err != nil { diff --git a/currentstateserver/storage/sqlite3/storage.go b/currentstateserver/storage/sqlite3/storage.go index 4454c9ed..e79afd70 100644 --- a/currentstateserver/storage/sqlite3/storage.go +++ b/currentstateserver/storage/sqlite3/storage.go @@ -10,7 +10,8 @@ import ( type Database struct { shared.Database - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements } @@ -22,7 +23,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil { + d.writer = sqlutil.NewExclusiveWriter() + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil { return nil, err } currRoomState, err := NewSqliteCurrentRoomStateTable(d.db) @@ -31,6 +33,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: d.db, + Writer: d.writer, CurrentRoomState: currRoomState, } return &d, nil diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index b65ff0b6..b3b4da39 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -27,7 +27,8 @@ import ( type Database struct { shared.Database sqlutil.PartitionOffsetStatements - db *sql.DB + db *sql.DB + writer sqlutil.Writer } // NewDatabase opens a new database @@ -37,6 +38,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + d.writer = sqlutil.NewDummyWriter() joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -63,6 +65,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: d.db, + Writer: d.writer, FederationSenderJoinedHosts: joinedHosts, FederationSenderQueuePDUs: queuePDUs, FederationSenderQueueEDUs: queueEDUs, @@ -70,7 +73,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { FederationSenderRooms: rooms, FederationSenderBlacklist: blacklist, } - if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil { return nil, err } return &d, nil diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 4a681de6..4e347259 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -28,6 +28,7 @@ import ( type Database struct { DB *sql.DB + Writer sqlutil.Writer FederationSenderQueuePDUs tables.FederationSenderQueuePDUs FederationSenderQueueEDUs tables.FederationSenderQueueEDUs FederationSenderQueueJSON tables.FederationSenderQueueJSON @@ -64,7 +65,7 @@ func (d *Database) UpdateRoom( addHosts []types.JoinedHost, removeHosts []string, ) (joinedHosts []types.JoinedHost, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID) if err != nil { return err @@ -133,7 +134,12 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) func (d *Database) StoreJSON( ctx context.Context, js string, ) (*Receipt, error) { - nid, err := d.FederationSenderQueueJSON.InsertQueueJSON(ctx, nil, js) + var nid int64 + var err error + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js) + return nil + }) if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } @@ -143,11 +149,15 @@ func (d *Database) StoreJSON( } func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), nil, serverName) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName) + }) } func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), nil, serverName) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) + }) } func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { diff --git a/federationsender/storage/sqlite3/blacklist_table.go b/federationsender/storage/sqlite3/blacklist_table.go index b23bfcba..90b44ac9 100644 --- a/federationsender/storage/sqlite3/blacklist_table.go +++ b/federationsender/storage/sqlite3/blacklist_table.go @@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" + type blacklistStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertBlacklistStmt *sql.Stmt selectBlacklistStmt *sql.Stmt deleteBlacklistStmt *sql.Stmt @@ -50,8 +49,7 @@ type blacklistStatements struct { func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { s = &blacklistStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(blacklistSchema) if err != nil { @@ -75,11 +73,9 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) { func (s *blacklistStatements) InsertBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err } // selectRoomForUpdate locks the row for the room and returns the last_event_id. @@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist( func (s *blacklistStatements) DeleteBlacklist( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) - _, err := stmt.ExecContext(ctx, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err } diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 5dc18f4e..3bc45e7d 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -65,7 +65,6 @@ const selectJoinedHostsForRoomsSQL = "" + type joinedHostsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt @@ -75,8 +74,7 @@ type joinedHostsStatements struct { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { s = &joinedHostsStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(joinedHostsSchema) if err != nil { @@ -103,25 +101,21 @@ func (s *joinedHostsStatements) InsertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) - _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) + _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) + return err } func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - for _, eventID := range eventIDs { - stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) - if _, err := stmt.ExecContext(ctx, eventID); err != nil { - return err - } + for _, eventID := range eventIDs { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) + if _, err := stmt.ExecContext(ctx, eventID); err != nil { + return err } - return nil - }) + } + return nil } func (s *joinedHostsStatements) SelectJoinedHostsWithTx( diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go index 2abcc105..a6d60950 100644 --- a/federationsender/storage/sqlite3/queue_edus_table.go +++ b/federationsender/storage/sqlite3/queue_edus_table.go @@ -64,7 +64,6 @@ const selectQueueServerNamesSQL = "" + type queueEDUsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt @@ -74,8 +73,7 @@ type queueEDUsStatements struct { func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { s = &queueEDUsStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(queueEDUsSchema) if err != nil { @@ -106,16 +104,14 @@ func (s *queueEDUsStatements) InsertQueueEDU( serverName gomatrixserverlib.ServerName, nid int64, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) - _, err := stmt.ExecContext( - ctx, - eduType, // the EDU type - serverName, // destination server name - nid, // JSON blob NID - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + eduType, // the EDU type + serverName, // destination server name + nid, // JSON blob NID + ) + return err } func (s *queueEDUsStatements) DeleteQueueEDUs( @@ -135,11 +131,9 @@ func (s *queueEDUsStatements) DeleteQueueEDUs( params[k+1] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err := stmt.ExecContext(ctx, params...) - return err - }) + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err } func (s *queueEDUsStatements) SelectQueueEDUs( diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 867ffd44..3e3f60f6 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -50,7 +50,6 @@ const selectJSONSQL = "" + type queueJSONStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic @@ -58,8 +57,7 @@ type queueJSONStatements struct { func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { s = &queueJSONStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(queueJSONSchema) if err != nil { @@ -74,18 +72,15 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (lastid int64, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) - res, err := stmt.ExecContext(ctx, json) - if err != nil { - return fmt.Errorf("stmt.QueryContext: %w", err) - } - lastid, err = res.LastInsertId() - if err != nil { - return fmt.Errorf("res.LastInsertId: %w", err) - } - return nil - }) + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } return } @@ -103,11 +98,9 @@ func (s *queueJSONStatements) DeleteQueueJSON( iNIDs[k] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err = stmt.ExecContext(ctx, iNIDs...) - return err - }) + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err } func (s *queueJSONStatements) SelectQueueJSON( diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 538ba3db..70519c9e 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -71,7 +71,6 @@ const selectQueuePDUsServerNamesSQL = "" + type queuePDUsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertQueuePDUStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt @@ -83,8 +82,7 @@ type queuePDUsStatements struct { func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { s = &queuePDUsStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(queuePDUsSchema) if err != nil { @@ -121,16 +119,14 @@ func (s *queuePDUsStatements) InsertQueuePDU( serverName gomatrixserverlib.ServerName, nid int64, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) - _, err := stmt.ExecContext( - ctx, - transactionID, // the transaction ID that we initially attempted - serverName, // destination server name - nid, // JSON blob NID - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err } func (s *queuePDUsStatements) DeleteQueuePDUs( @@ -150,11 +146,9 @@ func (s *queuePDUsStatements) DeleteQueuePDUs( params[k+1] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err := stmt.ExecContext(ctx, params...) - return err - }) + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err } func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index 9a439fad..0710ccca 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -44,7 +44,6 @@ const updateRoomSQL = "" + type roomStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt @@ -52,8 +51,7 @@ type roomStatements struct { func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { s = &roomStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err = db.Exec(roomSchema) if err != nil { @@ -77,10 +75,8 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { func (s *roomStatements) InsertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) - return err - }) + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + return err } // selectRoomForUpdate locks the row for the room and returns the last_event_id. @@ -103,9 +99,7 @@ func (s *roomStatements) SelectRoomForUpdate( func (s *roomStatements) UpdateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) - _, err := stmt.ExecContext(ctx, roomID, lastEventID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) + _, err := stmt.ExecContext(ctx, roomID, lastEventID) + return err } diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 41b91871..ba467f02 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -29,7 +29,8 @@ import ( type Database struct { shared.Database sqlutil.PartitionOffsetStatements - db *sql.DB + db *sql.DB + writer sqlutil.Writer } // NewDatabase opens a new database @@ -39,6 +40,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + d.writer = sqlutil.NewExclusiveWriter() joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -65,6 +67,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: d.db, + Writer: d.writer, FederationSenderJoinedHosts: joinedHosts, FederationSenderQueuePDUs: queuePDUs, FederationSenderQueueEDUs: queueEDUs, @@ -72,7 +75,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) { FederationSenderRooms: rooms, FederationSenderBlacklist: blacklist, } - if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil { + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil { return nil, err } return &d, nil diff --git a/internal/sqlutil/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go index 34882902..be079442 100644 --- a/internal/sqlutil/partition_offset_table.go +++ b/internal/sqlutil/partition_offset_table.go @@ -53,6 +53,8 @@ const upsertPartitionOffsetsSQL = "" + // PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table. type PartitionOffsetStatements struct { + db *sql.DB + writer Writer selectPartitionOffsetsStmt *sql.Stmt upsertPartitionOffsetStmt *sql.Stmt } @@ -60,7 +62,9 @@ type PartitionOffsetStatements struct { // Prepare converts the raw SQL statements into prepared statements. // Takes a prefix to prepend to the table name used to store the partition offsets. // This allows multiple components to share the same database schema. -func (s *PartitionOffsetStatements) Prepare(db *sql.DB, prefix string) (err error) { +func (s *PartitionOffsetStatements) Prepare(db *sql.DB, writer Writer, prefix string) (err error) { + s.db = db + s.writer = writer _, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1)) if err != nil { return @@ -121,6 +125,9 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets( func (s *PartitionOffsetStatements) upsertPartitionOffset( ctx context.Context, topic string, partition int32, offset int64, ) error { - _, err := s.upsertPartitionOffsetStmt.ExecContext(ctx, topic, partition, offset) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := TxStmt(txn, s.upsertPartitionOffsetStmt) + _, err := stmt.ExecContext(ctx, topic, partition, offset) + return err + }) } diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 002d7718..d296c418 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -103,7 +103,3 @@ func SQLiteDriverName() string { } return "sqlite3" } - -type TransactionWriter interface { - Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error -} diff --git a/internal/sqlutil/writer.go b/internal/sqlutil/writer.go new file mode 100644 index 00000000..5d93fef4 --- /dev/null +++ b/internal/sqlutil/writer.go @@ -0,0 +1,46 @@ +package sqlutil + +import "database/sql" + +// The Writer interface is designed to solve the problem of how +// to handle database writes for database engines that don't allow +// concurrent writes, e.g. SQLite. +// +// The interface has a single Do function which takes an optional +// database parameter, an optional transaction parameter and a +// required function parameter. The Writer will call the function +// provided when it is safe to do so, optionally providing a +// transaction to use. +// +// Depending on the combination of parameters provided, the Writer +// will behave in one of three ways: +// +// 1. `db` provided, `txn` provided: +// +// The Writer will call f() when it is safe to do so. The supplied +// "txn" will ALWAYS be passed through to f(). Use this when you +// already have a transaction open. +// +// 2. `db` provided, `txn` not provided (nil): +// +// The Writer will open a new transaction on the provided database +// and then will call f() when it is safe to do so. The new +// transaction will ALWAYS be passed through to f(). Use this if +// you plan to perform more than one SQL query within f(). +// +// 3. `db` not provided (nil), `txn` not provided (nil): +// +// The Writer will call f() when it is safe to do so, but will +// not make any attempt to open a new database transaction or to +// pass through an existing one. The "txn" parameter within f() +// will ALWAYS be nil in this mode. This is useful if you just +// want to perform a single query on an already-prepared statement +// without the overhead of opening a new transaction to do it in. +// +// You MUST take particular care not to call Do() from within f() +// on the same Writer, or it will likely result in a deadlock. +type Writer interface { + // Queue up one or more database write operations within the + // provided function to be executed when it is safe to do so. + Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error +} diff --git a/internal/sqlutil/writer_dummy.go b/internal/sqlutil/writer_dummy.go index e6ab81f6..f426c2bc 100644 --- a/internal/sqlutil/writer_dummy.go +++ b/internal/sqlutil/writer_dummy.go @@ -4,15 +4,21 @@ import ( "database/sql" ) -type DummyTransactionWriter struct { +// DummyWriter implements sqlutil.Writer. +// The DummyWriter is designed to allow reuse of the sqlutil.Writer +// interface but, unlike ExclusiveWriter, it will not guarantee +// writer exclusivity. This is fine in PostgreSQL where overlapping +// transactions and writes are acceptable. +type DummyWriter struct { } -func NewDummyTransactionWriter() TransactionWriter { - return &DummyTransactionWriter{} +// NewDummyWriter returns a new dummy writer. +func NewDummyWriter() Writer { + return &DummyWriter{} } -func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { - if txn == nil { +func (w *DummyWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if db != nil && txn == nil { return WithTransaction(db, func(txn *sql.Tx) error { return f(txn) }) diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go index 2e3666ae..002bc32c 100644 --- a/internal/sqlutil/writer_exclusive.go +++ b/internal/sqlutil/writer_exclusive.go @@ -7,16 +7,17 @@ import ( "go.uber.org/atomic" ) -// ExclusiveTransactionWriter allows queuing database writes so that you don't +// ExclusiveWriter implements sqlutil.Writer. +// ExclusiveWriter allows queuing database writes so that you don't // contend on database locks in, e.g. SQLite. Only one task will run -// at a time on a given ExclusiveTransactionWriter. -type ExclusiveTransactionWriter struct { +// at a time on a given ExclusiveWriter. +type ExclusiveWriter struct { running atomic.Bool todo chan transactionWriterTask } -func NewTransactionWriter() TransactionWriter { - return &ExclusiveTransactionWriter{ +func NewExclusiveWriter() Writer { + return &ExclusiveWriter{ todo: make(chan transactionWriterTask), } } @@ -34,7 +35,7 @@ type transactionWriterTask struct { // txn parameter if one is supplied, and if not, will take out a // new transaction from the database supplied in the database // parameter. Either way, this will block until the task is done. -func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { +func (w *ExclusiveWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { if w.todo == nil { return errors.New("not initialised") } @@ -55,20 +56,20 @@ func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql // of these goroutines will run at a time. A transaction will be // opened using the database object from the task and then this will // be passed as a parameter to the task function. -func (w *ExclusiveTransactionWriter) run() { +func (w *ExclusiveWriter) run() { if !w.running.CAS(false, true) { return } defer w.running.Store(false) for task := range w.todo { - if task.txn != nil { + if task.db != nil && task.txn != nil { task.wait <- task.f(task.txn) - } else if task.db != nil { + } else if task.db != nil && task.txn == nil { task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { return task.f(txn) }) } else { - panic("expected database or transaction but got neither") + task.wait <- task.f(nil) } close(task.wait) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index c95790be..2af33761 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" + type deviceKeysStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer upsertDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt @@ -71,10 +71,10 @@ type deviceKeysStatements struct { deleteAllDeviceKeysStmt *sql.Stmt } -func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { +func NewSqliteDeviceKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.DeviceKeys, error) { s := &deviceKeysStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(deviceKeysSchema) if err != nil { diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index f451d657..cd178413 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -52,15 +52,15 @@ const selectKeyChangesSQL = "" + type keyChangesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer upsertKeyChangeStmt *sql.Stmt selectKeyChangesStmt *sql.Stmt } -func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { +func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.Writer) (tables.KeyChanges, error) { s := &keyChangesStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(keyChangesSchema) if err != nil { diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index c71cc47d..d788f676 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" + type oneTimeKeysStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer upsertKeysStmt *sql.Stmt selectKeysStmt *sql.Stmt selectKeysCountStmt *sql.Stmt @@ -68,10 +68,10 @@ type oneTimeKeysStatements struct { deleteOneTimeKeyStmt *sql.Stmt } -func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { +func NewSqliteOneTimeKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.OneTimeKeys, error) { s := &oneTimeKeysStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), + writer: writer, } _, err := db.Exec(oneTimeKeysSchema) if err != nil { diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index a989476d..8b6f8813 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -49,13 +50,18 @@ const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" type staleDeviceListsStatements struct { + db *sql.DB + writer sqlutil.Writer upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt } -func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { - s := &staleDeviceListsStatements{} +func NewSqliteStaleDeviceListsTable(db *sql.DB, writer sqlutil.Writer) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, + writer: writer, + } _, err := db.Exec(staleDeviceListsSchema) if err != nil { return nil, err @@ -77,8 +83,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.upsertStaleDeviceListStmt) + _, err = stmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err + }) } func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index bb293558..1a2a237f 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -25,19 +25,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } - otk, err := NewSqliteOneTimeKeysTable(db) + writer := sqlutil.NewExclusiveWriter() + otk, err := NewSqliteOneTimeKeysTable(db, writer) if err != nil { return nil, err } - dk, err := NewSqliteDeviceKeysTable(db) + dk, err := NewSqliteDeviceKeysTable(db, writer) if err != nil { return nil, err } - kc, err := NewSqliteKeyChangesTable(db) + kc, err := NewSqliteKeyChangesTable(db, writer) if err != nil { return nil, err } - sdl, err := NewSqliteStaleDeviceListsTable(db) + sdl, err := NewSqliteStaleDeviceListsTable(db, writer) if err != nil { return nil, err } diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go index ff6ddf3d..dcc1b41e 100644 --- a/mediaapi/storage/sqlite3/media_repository_table.go +++ b/mediaapi/storage/sqlite3/media_repository_table.go @@ -62,14 +62,14 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user type mediaStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertMediaStmt *sql.Stmt selectMediaStmt *sql.Stmt } -func (s *mediaStatements) prepare(db *sql.DB) (err error) { +func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(mediaSchema) if err != nil { diff --git a/mediaapi/storage/sqlite3/sql.go b/mediaapi/storage/sqlite3/sql.go index 9cd78b8e..245bd40c 100644 --- a/mediaapi/storage/sqlite3/sql.go +++ b/mediaapi/storage/sqlite3/sql.go @@ -17,6 +17,8 @@ package sqlite3 import ( "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) type statements struct { @@ -24,11 +26,11 @@ type statements struct { thumbnail thumbnailStatements } -func (s *statements) prepare(db *sql.DB) (err error) { - if err = s.media.prepare(db); err != nil { +func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { + if err = s.media.prepare(db, writer); err != nil { return } - if err = s.thumbnail.prepare(db); err != nil { + if err = s.thumbnail.prepare(db, writer); err != nil { return } diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go index a1e7fec7..d5c3031e 100644 --- a/mediaapi/storage/sqlite3/storage.go +++ b/mediaapi/storage/sqlite3/storage.go @@ -31,16 +31,19 @@ import ( type Database struct { statements statements db *sql.DB + writer sqlutil.Writer } // Open opens a postgres database. func Open(dbProperties *config.DatabaseOptions) (*Database, error) { - var d Database + d := Database{ + writer: sqlutil.NewExclusiveWriter(), + } var err error if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - if err = d.statements.prepare(d.db); err != nil { + if err = d.statements.prepare(d.db, d.writer); err != nil { return nil, err } return &d, nil diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go index 432a1590..06b056b6 100644 --- a/mediaapi/storage/sqlite3/thumbnail_table.go +++ b/mediaapi/storage/sqlite3/thumbnail_table.go @@ -21,6 +21,7 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/mediaapi/types" "github.com/matrix-org/gomatrixserverlib" ) @@ -57,16 +58,20 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method ` type thumbnailStatements struct { + db *sql.DB + writer sqlutil.Writer insertThumbnailStmt *sql.Stmt selectThumbnailStmt *sql.Stmt selectThumbnailsStmt *sql.Stmt } -func (s *thumbnailStatements) prepare(db *sql.DB) (err error) { +func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { _, err = db.Exec(thumbnailSchema) if err != nil { return } + s.db = db + s.writer = writer return statementList{ {&s.insertThumbnailStmt, insertThumbnailSQL}, @@ -79,18 +84,21 @@ func (s *thumbnailStatements) insertThumbnail( ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata, ) error { thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000) - _, err := s.insertThumbnailStmt.ExecContext( - ctx, - thumbnailMetadata.MediaMetadata.MediaID, - thumbnailMetadata.MediaMetadata.Origin, - thumbnailMetadata.MediaMetadata.ContentType, - thumbnailMetadata.MediaMetadata.FileSizeBytes, - thumbnailMetadata.MediaMetadata.CreationTimestamp, - thumbnailMetadata.ThumbnailSize.Width, - thumbnailMetadata.ThumbnailSize.Height, - thumbnailMetadata.ThumbnailSize.ResizeMethod, - ) - return err + return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt) + _, err := stmt.ExecContext( + ctx, + thumbnailMetadata.MediaMetadata.MediaID, + thumbnailMetadata.MediaMetadata.Origin, + thumbnailMetadata.MediaMetadata.ContentType, + thumbnailMetadata.MediaMetadata.FileSizeBytes, + thumbnailMetadata.MediaMetadata.CreationTimestamp, + thumbnailMetadata.ThumbnailSize.Width, + thumbnailMetadata.ThumbnailSize.Height, + thumbnailMetadata.ThumbnailSize.ResizeMethod, + ) + return err + }) } func (s *thumbnailStatements) selectThumbnail( diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 0b7ed225..d217b5d2 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -98,7 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: db, - Writer: sqlutil.NewDummyTransactionWriter(), + Writer: sqlutil.NewDummyWriter(), EventTypesTable: eventTypes, EventStateKeysTable: eventStateKeys, EventJSONTable: eventJSON, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 56c2b029..7101376a 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -27,7 +27,7 @@ const redactionsArePermanent = false type Database struct { DB *sql.DB - Writer sqlutil.TransactionWriter + Writer sqlutil.Writer EventsTable tables.Events EventJSONTable tables.EventJSON EventTypesTable tables.EventTypes diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 874bbbc7..d1738966 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -41,7 +41,7 @@ type Database struct { invites tables.Invites membership tables.Membership db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer } // Open a sqlite database. @@ -52,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - d.writer = sqlutil.NewTransactionWriter() + d.writer = sqlutil.NewExclusiveWriter() //d.db.Exec("PRAGMA journal_mode=WAL;") //d.db.Exec("PRAGMA read_uncommitted = true;") @@ -120,7 +120,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) { } d.Database = shared.Database{ DB: d.db, - Writer: sqlutil.NewTransactionWriter(), + Writer: sqlutil.NewExclusiveWriter(), EventsTable: d.events, EventTypesTable: d.eventTypes, EventStateKeysTable: d.eventStateKeys, diff --git a/serverkeyapi/storage/sqlite3/keydb.go b/serverkeyapi/storage/sqlite3/keydb.go index 5174ece1..0ee74bc1 100644 --- a/serverkeyapi/storage/sqlite3/keydb.go +++ b/serverkeyapi/storage/sqlite3/keydb.go @@ -30,6 +30,7 @@ import ( // A Database implements gomatrixserverlib.KeyDatabase and is used to store // the public keys for other matrix servers. type Database struct { + writer sqlutil.Writer statements serverKeyStatements } @@ -47,8 +48,10 @@ func NewDatabase( if err != nil { return nil, err } - d := &Database{} - err = d.statements.prepare(db) + d := &Database{ + writer: sqlutil.NewExclusiveWriter(), + } + err = d.statements.prepare(db, d.writer) if err != nil { return nil, err } diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go index b829eae7..f756ef5e 100644 --- a/serverkeyapi/storage/sqlite3/server_key_table.go +++ b/serverkeyapi/storage/sqlite3/server_key_table.go @@ -63,14 +63,14 @@ const upsertServerKeysSQL = "" + type serverKeyStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer bulkSelectServerKeysStmt *sql.Stmt upsertServerKeysStmt *sql.Stmt } -func (s *serverKeyStatements) prepare(db *sql.DB) (err error) { +func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(serverKeysSchema) if err != nil { return diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go index 26ef082f..36e8de67 100644 --- a/syncapi/storage/postgres/syncserver.go +++ b/syncapi/storage/postgres/syncserver.go @@ -30,7 +30,8 @@ import ( // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { shared.Database - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements } @@ -41,7 +42,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } - if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { + d.writer = sqlutil.NewDummyWriter() + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { return nil, err } accountData, err := NewPostgresAccountDataTable(d.db) @@ -78,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewDummyWriter(), Invites: invites, AccountData: accountData, OutputEvents: events, @@ -86,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e BackwardExtremities: backwardExtremities, Filter: filter, SendToDevice: sendToDevice, - SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), } return &d, nil diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index fdbf6758..699a6647 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -37,6 +37,7 @@ import ( // For now this contains the shared functions type Database struct { DB *sql.DB + Writer sqlutil.Writer Invites tables.Invites AccountData tables.AccountData OutputEvents tables.Events @@ -45,7 +46,6 @@ type Database struct { BackwardExtremities tables.BackwardsExtremities SendToDevice tables.SendToDevice Filter tables.Filter - SendToDeviceWriter sqlutil.TransactionWriter EDUCache *cache.EDUCache } @@ -129,10 +129,7 @@ func (d *Database) GetStateEvent( func (d *Database) GetStateEventsForRoom( ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter, ) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter) - return err - }) + stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter) return } @@ -171,9 +168,9 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition func (d *Database) AddInviteEvent( ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent, ) (sp types.StreamPosition, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent) - return err + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent) + return nil }) return } @@ -182,8 +179,12 @@ func (d *Database) AddInviteEvent( // Returns an error if there was a problem communicating with the database. func (d *Database) RetireInviteEvent( ctx context.Context, inviteEventID string, -) (types.StreamPosition, error) { - return d.Invites.DeleteInviteEvent(ctx, inviteEventID) +) (sp types.StreamPosition, err error) { + _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID) + return nil + }) + return } // GetAccountDataInRange returns all account data for a given user inserted or @@ -207,7 +208,7 @@ func (d *Database) GetAccountDataInRange( func (d *Database) UpsertAccountData( ctx context.Context, userID, roomID, dataType string, ) (sp types.StreamPosition, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType) return err }) @@ -237,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea // handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of // the events listed in the event's 'prev_events'. This function also updates the backwards extremities table // to account for the fact that the given event is no longer a backwards extremity, but may be marked as such. +// This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error { if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil { return err @@ -275,7 +277,7 @@ func (d *Database) WriteEvent( addStateEventIDs, removeStateEventIDs []string, transactionID *api.TransactionID, excludeFromSync bool, ) (pduPosition types.StreamPosition, returnErr error) { - returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { + returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error pos, err := d.OutputEvents.InsertEvent( ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync, @@ -304,6 +306,7 @@ func (d *Database) WriteEvent( return pduPosition, returnErr } +// This function should always be called within a sqlutil.Writer for safety in SQLite. func (d *Database) updateRoomState( ctx context.Context, txn *sql.Tx, removedEventIDs []string, @@ -1114,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage( } // Delegate the database write task to the SendToDeviceWriter. It'll guarantee // that we don't lock the table for writes in more than one place. - err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.AddSendToDeviceEvent( ctx, txn, userID, deviceID, string(j), ) @@ -1179,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates( // If we need to write to the database then we'll ask the SendToDeviceWriter to // do that for us. It'll guarantee that we don't lock the table for writes in // more than one place. - err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // Delete any send-to-device messages marked for deletion. if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index 248ec926..72c46e48 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -20,7 +20,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" @@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertAccountDataStmt *sql.Stmt selectMaxAccountDataIDStmt *sql.Stmt @@ -61,7 +59,6 @@ type accountDataStatements struct { func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) { s := &accountDataStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(accountDataSchema) @@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData( ctx context.Context, txn *sql.Tx, userID, roomID, dataType string, ) (pos types.StreamPosition, err error) { - return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - var err error - pos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } - _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) - return err - }) + pos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType) + return } func (s *accountDataStatements) SelectAccountDataInRange( diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index d96f2fe5..116c33dc 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -19,7 +19,6 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" ) @@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" + type backwardExtremitiesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt @@ -57,8 +55,7 @@ type backwardExtremitiesStatements struct { func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { s := &backwardExtremitiesStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err := db.Exec(backwardExtremitiesSchema) if err != nil { @@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) - return err - }) + _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID) + return err } func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( @@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom( func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( ctx context.Context, txn *sql.Tx, roomID, knownEventID string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) - return err - }) + _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) + return err } diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 2f0068ed..6f822c90 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" + type currentRoomStateStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -98,7 +97,6 @@ type currentRoomStateStatements struct { func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { s := ¤tRoomStateStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(currentRoomStateSchema) @@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState( func (s *currentRoomStateStatements) DeleteRoomStateByEventID( ctx context.Context, txn *sql.Tx, eventID string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) - _, err := stmt.ExecContext(ctx, eventID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt) + _, err := stmt.ExecContext(ctx, eventID) + return err } func (s *currentRoomStateStatements) UpsertRoomState( @@ -225,22 +221,20 @@ func (s *currentRoomStateStatements) UpsertRoomState( } // upsert state event - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) - _, err := stmt.ExecContext( - ctx, - event.RoomID(), - event.EventID(), - event.Type(), - event.Sender(), - containsURL, - *event.StateKey(), - headeredJSON, - membership, - addedAt, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt) + _, err = stmt.ExecContext( + ctx, + event.RoomID(), + event.EventID(), + event.Type(), + event.Sender(), + containsURL, + *event.StateKey(), + headeredJSON, + membership, + addedAt, + ) + return err } func minOfInts(a, b int) int { diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go index 338b0b50..3092bcd7 100644 --- a/syncapi/storage/sqlite3/filter_table.go +++ b/syncapi/storage/sqlite3/filter_table.go @@ -20,7 +20,6 @@ import ( "encoding/json" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/gomatrixserverlib" ) @@ -52,7 +51,6 @@ const insertFilterSQL = "" + type filterStatements struct { db *sql.DB - writer sqlutil.TransactionWriter selectFilterStmt *sql.Stmt selectFilterIDByContentStmt *sql.Stmt insertFilterStmt *sql.Stmt @@ -64,8 +62,7 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) { return nil, err } s := &filterStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil { return nil, err @@ -114,33 +111,30 @@ func (s *filterStatements) InsertFilter( return "", err } - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - // Check if filter already exists in the database using its localpart and content - // - // This can result in a race condition when two clients try to insert the - // same filter and localpart at the same time, however this is not a - // problem as both calls will result in the same filterID - err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, - localpart, filterJSON).Scan(&existingFilterID) - if err != nil && err != sql.ErrNoRows { - return err - } - // If it does, return the existing ID - if existingFilterID != "" { - return nil - } - - // Otherwise insert the filter and return the new ID - res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) - if err != nil { - return err - } - rowid, err := res.LastInsertId() - if err != nil { - return err - } - filterID = fmt.Sprintf("%d", rowid) - return nil - }) + // Check if filter already exists in the database using its localpart and content + // + // This can result in a race condition when two clients try to insert the + // same filter and localpart at the same time, however this is not a + // problem as both calls will result in the same filterID + err = s.selectFilterIDByContentStmt.QueryRowContext(ctx, + localpart, filterJSON).Scan(&existingFilterID) + if err != nil && err != sql.ErrNoRows { + return "", err + } + // If it does, return the existing ID + if existingFilterID != "" { + return existingFilterID, nil + } + + // Otherwise insert the filter and return the new ID + res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart) + if err != nil { + return "", err + } + rowid, err := res.LastInsertId() + if err != nil { + return "", err + } + filterID = fmt.Sprintf("%d", rowid) return } diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index 0bbd79f7..45862efb 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" + type inviteEventsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt @@ -70,7 +69,6 @@ type inviteEventsStatements struct { func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) { s := &inviteEventsStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(inviteEventsSchema) @@ -95,45 +93,37 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv func (s *inviteEventsStatements) InsertInviteEvent( ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent, ) (streamPos types.StreamPosition, err error) { - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - var err error - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } - var headeredJSON []byte - headeredJSON, err = json.Marshal(inviteEvent) - if err != nil { - return err - } + var headeredJSON []byte + headeredJSON, err = json.Marshal(inviteEvent) + if err != nil { + return + } - _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext( - ctx, - streamPos, - inviteEvent.RoomID(), - inviteEvent.EventID(), - *inviteEvent.StateKey(), - headeredJSON, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt) + _, err = stmt.ExecContext( + ctx, + streamPos, + inviteEvent.RoomID(), + inviteEvent.EventID(), + *inviteEvent.StateKey(), + headeredJSON, + ) return } func (s *inviteEventsStatements) DeleteInviteEvent( ctx context.Context, inviteEventID string, ) (types.StreamPosition, error) { - var streamPos types.StreamPosition - err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - var err error - streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil) - if err != nil { - return err - } - _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) - return err - }) + streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil) + if err != nil { + return streamPos, err + } + _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 0d154650..f10d0106 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" + type outputRoomEventsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter streamIDStatements *streamIDStatements insertEventStmt *sql.Stmt selectEventsStmt *sql.Stmt @@ -120,7 +119,6 @@ type outputRoomEventsStatements struct { func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) { s := &outputRoomEventsStatements{ db: db, - writer: sqlutil.NewTransactionWriter(), streamIDStatements: streamID, } _, err := db.Exec(outputRoomEventsSchema) @@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event if err != nil { return err } - return s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) - return err - }) + _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID()) + return err } // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos. @@ -304,32 +300,27 @@ func (s *outputRoomEventsStatements) InsertEvent( return 0, err } - var streamPos types.StreamPosition - err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) - if err != nil { - return err - } - - insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) - _, ierr := insertStmt.ExecContext( - ctx, - streamPos, - event.RoomID(), - event.EventID(), - headeredJSON, - event.Type(), - event.Sender(), - containsURL, - string(addStateJSON), - string(removeStateJSON), - sessionID, - txnID, - excludeFromSync, - excludeFromSync, - ) - return ierr - }) + streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return 0, err + } + insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt) + _, err = insertStmt.ExecContext( + ctx, + streamPos, + event.RoomID(), + event.EventID(), + headeredJSON, + event.Type(), + event.Sender(), + containsURL, + string(addStateJSON), + string(removeStateJSON), + sessionID, + txnID, + excludeFromSync, + excludeFromSync, + ) return streamPos, err } diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 5c4ab005..d8c97b7e 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" + type outputRoomEventsTopologyStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt @@ -77,8 +76,7 @@ type outputRoomEventsTopologyStatements struct { func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { s := &outputRoomEventsTopologyStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err := db.Exec(outputRoomEventsTopologySchema) if err != nil { @@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) - _, err := stmt.ExecContext( - ctx, event.EventID(), event.Depth(), event.RoomID(), pos, - ) - return err - }) + stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt) + _, err = stmt.ExecContext( + ctx, event.EventID(), event.Depth(), event.RoomID(), pos, + ) + return } func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go index 53786589..fbc759b1 100644 --- a/syncapi/storage/sqlite3/send_to_device_table.go +++ b/syncapi/storage/sqlite3/send_to_device_table.go @@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = ` type sendToDeviceStatements struct { db *sql.DB - writer sqlutil.TransactionWriter insertSendToDeviceMessageStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt countSendToDeviceMessagesStmt *sql.Stmt @@ -81,8 +80,7 @@ type sendToDeviceStatements struct { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { s := &sendToDeviceStatements{ - db: db, - writer: sqlutil.NewTransactionWriter(), + db: db, } _, err := db.Exec(sendToDeviceSchema) if err != nil { @@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func (s *sendToDeviceStatements) InsertSendToDeviceMessage( ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ) (err error) { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) - return err - }) + _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) + return } func (s *sendToDeviceStatements) CountSendToDeviceMessages( @@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( for k, v := range nids { params[k+1] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.ExecContext(ctx, query, params...) - return err - }) + _, err = txn.ExecContext(ctx, query, params...) + return } func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( @@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( for k, v := range nids { params[k] = v } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - _, err := txn.ExecContext(ctx, query, params...) - return err - }) + _, err = txn.ExecContext(ctx, query, params...) + return } diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go index 1971e7f3..e6bdc4fc 100644 --- a/syncapi/storage/sqlite3/stream_id_table.go +++ b/syncapi/storage/sqlite3/stream_id_table.go @@ -28,14 +28,12 @@ const selectStreamIDStmt = "" + type streamIDStatements struct { db *sql.DB - writer sqlutil.TransactionWriter increaseStreamIDStmt *sql.Stmt selectStreamIDStmt *sql.Stmt } func (s *streamIDStatements) prepare(db *sql.DB) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(streamIDTableSchema) if err != nil { return @@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) { func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) { increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt) selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt) - err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error { - if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil { - return ierr - } - if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil { - return serr - } - return nil - }) + if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil { + return + } + err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos) return } diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go index 9564a23a..81197bb7 100644 --- a/syncapi/storage/sqlite3/syncserver.go +++ b/syncapi/storage/sqlite3/syncserver.go @@ -31,7 +31,8 @@ import ( // both the database for PDUs and caches for EDUs. type SyncServerDatasource struct { shared.Database - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements streamID streamIDStatements } @@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e if d.db, err = sqlutil.Open(dbProperties); err != nil { return nil, err } + d.writer = sqlutil.NewExclusiveWriter() if err = d.prepare(); err != nil { return nil, err } @@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e } func (d *SyncServerDatasource) prepare() (err error) { - if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil { + if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil { return err } if err = d.streamID.prepare(d.db); err != nil { @@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) { } d.Database = shared.Database{ DB: d.db, + Writer: sqlutil.NewExclusiveWriter(), Invites: invites, AccountData: accountData, OutputEvents: events, @@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) { Topology: topology, Filter: filter, SendToDevice: sendToDevice, - SendToDeviceWriter: sqlutil.NewTransactionWriter(), EDUCache: cache.New(), } return nil diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 9653c019..b36264dd 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -34,7 +34,8 @@ import ( // Database represents an account database type Database struct { - db *sql.DB + db *sql.DB + writer sqlutil.Writer sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements @@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } - partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, "account"); err != nil { + d := &Database{ + serverName: serverName, + db: db, + writer: sqlutil.NewDummyWriter(), + } + if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { return nil, err } - a := accountsStatements{} - if err = a.prepare(db, serverName); err != nil { + if err = d.accounts.prepare(db, serverName); err != nil { return nil, err } - p := profilesStatements{} - if err = p.prepare(db); err != nil { + if err = d.profiles.prepare(db); err != nil { return nil, err } - ac := accountDataStatements{} - if err = ac.prepare(db); err != nil { + if err = d.accountDatas.prepare(db); err != nil { return nil, err } - t := threepidStatements{} - if err = t.prepare(db); err != nil { + if err = d.threepids.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, ac, t, serverName}, nil + return d, nil } // GetAccountByPassword returns the account associated with the given localpart and password. diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go index 9b40e657..aee8db6e 100644 --- a/userapi/storage/accounts/sqlite3/account_data_table.go +++ b/userapi/storage/accounts/sqlite3/account_data_table.go @@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" + type accountDataStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertAccountDataStmt *sql.Stmt selectAccountDataStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt } -func (s *accountDataStatements) prepare(db *sql.DB) (err error) { +func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(accountDataSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 586bcab9..83b90668 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" + type accountsStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt @@ -67,9 +67,9 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(accountsSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go index cd35d298..1ec45e03 100644 --- a/userapi/storage/accounts/sqlite3/profile_table.go +++ b/userapi/storage/accounts/sqlite3/profile_table.go @@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" + type profilesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertProfileStmt *sql.Stmt selectProfileByLocalpartStmt *sql.Stmt setAvatarURLStmt *sql.Stmt @@ -61,9 +61,9 @@ type profilesStatements struct { selectProfilesBySearchStmt *sql.Stmt } -func (s *profilesStatements) prepare(db *sql.DB) (err error) { +func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(profilesSchema) if err != nil { return diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 4d2c5e51..4f45f754 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -33,7 +33,9 @@ import ( // Database represents an account database type Database struct { - db *sql.DB + db *sql.DB + writer sqlutil.Writer + sqlutil.PartitionOffsetStatements accounts accountsStatements profiles profilesStatements @@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } + d := &Database{ + serverName: serverName, + db: db, + writer: sqlutil.NewExclusiveWriter(), + } partitions := sqlutil.PartitionOffsetStatements{} - if err = partitions.Prepare(db, "account"); err != nil { + if err = partitions.Prepare(db, d.writer, "account"); err != nil { return nil, err } - a := accountsStatements{} - if err = a.prepare(db, serverName); err != nil { + if err = d.accounts.prepare(db, d.writer, serverName); err != nil { return nil, err } - p := profilesStatements{} - if err = p.prepare(db); err != nil { + if err = d.profiles.prepare(db, d.writer); err != nil { return nil, err } - ac := accountDataStatements{} - if err = ac.prepare(db); err != nil { + if err = d.accountDatas.prepare(db, d.writer); err != nil { return nil, err } - t := threepidStatements{} - if err = t.prepare(db); err != nil { + if err = d.threepids.prepare(db, d.writer); err != nil { return nil, err } - return &Database{ - db: db, - PartitionOffsetStatements: partitions, - accounts: a, - profiles: p, - accountDatas: ac, - threepids: t, - serverName: serverName, - }, nil + return d, nil } // GetAccountByPassword returns the account associated with the given localpart and password. diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go index 3000d7c4..230978fe 100644 --- a/userapi/storage/accounts/sqlite3/threepid_table.go +++ b/userapi/storage/accounts/sqlite3/threepid_table.go @@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" + type threepidStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer selectLocalpartForThreePIDStmt *sql.Stmt selectThreePIDsForLocalpartStmt *sql.Stmt insertThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt } -func (s *threepidStatements) prepare(db *sql.DB) (err error) { +func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(threepidSchema) if err != nil { return diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go index 962e63b0..c93e8b77 100644 --- a/userapi/storage/devices/sqlite3/devices_table.go +++ b/userapi/storage/devices/sqlite3/devices_table.go @@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" + type devicesStatements struct { db *sql.DB - writer sqlutil.TransactionWriter + writer sqlutil.Writer insertDeviceStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt @@ -91,9 +91,9 @@ type devicesStatements struct { serverName gomatrixserverlib.ServerName } -func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { +func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { s.db = db - s.writer = sqlutil.NewTransactionWriter() + s.writer = writer _, err = db.Exec(devicesSchema) if err != nil { return @@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice( ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 - err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) - insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { - return err - } - sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { - return err - } - return nil - }) - if err != nil { + countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt) + insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) + if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil { + return nil, err + } + sessionID++ + if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil { return nil, err } return &api.Device{ @@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice( func (s *devicesStatements) deleteDevice( ctx context.Context, txn *sql.Tx, id, localpart string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) + _, err := stmt.ExecContext(ctx, id, localpart) + return err } func (s *devicesStatements) deleteDevices( @@ -179,36 +171,30 @@ func (s *devicesStatements) deleteDevices( if err != nil { return err } - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) - params[0] = localpart - for i, v := range devices { - params[i+1] = v - } - _, err = stmt.ExecContext(ctx, params...) - return err - }) + stmt := sqlutil.TxStmt(txn, prep) + params := make([]interface{}, len(devices)+1) + params[0] = localpart + for i, v := range devices { + params[i+1] = v + } + _, err = stmt.ExecContext(ctx, params...) + return err } func (s *devicesStatements) deleteDevicesByLocalpart( ctx context.Context, txn *sql.Tx, localpart string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart) - return err - }) + stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) + _, err := stmt.ExecContext(ctx, localpart) + return err } func (s *devicesStatements) updateDeviceName( ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ) error { - return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { - stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) - return err - }) + stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) + _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + return err } func (s *devicesStatements) selectDeviceByToken( diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 1f2b59f3..4f426c6e 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -34,6 +34,7 @@ var deviceIDByteLength = 6 // Database represents a device database. type Database struct { db *sql.DB + writer sqlutil.Writer devices devicesStatements } @@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err != nil { return nil, err } + writer := sqlutil.NewExclusiveWriter() d := devicesStatements{} - if err = d.prepare(db, serverName); err != nil { + if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, d}, nil + return &Database{db, writer, d}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -88,7 +90,7 @@ func (d *Database) CreateDevice( displayName *string, ) (dev *api.Device, returnErr error) { if deviceID != nil { - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil { @@ -108,7 +110,7 @@ func (d *Database) CreateDevice( return } - returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { var err error dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName) return err @@ -138,7 +140,7 @@ func generateDeviceID() (string, error) { func (d *Database) UpdateDevice( ctx context.Context, localpart, deviceID string, displayName *string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName) }) } @@ -150,7 +152,7 @@ func (d *Database) UpdateDevice( func (d *Database) RemoveDevice( ctx context.Context, deviceID, localpart string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows { return err } @@ -165,7 +167,7 @@ func (d *Database) RemoveDevice( func (d *Database) RemoveDevices( ctx context.Context, localpart string, devices []string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { return err } @@ -179,7 +181,7 @@ func (d *Database) RemoveDevices( func (d *Database) RemoveAllDevices( ctx context.Context, localpart string, ) error { - return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows { return err } -- cgit v1.2.3