diff options
Diffstat (limited to 'userapi/storage/postgres/accounts_table.go')
-rw-r--r-- | userapi/storage/postgres/accounts_table.go | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 9e3e456a..92311d56 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" log "github.com/sirupsen/logrus" ) @@ -78,14 +79,15 @@ type accountsStatements struct { serverName gomatrixserverlib.ServerName } -func (s *accountsStatements) execSchema(db *sql.DB) error { +func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) { + s := &accountsStatements{ + serverName: serverName, + } _, err := db.Exec(accountsSchema) - return err -} - -func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { - s.serverName = server - return sqlutil.StatementList{ + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ {&s.insertAccountStmt, insertAccountSQL}, {&s.updatePasswordStmt, updatePasswordSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL}, @@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. -func (s *accountsStatements) insertAccount( +func (s *accountsStatements) InsertAccount( ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 @@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount( }, nil } -func (s *accountsStatements) updatePassword( +func (s *accountsStatements) UpdatePassword( ctx context.Context, localpart, passwordHash string, ) (err error) { _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) return } -func (s *accountsStatements) deactivateAccount( +func (s *accountsStatements) DeactivateAccount( ctx context.Context, localpart string, ) (err error) { _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) return } -func (s *accountsStatements) selectPasswordHash( +func (s *accountsStatements) SelectPasswordHash( ctx context.Context, localpart string, ) (hash string, err error) { err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) return } -func (s *accountsStatements) selectAccountByLocalpart( +func (s *accountsStatements) SelectAccountByLocalpart( ctx context.Context, localpart string, ) (*api.Account, error) { var appserviceIDPtr sql.NullString @@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart( return &acc, nil } -func (s *accountsStatements) selectNewNumericLocalpart( +func (s *accountsStatements) SelectNewNumericLocalpart( ctx context.Context, txn *sql.Tx, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt |