aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--appservice/storage/sqlite3/appservice_events_table.go33
-rw-r--r--appservice/storage/sqlite3/txn_id_counter_table.go11
-rw-r--r--currentstateserver/storage/sqlite3/current_room_state_table.go38
-rw-r--r--mediaapi/storage/sqlite3/media_repository_table.go33
-rw-r--r--roomserver/internal/input_events.go11
-rw-r--r--roomserver/storage/shared/storage.go20
-rw-r--r--roomserver/storage/sqlite3/events_table.go3
-rw-r--r--roomserver/storage/sqlite3/published_table.go3
-rw-r--r--roomserver/storage/sqlite3/room_aliases_table.go6
-rw-r--r--roomserver/storage/sqlite3/rooms_table.go21
-rw-r--r--serverkeyapi/storage/sqlite3/server_key_table.go25
-rw-r--r--syncapi/storage/shared/syncserver.go12
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go20
-rw-r--r--syncapi/storage/sqlite3/backwards_extremities_table.go20
-rw-r--r--syncapi/storage/sqlite3/current_room_state_table.go42
-rw-r--r--syncapi/storage/sqlite3/filter_table.go61
-rw-r--r--syncapi/storage/sqlite3/invites_table.go57
-rw-r--r--syncapi/storage/sqlite3/output_room_events_table.go70
-rw-r--r--syncapi/storage/sqlite3/output_room_events_topology_table.go19
-rw-r--r--syncapi/storage/sqlite3/send_to_device_table.go25
-rw-r--r--syncapi/storage/sqlite3/stream_id_table.go19
-rw-r--r--syncapi/storage/storage_test.go9
-rw-r--r--userapi/storage/accounts/sqlite3/account_data_table.go12
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go20
-rw-r--r--userapi/storage/accounts/sqlite3/profile_table.go11
-rw-r--r--userapi/storage/accounts/sqlite3/threepid_table.go19
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go66
27 files changed, 440 insertions, 246 deletions
diff --git a/appservice/storage/sqlite3/appservice_events_table.go b/appservice/storage/sqlite3/appservice_events_table.go
index 479f2213..da31f235 100644
--- a/appservice/storage/sqlite3/appservice_events_table.go
+++ b/appservice/storage/sqlite3/appservice_events_table.go
@@ -21,6 +21,7 @@ import (
"encoding/json"
"time"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@@ -65,6 +66,8 @@ const (
)
type eventsStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
selectEventsByApplicationServiceIDStmt *sql.Stmt
countEventsByApplicationServiceIDStmt *sql.Stmt
insertEventStmt *sql.Stmt
@@ -73,6 +76,8 @@ type eventsStatements struct {
}
func (s *eventsStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(appserviceEventsSchema)
if err != nil {
return
@@ -217,13 +222,15 @@ func (s *eventsStatements) insertEvent(
return err
}
- _, err = s.insertEventStmt.ExecContext(
- ctx,
- appServiceID,
- eventJSON,
- -1, // No transaction ID yet
- )
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.insertEventStmt.ExecContext(
+ ctx,
+ appServiceID,
+ eventJSON,
+ -1, // No transaction ID yet
+ )
+ return err
+ })
}
// updateTxnIDForEvents sets the transactionID for a collection of events. Done
@@ -234,8 +241,10 @@ func (s *eventsStatements) updateTxnIDForEvents(
appserviceID string,
maxID, txnID int,
) (err error) {
- _, err = s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.updateTxnIDForEventsStmt.ExecContext(ctx, txnID, appserviceID, maxID)
+ return err
+ })
}
// deleteEventsBeforeAndIncludingID removes events matching given IDs from the database.
@@ -244,6 +253,8 @@ func (s *eventsStatements) deleteEventsBeforeAndIncludingID(
appserviceID string,
eventTableID int,
) (err error) {
- _, err = s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err := s.deleteEventsBeforeAndIncludingIDStmt.ExecContext(ctx, appserviceID, eventTableID)
+ return err
+ })
}
diff --git a/appservice/storage/sqlite3/txn_id_counter_table.go b/appservice/storage/sqlite3/txn_id_counter_table.go
index b1ee6076..501ab5aa 100644
--- a/appservice/storage/sqlite3/txn_id_counter_table.go
+++ b/appservice/storage/sqlite3/txn_id_counter_table.go
@@ -18,6 +18,8 @@ package sqlite3
import (
"context"
"database/sql"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
)
const txnIDSchema = `
@@ -35,10 +37,14 @@ const selectTxnIDSQL = `
`
type txnStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
selectTxnIDStmt *sql.Stmt
}
func (s *txnStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(txnIDSchema)
if err != nil {
return
@@ -55,6 +61,9 @@ func (s *txnStatements) prepare(db *sql.DB) (err error) {
func (s *txnStatements) selectTxnID(
ctx context.Context,
) (txnID int, err error) {
- err = s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
+ err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ err := s.selectTxnIDStmt.QueryRowContext(ctx).Scan(&txnID)
+ return err
+ })
return
}
diff --git a/currentstateserver/storage/sqlite3/current_room_state_table.go b/currentstateserver/storage/sqlite3/current_room_state_table.go
index 8fac4f35..b95fb435 100644
--- a/currentstateserver/storage/sqlite3/current_room_state_table.go
+++ b/currentstateserver/storage/sqlite3/current_room_state_table.go
@@ -68,6 +68,7 @@ const selectBulkStateContentWildSQL = "" +
type currentRoomStateStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt
@@ -76,7 +77,8 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
- db: db,
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
}
_, err := db.Exec(currentRoomStateSchema)
if err != nil {
@@ -125,9 +127,11 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
- _, err := stmt.ExecContext(ctx, eventID)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+ _, err := stmt.ExecContext(ctx, eventID)
+ return err
+ })
}
func (s *currentRoomStateStatements) UpsertRoomState(
@@ -140,18 +144,20 @@ func (s *currentRoomStateStatements) UpsertRoomState(
}
// upsert state event
- stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
- _, err = stmt.ExecContext(
- ctx,
- event.RoomID(),
- event.EventID(),
- event.Type(),
- event.Sender(),
- *event.StateKey(),
- headeredJSON,
- contentVal,
- )
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
+ _, err = stmt.ExecContext(
+ ctx,
+ event.RoomID(),
+ event.EventID(),
+ event.Type(),
+ event.Sender(),
+ *event.StateKey(),
+ headeredJSON,
+ contentVal,
+ )
+ return err
+ })
}
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
diff --git a/mediaapi/storage/sqlite3/media_repository_table.go b/mediaapi/storage/sqlite3/media_repository_table.go
index 8e2e6236..f53f164d 100644
--- a/mediaapi/storage/sqlite3/media_repository_table.go
+++ b/mediaapi/storage/sqlite3/media_repository_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"time"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -60,11 +61,16 @@ SELECT content_type, file_size_bytes, creation_ts, upload_name, base64hash, user
`
type mediaStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertMediaStmt *sql.Stmt
selectMediaStmt *sql.Stmt
}
func (s *mediaStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
+
_, err = db.Exec(mediaSchema)
if err != nil {
return
@@ -80,18 +86,21 @@ func (s *mediaStatements) insertMedia(
ctx context.Context, mediaMetadata *types.MediaMetadata,
) error {
mediaMetadata.CreationTimestamp = types.UnixMs(time.Now().UnixNano() / 1000000)
- _, err := s.insertMediaStmt.ExecContext(
- ctx,
- mediaMetadata.MediaID,
- mediaMetadata.Origin,
- mediaMetadata.ContentType,
- mediaMetadata.FileSizeBytes,
- mediaMetadata.CreationTimestamp,
- mediaMetadata.UploadName,
- mediaMetadata.Base64Hash,
- mediaMetadata.UserID,
- )
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertMediaStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ mediaMetadata.MediaID,
+ mediaMetadata.Origin,
+ mediaMetadata.ContentType,
+ mediaMetadata.FileSizeBytes,
+ mediaMetadata.CreationTimestamp,
+ mediaMetadata.UploadName,
+ mediaMetadata.Base64Hash,
+ mediaMetadata.UserID,
+ )
+ return err
+ })
}
func (s *mediaStatements) selectMedia(
diff --git a/roomserver/internal/input_events.go b/roomserver/internal/input_events.go
index 04538cf6..a6308299 100644
--- a/roomserver/internal/input_events.go
+++ b/roomserver/internal/input_events.go
@@ -18,6 +18,7 @@ package internal
import (
"context"
+ "fmt"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
@@ -65,13 +66,13 @@ func (r *RoomserverInternalAPI) processRoomEvent(
// Store the event.
roomNID, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil {
- return
+ return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
}
// if storing this event results in it being redacted then do so.
if redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, &event)
if rerr != nil {
- return "", rerr
+ return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr)
}
event = *r
}
@@ -93,7 +94,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
// Lets calculate one.
err = r.calculateAndSetState(ctx, input, roomNID, &stateAtEvent, event)
if err != nil {
- return
+ return "", fmt.Errorf("r.calculateAndSetState: %w", err)
}
}
@@ -105,7 +106,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
input.SendAsServer, // send as server
input.TransactionID, // transaction ID
); err != nil {
- return
+ return "", fmt.Errorf("r.updateLatestEvents: %w", err)
}
// processing this event resulted in an event (which may not be the one we're processing)
@@ -123,7 +124,7 @@ func (r *RoomserverInternalAPI) processRoomEvent(
},
})
if err != nil {
- return
+ return "", fmt.Errorf("r.WriteOutputEvents: %w", err)
}
}
diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index e2e5daf9..e858a9b0 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
+ "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
@@ -362,7 +363,7 @@ func (d *Database) StoreEvent(
ctx, txn, txnAndSessionID.TransactionID,
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
); err != nil {
- return err
+ return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err)
}
}
@@ -377,15 +378,15 @@ func (d *Database) StoreEvent(
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
- return err
+ return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
}
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
- return err
+ return fmt.Errorf("d.assignRoomNID: %w", err)
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
- return err
+ return fmt.Errorf("d.assignEventTypeNID: %w", err)
}
eventStateKey := event.StateKey()
@@ -393,7 +394,7 @@ func (d *Database) StoreEvent(
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
- return err
+ return fmt.Errorf("d.assignStateKeyNID: %w", err)
}
}
@@ -411,17 +412,20 @@ func (d *Database) StoreEvent(
if err == sql.ErrNoRows {
// We've already inserted the event so select the numeric event ID
eventNID, stateNID, err = d.EventsTable.SelectEvent(ctx, txn, event.EventID())
+ if err != nil {
+ return fmt.Errorf("d.EventsTable.SelectEvent: %w", err)
+ }
}
if err != nil {
- return err
+ return fmt.Errorf("d.EventsTable.InsertEvent: %w", err)
}
}
if err = d.EventJSONTable.InsertEventJSON(ctx, txn, eventNID, event.JSON()); err != nil {
- return err
+ return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
}
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
- return err
+ return nil
})
if err != nil {
return 0, types.StateAtEvent{}, nil, "", err
diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go
index 378441c3..b3cfee07 100644
--- a/roomserver/storage/sqlite3/events_table.go
+++ b/roomserver/storage/sqlite3/events_table.go
@@ -287,7 +287,8 @@ func (s *eventStatements) UpdateEventState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
+ stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt)
+ _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID))
return err
})
}
diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go
index 96575241..85f1e0a4 100644
--- a/roomserver/storage/sqlite3/published_table.go
+++ b/roomserver/storage/sqlite3/published_table.go
@@ -71,7 +71,8 @@ func (s *publishedStatements) UpsertRoomPublished(
ctx context.Context, roomID string, published bool,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published)
+ stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt)
+ _, err := stmt.ExecContext(ctx, roomID, published)
return err
})
}
diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go
index 096b73f9..4a535777 100644
--- a/roomserver/storage/sqlite3/room_aliases_table.go
+++ b/roomserver/storage/sqlite3/room_aliases_table.go
@@ -87,7 +87,8 @@ func (s *roomAliasesStatements) InsertRoomAlias(
ctx context.Context, alias string, roomID string, creatorUserID string,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID)
+ stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt)
+ _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID)
return err
})
}
@@ -139,7 +140,8 @@ func (s *roomAliasesStatements) DeleteRoomAlias(
ctx context.Context, alias string,
) (err error) {
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
- _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias)
+ stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt)
+ _, err := stmt.ExecContext(ctx, alias)
return err
})
}
diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go
index 9eeadea9..bb30a63b 100644
--- a/roomserver/storage/sqlite3/rooms_table.go
+++ b/roomserver/storage/sqlite3/rooms_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"encoding/json"
"errors"
+ "fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
@@ -98,17 +99,23 @@ func NewSqliteRoomsTable(db *sql.DB) (tables.Rooms, error) {
func (s *roomStatements) InsertRoomNID(
ctx context.Context, txn *sql.Tx,
roomID string, roomVersion gomatrixserverlib.RoomVersion,
-) (types.RoomNID, error) {
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+) (roomNID types.RoomNID, err error) {
+ err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
insertStmt := sqlutil.TxStmt(txn, s.insertRoomNIDStmt)
- _, err := insertStmt.ExecContext(ctx, roomID, roomVersion)
- return err
+ _, err = insertStmt.ExecContext(ctx, roomID, roomVersion)
+ if err != nil {
+ return fmt.Errorf("insertStmt.ExecContext: %w", err)
+ }
+ roomNID, err = s.SelectRoomNID(ctx, txn, roomID)
+ if err != nil {
+ return fmt.Errorf("s.SelectRoomNID: %w", err)
+ }
+ return nil
})
- if err == nil {
- return s.SelectRoomNID(ctx, txn, roomID)
- } else {
+ if err != nil {
return types.RoomNID(0), err
}
+ return
}
func (s *roomStatements) SelectRoomNID(
diff --git a/serverkeyapi/storage/sqlite3/server_key_table.go b/serverkeyapi/storage/sqlite3/server_key_table.go
index 4f03dccb..423292a5 100644
--- a/serverkeyapi/storage/sqlite3/server_key_table.go
+++ b/serverkeyapi/storage/sqlite3/server_key_table.go
@@ -63,12 +63,14 @@ const upsertServerKeysSQL = "" +
type serverKeyStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
func (s *serverKeyStatements) prepare(db *sql.DB) (err error) {
s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(serverKeysSchema)
if err != nil {
return
@@ -136,16 +138,19 @@ func (s *serverKeyStatements) upsertServerKeys(
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
- _, err := s.upsertServerKeysStmt.ExecContext(
- ctx,
- string(request.ServerName),
- string(request.KeyID),
- nameAndKeyID(request),
- key.ValidUntilTS,
- key.ExpiredTS,
- key.Key.Encode(),
- )
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ string(request.ServerName),
+ string(request.KeyID),
+ nameAndKeyID(request),
+ key.ValidUntilTS,
+ key.ExpiredTS,
+ key.Key.Encode(),
+ )
+ return err
+ })
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go
index 32079291..e1312671 100644
--- a/syncapi/storage/shared/syncserver.go
+++ b/syncapi/storage/shared/syncserver.go
@@ -281,16 +281,16 @@ func (d *Database) WriteEvent(
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
)
if err != nil {
- return err
+ return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
}
pduPosition = pos
if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
- return err
+ return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
}
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
- return err
+ return fmt.Errorf("d.handleBackwardExtremities: %w", err)
}
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
@@ -313,7 +313,7 @@ func (d *Database) updateRoomState(
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
for _, eventID := range removedEventIDs {
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
- return err
+ return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
}
}
@@ -326,13 +326,13 @@ func (d *Database) updateRoomState(
if event.Type() == "m.room.member" {
value, err := event.Membership()
if err != nil {
- return err
+ return fmt.Errorf("event.Membership: %w", err)
}
membership = &value
}
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
- return err
+ return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
}
}
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
index ae5caa4e..609cef14 100644
--- a/syncapi/storage/sqlite3/account_data_table.go
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
@@ -57,6 +60,8 @@ type accountDataStatements struct {
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(accountDataSchema)
@@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
- pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return
- }
- _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
- return
+ return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ var err error
+ pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return err
+ }
+ _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
+ return err
+ })
}
func (s *accountDataStatements) SelectAccountDataInRange(
diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go
index e16e54a6..1aeb041f 100644
--- a/syncapi/storage/sqlite3/backwards_extremities_table.go
+++ b/syncapi/storage/sqlite3/backwards_extremities_table.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
)
@@ -47,13 +48,18 @@ const deleteBackwardExtremitySQL = "" +
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
type backwardExtremitiesStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertBackwardExtremityStmt *sql.Stmt
selectBackwardExtremitiesForRoomStmt *sql.Stmt
deleteBackwardExtremityStmt *sql.Stmt
}
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
- s := &backwardExtremitiesStatements{}
+ s := &backwardExtremitiesStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(backwardExtremitiesSchema)
if err != nil {
return nil, err
@@ -73,8 +79,10 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
) (err error) {
- _, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
+ return err
+ })
}
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@@ -102,6 +110,8 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
) (err error) {
- _, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
+ return err
+ })
}
diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go
index 85f212ad..08b42f5b 100644
--- a/syncapi/storage/sqlite3/current_room_state_table.go
+++ b/syncapi/storage/sqlite3/current_room_state_table.go
@@ -84,6 +84,8 @@ const selectEventsWithEventIDsSQL = "" +
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
type currentRoomStateStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
upsertRoomStateStmt *sql.Stmt
deleteRoomStateByEventIDStmt *sql.Stmt
@@ -95,6 +97,8 @@ type currentRoomStateStatements struct {
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
s := &currentRoomStateStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(currentRoomStateSchema)
@@ -196,9 +200,11 @@ func (s *currentRoomStateStatements) SelectCurrentState(
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
- _, err := stmt.ExecContext(ctx, eventID)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+ _, err := stmt.ExecContext(ctx, eventID)
+ return err
+ })
}
func (s *currentRoomStateStatements) UpsertRoomState(
@@ -219,20 +225,22 @@ func (s *currentRoomStateStatements) UpsertRoomState(
}
// upsert state event
- stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
- _, err = stmt.ExecContext(
- ctx,
- event.RoomID(),
- event.EventID(),
- event.Type(),
- event.Sender(),
- containsURL,
- *event.StateKey(),
- headeredJSON,
- membership,
- addedAt,
- )
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
+ _, err := stmt.ExecContext(
+ ctx,
+ event.RoomID(),
+ event.EventID(),
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ *event.StateKey(),
+ headeredJSON,
+ membership,
+ addedAt,
+ )
+ return err
+ })
}
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
diff --git a/syncapi/storage/sqlite3/filter_table.go b/syncapi/storage/sqlite3/filter_table.go
index 8b26759d..3e8a4655 100644
--- a/syncapi/storage/sqlite3/filter_table.go
+++ b/syncapi/storage/sqlite3/filter_table.go
@@ -20,6 +20,7 @@ import (
"encoding/json"
"fmt"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -50,6 +51,8 @@ const insertFilterSQL = "" +
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
type filterStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
selectFilterStmt *sql.Stmt
selectFilterIDByContentStmt *sql.Stmt
insertFilterStmt *sql.Stmt
@@ -60,7 +63,10 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
if err != nil {
return nil, err
}
- s := &filterStatements{}
+ s := &filterStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
return nil, err
}
@@ -108,30 +114,33 @@ func (s *filterStatements) InsertFilter(
return "", err
}
- // Check if filter already exists in the database using its localpart and content
- //
- // This can result in a race condition when two clients try to insert the
- // same filter and localpart at the same time, however this is not a
- // problem as both calls will result in the same filterID
- err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
- localpart, filterJSON).Scan(&existingFilterID)
- if err != nil && err != sql.ErrNoRows {
- return "", err
- }
- // If it does, return the existing ID
- if existingFilterID != "" {
- return existingFilterID, err
- }
-
- // Otherwise insert the filter and return the new ID
- res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
- if err != nil {
- return "", err
- }
- rowid, err := res.LastInsertId()
- if err != nil {
- return "", err
- }
- filterID = fmt.Sprintf("%d", rowid)
+ err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ // Check if filter already exists in the database using its localpart and content
+ //
+ // This can result in a race condition when two clients try to insert the
+ // same filter and localpart at the same time, however this is not a
+ // problem as both calls will result in the same filterID
+ err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
+ localpart, filterJSON).Scan(&existingFilterID)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+ // If it does, return the existing ID
+ if existingFilterID != "" {
+ return nil
+ }
+
+ // Otherwise insert the filter and return the new ID
+ res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
+ if err != nil {
+ return err
+ }
+ rowid, err := res.LastInsertId()
+ if err != nil {
+ return err
+ }
+ filterID = fmt.Sprintf("%d", rowid)
+ return nil
+ })
return
}
diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go
index aa051388..19e7a7c6 100644
--- a/syncapi/storage/sqlite3/invites_table.go
+++ b/syncapi/storage/sqlite3/invites_table.go
@@ -58,6 +58,8 @@ const selectMaxInviteIDSQL = "" +
"SELECT MAX(id) FROM syncapi_invite_events"
type inviteEventsStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertInviteEventStmt *sql.Stmt
selectInviteEventsInRangeStmt *sql.Stmt
@@ -67,6 +69,8 @@ type inviteEventsStatements struct {
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
s := &inviteEventsStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(inviteEventsSchema)
@@ -91,36 +95,45 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) {
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return
- }
+ err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ var err error
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return err
+ }
- var headeredJSON []byte
- headeredJSON, err = json.Marshal(inviteEvent)
- if err != nil {
- return
- }
+ var headeredJSON []byte
+ headeredJSON, err = json.Marshal(inviteEvent)
+ if err != nil {
+ return err
+ }
- _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
- ctx,
- streamPos,
- inviteEvent.RoomID(),
- inviteEvent.EventID(),
- *inviteEvent.StateKey(),
- headeredJSON,
- )
+ _, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
+ ctx,
+ streamPos,
+ inviteEvent.RoomID(),
+ inviteEvent.EventID(),
+ *inviteEvent.StateKey(),
+ headeredJSON,
+ )
+ return err
+ })
return
}
func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, inviteEventID string,
) (types.StreamPosition, error) {
- streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
- if err != nil {
- return streamPos, err
- }
- _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
+ var streamPos types.StreamPosition
+ err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ var err error
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
+ if err != nil {
+ return err
+ }
+ _, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
+ return err
+ })
return streamPos, err
}
diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go
index da2ea3f6..12b4dbab 100644
--- a/syncapi/storage/sqlite3/output_room_events_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_table.go
@@ -104,6 +104,8 @@ const selectStateInRangeSQL = "" +
" LIMIT $8" // limit
type outputRoomEventsStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
@@ -117,6 +119,8 @@ type outputRoomEventsStatements struct {
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
s := &outputRoomEventsStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(outputRoomEventsSchema)
@@ -155,8 +159,10 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
if err != nil {
return err
}
- _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
- return err
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ _, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
+ return err
+ })
}
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@@ -267,7 +273,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
ctx context.Context, txn *sql.Tx,
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
transactionID *api.TransactionID, excludeFromSync bool,
-) (streamPos types.StreamPosition, err error) {
+) (types.StreamPosition, error) {
var txnID *string
var sessionID *int64
if transactionID != nil {
@@ -284,43 +290,47 @@ func (s *outputRoomEventsStatements) InsertEvent(
}
var headeredJSON []byte
- headeredJSON, err = json.Marshal(event)
- if err != nil {
- return
- }
-
- streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ headeredJSON, err := json.Marshal(event)
if err != nil {
- return
+ return 0, err
}
addStateJSON, err := json.Marshal(addState)
if err != nil {
- return
+ return 0, err
}
removeStateJSON, err := json.Marshal(removeState)
if err != nil {
- return
+ return 0, err
}
- insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
- _, err = insertStmt.ExecContext(
- ctx,
- streamPos,
- event.RoomID(),
- event.EventID(),
- headeredJSON,
- event.Type(),
- event.Sender(),
- containsURL,
- string(addStateJSON),
- string(removeStateJSON),
- sessionID,
- txnID,
- excludeFromSync,
- excludeFromSync,
- )
- return
+ var streamPos types.StreamPosition
+ err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return err
+ }
+
+ insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
+ _, ierr := insertStmt.ExecContext(
+ ctx,
+ streamPos,
+ event.RoomID(),
+ event.EventID(),
+ headeredJSON,
+ event.Type(),
+ event.Sender(),
+ containsURL,
+ string(addStateJSON),
+ string(removeStateJSON),
+ sessionID,
+ txnID,
+ excludeFromSync,
+ excludeFromSync,
+ )
+ return ierr
+ })
+ return streamPos, err
}
func (s *outputRoomEventsStatements) SelectRecentEvents(
diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go
index 811dfa4f..2e71e8f3 100644
--- a/syncapi/storage/sqlite3/output_room_events_topology_table.go
+++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go
@@ -66,6 +66,8 @@ const selectMaxPositionInTopologySQL = "" +
" WHERE room_id = $1 ORDER BY stream_position DESC"
type outputRoomEventsTopologyStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertEventInTopologyStmt *sql.Stmt
selectEventIDsInRangeASCStmt *sql.Stmt
selectEventIDsInRangeDESCStmt *sql.Stmt
@@ -74,7 +76,10 @@ type outputRoomEventsTopologyStatements struct {
}
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
- s := &outputRoomEventsTopologyStatements{}
+ s := &outputRoomEventsTopologyStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(outputRoomEventsTopologySchema)
if err != nil {
return nil, err
@@ -102,11 +107,13 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
) (err error) {
- stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
- _, err = stmt.ExecContext(
- ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
- )
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
+ _, err := stmt.ExecContext(
+ ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
+ )
+ return err
+ })
}
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
diff --git a/syncapi/storage/sqlite3/send_to_device_table.go b/syncapi/storage/sqlite3/send_to_device_table.go
index 42bd3c19..88b319fb 100644
--- a/syncapi/storage/sqlite3/send_to_device_table.go
+++ b/syncapi/storage/sqlite3/send_to_device_table.go
@@ -72,13 +72,18 @@ const deleteSendToDeviceMessagesSQL = `
`
type sendToDeviceStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
}
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
- s := &sendToDeviceStatements{}
+ s := &sendToDeviceStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
+ }
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
return nil, err
@@ -98,8 +103,10 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
- _, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+ return err
+ })
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@@ -156,8 +163,10 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
for k, v := range nids {
params[k+1] = v
}
- _, err = txn.ExecContext(ctx, query, params...)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := txn.ExecContext(ctx, query, params...)
+ return err
+ })
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@@ -168,6 +177,8 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
for k, v := range nids {
params[k] = v
}
- _, err = txn.ExecContext(ctx, query, params...)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := txn.ExecContext(ctx, query, params...)
+ return err
+ })
}
diff --git a/syncapi/storage/sqlite3/stream_id_table.go b/syncapi/storage/sqlite3/stream_id_table.go
index 57abd9c4..cf3eed5b 100644
--- a/syncapi/storage/sqlite3/stream_id_table.go
+++ b/syncapi/storage/sqlite3/stream_id_table.go
@@ -27,11 +27,15 @@ const selectStreamIDStmt = "" +
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
type streamIDStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
increaseStreamIDStmt *sql.Stmt
selectStreamIDStmt *sql.Stmt
}
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(streamIDTableSchema)
if err != nil {
return
@@ -48,11 +52,14 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
- if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
- return
- }
- if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
- return
- }
+ err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
+ return ierr
+ }
+ if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
+ return serr
+ }
+ return nil
+ })
return
}
diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go
index feacbc18..474d3222 100644
--- a/syncapi/storage/storage_test.go
+++ b/syncapi/storage/storage_test.go
@@ -5,6 +5,7 @@ import (
"crypto/ed25519"
"encoding/json"
"fmt"
+ "os"
"testing"
"time"
@@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head
}
func MustCreateDatabase(t *testing.T) storage.Database {
- db, err := sqlite3.NewDatabase("file::memory:")
+ dbname := fmt.Sprintf("test_%s.db", t.Name())
+ if _, err := os.Stat(dbname); err == nil {
+ if err = os.Remove(dbname); err != nil {
+ t.Fatalf("tried to delete stale test database but failed: %s", err)
+ }
+ }
+ db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname))
if err != nil {
t.Fatalf("NewSyncServerDatasource returned %s", err)
}
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/accounts/sqlite3/account_data_table.go
index d048dbd1..cb54412a 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/accounts/sqlite3/account_data_table.go
@@ -18,6 +18,8 @@ import (
"context"
"database/sql"
"encoding/json"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
)
const accountDataSchema = `
@@ -48,12 +50,16 @@ const selectAccountDataByTypeSQL = "" +
"SELECT content FROM account_data WHERE localpart = $1 AND room_id = $2 AND type = $3"
type accountDataStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertAccountDataStmt *sql.Stmt
selectAccountDataStmt *sql.Stmt
selectAccountDataByTypeStmt *sql.Stmt
}
func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountDataSchema)
if err != nil {
return
@@ -73,8 +79,10 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
func (s *accountDataStatements) insertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) {
- _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
+ return err
+ })
}
func (s *accountDataStatements) selectAccountData(
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 768f536d..27c3d845 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/clientapi/userutil"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@@ -57,6 +58,8 @@ const selectNewNumericLocalpartSQL = "" +
// TODO: Update password
type accountsStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
@@ -65,6 +68,8 @@ type accountsStatements struct {
}
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(accountsSchema)
if err != nil {
return
@@ -94,12 +99,15 @@ func (s *accountsStatements) insertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
- var err error
- if appserviceID == "" {
- _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
- } else {
- _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
- }
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ var err error
+ if appserviceID == "" {
+ _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
+ } else {
+ _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
+ }
+ return err
+ })
if err != nil {
return nil, err
}
diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/accounts/sqlite3/profile_table.go
index 9b5192a0..68cea516 100644
--- a/userapi/storage/accounts/sqlite3/profile_table.go
+++ b/userapi/storage/accounts/sqlite3/profile_table.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
)
const profilesSchema = `
@@ -46,6 +47,8 @@ const setDisplayNameSQL = "" +
"UPDATE account_profiles SET display_name = $1 WHERE localpart = $2"
type profilesStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
insertProfileStmt *sql.Stmt
selectProfileByLocalpartStmt *sql.Stmt
setAvatarURLStmt *sql.Stmt
@@ -53,6 +56,8 @@ type profilesStatements struct {
}
func (s *profilesStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(profilesSchema)
if err != nil {
return
@@ -75,8 +80,10 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
func (s *profilesStatements) insertProfile(
ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
- _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ _, err := txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
+ return err
+ })
}
func (s *profilesStatements) selectProfileByLocalpart(
diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/accounts/sqlite3/threepid_table.go
index 0200dee7..0104e834 100644
--- a/userapi/storage/accounts/sqlite3/threepid_table.go
+++ b/userapi/storage/accounts/sqlite3/threepid_table.go
@@ -53,6 +53,8 @@ const deleteThreePIDSQL = "" +
"DELETE FROM account_threepid WHERE threepid = $1 AND medium = $2"
type threepidStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
selectLocalpartForThreePIDStmt *sql.Stmt
selectThreePIDsForLocalpartStmt *sql.Stmt
insertThreePIDStmt *sql.Stmt
@@ -60,6 +62,8 @@ type threepidStatements struct {
}
func (s *threepidStatements) prepare(db *sql.DB) (err error) {
+ s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(threepidSchema)
if err != nil {
return
@@ -118,13 +122,18 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
func (s *threepidStatements) insertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) {
- stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
- _, err = stmt.ExecContext(ctx, threepid, medium, localpart)
- return
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
+ _, err := stmt.ExecContext(ctx, threepid, medium, localpart)
+ return err
+ })
}
func (s *threepidStatements) deleteThreePID(
ctx context.Context, threepid string, medium string) (err error) {
- _, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium)
- return
+ return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
+ _, err := stmt.ExecContext(ctx, threepid, medium)
+ return err
+ })
}
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index 07ea5dca..ec52c64b 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -74,6 +74,7 @@ const deleteDevicesSQL = "" +
type devicesStatements struct {
db *sql.DB
+ writer *sqlutil.TransactionWriter
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@@ -87,6 +88,7 @@ type devicesStatements struct {
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
+ s.writer = sqlutil.NewTransactionWriter()
_, err = db.Exec(devicesSchema)
if err != nil {
return
@@ -128,13 +130,19 @@ func (s *devicesStatements) insertDevice(
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
- countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
- insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
- if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
- return nil, err
- }
- sessionID++
- if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
+ err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
+ insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
+ if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
+ return err
+ }
+ sessionID++
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
return nil, err
}
return &api.Device{
@@ -148,9 +156,11 @@ func (s *devicesStatements) insertDevice(
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
- _, err := stmt.ExecContext(ctx, id, localpart)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
+ _, err := stmt.ExecContext(ctx, id, localpart)
+ return err
+ })
}
func (s *devicesStatements) deleteDevices(
@@ -161,31 +171,37 @@ func (s *devicesStatements) deleteDevices(
if err != nil {
return err
}
- stmt := sqlutil.TxStmt(txn, prep)
- params := make([]interface{}, len(devices)+1)
- params[0] = localpart
- for i, v := range devices {
- params[i+1] = v
- }
- params = append(params, params...)
- _, err = stmt.ExecContext(ctx, params...)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, prep)
+ params := make([]interface{}, len(devices)+1)
+ params[0] = localpart
+ for i, v := range devices {
+ params[i+1] = v
+ }
+ params = append(params, params...)
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+ })
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
- stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
+ _, err := stmt.ExecContext(ctx, localpart)
+ return err
+ })
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
- stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
- _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
- return err
+ return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
+ stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
+ _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
+ return err
+ })
}
func (s *devicesStatements) selectDeviceByToken(