aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-08-21 10:42:08 +0100
committerGitHub <noreply@github.com>2020-08-21 10:42:08 +0100
commit9d53351dc20283103bf2eec6b92831033d06c5a8 (patch)
tree653cf0ddca3f777bcdba188187fb78fe39ae2b02
parent5aaf32bbed4d704d5a22ad7dff79f7a68002a213 (diff)
Component-wide TransactionWriters (#1290)
* Offset updates take place using TransactionWriter * Refactor TransactionWriter in current state server * Refactor TransactionWriter in federation sender * Refactor TransactionWriter in key server * Refactor TransactionWriter in media API * Refactor TransactionWriter in server key API * Refactor TransactionWriter in sync API * Refactor TransactionWriter in user API * Fix deadlocking Sync API tests * Un-deadlock device database * Fix appservice API * Rename TransactionWriters to Writers * Move writers up a layer in sync API * Document sqlutil.Writer interface * Add note to Writer documentation
-rw-r--r--appservice/storage/postgres/storage.go4
-rw-r--r--appservice/storage/sqlite3/appservice_events_table.go6
-rw-r--r--appservice/storage/sqlite3/storage.go8
-rw-r--r--appservice/storage/sqlite3/txn_id_counter_table.go6
-rw-r--r--currentstateserver/storage/postgres/storage.go7
-rw-r--r--currentstateserver/storage/shared/storage.go3
-rw-r--r--currentstateserver/storage/sqlite3/current_room_state_table.go4
-rw-r--r--currentstateserver/storage/sqlite3/storage.go7
-rw-r--r--federationsender/storage/postgres/storage.go7
-rw-r--r--federationsender/storage/shared/storage.go18
-rw-r--r--federationsender/storage/sqlite3/blacklist_table.go20
-rw-r--r--federationsender/storage/sqlite3/joined_hosts_table.go26
-rw-r--r--federationsender/storage/sqlite3/queue_edus_table.go30
-rw-r--r--federationsender/storage/sqlite3/queue_json_table.go33
-rw-r--r--federationsender/storage/sqlite3/queue_pdus_table.go30
-rw-r--r--federationsender/storage/sqlite3/room_table.go18
-rw-r--r--federationsender/storage/sqlite3/storage.go7
-rw-r--r--internal/sqlutil/partition_offset_table.go13
-rw-r--r--internal/sqlutil/sql.go4
-rw-r--r--internal/sqlutil/writer.go46
-rw-r--r--internal/sqlutil/writer_dummy.go16
-rw-r--r--internal/sqlutil/writer_exclusive.go21
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go6
-rw-r--r--keyserver/storage/sqlite3/key_changes_table.go6
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go6
-rw-r--r--keyserver/storage/sqlite3/stale_device_lists.go17
-rw-r--r--keyserver/storage/sqlite3/storage.go9
-rw-r--r--mediaapi/storage/sqlite3/media_repository_table.go6
-rw-r--r--mediaapi/storage/sqlite3/sql.go8
-rw-r--r--mediaapi/storage/sqlite3/storage.go7
-rw-r--r--mediaapi/storage/sqlite3/thumbnail_table.go34
-rw-r--r--roomserver/storage/postgres/storage.go2
-rw-r--r--roomserver/storage/shared/storage.go2
-rw-r--r--roomserver/storage/sqlite3/storage.go6
-rw-r--r--serverkeyapi/storage/sqlite3/keydb.go7
-rw-r--r--serverkeyapi/storage/sqlite3/server_key_table.go6
-rw-r--r--syncapi/storage/postgres/syncserver.go8
-rw-r--r--syncapi/storage/shared/syncserver.go31
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go18
-rw-r--r--syncapi/storage/sqlite3/backwards_extremities_table.go17
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go40
-rw-r--r--syncapi/storage/sqlite3/filter_table.go58
-rw-r--r--syncapi/storage/sqlite3/invites_table.go56
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go55
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go16
-rw-r--r--syncapi/storage/sqlite3/send_to_device_table.go22
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go15
-rw-r--r--syncapi/storage/sqlite3/syncserver.go8
-rw-r--r--userapi/storage/accounts/postgres/storage.go25
-rw-r--r--userapi/storage/accounts/sqlite3/account_data_table.go6
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go6
-rw-r--r--userapi/storage/accounts/sqlite3/profile_table.go6
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go33
-rw-r--r--userapi/storage/accounts/sqlite3/threepid_table.go6
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go68
-rw-r--r--userapi/storage/devices/sqlite3/storage.go18
56 files changed, 484 insertions, 484 deletions
diff --git a/appservice/storage/postgres/storage.go b/appservice/storage/postgres/storage.go
index 9fda87ae..95215816 100644
--- a/appservice/storage/postgres/storage.go
+++ b/appservice/storage/postgres/storage.go
@@ -32,6 +32,7 @@ type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
+ writer sqlutil.Writer
}
// NewDatabase opens a new database
@@ -41,10 +42,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if result.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
+ result.writer = sqlutil.NewDummyWriter()
if err = result.prepare(); err != nil {
return nil, err
}
- if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil {
+ if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
return nil, err
}
return &result, nil
diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go
index 5cc07ed3..5dfb72f6 100644
--- a/appservice/storage/sqlite3/appservice_events_table.go
+++ b/appservice/storage/sqlite3/appservice_events_table.go
@@ -67,7 +67,7 @@ const (
type eventsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
@@ -75,9 +75,9 @@ type eventsStatements struct {
deleteEventsBeforeAndIncludingIDStmt *sql.Stmt
}
-func (s *eventsStatements) prepare(db *sql.DB) (err error) {
+func (s *eventsStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
diff --git a/appservice/storage/sqlite3/storage.go b/appservice/storage/sqlite3/storage.go
index 59af9016..916845ab 100644
--- a/appservice/storage/sqlite3/storage.go
+++ b/appservice/storage/sqlite3/storage.go
@@ -32,6 +32,7 @@ type Database struct {
events eventsStatements
txnID txnStatements
db *sql.DB
+ writer sqlutil.Writer
}
// NewDatabase opens a new database
@@ -41,21 +42,22 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if result.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
+ result.writer = sqlutil.NewExclusiveWriter()
if err = result.prepare(); err != nil {
return nil, err
}
- if err = result.PartitionOffsetStatements.Prepare(result.db, "appservice"); err != nil {
+ if err = result.PartitionOffsetStatements.Prepare(result.db, result.writer, "appservice"); err != nil {
return nil, err
}
return &result, nil
}
func (d *Database) prepare() error {
- if err := d.events.prepare(d.db); err != nil {
+ if err := d.events.prepare(d.db, d.writer); err != nil {
return err
}
- return d.txnID.prepare(d.db)
+ return d.txnID.prepare(d.db, d.writer)
}
// StoreEvent takes in a gomatrixserverlib.HeaderedEvent and stores it in the database
diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go
index 0ae0feee..b2940e35 100644
--- a/appservice/storage/sqlite3/txn_id_counter_table.go
+++ b/appservice/storage/sqlite3/txn_id_counter_table.go
@@ -38,13 +38,13 @@ const selectTxnIDSQL = `
type txnStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
selectTxnIDStmt *sql.Stmt
}
-func (s *txnStatements) prepare(db *sql.DB) (err error) {
+func (s *txnStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(txnIDSchema)
if err != nil {
return
diff --git a/currentstateserver/storage/postgres/storage.go b/currentstateserver/storage/postgres/storage.go
index 0cd7e555..cb5ebff0 100644
--- a/currentstateserver/storage/postgres/storage.go
+++ b/currentstateserver/storage/postgres/storage.go
@@ -10,7 +10,8 @@ import (
type Database struct {
shared.Database
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
}
@@ -21,7 +22,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil {
+ d.writer = sqlutil.NewDummyWriter()
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil {
return nil, err
}
currRoomState, err := NewPostgresCurrentRoomStateTable(d.db)
@@ -30,6 +32,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: d.db,
+ Writer: d.writer,
CurrentRoomState: currRoomState,
}
return &d, nil
diff --git a/currentstateserver/storage/shared/storage.go b/currentstateserver/storage/shared/storage.go
index 46ef9e6c..2cf40ccc 100644
--- a/currentstateserver/storage/shared/storage.go
+++ b/currentstateserver/storage/shared/storage.go
@@ -27,6 +27,7 @@ import (
type Database struct {
DB *sql.DB
+ Writer sqlutil.Writer
CurrentRoomState tables.CurrentRoomState
}
@@ -59,7 +60,7 @@ func (d *Database) RedactEvent(ctx context.Context, redactedEventID string, reda
func (d *Database) StoreStateEvents(ctx context.Context, addStateEvents []gomatrixserverlib.HeaderedEvent,
removeStateEventIDs []string) error {
- return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removeStateEventIDs {
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go
index 9d2fe6e0..c6cf40ed 100644
--- a/currentstateserver/storage/sqlite3/current_room_state_table.go
+++ b/currentstateserver/storage/sqlite3/current_room_state_table.go
@@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt
@@ -96,7 +96,7 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
+ writer: sqlutil.NewExclusiveWriter(),
}
_, err := db.Exec(currentRoomStateSchema)
if err != nil {
diff --git a/currentstateserver/storage/sqlite3/storage.go b/currentstateserver/storage/sqlite3/storage.go
index 4454c9ed..e79afd70 100644
--- a/currentstateserver/storage/sqlite3/storage.go
+++ b/currentstateserver/storage/sqlite3/storage.go
@@ -10,7 +10,8 @@ import (
type Database struct {
shared.Database
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
}
@@ -22,7 +23,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- if err = d.PartitionOffsetStatements.Prepare(d.db, "currentstate"); err != nil {
+ d.writer = sqlutil.NewExclusiveWriter()
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "currentstate"); err != nil {
return nil, err
}
currRoomState, err := NewSqliteCurrentRoomStateTable(d.db)
@@ -31,6 +33,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: d.db,
+ Writer: d.writer,
CurrentRoomState: currRoomState,
}
return &d, nil
diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go
index b65ff0b6..b3b4da39 100644
--- a/federationsender/storage/postgres/storage.go
+++ b/federationsender/storage/postgres/storage.go
@@ -27,7 +27,8 @@ import (
type Database struct {
shared.Database
sqlutil.PartitionOffsetStatements
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
}
// NewDatabase opens a new database
@@ -37,6 +38,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
+ d.writer = sqlutil.NewDummyWriter()
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil {
return nil, err
@@ -63,6 +65,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: d.db,
+ Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts,
FederationSenderQueuePDUs: queuePDUs,
FederationSenderQueueEDUs: queueEDUs,
@@ -70,7 +73,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
FederationSenderRooms: rooms,
FederationSenderBlacklist: blacklist,
}
- if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil {
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
return nil, err
}
return &d, nil
diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go
index 4a681de6..4e347259 100644
--- a/federationsender/storage/shared/storage.go
+++ b/federationsender/storage/shared/storage.go
@@ -28,6 +28,7 @@ import (
type Database struct {
DB *sql.DB
+ Writer sqlutil.Writer
FederationSenderQueuePDUs tables.FederationSenderQueuePDUs
FederationSenderQueueEDUs tables.FederationSenderQueueEDUs
FederationSenderQueueJSON tables.FederationSenderQueueJSON
@@ -64,7 +65,7 @@ func (d *Database) UpdateRoom(
addHosts []types.JoinedHost,
removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.FederationSenderRooms.InsertRoom(ctx, txn, roomID)
if err != nil {
return err
@@ -133,7 +134,12 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string)
func (d *Database) StoreJSON(
ctx context.Context, js string,
) (*Receipt, error) {
- nid, err := d.FederationSenderQueueJSON.InsertQueueJSON(ctx, nil, js)
+ var nid int64
+ var err error
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ nid, err = d.FederationSenderQueueJSON.InsertQueueJSON(ctx, txn, js)
+ return nil
+ })
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
@@ -143,11 +149,15 @@ func (d *Database) StoreJSON(
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
- return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), nil, serverName)
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.FederationSenderBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
+ })
}
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
- return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), nil, serverName)
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.FederationSenderBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
+ })
}
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
diff --git a/federationsender/storage/sqlite3/blacklist_table.go b/federationsender/storage/sqlite3/blacklist_table.go
index b23bfcba..90b44ac9 100644
--- a/federationsender/storage/sqlite3/blacklist_table.go
+++ b/federationsender/storage/sqlite3/blacklist_table.go
@@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" +
type blacklistStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
@@ -50,8 +49,7 @@ type blacklistStatements struct {
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
s = &blacklistStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(blacklistSchema)
if err != nil {
@@ -75,11 +73,9 @@ func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
- _, err := stmt.ExecContext(ctx, serverName)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
+ _, err := stmt.ExecContext(ctx, serverName)
+ return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
@@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist(
func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
- _, err := stmt.ExecContext(ctx, serverName)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
+ _, err := stmt.ExecContext(ctx, serverName)
+ return err
}
diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go
index 5dc18f4e..3bc45e7d 100644
--- a/federationsender/storage/sqlite3/joined_hosts_table.go
+++ b/federationsender/storage/sqlite3/joined_hosts_table.go
@@ -65,7 +65,6 @@ const selectJoinedHostsForRoomsSQL = "" +
type joinedHostsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
@@ -75,8 +74,7 @@ type joinedHostsStatements struct {
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(joinedHostsSchema)
if err != nil {
@@ -103,25 +101,21 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
- _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
+ _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
+ return err
}
func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- for _, eventID := range eventIDs {
- stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
- if _, err := stmt.ExecContext(ctx, eventID); err != nil {
- return err
- }
+ for _, eventID := range eventIDs {
+ stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
+ if _, err := stmt.ExecContext(ctx, eventID); err != nil {
+ return err
}
- return nil
- })
+ }
+ return nil
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go
index 2abcc105..a6d60950 100644
--- a/federationsender/storage/sqlite3/queue_edus_table.go
+++ b/federationsender/storage/sqlite3/queue_edus_table.go
@@ -64,7 +64,6 @@ const selectQueueServerNamesSQL = "" +
type queueEDUsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
@@ -74,8 +73,7 @@ type queueEDUsStatements struct {
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {
@@ -106,16 +104,14 @@ func (s *queueEDUsStatements) InsertQueueEDU(
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
- _, err := stmt.ExecContext(
- ctx,
- eduType, // the EDU type
- serverName, // destination server name
- nid, // JSON blob NID
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ eduType, // the EDU type
+ serverName, // destination server name
+ nid, // JSON blob NID
+ )
+ return err
}
func (s *queueEDUsStatements) DeleteQueueEDUs(
@@ -135,11 +131,9 @@ func (s *queueEDUsStatements) DeleteQueueEDUs(
params[k+1] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, deleteStmt)
- _, err := stmt.ExecContext(ctx, params...)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, deleteStmt)
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
}
func (s *queueEDUsStatements) SelectQueueEDUs(
diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go
index 867ffd44..3e3f60f6 100644
--- a/federationsender/storage/sqlite3/queue_json_table.go
+++ b/federationsender/storage/sqlite3/queue_json_table.go
@@ -50,7 +50,6 @@ const selectJSONSQL = "" +
type queueJSONStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
@@ -58,8 +57,7 @@ type queueJSONStatements struct {
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(queueJSONSchema)
if err != nil {
@@ -74,18 +72,15 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
- res, err := stmt.ExecContext(ctx, json)
- if err != nil {
- return fmt.Errorf("stmt.QueryContext: %w", err)
- }
- lastid, err = res.LastInsertId()
- if err != nil {
- return fmt.Errorf("res.LastInsertId: %w", err)
- }
- return nil
- })
+ stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
+ res, err := stmt.ExecContext(ctx, json)
+ if err != nil {
+ return 0, fmt.Errorf("stmt.QueryContext: %w", err)
+ }
+ lastid, err = res.LastInsertId()
+ if err != nil {
+ return 0, fmt.Errorf("res.LastInsertId: %w", err)
+ }
return
}
@@ -103,11 +98,9 @@ func (s *queueJSONStatements) DeleteQueueJSON(
iNIDs[k] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, deleteStmt)
- _, err = stmt.ExecContext(ctx, iNIDs...)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, deleteStmt)
+ _, err = stmt.ExecContext(ctx, iNIDs...)
+ return err
}
func (s *queueJSONStatements) SelectQueueJSON(
diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go
index 538ba3db..70519c9e 100644
--- a/federationsender/storage/sqlite3/queue_pdus_table.go
+++ b/federationsender/storage/sqlite3/queue_pdus_table.go
@@ -71,7 +71,6 @@ const selectQueuePDUsServerNamesSQL = "" +
type queuePDUsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt
@@ -83,8 +82,7 @@ type queuePDUsStatements struct {
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(queuePDUsSchema)
if err != nil {
@@ -121,16 +119,14 @@ func (s *queuePDUsStatements) InsertQueuePDU(
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
- _, err := stmt.ExecContext(
- ctx,
- transactionID, // the transaction ID that we initially attempted
- serverName, // destination server name
- nid, // JSON blob NID
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ transactionID, // the transaction ID that we initially attempted
+ serverName, // destination server name
+ nid, // JSON blob NID
+ )
+ return err
}
func (s *queuePDUsStatements) DeleteQueuePDUs(
@@ -150,11 +146,9 @@ func (s *queuePDUsStatements) DeleteQueuePDUs(
params[k+1] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, deleteStmt)
- _, err := stmt.ExecContext(ctx, params...)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, deleteStmt)
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
}
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go
index 9a439fad..0710ccca 100644
--- a/federationsender/storage/sqlite3/room_table.go
+++ b/federationsender/storage/sqlite3/room_table.go
@@ -44,7 +44,6 @@ const updateRoomSQL = "" +
type roomStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
@@ -52,8 +51,7 @@ type roomStatements struct {
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(roomSchema)
if err != nil {
@@ -77,10 +75,8 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
func (s *roomStatements) InsertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
- return err
- })
+ _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
+ return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
@@ -103,9 +99,7 @@ func (s *roomStatements) SelectRoomForUpdate(
func (s *roomStatements) UpdateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
- _, err := stmt.ExecContext(ctx, roomID, lastEventID)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
+ _, err := stmt.ExecContext(ctx, roomID, lastEventID)
+ return err
}
diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go
index 41b91871..ba467f02 100644
--- a/federationsender/storage/sqlite3/storage.go
+++ b/federationsender/storage/sqlite3/storage.go
@@ -29,7 +29,8 @@ import (
type Database struct {
shared.Database
sqlutil.PartitionOffsetStatements
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
}
// NewDatabase opens a new database
@@ -39,6 +40,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
+ d.writer = sqlutil.NewExclusiveWriter()
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil {
return nil, err
@@ -65,6 +67,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: d.db,
+ Writer: d.writer,
FederationSenderJoinedHosts: joinedHosts,
FederationSenderQueuePDUs: queuePDUs,
FederationSenderQueueEDUs: queueEDUs,
@@ -72,7 +75,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*Database, error) {
FederationSenderRooms: rooms,
FederationSenderBlacklist: blacklist,
}
- if err = d.PartitionOffsetStatements.Prepare(d.db, "federationsender"); err != nil {
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
return nil, err
}
return &d, nil
diff --git a/internal/sqlutil/partition_offset_table.go b/internal/sqlutil/partition_offset_table.go
index 34882902..be079442 100644
--- a/internal/sqlutil/partition_offset_table.go
+++ b/internal/sqlutil/partition_offset_table.go
@@ -53,6 +53,8 @@ const upsertPartitionOffsetsSQL = "" +
// PartitionOffsetStatements represents a set of statements that can be run on a partition_offsets table.
type PartitionOffsetStatements struct {
+ db *sql.DB
+ writer Writer
selectPartitionOffsetsStmt *sql.Stmt
upsertPartitionOffsetStmt *sql.Stmt
}
@@ -60,7 +62,9 @@ type PartitionOffsetStatements struct {
// Prepare converts the raw SQL statements into prepared statements.
// Takes a prefix to prepend to the table name used to store the partition offsets.
// This allows multiple components to share the same database schema.
-func (s *PartitionOffsetStatements) Prepare(db *sql.DB, prefix string) (err error) {
+func (s *PartitionOffsetStatements) Prepare(db *sql.DB, writer Writer, prefix string) (err error) {
+ s.db = db
+ s.writer = writer
_, err = db.Exec(strings.Replace(partitionOffsetsSchema, "${prefix}", prefix, -1))
if err != nil {
return
@@ -121,6 +125,9 @@ func (s *PartitionOffsetStatements) selectPartitionOffsets(
func (s *PartitionOffsetStatements) upsertPartitionOffset(
ctx context.Context, topic string, partition int32, offset int64,
) error {
- _, err := s.upsertPartitionOffsetStmt.ExecContext(ctx, topic, partition, offset)
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := TxStmt(txn, s.upsertPartitionOffsetStmt)
+ _, err := stmt.ExecContext(ctx, topic, partition, offset)
+ return err
+ })
}
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go
index 002d7718..d296c418 100644
--- a/internal/sqlutil/sql.go
+++ b/internal/sqlutil/sql.go
@@ -103,7 +103,3 @@ func SQLiteDriverName() string {
}
return "sqlite3"
}
-
-type TransactionWriter interface {
- Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
-}
diff --git a/internal/sqlutil/writer.go b/internal/sqlutil/writer.go
new file mode 100644
index 00000000..5d93fef4
--- /dev/null
+++ b/internal/sqlutil/writer.go
@@ -0,0 +1,46 @@
+package sqlutil
+
+import "database/sql"
+
+// The Writer interface is designed to solve the problem of how
+// to handle database writes for database engines that don't allow
+// concurrent writes, e.g. SQLite.
+//
+// The interface has a single Do function which takes an optional
+// database parameter, an optional transaction parameter and a
+// required function parameter. The Writer will call the function
+// provided when it is safe to do so, optionally providing a
+// transaction to use.
+//
+// Depending on the combination of parameters provided, the Writer
+// will behave in one of three ways:
+//
+// 1. `db` provided, `txn` provided:
+//
+// The Writer will call f() when it is safe to do so. The supplied
+// "txn" will ALWAYS be passed through to f(). Use this when you
+// already have a transaction open.
+//
+// 2. `db` provided, `txn` not provided (nil):
+//
+// The Writer will open a new transaction on the provided database
+// and then will call f() when it is safe to do so. The new
+// transaction will ALWAYS be passed through to f(). Use this if
+// you plan to perform more than one SQL query within f().
+//
+// 3. `db` not provided (nil), `txn` not provided (nil):
+//
+// The Writer will call f() when it is safe to do so, but will
+// not make any attempt to open a new database transaction or to
+// pass through an existing one. The "txn" parameter within f()
+// will ALWAYS be nil in this mode. This is useful if you just
+// want to perform a single query on an already-prepared statement
+// without the overhead of opening a new transaction to do it in.
+//
+// You MUST take particular care not to call Do() from within f()
+// on the same Writer, or it will likely result in a deadlock.
+type Writer interface {
+ // Queue up one or more database write operations within the
+ // provided function to be executed when it is safe to do so.
+ Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
+}
diff --git a/internal/sqlutil/writer_dummy.go b/internal/sqlutil/writer_dummy.go
index e6ab81f6..f426c2bc 100644
--- a/internal/sqlutil/writer_dummy.go
+++ b/internal/sqlutil/writer_dummy.go
@@ -4,15 +4,21 @@ import (
"database/sql"
)
-type DummyTransactionWriter struct {
+// DummyWriter implements sqlutil.Writer.
+// The DummyWriter is designed to allow reuse of the sqlutil.Writer
+// interface but, unlike ExclusiveWriter, it will not guarantee
+// writer exclusivity. This is fine in PostgreSQL where overlapping
+// transactions and writes are acceptable.
+type DummyWriter struct {
}
-func NewDummyTransactionWriter() TransactionWriter {
- return &DummyTransactionWriter{}
+// NewDummyWriter returns a new dummy writer.
+func NewDummyWriter() Writer {
+ return &DummyWriter{}
}
-func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
- if txn == nil {
+func (w *DummyWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
+ if db != nil && txn == nil {
return WithTransaction(db, func(txn *sql.Tx) error {
return f(txn)
})
diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go
index 2e3666ae..002bc32c 100644
--- a/internal/sqlutil/writer_exclusive.go
+++ b/internal/sqlutil/writer_exclusive.go
@@ -7,16 +7,17 @@ import (
"go.uber.org/atomic"
)
-// ExclusiveTransactionWriter allows queuing database writes so that you don't
+// ExclusiveWriter implements sqlutil.Writer.
+// ExclusiveWriter allows queuing database writes so that you don't
// contend on database locks in, e.g. SQLite. Only one task will run
-// at a time on a given ExclusiveTransactionWriter.
-type ExclusiveTransactionWriter struct {
+// at a time on a given ExclusiveWriter.
+type ExclusiveWriter struct {
running atomic.Bool
todo chan transactionWriterTask
}
-func NewTransactionWriter() TransactionWriter {
- return &ExclusiveTransactionWriter{
+func NewExclusiveWriter() Writer {
+ return &ExclusiveWriter{
todo: make(chan transactionWriterTask),
}
}
@@ -34,7 +35,7 @@ type transactionWriterTask struct {
// txn parameter if one is supplied, and if not, will take out a
// new transaction from the database supplied in the database
// parameter. Either way, this will block until the task is done.
-func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
+func (w *ExclusiveWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if w.todo == nil {
return errors.New("not initialised")
}
@@ -55,20 +56,20 @@ func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql
// of these goroutines will run at a time. A transaction will be
// opened using the database object from the task and then this will
// be passed as a parameter to the task function.
-func (w *ExclusiveTransactionWriter) run() {
+func (w *ExclusiveWriter) run() {
if !w.running.CAS(false, true) {
return
}
defer w.running.Store(false)
for task := range w.todo {
- if task.txn != nil {
+ if task.db != nil && task.txn != nil {
task.wait <- task.f(task.txn)
- } else if task.db != nil {
+ } else if task.db != nil && task.txn == nil {
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
return task.f(txn)
})
} else {
- panic("expected database or transaction but got neither")
+ task.wait <- task.f(nil)
}
close(task.wait)
}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index c95790be..2af33761 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" +
type deviceKeysStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
@@ -71,10 +71,10 @@ type deviceKeysStatements struct {
deleteAllDeviceKeysStmt *sql.Stmt
}
-func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
+func NewSqliteDeviceKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
+ writer: writer,
}
_, err := db.Exec(deviceKeysSchema)
if err != nil {
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
index f451d657..cd178413 100644
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ b/keyserver/storage/sqlite3/key_changes_table.go
@@ -52,15 +52,15 @@ const selectKeyChangesSQL = "" +
type keyChangesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
-func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+func NewSqliteKeyChangesTable(db *sql.DB, writer sqlutil.Writer) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
+ writer: writer,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index c71cc47d..d788f676 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
type oneTimeKeysStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
@@ -68,10 +68,10 @@ type oneTimeKeysStatements struct {
deleteOneTimeKeyStmt *sql.Stmt
}
-func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
+func NewSqliteOneTimeKeysTable(db *sql.DB, writer sqlutil.Writer) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
+ writer: writer,
}
_, err := db.Exec(oneTimeKeysSchema)
if err != nil {
diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go
index a989476d..8b6f8813 100644
--- a/keyserver/storage/sqlite3/stale_device_lists.go
+++ b/keyserver/storage/sqlite3/stale_device_lists.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -49,13 +50,18 @@ const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1"
type staleDeviceListsStatements struct {
+ db *sql.DB
+ writer sqlutil.Writer
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
}
-func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
- s := &staleDeviceListsStatements{}
+func NewSqliteStaleDeviceListsTable(db *sql.DB, writer sqlutil.Writer) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{
+ db: db,
+ writer: writer,
+ }
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
@@ -77,8 +83,11 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context,
if err != nil {
return err
}
- _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.upsertStaleDeviceListStmt)
+ _, err = stmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix())
+ return err
+ })
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go
index bb293558..1a2a237f 100644
--- a/keyserver/storage/sqlite3/storage.go
+++ b/keyserver/storage/sqlite3/storage.go
@@ -25,19 +25,20 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil {
return nil, err
}
- otk, err := NewSqliteOneTimeKeysTable(db)
+ writer := sqlutil.NewExclusiveWriter()
+ otk, err := NewSqliteOneTimeKeysTable(db, writer)
if err != nil {
return nil, err
}
- dk, err := NewSqliteDeviceKeysTable(db)
+ dk, err := NewSqliteDeviceKeysTable(db, writer)
if err != nil {
return nil, err
}
- kc, err := NewSqliteKeyChangesTable(db)
+ kc, err := NewSqliteKeyChangesTable(db, writer)
if err != nil {
return nil, err
}
- sdl, err := NewSqliteStaleDeviceListsTable(db)
+ sdl, err := NewSqliteStaleDeviceListsTable(db, writer)
if err != nil {
return nil, err
}
diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go
index ff6ddf3d..dcc1b41e 100644
--- a/mediaapi/storage/sqlite3/media_repository_table.go
+++ b/mediaapi/storage/sqlite3/media_repository_table.go
@@ -62,14 +62,14 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
type mediaStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
}
-func (s *mediaStatements) prepare(db *sql.DB) (err error) {
+func (s *mediaStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(mediaSchema)
if err != nil {
diff --git a/mediaapi/storage/sqlite3/sql.go b/mediaapi/storage/sqlite3/sql.go
index 9cd78b8e..245bd40c 100644
--- a/mediaapi/storage/sqlite3/sql.go
+++ b/mediaapi/storage/sqlite3/sql.go
@@ -17,6 +17,8 @@ package sqlite3
import (
"database/sql"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
)
type statements struct {
@@ -24,11 +26,11 @@ type statements struct {
thumbnail thumbnailStatements
}
-func (s *statements) prepare(db *sql.DB) (err error) {
- if err = s.media.prepare(db); err != nil {
+func (s *statements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
+ if err = s.media.prepare(db, writer); err != nil {
return
}
- if err = s.thumbnail.prepare(db); err != nil {
+ if err = s.thumbnail.prepare(db, writer); err != nil {
return
}
diff --git a/mediaapi/storage/sqlite3/storage.go b/mediaapi/storage/sqlite3/storage.go
index a1e7fec7..d5c3031e 100644
--- a/mediaapi/storage/sqlite3/storage.go
+++ b/mediaapi/storage/sqlite3/storage.go
@@ -31,16 +31,19 @@ import (
type Database struct {
statements statements
db *sql.DB
+ writer sqlutil.Writer
}
// Open opens a postgres database.
func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
- var d Database
+ d := Database{
+ writer: sqlutil.NewExclusiveWriter(),
+ }
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- if err = d.statements.prepare(d.db); err != nil {
+ if err = d.statements.prepare(d.db, d.writer); err != nil {
return nil, err
}
return &d, nil
diff --git a/mediaapi/storage/sqlite3/thumbnail_table.go b/mediaapi/storage/sqlite3/thumbnail_table.go
index 432a1590..06b056b6 100644
--- a/mediaapi/storage/sqlite3/thumbnail_table.go
+++ b/mediaapi/storage/sqlite3/thumbnail_table.go
@@ -21,6 +21,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -57,16 +58,20 @@ SELECT content_type, file_size_bytes, creation_ts, width, height, resize_method
`
type thumbnailStatements struct {
+ db *sql.DB
+ writer sqlutil.Writer
insertThumbnailStmt *sql.Stmt
selectThumbnailStmt *sql.Stmt
selectThumbnailsStmt *sql.Stmt
}
-func (s *thumbnailStatements) prepare(db *sql.DB) (err error) {
+func (s *thumbnailStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
_, err = db.Exec(thumbnailSchema)
if err != nil {
return
}
+ s.db = db
+ s.writer = writer
return statementList{
{&s.insertThumbnailStmt, insertThumbnailSQL},
@@ -79,18 +84,21 @@ func (s *thumbnailStatements) insertThumbnail(
ctx context.Context, thumbnailMetadata *types.ThumbnailMetadata,
) error {
thumbnailMetadata.MediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
- _, err := s.insertThumbnailStmt.ExecContext(
- ctx,
- thumbnailMetadata.MediaMetadata.MediaID,
- thumbnailMetadata.MediaMetadata.Origin,
- thumbnailMetadata.MediaMetadata.ContentType,
- thumbnailMetadata.MediaMetadata.FileSizeBytes,
- thumbnailMetadata.MediaMetadata.CreationTimestamp,
- thumbnailMetadata.ThumbnailSize.Width,
- thumbnailMetadata.ThumbnailSize.Height,
- thumbnailMetadata.ThumbnailSize.ResizeMethod,
- )
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertThumbnailStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ thumbnailMetadata.MediaMetadata.MediaID,
+ thumbnailMetadata.MediaMetadata.Origin,
+ thumbnailMetadata.MediaMetadata.ContentType,
+ thumbnailMetadata.MediaMetadata.FileSizeBytes,
+ thumbnailMetadata.MediaMetadata.CreationTimestamp,
+ thumbnailMetadata.ThumbnailSize.Width,
+ thumbnailMetadata.ThumbnailSize.Height,
+ thumbnailMetadata.ThumbnailSize.ResizeMethod,
+ )
+ return err
+ })
}
func (s *thumbnailStatements) selectThumbnail(
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 0b7ed225..d217b5d2 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -98,7 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: db,
- Writer: sqlutil.NewDummyTransactionWriter(),
+ Writer: sqlutil.NewDummyWriter(),
EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON,
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 56c2b029..7101376a 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -27,7 +27,7 @@ const redactionsArePermanent = false
type Database struct {
DB *sql.DB
- Writer sqlutil.TransactionWriter
+ Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 874bbbc7..d1738966 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -41,7 +41,7 @@ type Database struct {
invites tables.Invites
membership tables.Membership
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
}
// Open a sqlite database.
@@ -52,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- d.writer = sqlutil.NewTransactionWriter()
+ d.writer = sqlutil.NewExclusiveWriter()
//d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA read_uncommitted = true;")
@@ -120,7 +120,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: d.db,
- Writer: sqlutil.NewTransactionWriter(),
+ Writer: sqlutil.NewExclusiveWriter(),
EventsTable: d.events,
EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys,
diff --git a/serverkeyapi/storage/sqlite3/keydb.go b/serverkeyapi/storage/sqlite3/keydb.go
index 5174ece1..0ee74bc1 100644
--- a/serverkeyapi/storage/sqlite3/keydb.go
+++ b/serverkeyapi/storage/sqlite3/keydb.go
@@ -30,6 +30,7 @@ import (
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers.
type Database struct {
+ writer sqlutil.Writer
statements serverKeyStatements
}
@@ -47,8 +48,10 @@ func NewDatabase(
if err != nil {
return nil, err
}
- d := &Database{}
- err = d.statements.prepare(db)
+ d := &Database{
+ writer: sqlutil.NewExclusiveWriter(),
+ }
+ err = d.statements.prepare(db, d.writer)
if err != nil {
return nil, err
}
diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go
index b829eae7..f756ef5e 100644
--- a/serverkeyapi/storage/sqlite3/server_key_table.go
+++ b/serverkeyapi/storage/sqlite3/server_key_table.go
@@ -63,14 +63,14 @@ const upsertServerKeysSQL = "" +
type serverKeyStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
-func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
+func (s *serverKeyStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(serverKeysSchema)
if err != nil {
return
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index 26ef082f..36e8de67 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -30,7 +30,8 @@ import (
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
shared.Database
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
}
@@ -41,7 +42,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+ d.writer = sqlutil.NewDummyWriter()
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return nil, err
}
accountData, err := NewPostgresAccountDataTable(d.db)
@@ -78,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
d.Database = shared.Database{
DB: d.db,
+ Writer: sqlutil.NewDummyWriter(),
Invites: invites,
AccountData: accountData,
OutputEvents: events,
@@ -86,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
BackwardExtremities: backwardExtremities,
Filter: filter,
SendToDevice: sendToDevice,
- SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(),
}
return &d, nil
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index fdbf6758..699a6647 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -37,6 +37,7 @@ import (
// For now this contains the shared functions
type Database struct {
DB *sql.DB
+ Writer sqlutil.Writer
Invites tables.Invites
AccountData tables.AccountData
OutputEvents tables.Events
@@ -45,7 +46,6 @@ type Database struct {
BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice
Filter tables.Filter
- SendToDeviceWriter sqlutil.TransactionWriter
EDUCache *cache.EDUCache
}
@@ -129,10 +129,7 @@ func (d *Database) GetStateEvent(
func (d *Database) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
- stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter)
- return err
- })
+ stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter)
return
}
@@ -171,9 +168,9 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition
func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
- sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
- return err
+ _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent)
+ return nil
})
return
}
@@ -182,8 +179,12 @@ func (d *Database) AddInviteEvent(
// Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string,
-) (types.StreamPosition, error) {
- return d.Invites.DeleteInviteEvent(ctx, inviteEventID)
+) (sp types.StreamPosition, err error) {
+ _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID)
+ return nil
+ })
+ return
}
// GetAccountDataInRange returns all account data for a given user inserted or
@@ -207,7 +208,7 @@ func (d *Database) GetAccountDataInRange(
func (d *Database) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
return err
})
@@ -237,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
+// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
return err
@@ -275,7 +277,7 @@ func (d *Database) WriteEvent(
addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition types.StreamPosition, returnErr error) {
- returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
@@ -304,6 +306,7 @@ func (d *Database) WriteEvent(
return pduPosition, returnErr
}
+// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) updateRoomState(
ctx context.Context, txn *sql.Tx,
removedEventIDs []string,
@@ -1114,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
}
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place.
- err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j),
)
@@ -1179,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates(
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
- err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
index 248ec926..72c46e48 100644
--- a/syncapi/storage/sqlite3/account_data_table.go
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -20,7 +20,6 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
@@ -61,7 +59,6 @@ type accountDataStatements struct {
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(accountDataSchema)
@@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
- return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- var err error
- pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return err
- }
- _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
- return err
- })
+ pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
+ _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
+ return
}
func (s *accountDataStatements) SelectAccountDataInRange(
diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go
index d96f2fe5..116c33dc 100644
--- a/syncapi/storage/sqlite3/backwards_extremities_table.go
+++ b/syncapi/storage/sqlite3/backwards_extremities_table.go
@@ -19,7 +19,6 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
)
@@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" +
type backwardExtremitiesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
@@ -57,8 +55,7 @@ type backwardExtremitiesStatements struct {
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err := db.Exec(backwardExtremitiesSchema)
if err != nil {
@@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
- return err
- })
+ _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
+ return err
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
- return err
- })
+ _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
+ return err
}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 2f0068ed..6f822c90 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@@ -98,7 +97,6 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(currentRoomStateSchema)
@@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
- _, err := stmt.ExecContext(ctx, eventID)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+ _, err := stmt.ExecContext(ctx, eventID)
+ return err
}
func (s *currentRoomStateStatements) UpsertRoomState(
@@ -225,22 +221,20 @@ func (s *currentRoomStateStatements) UpsertRoomState(
}
// upsert state event
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
- _, err := stmt.ExecContext(
- ctx,
- event.RoomID(),
- event.EventID(),
- event.Type(),
- event.Sender(),
- containsURL,
- *event.StateKey(),
- headeredJSON,
- membership,
- addedAt,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
+ _, err = stmt.ExecContext(
+ ctx,
+ event.RoomID(),
+ event.EventID(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ *event.StateKey(),
+ headeredJSON,
+ membership,
+ addedAt,
+ )
+ return err
}
func minOfInts(a, b int) int {
diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go
index 338b0b50..3092bcd7 100644
--- a/syncapi/storage/sqlite3/filter_table.go
+++ b/syncapi/storage/sqlite3/filter_table.go
@@ -20,7 +20,6 @@ import (
"encoding/json"
"fmt"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -52,7 +51,6 @@ const insertFilterSQL = "" +
type filterStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
@@ -64,8 +62,7 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
return nil, err
}
s := &filterStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err
@@ -114,33 +111,30 @@ func (s *filterStatements) InsertFilter(
return "", err
}
- err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- // Check if filter already exists in the database using its localpart and content
- //
- // This can result in a race condition when two clients try to insert the
- // same filter and localpart at the same time, however this is not a
- // problem as both calls will result in the same filterID
- err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
- localpart, filterJSON).Scan(&existingFilterID)
- if err != nil && err != sql.ErrNoRows {
- return err
- }
- // If it does, return the existing ID
- if existingFilterID != "" {
- return nil
- }
-
- // Otherwise insert the filter and return the new ID
- res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
- if err != nil {
- return err
- }
- rowid, err := res.LastInsertId()
- if err != nil {
- return err
- }
- filterID = fmt.Sprintf("%d", rowid)
- return nil
- })
+ // Check if filter already exists in the database using its localpart and content
+ //
+ // This can result in a race condition when two clients try to insert the
+ // same filter and localpart at the same time, however this is not a
+ // problem as both calls will result in the same filterID
+ err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
+ localpart, filterJSON).Scan(&existingFilterID)
+ if err != nil && err != sql.ErrNoRows {
+ return "", err
+ }
+ // If it does, return the existing ID
+ if existingFilterID != "" {
+ return existingFilterID, nil
+ }
+
+ // Otherwise insert the filter and return the new ID
+ res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
+ if err != nil {
+ return "", err
+ }
+ rowid, err := res.LastInsertId()
+ if err != nil {
+ return "", err
+ }
+ filterID = fmt.Sprintf("%d", rowid)
return
}
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index 0bbd79f7..45862efb 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
@@ -70,7 +69,6 @@ type inviteEventsStatements struct {
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(inviteEventsSchema)
@@ -95,45 +93,37 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) {
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- var err error
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return err
- }
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
- var headeredJSON []byte
- headeredJSON, err = json.Marshal(inviteEvent)
- if err != nil {
- return err
- }
+ var headeredJSON []byte
+ headeredJSON, err = json.Marshal(inviteEvent)
+ if err != nil {
+ return
+ }
- _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
- ctx,
- streamPos,
- inviteEvent.RoomID(),
- inviteEvent.EventID(),
- *inviteEvent.StateKey(),
- headeredJSON,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
+ _, err = stmt.ExecContext(
+ ctx,
+ streamPos,
+ inviteEvent.RoomID(),
+ inviteEvent.EventID(),
+ *inviteEvent.StateKey(),
+ headeredJSON,
+ )
return
}
func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) {
- var streamPos types.StreamPosition
- err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- var err error
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
- if err != nil {
- return err
- }
- _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
- return err
- })
+ streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
+ if err != nil {
+ return streamPos, err
+ }
+ _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return streamPos, err
}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index 0d154650..f10d0106 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" +
type outputRoomEventsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
@@ -120,7 +119,6 @@ type outputRoomEventsStatements struct {
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(outputRoomEventsSchema)
@@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil {
return err
}
- return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
- return err
- })
+ _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
+ return err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@@ -304,32 +300,27 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err
}
- var streamPos types.StreamPosition
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return err
- }
-
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
- _, ierr := insertStmt.ExecContext(
- ctx,
- streamPos,
- event.RoomID(),
- event.EventID(),
- headeredJSON,
- event.Type(),
- event.Sender(),
- containsURL,
- string(addStateJSON),
- string(removeStateJSON),
- sessionID,
- txnID,
- excludeFromSync,
- excludeFromSync,
- )
- return ierr
- })
+ streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
+ _, err = insertStmt.ExecContext(
+ ctx,
+ streamPos,
+ event.RoomID(),
+ event.EventID(),
+ headeredJSON,
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ string(addStateJSON),
+ string(removeStateJSON),
+ sessionID,
+ txnID,
+ excludeFromSync,
+ excludeFromSync,
+ )
return streamPos, err
}
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
index 5c4ab005..d8c97b7e 100644
--- a/syncapi/storage/sqlite3/output_room_events_topology_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" +
type outputRoomEventsTopologyStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
@@ -77,8 +76,7 @@ type outputRoomEventsTopologyStatements struct {
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil {
@@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
- _, err := stmt.ExecContext(
- ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
+ _, err = stmt.ExecContext(
+ ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
+ )
+ return
}
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
index 53786589..fbc759b1 100644
--- a/syncapi/storage/sqlite3/send_to_device_table.go
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
@@ -81,8 +80,7 @@ type sendToDeviceStatements struct {
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
@@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
- return err
- })
+ _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+ return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids {
params[k+1] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.ExecContext(ctx, query, params...)
- return err
- })
+ _, err = txn.ExecContext(ctx, query, params...)
+ return
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids {
params[k] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.ExecContext(ctx, query, params...)
- return err
- })
+ _, err = txn.ExecContext(ctx, query, params...)
+ return
}
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index 1971e7f3..e6bdc4fc 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -28,14 +28,12 @@ const selectStreamIDStmt = "" +
type streamIDStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(streamIDTableSchema)
if err != nil {
return
@@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
- return ierr
- }
- if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
- return serr
- }
- return nil
- })
+ if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
+ return
+ }
+ err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index 9564a23a..81197bb7 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -31,7 +31,8 @@ import (
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
shared.Database
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
streamID streamIDStatements
}
@@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
+ d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(); err != nil {
return nil, err
}
@@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
func (d *SyncServerDatasource) prepare() (err error) {
- if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err
}
if err = d.streamID.prepare(d.db); err != nil {
@@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
}
d.Database = shared.Database{
DB: d.db,
+ Writer: sqlutil.NewExclusiveWriter(),
Invites: invites,
AccountData: accountData,
OutputEvents: events,
@@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
Topology: topology,
Filter: filter,
SendToDevice: sendToDevice,
- SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(),
}
return nil
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index 9653c019..b36264dd 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -34,7 +34,8 @@ import (
// Database represents an account database
type Database struct {
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@@ -49,27 +50,27 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
- partitions := sqlutil.PartitionOffsetStatements{}
- if err = partitions.Prepare(db, "account"); err != nil {
+ d := &Database{
+ serverName: serverName,
+ db: db,
+ writer: sqlutil.NewDummyWriter(),
+ }
+ if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
- a := accountsStatements{}
- if err = a.prepare(db, serverName); err != nil {
+ if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
- p := profilesStatements{}
- if err = p.prepare(db); err != nil {
+ if err = d.profiles.prepare(db); err != nil {
return nil, err
}
- ac := accountDataStatements{}
- if err = ac.prepare(db); err != nil {
+ if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
- t := threepidStatements{}
- if err = t.prepare(db); err != nil {
+ if err = d.threepids.prepare(db); err != nil {
return nil, err
}
- return &Database{db, partitions, a, p, ac, t, serverName}, nil
+ return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go
index 9b40e657..aee8db6e 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/accounts/sqlite3/account_data_table.go
@@ -51,15 +51,15 @@ const selectAccountDataByTypeSQL = "" +
type accountDataStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
-func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
+func (s *accountDataStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(accountDataSchema)
if err != nil {
return
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 586bcab9..83b90668 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
type accountsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
@@ -67,9 +67,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(accountsSchema)
if err != nil {
return
diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go
index cd35d298..1ec45e03 100644
--- a/userapi/storage/accounts/sqlite3/profile_table.go
+++ b/userapi/storage/accounts/sqlite3/profile_table.go
@@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
type profilesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
@@ -61,9 +61,9 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt
}
-func (s *profilesStatements) prepare(db *sql.DB) (err error) {
+func (s *profilesStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(profilesSchema)
if err != nil {
return
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 4d2c5e51..4f45f754 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -33,7 +33,9 @@ import (
// Database represents an account database
type Database struct {
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
+
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
@@ -53,35 +55,28 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
+ d := &Database{
+ serverName: serverName,
+ db: db,
+ writer: sqlutil.NewExclusiveWriter(),
+ }
partitions := sqlutil.PartitionOffsetStatements{}
- if err = partitions.Prepare(db, "account"); err != nil {
+ if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
- a := accountsStatements{}
- if err = a.prepare(db, serverName); err != nil {
+ if err = d.accounts.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
- p := profilesStatements{}
- if err = p.prepare(db); err != nil {
+ if err = d.profiles.prepare(db, d.writer); err != nil {
return nil, err
}
- ac := accountDataStatements{}
- if err = ac.prepare(db); err != nil {
+ if err = d.accountDatas.prepare(db, d.writer); err != nil {
return nil, err
}
- t := threepidStatements{}
- if err = t.prepare(db); err != nil {
+ if err = d.threepids.prepare(db, d.writer); err != nil {
return nil, err
}
- return &Database{
- db: db,
- PartitionOffsetStatements: partitions,
- accounts: a,
- profiles: p,
- accountDatas: ac,
- threepids: t,
- serverName: serverName,
- }, nil
+ return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go
index 3000d7c4..230978fe 100644
--- a/userapi/storage/accounts/sqlite3/threepid_table.go
+++ b/userapi/storage/accounts/sqlite3/threepid_table.go
@@ -54,16 +54,16 @@ const deleteThreePIDSQL = "" +
type threepidStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
deleteThreePIDStmt *sql.Stmt
}
-func (s *threepidStatements) prepare(db *sql.DB) (err error) {
+func (s *threepidStatements) prepare(db *sql.DB, writer sqlutil.Writer) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(threepidSchema)
if err != nil {
return
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index 962e63b0..c93e8b77 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
type devicesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
+ writer sqlutil.Writer
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@@ -91,9 +91,9 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
+ s.writer = writer
_, err = db.Exec(devicesSchema)
if err != nil {
return
@@ -138,19 +138,13 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
- insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
- if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
- return err
- }
- sessionID++
- if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
- return err
- }
- return nil
- })
- if err != nil {
+ countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
+ insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
+ if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
+ return nil, err
+ }
+ sessionID++
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &api.Device{
@@ -164,11 +158,9 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
+ _, err := stmt.ExecContext(ctx, id, localpart)
+ return err
}
func (s *devicesStatements) deleteDevices(
@@ -179,36 +171,30 @@ func (s *devicesStatements) deleteDevices(
if err != nil {
return err
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, prep)
- params := make([]interface{}, len(devices)+1)
- params[0] = localpart
- for i, v := range devices {
- params[i+1] = v
- }
- _, err = stmt.ExecContext(ctx, params...)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, prep)
+ params := make([]interface{}, len(devices)+1)
+ params[0] = localpart
+ for i, v := range devices {
+ params[i+1] = v
+ }
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
+ _, err := stmt.ExecContext(ctx, localpart)
+ return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ return err
}
func (s *devicesStatements) selectDeviceByToken(
diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go
index 1f2b59f3..4f426c6e 100644
--- a/userapi/storage/devices/sqlite3/storage.go
+++ b/userapi/storage/devices/sqlite3/storage.go
@@ -34,6 +34,7 @@ var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
+ writer sqlutil.Writer
devices devicesStatements
}
@@ -43,11 +44,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err != nil {
return nil, err
}
+ writer := sqlutil.NewExclusiveWriter()
d := devicesStatements{}
- if err = d.prepare(db, serverName); err != nil {
+ if err = d.prepare(db, writer, serverName); err != nil {
return nil, err
}
- return &Database{db, d}, nil
+ return &Database{db, writer, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
@@ -88,7 +90,7 @@ func (d *Database) CreateDevice(
displayName *string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
@@ -108,7 +110,7 @@ func (d *Database) CreateDevice(
return
}
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
@@ -138,7 +140,7 @@ func generateDeviceID() (string, error) {
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
@@ -150,7 +152,7 @@ func (d *Database) UpdateDevice(
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
@@ -165,7 +167,7 @@ func (d *Database) RemoveDevice(
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
@@ -179,7 +181,7 @@ func (d *Database) RemoveDevices(
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}