aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/accounts/sqlite3/accounts_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/accounts/sqlite3/accounts_table.go')
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go21
1 files changed, 8 insertions, 13 deletions
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 83b90668..798a6de9 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -20,7 +20,6 @@ import (
"time"
"github.com/matrix-org/dendrite/clientapi/userutil"
- "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
@@ -59,7 +58,6 @@ const selectNewNumericLocalpartSQL = "" +
type accountsStatements struct {
db *sql.DB
- writer sqlutil.Writer
insertAccountStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
@@ -67,9 +65,9 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *accountsStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
+func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
- s.writer = writer
+
_, err = db.Exec(accountsSchema)
if err != nil {
return
@@ -99,15 +97,12 @@ func (s *accountsStatements) insertAccount(
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
- err := s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
- var err error
- if appserviceID == "" {
- _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
- } else {
- _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
- }
- return err
- })
+ var err error
+ if appserviceID == "" {
+ _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
+ } else {
+ _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
+ }
if err != nil {
return nil, err
}