aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-08-19 15:38:27 +0100
committerGitHub <noreply@github.com>2020-08-19 15:38:27 +0100
commitb24747b305a0770fdd746655e702aa1c1c049765 (patch)
tree88d94b762fafb4852421eb243313edbfc96ccfa9
parent775b04d776ddc06fdee5ece6a407008f00edb7f2 (diff)
Transaction writer changes, move roomserver writers (#1285)
* Updated TransactionWriters, moved locks in roomserver, various other tweaks * Fix redaction deadlocks * Fix lint issue * Rename SQLiteTransactionWriter to ExclusiveTransactionWriter * Fix us not sending transactions through in latest events updater
-rw-r--r--appservice/storage/sqlite3/appservice_events_table.go2
-rw-r--r--appservice/storage/sqlite3/txn_id_counter_table.go2
-rw-r--r--currentstateserver/storage/sqlite3/current_room_state_table.go2
-rw-r--r--federationsender/storage/postgres/blacklist_table.go20
-rw-r--r--federationsender/storage/sqlite3/blacklist_table.go2
-rw-r--r--federationsender/storage/sqlite3/joined_hosts_table.go2
-rw-r--r--federationsender/storage/sqlite3/queue_edus_table.go2
-rw-r--r--federationsender/storage/sqlite3/queue_json_table.go2
-rw-r--r--federationsender/storage/sqlite3/queue_pdus_table.go2
-rw-r--r--federationsender/storage/sqlite3/room_table.go2
-rw-r--r--internal/sqlutil/sql.go71
-rw-r--r--internal/sqlutil/writer_dummy.go22
-rw-r--r--internal/sqlutil/writer_exclusive.go75
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go2
-rw-r--r--keyserver/storage/sqlite3/key_changes_table.go2
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go2
-rw-r--r--mediaapi/storage/sqlite3/media_repository_table.go2
-rw-r--r--roomserver/internal/input_latest_events.go36
-rw-r--r--roomserver/state/state.go22
-rw-r--r--roomserver/storage/postgres/storage.go1
-rw-r--r--roomserver/storage/shared/latest_events_updater.go26
-rw-r--r--roomserver/storage/shared/membership_updater.go34
-rw-r--r--roomserver/storage/shared/storage.go43
-rw-r--r--roomserver/storage/sqlite3/event_json_table.go12
-rw-r--r--roomserver/storage/sqlite3/event_state_keys_table.go28
-rw-r--r--roomserver/storage/sqlite3/event_types_table.go27
-rw-r--r--roomserver/storage/sqlite3/events_table.go54
-rw-r--r--roomserver/storage/sqlite3/invite_table.go66
-rw-r--r--roomserver/storage/sqlite3/membership_table.go26
-rw-r--r--roomserver/storage/sqlite3/previous_events_table.go18
-rw-r--r--roomserver/storage/sqlite3/published_table.go16
-rw-r--r--roomserver/storage/sqlite3/redactions_table.go22
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go25
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go46
-rw-r--r--roomserver/storage/sqlite3/state_block_table.go37
-rw-r--r--roomserver/storage/sqlite3/state_snapshot_table.go29
-rw-r--r--roomserver/storage/sqlite3/storage.go32
-rw-r--r--roomserver/storage/sqlite3/transactions_table.go20
-rw-r--r--serverkeyapi/storage/sqlite3/server_key_table.go2
-rw-r--r--syncapi/storage/shared/syncserver.go2
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go2
-rw-r--r--syncapi/storage/sqlite3/backwards_extremities_table.go2
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go2
-rw-r--r--syncapi/storage/sqlite3/filter_table.go2
-rw-r--r--syncapi/storage/sqlite3/invites_table.go2
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go2
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go2
-rw-r--r--syncapi/storage/sqlite3/send_to_device_table.go2
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go2
-rw-r--r--userapi/storage/accounts/sqlite3/account_data_table.go2
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go2
-rw-r--r--userapi/storage/accounts/sqlite3/profile_table.go2
-rw-r--r--userapi/storage/accounts/sqlite3/threepid_table.go2
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go2
54 files changed, 432 insertions, 434 deletions
diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go
index da31f235..5cc07ed3 100644
--- a/appservice/storage/sqlite3/appservice_events_table.go
+++ b/appservice/storage/sqlite3/appservice_events_table.go
@@ -67,7 +67,7 @@ const (
type eventsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go
index 501ab5aa..0ae0feee 100644
--- a/appservice/storage/sqlite3/txn_id_counter_table.go
+++ b/appservice/storage/sqlite3/txn_id_counter_table.go
@@ -38,7 +38,7 @@ const selectTxnIDSQL = `
type txnStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
selectTxnIDStmt *sql.Stmt
}
diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go
index 5c7e8b0a..9d2fe6e0 100644
--- a/currentstateserver/storage/sqlite3/current_room_state_table.go
+++ b/currentstateserver/storage/sqlite3/current_room_state_table.go
@@ -83,7 +83,7 @@ const selectKnownUsersSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt
diff --git a/federationsender/storage/postgres/blacklist_table.go b/federationsender/storage/postgres/blacklist_table.go
index 8de6feec..f92c59e5 100644
--- a/federationsender/storage/postgres/blacklist_table.go
+++ b/federationsender/storage/postgres/blacklist_table.go
@@ -42,7 +42,6 @@ const deleteBlacklistSQL = "" +
type blacklistStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
@@ -50,8 +49,7 @@ type blacklistStatements struct {
func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
s = &blacklistStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err = db.Exec(blacklistSchema)
if err != nil {
@@ -75,11 +73,9 @@ func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
- _, err := stmt.ExecContext(ctx, serverName)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
+ _, err := stmt.ExecContext(ctx, serverName)
+ return err
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
@@ -105,9 +101,7 @@ func (s *blacklistStatements) SelectBlacklist(
func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
- _, err := stmt.ExecContext(ctx, serverName)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
+ _, err := stmt.ExecContext(ctx, serverName)
+ return err
}
diff --git a/federationsender/storage/sqlite3/blacklist_table.go b/federationsender/storage/sqlite3/blacklist_table.go
index a14fe0c4..b23bfcba 100644
--- a/federationsender/storage/sqlite3/blacklist_table.go
+++ b/federationsender/storage/sqlite3/blacklist_table.go
@@ -42,7 +42,7 @@ const deleteBlacklistSQL = "" +
type blacklistStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go
index 53736fa1..5dc18f4e 100644
--- a/federationsender/storage/sqlite3/joined_hosts_table.go
+++ b/federationsender/storage/sqlite3/joined_hosts_table.go
@@ -65,7 +65,7 @@ const selectJoinedHostsForRoomsSQL = "" +
type joinedHostsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go
index cd11a0ea..2abcc105 100644
--- a/federationsender/storage/sqlite3/queue_edus_table.go
+++ b/federationsender/storage/sqlite3/queue_edus_table.go
@@ -64,7 +64,7 @@ const selectQueueServerNamesSQL = "" +
type queueEDUsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go
index 46dfd9ab..867ffd44 100644
--- a/federationsender/storage/sqlite3/queue_json_table.go
+++ b/federationsender/storage/sqlite3/queue_json_table.go
@@ -50,7 +50,7 @@ const selectJSONSQL = "" +
type queueJSONStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go
index 1474bfc0..538ba3db 100644
--- a/federationsender/storage/sqlite3/queue_pdus_table.go
+++ b/federationsender/storage/sqlite3/queue_pdus_table.go
@@ -71,7 +71,7 @@ const selectQueuePDUsServerNamesSQL = "" +
type queuePDUsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsByTransactionStmt *sql.Stmt
diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go
index 51793874..9a439fad 100644
--- a/federationsender/storage/sqlite3/room_table.go
+++ b/federationsender/storage/sqlite3/room_table.go
@@ -44,7 +44,7 @@ const updateRoomSQL = "" +
type roomStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go
index 95467c63..002d7718 100644
--- a/internal/sqlutil/sql.go
+++ b/internal/sqlutil/sql.go
@@ -19,8 +19,6 @@ import (
"errors"
"fmt"
"runtime"
-
- "go.uber.org/atomic"
)
// ErrUserExists is returned if a username already exists in the database.
@@ -52,7 +50,7 @@ func EndTransaction(txn Transaction, succeeded *bool) error {
func WithTransaction(db *sql.DB, fn func(txn *sql.Tx) error) (err error) {
txn, err := db.Begin()
if err != nil {
- return
+ return fmt.Errorf("sqlutil.WithTransaction.Begin: %w", err)
}
succeeded := false
defer func() {
@@ -106,69 +104,6 @@ func SQLiteDriverName() string {
return "sqlite3"
}
-// TransactionWriter allows queuing database writes so that you don't
-// contend on database locks in, e.g. SQLite. Only one task will run
-// at a time on a given TransactionWriter.
-type TransactionWriter struct {
- running atomic.Bool
- todo chan transactionWriterTask
-}
-
-func NewTransactionWriter() *TransactionWriter {
- return &TransactionWriter{
- todo: make(chan transactionWriterTask),
- }
-}
-
-// transactionWriterTask represents a specific task.
-type transactionWriterTask struct {
- db *sql.DB
- txn *sql.Tx
- f func(txn *sql.Tx) error
- wait chan error
-}
-
-// Do queues a task to be run by a TransactionWriter. The function
-// provided will be ran within a transaction as supplied by the
-// txn parameter if one is supplied, and if not, will take out a
-// new transaction from the database supplied in the database
-// parameter. Either way, this will block until the task is done.
-func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
- if w.todo == nil {
- return errors.New("not initialised")
- }
- if !w.running.Load() {
- go w.run()
- }
- task := transactionWriterTask{
- db: db,
- txn: txn,
- f: f,
- wait: make(chan error, 1),
- }
- w.todo <- task
- return <-task.wait
-}
-
-// run processes the tasks for a given transaction writer. Only one
-// of these goroutines will run at a time. A transaction will be
-// opened using the database object from the task and then this will
-// be passed as a parameter to the task function.
-func (w *TransactionWriter) run() {
- if !w.running.CAS(false, true) {
- return
- }
- defer w.running.Store(false)
- for task := range w.todo {
- if task.txn != nil {
- task.wait <- task.f(task.txn)
- } else if task.db != nil {
- task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
- return task.f(txn)
- })
- } else {
- panic("expected database or transaction but got neither")
- }
- close(task.wait)
- }
+type TransactionWriter interface {
+ Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
}
diff --git a/internal/sqlutil/writer_dummy.go b/internal/sqlutil/writer_dummy.go
new file mode 100644
index 00000000..e6ab81f6
--- /dev/null
+++ b/internal/sqlutil/writer_dummy.go
@@ -0,0 +1,22 @@
+package sqlutil
+
+import (
+ "database/sql"
+)
+
+type DummyTransactionWriter struct {
+}
+
+func NewDummyTransactionWriter() TransactionWriter {
+ return &DummyTransactionWriter{}
+}
+
+func (w *DummyTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
+ if txn == nil {
+ return WithTransaction(db, func(txn *sql.Tx) error {
+ return f(txn)
+ })
+ } else {
+ return f(txn)
+ }
+}
diff --git a/internal/sqlutil/writer_exclusive.go b/internal/sqlutil/writer_exclusive.go
new file mode 100644
index 00000000..2e3666ae
--- /dev/null
+++ b/internal/sqlutil/writer_exclusive.go
@@ -0,0 +1,75 @@
+package sqlutil
+
+import (
+ "database/sql"
+ "errors"
+
+ "go.uber.org/atomic"
+)
+
+// ExclusiveTransactionWriter allows queuing database writes so that you don't
+// contend on database locks in, e.g. SQLite. Only one task will run
+// at a time on a given ExclusiveTransactionWriter.
+type ExclusiveTransactionWriter struct {
+ running atomic.Bool
+ todo chan transactionWriterTask
+}
+
+func NewTransactionWriter() TransactionWriter {
+ return &ExclusiveTransactionWriter{
+ todo: make(chan transactionWriterTask),
+ }
+}
+
+// transactionWriterTask represents a specific task.
+type transactionWriterTask struct {
+ db *sql.DB
+ txn *sql.Tx
+ f func(txn *sql.Tx) error
+ wait chan error
+}
+
+// Do queues a task to be run by a TransactionWriter. The function
+// provided will be ran within a transaction as supplied by the
+// txn parameter if one is supplied, and if not, will take out a
+// new transaction from the database supplied in the database
+// parameter. Either way, this will block until the task is done.
+func (w *ExclusiveTransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
+ if w.todo == nil {
+ return errors.New("not initialised")
+ }
+ if !w.running.Load() {
+ go w.run()
+ }
+ task := transactionWriterTask{
+ db: db,
+ txn: txn,
+ f: f,
+ wait: make(chan error, 1),
+ }
+ w.todo <- task
+ return <-task.wait
+}
+
+// run processes the tasks for a given transaction writer. Only one
+// of these goroutines will run at a time. A transaction will be
+// opened using the database object from the task and then this will
+// be passed as a parameter to the task function.
+func (w *ExclusiveTransactionWriter) run() {
+ if !w.running.CAS(false, true) {
+ return
+ }
+ defer w.running.Store(false)
+ for task := range w.todo {
+ if task.txn != nil {
+ task.wait <- task.f(task.txn)
+ } else if task.db != nil {
+ task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
+ return task.f(txn)
+ })
+ } else {
+ panic("expected database or transaction but got neither")
+ }
+ close(task.wait)
+ }
+}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index a4d71fe1..c95790be 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -63,7 +63,7 @@ const deleteAllDeviceKeysSQL = "" +
type deviceKeysStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go
index 02b9d193..f451d657 100644
--- a/keyserver/storage/sqlite3/key_changes_table.go
+++ b/keyserver/storage/sqlite3/key_changes_table.go
@@ -52,7 +52,7 @@ const selectKeyChangesSQL = "" +
type keyChangesStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index 907966a7..c71cc47d 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -60,7 +60,7 @@ const selectKeyByAlgorithmSQL = "" +
type oneTimeKeysStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go
index f53f164d..ff6ddf3d 100644
--- a/mediaapi/storage/sqlite3/media_repository_table.go
+++ b/mediaapi/storage/sqlite3/media_repository_table.go
@@ -62,7 +62,7 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
type mediaStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
}
diff --git a/roomserver/internal/input_latest_events.go b/roomserver/internal/input_latest_events.go
index 0158c8f7..3be5218d 100644
--- a/roomserver/internal/input_latest_events.go
+++ b/roomserver/internal/input_latest_events.go
@@ -57,7 +57,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil {
- return
+ return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
}
succeeded := false
defer func() {
@@ -79,7 +79,7 @@ func (r *RoomserverInternalAPI) updateLatestEvents(
}
if err = u.doUpdateLatestEvents(); err != nil {
- return err
+ return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
}
succeeded = true
@@ -137,7 +137,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// don't need to do anything, as we've handled it already.
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
if err != nil {
- return err
+ return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
} else if hasBeenSent {
return nil
}
@@ -145,7 +145,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Update the roomserver_previous_events table with references. This
// is effectively tracking the structure of the DAG.
if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
- return err
+ return fmt.Errorf("u.updater.StorePreviousEvents: %w", err)
}
// Get the event reference for our new event. This will be used when
@@ -156,7 +156,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// in the room. If it is then it isn't a latest event.
alreadyReferenced, err := u.updater.IsReferenced(eventReference)
if err != nil {
- return err
+ return fmt.Errorf("u.updater.IsReferenced: %w", err)
}
// Work out what the latest events are.
@@ -173,19 +173,19 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Now that we know what the latest events are, it's time to get the
// latest state.
if err = u.latestState(); err != nil {
- return err
+ return fmt.Errorf("u.latestState: %w", err)
}
// If we need to generate any output events then here's where we do it.
// TODO: Move this!
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
if err != nil {
- return err
+ return fmt.Errorf("u.api.updateMemberships: %w", err)
}
update, err := u.makeOutputNewRoomEvent()
if err != nil {
- return err
+ return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
}
updates = append(updates, *update)
@@ -198,14 +198,18 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
- return err
+ return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
}
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
- return err
+ return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
}
- return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID)
+ if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil {
+ return fmt.Errorf("u.updater.MarkEventAsSent: %w", err)
+ }
+
+ return nil
}
func (u *latestEventsUpdater) latestState() error {
@@ -225,7 +229,7 @@ func (u *latestEventsUpdater) latestState() error {
u.ctx, u.roomNID, latestStateAtEvents,
)
if err != nil {
- return err
+ return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
}
// If we are overwriting the state then we should make sure that we
@@ -244,7 +248,7 @@ func (u *latestEventsUpdater) latestState() error {
u.ctx, u.oldStateNID, u.newStateNID,
)
if err != nil {
- return err
+ return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
}
// Also work out the state before the event removes and the event
@@ -252,7 +256,11 @@ func (u *latestEventsUpdater) latestState() error {
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
)
- return err
+ if err != nil {
+ return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
+ }
+
+ return nil
}
func calculateLatest(
diff --git a/roomserver/state/state.go b/roomserver/state/state.go
index d5be4a90..b9ad4a50 100644
--- a/roomserver/state/state.go
+++ b/roomserver/state/state.go
@@ -558,7 +558,11 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
// 2) There weren't any prev_events for this event so the state is
// empty.
metrics.algorithm = "empty_state"
- return metrics.stop(v.db.AddState(ctx, roomNID, nil, nil))
+ stateNID, err := v.db.AddState(ctx, roomNID, nil, nil)
+ if err != nil {
+ err = fmt.Errorf("v.db.AddState: %w", err)
+ }
+ return metrics.stop(stateNID, err)
}
if len(prevStates) == 1 {
@@ -578,22 +582,30 @@ func (v StateResolution) CalculateAndStoreStateAfterEvents(
)
if err != nil {
metrics.algorithm = "_load_state_blocks"
- return metrics.stop(0, err)
+ return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err))
}
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
if len(stateBlockNIDs) < maxStateBlockNIDs {
// 4) The number of state data blocks is small enough that we can just
// add the state event as a block of size one to the end of the blocks.
metrics.algorithm = "single_delta"
- return metrics.stop(v.db.AddState(
+ stateNID, err := v.db.AddState(
ctx, roomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
- ))
+ )
+ if err != nil {
+ err = fmt.Errorf("v.db.AddState: %w", err)
+ }
+ return metrics.stop(stateNID, err)
}
// If there are too many deltas then we need to calculate the full state
// So fall through to calculateAndStoreStateAfterManyEvents
}
- return v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
+ stateNID, err := v.calculateAndStoreStateAfterManyEvents(ctx, roomNID, prevStates, metrics)
+ if err != nil {
+ return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
+ }
+ return stateNID, nil
}
// maxStateBlockNIDs is the maximum number of state data blocks to use to encode a snapshot of room state.
diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go
index 52ff479b..0b7ed225 100644
--- a/roomserver/storage/postgres/storage.go
+++ b/roomserver/storage/postgres/storage.go
@@ -98,6 +98,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
}
d.Database = shared.Database{
DB: db,
+ Writer: sqlutil.NewDummyTransactionWriter(),
EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON,
diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go
index 21b168a4..e9a0f698 100644
--- a/roomserver/storage/shared/latest_events_updater.go
+++ b/roomserver/storage/shared/latest_events_updater.go
@@ -3,6 +3,7 @@ package shared
import (
"context"
"database/sql"
+ "fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -65,12 +66,14 @@ func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
// StorePreviousEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
- for _, ref := range previousEventReferences {
- if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
- return err
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ for _, ref := range previousEventReferences {
+ if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
+ return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
+ }
}
- }
- return nil
+ return nil
+ })
}
// IsReferenced implements types.RoomRecentEventsUpdater
@@ -82,7 +85,7 @@ func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.Even
if err == sql.ErrNoRows {
return false, nil
}
- return false, err
+ return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
@@ -94,7 +97,12 @@ func (u *LatestEventsUpdater) SetLatestEvents(
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
- return u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, u.txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID)
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
+ return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
+ }
+ return nil
+ })
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
@@ -104,7 +112,9 @@ func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, e
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
- return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, u.txn, eventNID)
+ return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
+ return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
+ })
}
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go
index 5955844f..329813bf 100644
--- a/roomserver/storage/shared/membership_updater.go
+++ b/roomserver/storage/shared/membership_updater.go
@@ -3,6 +3,7 @@ package shared
import (
"context"
"database/sql"
+ "fmt"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
@@ -41,9 +42,14 @@ func (d *Database) membershipUpdaterTxn(
targetUserNID types.EventStateKeyNID,
targetLocal bool,
) (*MembershipUpdater, error) {
-
- if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
- return nil, err
+ err := d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
+ if err := d.MembershipTable.InsertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
+ return fmt.Errorf("d.MembershipTable.InsertMembership: %w", err)
+ }
+ return nil
+ })
+ if err != nil {
+ return nil, fmt.Errorf("u.d.Writer.Do: %w", err)
}
membership, err := d.MembershipTable.SelectMembershipForUpdate(ctx, txn, roomNID, targetUserNID)
@@ -75,19 +81,19 @@ func (u *MembershipUpdater) IsLeave() bool {
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
if err != nil {
- return false, err
+ return false, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
inserted, err := u.d.InvitesTable.InsertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
)
if err != nil {
- return false, err
+ return false, fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
}
if u.membership != tables.MembershipStateInvite {
if err = u.d.MembershipTable.UpdateMembership(
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, 0,
); err != nil {
- return false, err
+ return false, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return inserted, nil
@@ -99,7 +105,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
// If this is a join event update, there is no invite to update
@@ -108,14 +114,14 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
}
}
// Look up the NID of the new join event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateJoin || isUpdate {
@@ -123,7 +129,7 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateJoin, nIDs[eventID],
); err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
@@ -134,19 +140,19 @@ func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpd
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, senderUserID)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
inviteEventIDs, err := u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
}
// Look up the NID of the new leave event
nIDs, err := u.d.EventNIDs(u.ctx, []string{eventID})
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateLeaveOrBan {
@@ -154,7 +160,7 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID,
tables.MembershipStateLeaveOrBan, nIDs[eventID],
); err != nil {
- return nil, err
+ return nil, fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return inviteEventIDs, nil
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index 00179e33..45020d55 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -27,6 +27,7 @@ const redactionsArePermanent = false
type Database struct {
DB *sql.DB
+ Writer sqlutil.TransactionWriter
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
@@ -83,20 +84,23 @@ func (d *Database) AddState(
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if len(state) > 0 {
var stateBlockNID types.StateBlockNID
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
if err != nil {
- return err
+ return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err)
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
}
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs)
- return err
+ if err != nil {
+ return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err)
+ }
+ return nil
})
if err != nil {
- return 0, err
+ return 0, fmt.Errorf("d.Writer.Do: %w", err)
}
return
}
@@ -110,7 +114,9 @@ func (d *Database) EventNIDs(
func (d *Database) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
- return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID)
+ return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
+ return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID)
+ })
}
func (d *Database) StateAtEventIDs(
@@ -221,7 +227,9 @@ func (d *Database) GetRoomVersionForRoomNID(
}
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
- return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
+ return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
+ return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID)
+ })
}
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
@@ -239,15 +247,21 @@ func (d *Database) GetCreatorIDForAlias(
}
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
- return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
+ return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
+ return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias)
+ })
}
func (d *Database) GetMembership(
ctx context.Context, roomNID types.RoomNID, requestSenderUserID string,
) (membershipEventNID types.EventNID, stillInRoom bool, err error) {
- requestSenderUserNID, err := d.assignStateKeyNID(ctx, nil, requestSenderUserID)
+ var requestSenderUserNID types.EventStateKeyNID
+ err = d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
+ requestSenderUserNID, err = d.assignStateKeyNID(ctx, nil, requestSenderUserID)
+ return err
+ })
if err != nil {
- return
+ return 0, false, fmt.Errorf("d.assignStateKeyNID: %w", err)
}
senderMembershipEventNID, senderMembership, err :=
@@ -350,6 +364,7 @@ func (d *Database) GetLatestEventsForUpdate(
return NewLatestEventsUpdater(ctx, d, txn, roomNID)
}
+// nolint:gocyclo
func (d *Database) StoreEvent(
ctx context.Context, event gomatrixserverlib.Event,
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
@@ -365,10 +380,10 @@ func (d *Database) StoreEvent(
err error
)
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if txnAndSessionID != nil {
if err = d.TransactionsTable.InsertTransaction(
- ctx, txn, txnAndSessionID.TransactionID,
+ ctx, nil, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err)
@@ -433,7 +448,7 @@ func (d *Database) StoreEvent(
return nil
})
if err != nil {
- return 0, types.StateAtEvent{}, nil, "", err
+ return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
}
return roomNID, types.StateAtEvent{
@@ -449,7 +464,9 @@ func (d *Database) StoreEvent(
}
func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error {
- return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish)
+ return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error {
+ return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish)
+ })
}
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
diff --git a/roomserver/storage/sqlite3/event_json_table.go b/roomserver/storage/sqlite3/event_json_table.go
index e8118ad7..3cd44b1d 100644
--- a/roomserver/storage/sqlite3/event_json_table.go
+++ b/roomserver/storage/sqlite3/event_json_table.go
@@ -49,15 +49,13 @@ const bulkSelectEventJSONSQL = `
type eventJSONStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertEventJSONStmt *sql.Stmt
bulkSelectEventJSONStmt *sql.Stmt
}
-func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventJSON, error) {
+func NewSqliteEventJSONTable(db *sql.DB) (tables.EventJSON, error) {
s := &eventJSONStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(eventJSONSchema)
if err != nil {
@@ -72,10 +70,8 @@ func NewSqliteEventJSONTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab
func (s *eventJSONStatements) InsertEventJSON(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
- return err
- })
+ _, err := sqlutil.TxStmt(txn, s.insertEventJSONStmt).ExecContext(ctx, int64(eventNID), eventJSON)
+ return err
}
func (s *eventJSONStatements) BulkSelectEventJSON(
diff --git a/roomserver/storage/sqlite3/event_state_keys_table.go b/roomserver/storage/sqlite3/event_state_keys_table.go
index c8ad052b..345df8c6 100644
--- a/roomserver/storage/sqlite3/event_state_keys_table.go
+++ b/roomserver/storage/sqlite3/event_state_keys_table.go
@@ -64,17 +64,15 @@ const bulkSelectEventStateKeyNIDSQL = `
type eventStateKeyStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertEventStateKeyNIDStmt *sql.Stmt
selectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyNIDStmt *sql.Stmt
bulkSelectEventStateKeyStmt *sql.Stmt
}
-func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventStateKeys, error) {
+func NewSqliteEventStateKeysTable(db *sql.DB) (tables.EventStateKeys, error) {
s := &eventStateKeyStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(eventStateKeysSchema)
if err != nil {
@@ -91,19 +89,15 @@ func NewSqliteEventStateKeysTable(db *sql.DB, writer *sqlutil.TransactionWriter)
func (s *eventStateKeyStatements) InsertEventStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
- var eventStateKeyNID int64
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
- res, err := insertStmt.ExecContext(ctx, eventStateKey)
- if err != nil {
- return err
- }
- eventStateKeyNID, err = res.LastInsertId()
- if err != nil {
- return err
- }
- return nil
- })
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStateKeyNIDStmt)
+ res, err := insertStmt.ExecContext(ctx, eventStateKey)
+ if err != nil {
+ return 0, err
+ }
+ eventStateKeyNID, err := res.LastInsertId()
+ if err != nil {
+ return 0, err
+ }
return types.EventStateKeyNID(eventStateKeyNID), err
}
diff --git a/roomserver/storage/sqlite3/event_types_table.go b/roomserver/storage/sqlite3/event_types_table.go
index 4a645789..26e2bf84 100644
--- a/roomserver/storage/sqlite3/event_types_table.go
+++ b/roomserver/storage/sqlite3/event_types_table.go
@@ -18,6 +18,7 @@ package sqlite3
import (
"context"
"database/sql"
+ "fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
@@ -78,17 +79,15 @@ const bulkSelectEventTypeNIDSQL = `
type eventTypeStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertEventTypeNIDStmt *sql.Stmt
insertEventTypeNIDResultStmt *sql.Stmt
selectEventTypeNIDStmt *sql.Stmt
bulkSelectEventTypeNIDStmt *sql.Stmt
}
-func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.EventTypes, error) {
+func NewSqliteEventTypesTable(db *sql.DB) (tables.EventTypes, error) {
s := &eventTypeStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(eventTypesSchema)
if err != nil {
@@ -104,18 +103,18 @@ func NewSqliteEventTypesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta
}
func (s *eventTypeStatements) InsertEventTypeNID(
- ctx context.Context, tx *sql.Tx, eventType string,
+ ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
var eventTypeNID int64
- err := s.writer.Do(s.db, tx, func(tx *sql.Tx) error {
- insertStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDStmt)
- resultStmt := sqlutil.TxStmt(tx, s.insertEventTypeNIDResultStmt)
- _, err := insertStmt.ExecContext(ctx, eventType)
- if err != nil {
- return err
- }
- return resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID)
- })
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDStmt)
+ resultStmt := sqlutil.TxStmt(txn, s.insertEventTypeNIDResultStmt)
+ _, err := insertStmt.ExecContext(ctx, eventType)
+ if err != nil {
+ return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
+ }
+ if err = resultStmt.QueryRowContext(ctx).Scan(&eventTypeNID); err != nil {
+ return 0, fmt.Errorf("resultStmt.QueryRowContext.Scan: %w", err)
+ }
return types.EventTypeNID(eventTypeNID), err
}
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index 0e39755c..26ea1d41 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -99,7 +99,6 @@ const selectRoomNIDForEventNIDSQL = "" +
type eventStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt
@@ -115,10 +114,9 @@ type eventStatements struct {
selectRoomNIDForEventNIDStmt *sql.Stmt
}
-func NewSqliteEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Events, error) {
+func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
s := &eventStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(eventsSchema)
if err != nil {
@@ -155,22 +153,19 @@ func (s *eventStatements) InsertEvent(
) (types.EventNID, types.StateSnapshotNID, error) {
// attempt to insert: the last_row_id is the event NID
var eventNID int64
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
- result, err := insertStmt.ExecContext(
- ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
- eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
- )
- if err != nil {
- return err
- }
- modified, err := result.RowsAffected()
- if modified == 0 && err == nil {
- return sql.ErrNoRows
- }
- eventNID, err = result.LastInsertId()
- return err
- })
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
+ result, err := insertStmt.ExecContext(
+ ctx, int64(roomNID), int64(eventTypeNID), int64(eventStateKeyNID),
+ eventID, referenceSHA256, eventNIDsAsArray(authEventNIDs), depth,
+ )
+ if err != nil {
+ return 0, 0, err
+ }
+ modified, err := result.RowsAffected()
+ if modified == 0 && err == nil {
+ return 0, 0, sql.ErrNoRows
+ }
+ eventNID, err = result.LastInsertId()
return types.EventNID(eventNID), 0, err
}
@@ -286,11 +281,8 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
func (s *eventStatements) UpdateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
- return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
- _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
- return err
- })
+ _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
+ return err
}
func (s *eventStatements) SelectEventSentToOutput(
@@ -302,11 +294,9 @@ func (s *eventStatements) SelectEventSentToOutput(
}
func (s *eventStatements) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
- _, err := updateStmt.ExecContext(ctx, int64(eventNID))
- return err
- })
+ updateStmt := sqlutil.TxStmt(txn, s.updateEventSentToOutputStmt)
+ _, err := updateStmt.ExecContext(ctx, int64(eventNID))
+ return err
}
func (s *eventStatements) SelectEventID(
@@ -334,7 +324,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("sqlutil.TxStmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateAtEventAndReference: rows.close() failed")
results := make([]types.StateAtEventAndReference, len(eventNIDs))
@@ -481,7 +471,7 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}
err = sqlutil.TxStmt(txn, sqlPrep).QueryRowContext(ctx, iEventIDs...).Scan(&result)
if err != nil {
- return 0, err
+ return 0, fmt.Errorf("sqlutil.TxStmt.QueryRowContext: %w", err)
}
return result, nil
}
diff --git a/roomserver/storage/sqlite3/invite_table.go b/roomserver/storage/sqlite3/invite_table.go
index 1305f4a8..327be6a0 100644
--- a/roomserver/storage/sqlite3/invite_table.go
+++ b/roomserver/storage/sqlite3/invite_table.go
@@ -64,17 +64,15 @@ SELECT invite_event_id FROM roomserver_invites WHERE room_nid = $1 AND target_ni
type inviteStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertInviteEventStmt *sql.Stmt
selectInviteActiveForUserInRoomStmt *sql.Stmt
updateInviteRetiredStmt *sql.Stmt
selectInvitesAboutToRetireStmt *sql.Stmt
}
-func NewSqliteInvitesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Invites, error) {
+func NewSqliteInvitesTable(db *sql.DB) (tables.Invites, error) {
s := &inviteStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(inviteSchema)
if err != nil {
@@ -96,20 +94,17 @@ func (s *inviteStatements) InsertInviteEvent(
inviteEventJSON []byte,
) (bool, error) {
var count int64
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
- result, err := stmt.ExecContext(
- ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
- )
- if err != nil {
- return err
- }
- count, err = result.RowsAffected()
- if err != nil {
- return err
- }
- return nil
- })
+ stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
+ result, err := stmt.ExecContext(
+ ctx, inviteEventID, roomNID, targetUserNID, senderUserNID, inviteEventJSON,
+ )
+ if err != nil {
+ return false, err
+ }
+ count, err = result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
return count != 0, err
}
@@ -117,26 +112,23 @@ func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- // gather all the event IDs we will retire
- stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
- rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
- if err != nil {
- return err
- }
- defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
- for rows.Next() {
- var inviteEventID string
- if err = rows.Scan(&inviteEventID); err != nil {
- return err
- }
- eventIDs = append(eventIDs, inviteEventID)
+ // gather all the event IDs we will retire
+ stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
+ rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
+ if err != nil {
+ return
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "UpdateInviteRetired: rows.close() failed")
+ for rows.Next() {
+ var inviteEventID string
+ if err = rows.Scan(&inviteEventID); err != nil {
+ return
}
- // now retire the invites
- stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
- _, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
- return err
- })
+ eventIDs = append(eventIDs, inviteEventID)
+ }
+ // now retire the invites
+ stmt = sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
+ _, err = stmt.ExecContext(ctx, roomNID, targetUserNID)
return
}
diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go
index 7b69cee3..b3ee69c0 100644
--- a/roomserver/storage/sqlite3/membership_table.go
+++ b/roomserver/storage/sqlite3/membership_table.go
@@ -77,7 +77,6 @@ const updateMembershipSQL = "" +
type membershipStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertMembershipStmt *sql.Stmt
selectMembershipForUpdateStmt *sql.Stmt
selectMembershipFromRoomAndTargetStmt *sql.Stmt
@@ -88,10 +87,9 @@ type membershipStatements struct {
updateMembershipStmt *sql.Stmt
}
-func NewSqliteMembershipTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Membership, error) {
+func NewSqliteMembershipTable(db *sql.DB) (tables.Membership, error) {
s := &membershipStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(membershipSchema)
if err != nil {
@@ -115,11 +113,9 @@ func (s *membershipStatements) InsertMembership(
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
- _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
+ _, err := stmt.ExecContext(ctx, roomNID, targetUserNID, localTarget)
+ return err
}
func (s *membershipStatements) SelectMembershipForUpdate(
@@ -201,11 +197,9 @@ func (s *membershipStatements) UpdateMembership(
senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
- _, err := stmt.ExecContext(
- ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.updateMembershipStmt)
+ _, err := stmt.ExecContext(
+ ctx, senderUserNID, membership, eventNID, roomNID, targetUserNID,
+ )
+ return err
}
diff --git a/roomserver/storage/sqlite3/previous_events_table.go b/roomserver/storage/sqlite3/previous_events_table.go
index ff804861..d28a42c6 100644
--- a/roomserver/storage/sqlite3/previous_events_table.go
+++ b/roomserver/storage/sqlite3/previous_events_table.go
@@ -54,15 +54,13 @@ const selectPreviousEventExistsSQL = `
type previousEventStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertPreviousEventStmt *sql.Stmt
selectPreviousEventExistsStmt *sql.Stmt
}
-func NewSqlitePrevEventsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.PreviousEvents, error) {
+func NewSqlitePrevEventsTable(db *sql.DB) (tables.PreviousEvents, error) {
s := &previousEventStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(previousEventSchema)
if err != nil {
@@ -82,13 +80,11 @@ func (s *previousEventStatements) InsertPreviousEvent(
previousEventReferenceSHA256 []byte,
eventNID types.EventNID,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
- _, err := stmt.ExecContext(
- ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertPreviousEventStmt)
+ _, err := stmt.ExecContext(
+ ctx, previousEventID, previousEventReferenceSHA256, int64(eventNID),
+ )
+ return err
}
// Check if the event reference exists
diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go
index a4a47aec..1d6ccd56 100644
--- a/roomserver/storage/sqlite3/published_table.go
+++ b/roomserver/storage/sqlite3/published_table.go
@@ -19,7 +19,6 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
@@ -45,16 +44,14 @@ const selectPublishedSQL = "" +
type publishedStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
upsertPublishedStmt *sql.Stmt
selectAllPublishedStmt *sql.Stmt
selectPublishedStmt *sql.Stmt
}
-func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Published, error) {
+func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) {
s := &publishedStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(publishedSchema)
if err != nil {
@@ -69,12 +66,9 @@ func NewSqlitePublishedTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tab
func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, roomID string, published bool,
-) (err error) {
- return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
- _, err := stmt.ExecContext(ctx, roomID, published)
- return err
- })
+) error {
+ _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
+ return err
}
func (s *publishedStatements) SelectPublishedFromRoomID(
diff --git a/roomserver/storage/sqlite3/redactions_table.go b/roomserver/storage/sqlite3/redactions_table.go
index ad900a4e..a2179357 100644
--- a/roomserver/storage/sqlite3/redactions_table.go
+++ b/roomserver/storage/sqlite3/redactions_table.go
@@ -53,17 +53,15 @@ const markRedactionValidatedSQL = "" +
type redactionStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertRedactionStmt *sql.Stmt
selectRedactionInfoByRedactionEventIDStmt *sql.Stmt
selectRedactionInfoByEventBeingRedactedStmt *sql.Stmt
markRedactionValidatedStmt *sql.Stmt
}
-func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Redactions, error) {
+func NewSqliteRedactionsTable(db *sql.DB) (tables.Redactions, error) {
s := &redactionStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(redactionsSchema)
if err != nil {
@@ -81,11 +79,9 @@ func NewSqliteRedactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (ta
func (s *redactionStatements) InsertRedaction(
ctx context.Context, txn *sql.Tx, info tables.RedactionInfo,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
- _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertRedactionStmt)
+ _, err := stmt.ExecContext(ctx, info.RedactionEventID, info.RedactsEventID, info.Validated)
+ return err
}
func (s *redactionStatements) SelectRedactionInfoByRedactionEventID(
@@ -121,9 +117,7 @@ func (s *redactionStatements) SelectRedactionInfoByEventBeingRedacted(
func (s *redactionStatements) MarkRedactionValidated(
ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
- _, err := stmt.ExecContext(ctx, redactionEventID, validated)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.markRedactionValidatedStmt)
+ _, err := stmt.ExecContext(ctx, redactionEventID, validated)
+ return err
}
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
index deba3ff5..a16e97aa 100644
--- a/roomserver/storage/sqlite3/room_aliases_table.go
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -20,7 +20,6 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
)
@@ -57,7 +56,6 @@ const deleteRoomAliasSQL = `
type roomAliasesStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertRoomAliasStmt *sql.Stmt
selectRoomIDFromAliasStmt *sql.Stmt
selectAliasesFromRoomIDStmt *sql.Stmt
@@ -65,10 +63,9 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
-func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.RoomAliases, error) {
+func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(roomAliasesSchema)
if err != nil {
@@ -85,12 +82,9 @@ func NewSqliteRoomAliasesTable(db *sql.DB, writer *sqlutil.TransactionWriter) (t
func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string,
-) (err error) {
- return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
- _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
- return err
- })
+) error {
+ _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
+ return err
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
@@ -138,10 +132,7 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string,
-) (err error) {
- return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
- _, err := stmt.ExecContext(ctx, alias)
- return err
- })
+) error {
+ _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
+ return err
}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index 8bbec508..6541cc0c 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -66,7 +66,6 @@ const selectRoomVersionForRoomNIDSQL = "" +
type roomStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertRoomNIDStmt *sql.Stmt
selectRoomNIDStmt *sql.Stmt
selectLatestEventNIDsStmt *sql.Stmt
@@ -76,10 +75,9 @@ type roomStatements struct {
selectRoomVersionForRoomNIDStmt *sql.Stmt
}
-func NewSqliteRoomsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Rooms, error) {
+func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(roomsSchema)
if err != nil {
@@ -100,20 +98,14 @@ func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (roomNID types.RoomNID, err error) {
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
- _, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
- if err != nil {
- return fmt.Errorf("insertStmt.ExecContext: %w", err)
- }
- roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
- if err != nil {
- return fmt.Errorf("s.SelectRoomNID: %w", err)
- }
- return nil
- })
+ insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
+ _, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
if err != nil {
- return types.RoomNID(0), err
+ return 0, fmt.Errorf("insertStmt.ExecContext: %w", err)
+ }
+ roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
+ if err != nil {
+ return 0, fmt.Errorf("s.SelectRoomNID: %w", err)
}
return
}
@@ -170,17 +162,15 @@ func (s *roomStatements) UpdateLatestEventNIDs(
lastEventSentNID types.EventNID,
stateSnapshotNID types.StateSnapshotNID,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
- _, err := stmt.ExecContext(
- ctx,
- eventNIDsAsArray(eventNIDs),
- int64(lastEventSentNID),
- int64(stateSnapshotNID),
- roomNID,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.updateLatestEventNIDsStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ eventNIDsAsArray(eventNIDs),
+ int64(lastEventSentNID),
+ int64(stateSnapshotNID),
+ roomNID,
+ )
+ return err
}
func (s *roomStatements) SelectRoomVersionForRoomID(
diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go
index 3e28e450..8033903f 100644
--- a/roomserver/storage/sqlite3/state_block_table.go
+++ b/roomserver/storage/sqlite3/state_block_table.go
@@ -74,17 +74,15 @@ const bulkSelectFilteredStateBlockEntriesSQL = "" +
type stateBlockStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
}
-func NewSqliteStateBlockTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateBlock, error) {
+func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(stateDataSchema)
if err != nil {
@@ -107,25 +105,22 @@ func (s *stateBlockStatements) BulkInsertStateData(
return 0, nil
}
var stateBlockNID types.StateBlockNID
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
+ err := txn.Stmt(s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
+ if err != nil {
+ return 0, err
+ }
+ for _, entry := range entries {
+ _, err = txn.Stmt(s.insertStateDataStmt).ExecContext(
+ ctx,
+ int64(stateBlockNID),
+ int64(entry.EventTypeNID),
+ int64(entry.EventStateKeyNID),
+ int64(entry.EventNID),
+ )
if err != nil {
- return err
+ return 0, err
}
- for _, entry := range entries {
- _, err := txn.Stmt(s.insertStateDataStmt).ExecContext(
- ctx,
- int64(stateBlockNID),
- int64(entry.EventTypeNID),
- int64(entry.EventStateKeyNID),
- int64(entry.EventNID),
- )
- if err != nil {
- return err
- }
- }
- return nil
- })
+ }
return stateBlockNID, err
}
diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go
index 799904ff..392c2a67 100644
--- a/roomserver/storage/sqlite3/state_snapshot_table.go
+++ b/roomserver/storage/sqlite3/state_snapshot_table.go
@@ -50,15 +50,13 @@ const bulkSelectStateBlockNIDsSQL = "" +
type stateSnapshotStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
-func NewSqliteStateSnapshotTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.StateSnapshot, error) {
+func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
@@ -78,19 +76,16 @@ func (s *stateSnapshotStatements) InsertState(
if err != nil {
return
}
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- insertStmt := txn.Stmt(s.insertStateStmt)
- res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
- if err != nil {
- return err
- }
- lastRowID, err := res.LastInsertId()
- if err != nil {
- return err
- }
- stateNID = types.StateSnapshotNID(lastRowID)
- return nil
- })
+ insertStmt := txn.Stmt(s.insertStateStmt)
+ res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
+ if err != nil {
+ return 0, err
+ }
+ lastRowID, err := res.LastInsertId()
+ if err != nil {
+ return 0, err
+ }
+ stateNID = types.StateSnapshotNID(lastRowID)
return
}
diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go
index 72431637..8e3af6b7 100644
--- a/roomserver/storage/sqlite3/storage.go
+++ b/roomserver/storage/sqlite3/storage.go
@@ -41,6 +41,7 @@ type Database struct {
invites tables.Invites
membership tables.Membership
db *sql.DB
+ writer sqlutil.TransactionWriter
}
// Open a sqlite database.
@@ -51,7 +52,7 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- writer := sqlutil.NewTransactionWriter()
+ d.writer = sqlutil.NewTransactionWriter()
//d.db.Exec("PRAGMA journal_mode=WAL;")
//d.db.Exec("PRAGMA read_uncommitted = true;")
@@ -61,64 +62,65 @@ func Open(dbProperties *config.DatabaseOptions) (*Database, error) {
// which it will never obtain.
d.db.SetMaxOpenConns(20)
- d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db, writer)
+ d.eventStateKeys, err = NewSqliteEventStateKeysTable(d.db)
if err != nil {
return nil, err
}
- d.eventTypes, err = NewSqliteEventTypesTable(d.db, writer)
+ d.eventTypes, err = NewSqliteEventTypesTable(d.db)
if err != nil {
return nil, err
}
- d.eventJSON, err = NewSqliteEventJSONTable(d.db, writer)
+ d.eventJSON, err = NewSqliteEventJSONTable(d.db)
if err != nil {
return nil, err
}
- d.events, err = NewSqliteEventsTable(d.db, writer)
+ d.events, err = NewSqliteEventsTable(d.db)
if err != nil {
return nil, err
}
- d.rooms, err = NewSqliteRoomsTable(d.db, writer)
+ d.rooms, err = NewSqliteRoomsTable(d.db)
if err != nil {
return nil, err
}
- d.transactions, err = NewSqliteTransactionsTable(d.db, writer)
+ d.transactions, err = NewSqliteTransactionsTable(d.db)
if err != nil {
return nil, err
}
- stateBlock, err := NewSqliteStateBlockTable(d.db, writer)
+ stateBlock, err := NewSqliteStateBlockTable(d.db)
if err != nil {
return nil, err
}
- stateSnapshot, err := NewSqliteStateSnapshotTable(d.db, writer)
+ stateSnapshot, err := NewSqliteStateSnapshotTable(d.db)
if err != nil {
return nil, err
}
- d.prevEvents, err = NewSqlitePrevEventsTable(d.db, writer)
+ d.prevEvents, err = NewSqlitePrevEventsTable(d.db)
if err != nil {
return nil, err
}
- roomAliases, err := NewSqliteRoomAliasesTable(d.db, writer)
+ roomAliases, err := NewSqliteRoomAliasesTable(d.db)
if err != nil {
return nil, err
}
- d.invites, err = NewSqliteInvitesTable(d.db, writer)
+ d.invites, err = NewSqliteInvitesTable(d.db)
if err != nil {
return nil, err
}
- d.membership, err = NewSqliteMembershipTable(d.db, writer)
+ d.membership, err = NewSqliteMembershipTable(d.db)
if err != nil {
return nil, err
}
- published, err := NewSqlitePublishedTable(d.db, writer)
+ published, err := NewSqlitePublishedTable(d.db)
if err != nil {
return nil, err
}
- redactions, err := NewSqliteRedactionsTable(d.db, writer)
+ redactions, err := NewSqliteRedactionsTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
+ Writer: sqlutil.NewTransactionWriter(),
EventsTable: d.events,
EventTypesTable: d.eventTypes,
EventStateKeysTable: d.eventStateKeys,
diff --git a/roomserver/storage/sqlite3/transactions_table.go b/roomserver/storage/sqlite3/transactions_table.go
index 65c18a8a..029122c5 100644
--- a/roomserver/storage/sqlite3/transactions_table.go
+++ b/roomserver/storage/sqlite3/transactions_table.go
@@ -45,15 +45,13 @@ const selectTransactionEventIDSQL = `
type transactionStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
insertTransactionStmt *sql.Stmt
selectTransactionEventIDStmt *sql.Stmt
}
-func NewSqliteTransactionsTable(db *sql.DB, writer *sqlutil.TransactionWriter) (tables.Transactions, error) {
+func NewSqliteTransactionsTable(db *sql.DB) (tables.Transactions, error) {
s := &transactionStatements{
- db: db,
- writer: writer,
+ db: db,
}
_, err := db.Exec(transactionsSchema)
if err != nil {
@@ -72,14 +70,12 @@ func (s *transactionStatements) InsertTransaction(
sessionID int64,
userID string,
eventID string,
-) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
- _, err := stmt.ExecContext(
- ctx, transactionID, sessionID, userID, eventID,
- )
- return err
- })
+) error {
+ stmt := sqlutil.TxStmt(txn, s.insertTransactionStmt)
+ _, err := stmt.ExecContext(
+ ctx, transactionID, sessionID, userID, eventID,
+ )
+ return err
}
func (s *transactionStatements) SelectTransactionEventID(
diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go
index 423292a5..b829eae7 100644
--- a/serverkeyapi/storage/sqlite3/server_key_table.go
+++ b/serverkeyapi/storage/sqlite3/server_key_table.go
@@ -63,7 +63,7 @@ const upsertServerKeysSQL = "" +
type serverKeyStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index dd5b838c..fdbf6758 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -45,7 +45,7 @@ type Database struct {
BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice
Filter tables.Filter
- SendToDeviceWriter *sqlutil.TransactionWriter
+ SendToDeviceWriter sqlutil.TransactionWriter
EDUCache *cache.EDUCache
}
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
index 609cef14..248ec926 100644
--- a/syncapi/storage/sqlite3/account_data_table.go
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -51,7 +51,7 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go
index 1aeb041f..d96f2fe5 100644
--- a/syncapi/storage/sqlite3/backwards_extremities_table.go
+++ b/syncapi/storage/sqlite3/backwards_extremities_table.go
@@ -49,7 +49,7 @@ const deleteBackwardExtremitySQL = "" +
type backwardExtremitiesStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 6edc99aa..77a21543 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -85,7 +85,7 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go
index 3e8a4655..338b0b50 100644
--- a/syncapi/storage/sqlite3/filter_table.go
+++ b/syncapi/storage/sqlite3/filter_table.go
@@ -52,7 +52,7 @@ const insertFilterSQL = "" +
type filterStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index 19e7a7c6..0bbd79f7 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -59,7 +59,7 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index 12b4dbab..0d154650 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -105,7 +105,7 @@ const selectStateInRangeSQL = "" +
type outputRoomEventsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
index 2e71e8f3..5c4ab005 100644
--- a/syncapi/storage/sqlite3/output_room_events_topology_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -67,7 +67,7 @@ const selectMaxPositionInTopologySQL = "" +
type outputRoomEventsTopologyStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
index 88b319fb..53786589 100644
--- a/syncapi/storage/sqlite3/send_to_device_table.go
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -73,7 +73,7 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index cf3eed5b..1971e7f3 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -28,7 +28,7 @@ const selectStreamIDStmt = "" +
type streamIDStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt
}
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go
index cb54412a..9b40e657 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/accounts/sqlite3/account_data_table.go
@@ -51,7 +51,7 @@ const selectAccountDataByTypeSQL = "" +
type accountDataStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 27c3d845..586bcab9 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -59,7 +59,7 @@ const selectNewNumericLocalpartSQL = "" +
type accountsStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go
index d4c404ca..cd35d298 100644
--- a/userapi/storage/accounts/sqlite3/profile_table.go
+++ b/userapi/storage/accounts/sqlite3/profile_table.go
@@ -53,7 +53,7 @@ const selectProfilesBySearchSQL = "" +
type profilesStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go
index 0104e834..3000d7c4 100644
--- a/userapi/storage/accounts/sqlite3/threepid_table.go
+++ b/userapi/storage/accounts/sqlite3/threepid_table.go
@@ -54,7 +54,7 @@ const deleteThreePIDSQL = "" +
type threepidStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index 9b535aab..962e63b0 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -78,7 +78,7 @@ const selectDevicesByIDSQL = "" +
type devicesStatements struct {
db *sql.DB
- writer *sqlutil.TransactionWriter
+ writer sqlutil.TransactionWriter
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt