aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/postgres/accounts_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/postgres/accounts_table.go')
-rw-r--r--userapi/storage/postgres/accounts_table.go28
1 files changed, 15 insertions, 13 deletions
diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go
index 9e3e456a..92311d56 100644
--- a/userapi/storage/postgres/accounts_table.go
+++ b/userapi/storage/postgres/accounts_table.go
@@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus"
)
@@ -78,14 +79,15 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName
}
-func (s *accountsStatements) execSchema(db *sql.DB) error {
+func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
+ s := &accountsStatements{
+ serverName: serverName,
+ }
_, err := db.Exec(accountsSchema)
- return err
-}
-
-func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
- s.serverName = server
- return sqlutil.StatementList{
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL},
@@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
-func (s *accountsStatements) insertAccount(
+func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
@@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount(
}, nil
}
-func (s *accountsStatements) updatePassword(
+func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string,
) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return
}
-func (s *accountsStatements) deactivateAccount(
+func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string,
) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return
}
-func (s *accountsStatements) selectPasswordHash(
+func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return
}
-func (s *accountsStatements) selectAccountByLocalpart(
+func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string,
) (*api.Account, error) {
var appserviceIDPtr sql.NullString
@@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil
}
-func (s *accountsStatements) selectNewNumericLocalpart(
+func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt