diff options
Diffstat (limited to 'userapi/storage/accounts/sqlite3/accounts_table.go')
-rw-r--r-- | userapi/storage/accounts/sqlite3/accounts_table.go | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 768f536d..27c3d845 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -20,6 +20,7 @@ import ( "time" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -57,6 +58,8 @@ const selectNewNumericLocalpartSQL = "" + // TODO: Update password type accountsStatements struct { + db *sql.DB + writer *sqlutil.TransactionWriter insertAccountStmt *sql.Stmt selectAccountByLocalpartStmt *sql.Stmt selectPasswordHashStmt *sql.Stmt @@ -65,6 +68,8 @@ type accountsStatements struct { } func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.db = db + s.writer = sqlutil.NewTransactionWriter() _, err = db.Exec(accountsSchema) if err != nil { return @@ -94,12 +99,15 @@ func (s *accountsStatements) insertAccount( createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt - var err error - if appserviceID == "" { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) - } else { - _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) - } + err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + var err error + if appserviceID == "" { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + return err + }) if err != nil { return nil, err } |