aboutsummaryrefslogtreecommitdiff
path: root/syncapi/storage/sqlite3/account_data_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'syncapi/storage/sqlite3/account_data_table.go')
-rw-r--r--syncapi/storage/sqlite3/account_data_table.go20
1 files changed, 14 insertions, 6 deletions
diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go
index ae5caa4e..609cef14 100644
--- a/syncapi/storage/sqlite3/account_data_table.go
+++ b/syncapi/storage/sqlite3/account_data_table.go
@@ -20,6 +20,7 @@ import (
"database/sql"
"github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
type accountDataStatements struct {
+ db *sql.DB
+ writer *sqlutil.TransactionWriter
streamIDStatements *streamIDStatements
insertAccountDataStmt *sql.Stmt
selectMaxAccountDataIDStmt *sql.Stmt
@@ -57,6 +60,8 @@ type accountDataStatements struct {
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
s := &accountDataStatements{
+ db: db,
+ writer: sqlutil.NewTransactionWriter(),
streamIDStatements: streamID,
}
_, err := db.Exec(accountDataSchema)
@@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
- pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
- if err != nil {
- return
- }
- _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
- return
+ return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
+ var err error
+ pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+ if err != nil {
+ return err
+ }
+ _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
+ return err
+ })
}
func (s *accountDataStatements) SelectAccountDataInRange(