aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/accounts/postgres/accounts_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/accounts/postgres/accounts_table.go')
-rw-r--r--userapi/storage/accounts/postgres/accounts_table.go16
1 files changed, 14 insertions, 2 deletions
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go
index 931ffb73..8c8d32cf 100644
--- a/userapi/storage/accounts/postgres/accounts_table.go
+++ b/userapi/storage/accounts/postgres/accounts_table.go
@@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+const updatePasswordSQL = "" +
+ "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
+
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
-// TODO: Update password
-
type accountsStatements struct {
insertAccountStmt *sql.Stmt
+ updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
+ if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
+ return
+ }
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
+func (s *accountsStatements) updatePassword(
+ ctx context.Context, localpart, passwordHash string,
+) (err error) {
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ return
+}
+
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {