aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-08-21 10:42:08 +0100
committerGitHub <noreply@github.com>2020-08-21 10:42:08 +0100
commit9d53351dc20283103bf2eec6b92831033d06c5a8 (patch)
tree653cf0ddca3f777bcdba188187fb78fe39ae2b02 /syncapi/storage
parent5aaf32bbed4d704d5a22ad7dff79f7a68002a213 (diff)
Component-wide TransactionWriters (#1290)
* Offset updates take place using TransactionWriter * Refactor TransactionWriter in current state server * Refactor TransactionWriter in federation sender * Refactor TransactionWriter in key server * Refactor TransactionWriter in media API * Refactor TransactionWriter in server key API * Refactor TransactionWriter in sync API * Refactor TransactionWriter in user API * Fix deadlocking Sync API tests * Un-deadlock device database * Fix appservice API * Rename TransactionWriters to Writers * Move writers up a layer in sync API * Document sqlutil.Writer interface * Add note to Writer documentation
Diffstat (limited to 'syncapi/storage')
-rw-r--r--syncapi/storage/postgres/syncserver.go8
-rw-r--r--syncapi/storage/shared/syncserver.go31
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go18
-rw-r--r--syncapi/storage/sqlite3/backwards_extremities_table.go17
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go40
-rw-r--r--syncapi/storage/sqlite3/filter_table.go58
-rw-r--r--syncapi/storage/sqlite3/invites_table.go56
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go55
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go16
-rw-r--r--syncapi/storage/sqlite3/send_to_device_table.go22
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go15
-rw-r--r--syncapi/storage/sqlite3/syncserver.go8
12 files changed, 144 insertions, 200 deletions
diff --git a/syncapi/storage/postgres/syncserver.go b/syncapi/storage/postgres/syncserver.go
index 26ef082f..36e8de67 100644
--- a/syncapi/storage/postgres/syncserver.go
+++ b/syncapi/storage/postgres/syncserver.go
@@ -30,7 +30,8 @@ import (
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
shared.Database
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
}
@@ -41,7 +42,8 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
- if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+ d.writer = sqlutil.NewDummyWriter()
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return nil, err
}
accountData, err := NewPostgresAccountDataTable(d.db)
@@ -78,6 +80,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
d.Database = shared.Database{
DB: d.db,
+ Writer: sqlutil.NewDummyWriter(),
Invites: invites,
AccountData: accountData,
OutputEvents: events,
@@ -86,7 +89,6 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
BackwardExtremities: backwardExtremities,
Filter: filter,
SendToDevice: sendToDevice,
- SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(),
}
return &d, nil
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index fdbf6758..699a6647 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -37,6 +37,7 @@ import (
// For now this contains the shared functions
type Database struct {
DB *sql.DB
+ Writer sqlutil.Writer
Invites tables.Invites
AccountData tables.AccountData
OutputEvents tables.Events
@@ -45,7 +46,6 @@ type Database struct {
BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice
Filter tables.Filter
- SendToDeviceWriter sqlutil.TransactionWriter
EDUCache *cache.EDUCache
}
@@ -129,10 +129,7 @@ func (d *Database) GetStateEvent(
func (d *Database) GetStateEventsForRoom(
ctx context.Context, roomID string, stateFilter *gomatrixserverlib.StateFilter,
) (stateEvents []gomatrixserverlib.HeaderedEvent, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
- stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, txn, roomID, stateFilter)
- return err
- })
+ stateEvents, err = d.CurrentRoomState.SelectCurrentState(ctx, nil, roomID, stateFilter)
return
}
@@ -171,9 +168,9 @@ func (d *Database) SyncStreamPosition(ctx context.Context) (types.StreamPosition
func (d *Database) AddInviteEvent(
ctx context.Context, inviteEvent gomatrixserverlib.HeaderedEvent,
) (sp types.StreamPosition, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
- sp, err = d.Invites.InsertInviteEvent(ctx, txn, inviteEvent)
- return err
+ _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ sp, err = d.Invites.InsertInviteEvent(ctx, nil, inviteEvent)
+ return nil
})
return
}
@@ -182,8 +179,12 @@ func (d *Database) AddInviteEvent(
// Returns an error if there was a problem communicating with the database.
func (d *Database) RetireInviteEvent(
ctx context.Context, inviteEventID string,
-) (types.StreamPosition, error) {
- return d.Invites.DeleteInviteEvent(ctx, inviteEventID)
+) (sp types.StreamPosition, err error) {
+ _ = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ sp, err = d.Invites.DeleteInviteEvent(ctx, inviteEventID)
+ return nil
+ })
+ return
}
// GetAccountDataInRange returns all account data for a given user inserted or
@@ -207,7 +208,7 @@ func (d *Database) GetAccountDataInRange(
func (d *Database) UpsertAccountData(
ctx context.Context, userID, roomID, dataType string,
) (sp types.StreamPosition, err error) {
- err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sp, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
return err
})
@@ -237,6 +238,7 @@ func (d *Database) StreamEventsToEvents(device *userapi.Device, in []types.Strea
// handleBackwardExtremities adds this event as a backwards extremity if and only if we do not have all of
// the events listed in the event's 'prev_events'. This function also updates the backwards extremities table
// to account for the fact that the given event is no longer a backwards extremity, but may be marked as such.
+// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, ev *gomatrixserverlib.HeaderedEvent) error {
if err := d.BackwardExtremities.DeleteBackwardExtremity(ctx, txn, ev.RoomID(), ev.EventID()); err != nil {
return err
@@ -275,7 +277,7 @@ func (d *Database) WriteEvent(
addStateEventIDs, removeStateEventIDs []string,
transactionID *api.TransactionID, excludeFromSync bool,
) (pduPosition types.StreamPosition, returnErr error) {
- returnErr = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
pos, err := d.OutputEvents.InsertEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
@@ -304,6 +306,7 @@ func (d *Database) WriteEvent(
return pduPosition, returnErr
}
+// This function should always be called within a sqlutil.Writer for safety in SQLite.
func (d *Database) updateRoomState(
ctx context.Context, txn *sql.Tx,
removedEventIDs []string,
@@ -1114,7 +1117,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
}
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place.
- err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j),
)
@@ -1179,7 +1182,7 @@ func (d *Database) CleanSendToDeviceUpdates(
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
- err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
+ err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
index 248ec926..72c46e48 100644
--- a/syncapi/storage/sqlite3/account_data_table.go
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -20,7 +20,6 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" +
type accountDataStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
@@ -61,7 +59,6 @@ type accountDataStatements struct {
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(accountDataSchema)
@@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
- return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- var err error
- pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return err
- }
- _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
- return err
- })
+ pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
+ _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
+ return
}
func (s *accountDataStatements) SelectAccountDataInRange(
diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go
index d96f2fe5..116c33dc 100644
--- a/syncapi/storage/sqlite3/backwards_extremities_table.go
+++ b/syncapi/storage/sqlite3/backwards_extremities_table.go
@@ -19,7 +19,6 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
)
@@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" +
type backwardExtremitiesStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
@@ -57,8 +55,7 @@ type backwardExtremitiesStatements struct {
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
s := &backwardExtremitiesStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err := db.Exec(backwardExtremitiesSchema)
if err != nil {
@@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
- return err
- })
+ _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
+ return err
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
- return err
- })
+ _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
+ return err
}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 2f0068ed..6f822c90 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@@ -98,7 +97,6 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(currentRoomStateSchema)
@@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) error {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
- _, err := stmt.ExecContext(ctx, eventID)
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+ _, err := stmt.ExecContext(ctx, eventID)
+ return err
}
func (s *currentRoomStateStatements) UpsertRoomState(
@@ -225,22 +221,20 @@ func (s *currentRoomStateStatements) UpsertRoomState(
}
// upsert state event
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
- _, err := stmt.ExecContext(
- ctx,
- event.RoomID(),
- event.EventID(),
- event.Type(),
- event.Sender(),
- containsURL,
- *event.StateKey(),
- headeredJSON,
- membership,
- addedAt,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
+ _, err = stmt.ExecContext(
+ ctx,
+ event.RoomID(),
+ event.EventID(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ *event.StateKey(),
+ headeredJSON,
+ membership,
+ addedAt,
+ )
+ return err
}
func minOfInts(a, b int) int {
diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go
index 338b0b50..3092bcd7 100644
--- a/syncapi/storage/sqlite3/filter_table.go
+++ b/syncapi/storage/sqlite3/filter_table.go
@@ -20,7 +20,6 @@ import (
"encoding/json"
"fmt"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -52,7 +51,6 @@ const insertFilterSQL = "" +
type filterStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
@@ -64,8 +62,7 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
return nil, err
}
s := &filterStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err
@@ -114,33 +111,30 @@ func (s *filterStatements) InsertFilter(
return "", err
}
- err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- // Check if filter already exists in the database using its localpart and content
- //
- // This can result in a race condition when two clients try to insert the
- // same filter and localpart at the same time, however this is not a
- // problem as both calls will result in the same filterID
- err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
- localpart, filterJSON).Scan(&existingFilterID)
- if err != nil && err != sql.ErrNoRows {
- return err
- }
- // If it does, return the existing ID
- if existingFilterID != "" {
- return nil
- }
-
- // Otherwise insert the filter and return the new ID
- res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
- if err != nil {
- return err
- }
- rowid, err := res.LastInsertId()
- if err != nil {
- return err
- }
- filterID = fmt.Sprintf("%d", rowid)
- return nil
- })
+ // Check if filter already exists in the database using its localpart and content
+ //
+ // This can result in a race condition when two clients try to insert the
+ // same filter and localpart at the same time, however this is not a
+ // problem as both calls will result in the same filterID
+ err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
+ localpart, filterJSON).Scan(&existingFilterID)
+ if err != nil && err != sql.ErrNoRows {
+ return "", err
+ }
+ // If it does, return the existing ID
+ if existingFilterID != "" {
+ return existingFilterID, nil
+ }
+
+ // Otherwise insert the filter and return the new ID
+ res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
+ if err != nil {
+ return "", err
+ }
+ rowid, err := res.LastInsertId()
+ if err != nil {
+ return "", err
+ }
+ filterID = fmt.Sprintf("%d", rowid)
return
}
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index 0bbd79f7..45862efb 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" +
type inviteEventsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
@@ -70,7 +69,6 @@ type inviteEventsStatements struct {
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(inviteEventsSchema)
@@ -95,45 +93,37 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) {
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- var err error
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return err
- }
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return
+ }
- var headeredJSON []byte
- headeredJSON, err = json.Marshal(inviteEvent)
- if err != nil {
- return err
- }
+ var headeredJSON []byte
+ headeredJSON, err = json.Marshal(inviteEvent)
+ if err != nil {
+ return
+ }
- _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
- ctx,
- streamPos,
- inviteEvent.RoomID(),
- inviteEvent.EventID(),
- *inviteEvent.StateKey(),
- headeredJSON,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
+ _, err = stmt.ExecContext(
+ ctx,
+ streamPos,
+ inviteEvent.RoomID(),
+ inviteEvent.EventID(),
+ *inviteEvent.StateKey(),
+ headeredJSON,
+ )
return
}
func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) {
- var streamPos types.StreamPosition
- err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- var err error
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
- if err != nil {
- return err
- }
- _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
- return err
- })
+ streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
+ if err != nil {
+ return streamPos, err
+ }
+ _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
return streamPos, err
}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index 0d154650..f10d0106 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" +
type outputRoomEventsStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
@@ -120,7 +119,6 @@ type outputRoomEventsStatements struct {
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
db: db,
- writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(outputRoomEventsSchema)
@@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil {
return err
}
- return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
- return err
- })
+ _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
+ return err
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@@ -304,32 +300,27 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err
}
- var streamPos types.StreamPosition
- err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return err
- }
-
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
- _, ierr := insertStmt.ExecContext(
- ctx,
- streamPos,
- event.RoomID(),
- event.EventID(),
- headeredJSON,
- event.Type(),
- event.Sender(),
- containsURL,
- string(addStateJSON),
- string(removeStateJSON),
- sessionID,
- txnID,
- excludeFromSync,
- excludeFromSync,
- )
- return ierr
- })
+ streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return 0, err
+ }
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
+ _, err = insertStmt.ExecContext(
+ ctx,
+ streamPos,
+ event.RoomID(),
+ event.EventID(),
+ headeredJSON,
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ string(addStateJSON),
+ string(removeStateJSON),
+ sessionID,
+ txnID,
+ excludeFromSync,
+ excludeFromSync,
+ )
return streamPos, err
}
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
index 5c4ab005..d8c97b7e 100644
--- a/syncapi/storage/sqlite3/output_room_events_topology_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" +
type outputRoomEventsTopologyStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
@@ -77,8 +76,7 @@ type outputRoomEventsTopologyStatements struct {
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
s := &outputRoomEventsTopologyStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil {
@@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
- _, err := stmt.ExecContext(
- ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
- )
- return err
- })
+ stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
+ _, err = stmt.ExecContext(
+ ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
+ )
+ return
}
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
index 53786589..fbc759b1 100644
--- a/syncapi/storage/sqlite3/send_to_device_table.go
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = `
type sendToDeviceStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
@@ -81,8 +80,7 @@ type sendToDeviceStatements struct {
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{
- db: db,
- writer: sqlutil.NewTransactionWriter(),
+ db: db,
}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
@@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
- return err
- })
+ _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+ return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids {
params[k+1] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.ExecContext(ctx, query, params...)
- return err
- })
+ _, err = txn.ExecContext(ctx, query, params...)
+ return
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids {
params[k] = v
}
- return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- _, err := txn.ExecContext(ctx, query, params...)
- return err
- })
+ _, err = txn.ExecContext(ctx, query, params...)
+ return
}
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index 1971e7f3..e6bdc4fc 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -28,14 +28,12 @@ const selectStreamIDStmt = "" +
type streamIDStatements struct {
db *sql.DB
- writer sqlutil.TransactionWriter
increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
s.db = db
- s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(streamIDTableSchema)
if err != nil {
return
@@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
- return ierr
- }
- if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
- return serr
- }
- return nil
- })
+ if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
+ return
+ }
+ err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
return
}
diff --git a/syncapi/storage/sqlite3/syncserver.go b/syncapi/storage/sqlite3/syncserver.go
index 9564a23a..81197bb7 100644
--- a/syncapi/storage/sqlite3/syncserver.go
+++ b/syncapi/storage/sqlite3/syncserver.go
@@ -31,7 +31,8 @@ import (
// both the database for PDUs and caches for EDUs.
type SyncServerDatasource struct {
shared.Database
- db *sql.DB
+ db *sql.DB
+ writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
streamID streamIDStatements
}
@@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
+ d.writer = sqlutil.NewExclusiveWriter()
if err = d.prepare(); err != nil {
return nil, err
}
@@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
}
func (d *SyncServerDatasource) prepare() (err error) {
- if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+ if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
return err
}
if err = d.streamID.prepare(d.db); err != nil {
@@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
}
d.Database = shared.Database{
DB: d.db,
+ Writer: sqlutil.NewExclusiveWriter(),
Invites: invites,
AccountData: accountData,
OutputEvents: events,
@@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
Topology: topology,
Filter: filter,
SendToDevice: sendToDevice,
- SendToDeviceWriter: sqlutil.NewTransactionWriter(),
EDUCache: cache.New(),
}
return nil