diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-02-18 11:31:05 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-18 11:31:05 +0000 |
commit | 153bfbbea579dfa10e8e804036f17c1a33b6fe80 (patch) | |
tree | e135dcefc59618d7b86cd8687c1a2a304385ce45 /userapi/storage/postgres/accounts_table.go | |
parent | 0a7dea44505f703af1e7e069602ca95aa5a83700 (diff) |
Merge both user API databases into one (#2186)
* Merge user API databases into one
* Remove DeviceDatabase from config
* Fix tests
* Try that again
* Clean up keyserver device keys when the devices no longer exist in the user API
* Tweak ordering
* Fix UserExists flag, device check
* Allow including empty entries so we can clean them up
* Remove logging
Diffstat (limited to 'userapi/storage/postgres/accounts_table.go')
-rw-r--r-- | userapi/storage/postgres/accounts_table.go | 180 |
1 files changed, 180 insertions, 0 deletions
diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go new file mode 100644 index 00000000..9e3e456a --- /dev/null +++ b/userapi/storage/postgres/accounts_table.go @@ -0,0 +1,180 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + + log "github.com/sirupsen/logrus" +) + +const accountsSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_accounts ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- When this account was first created, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The password hash for this account. Can be NULL if this is a passwordless account. + password_hash TEXT, + -- Identifies which application service this account belongs to, if any. + appservice_id TEXT, + -- If the account is currently active + is_deactivated BOOLEAN DEFAULT FALSE, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type SMALLINT NOT NULL + -- TODO: + -- upgraded_ts, devices, any email reset stuff? +); +-- Create sequence for autogenerated numeric usernames +CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; +` + +const insertAccountSQL = "" + + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + +const updatePasswordSQL = "" + + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" + +const deactivateAccountSQL = "" + + "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" + +const selectAccountByLocalpartSQL = "" + + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + +const selectNewNumericLocalpartSQL = "" + + "SELECT nextval('numeric_username_seq')" + +type accountsStatements struct { + insertAccountStmt *sql.Stmt + updatePasswordStmt *sql.Stmt + deactivateAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *accountsStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(accountsSchema) + return err +} + +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + s.serverName = server + return sqlutil.StatementList{ + {&s.insertAccountStmt, insertAccountSQL}, + {&s.updatePasswordStmt, updatePasswordSQL}, + {&s.deactivateAccountStmt, deactivateAccountSQL}, + {&s.selectAccountByLocalpartStmt, selectAccountByLocalpartSQL}, + {&s.selectPasswordHashStmt, selectPasswordHashSQL}, + {&s.selectNewNumericLocalpartStmt, selectNewNumericLocalpartSQL}, + }.Prepare(db) +} + +// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, +// this account will be passwordless. Returns an error if this account already exists. Returns the account +// on success. +func (s *accountsStatements) insertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, +) (*api.Account, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) + + var err error + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + } else { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + } + if err != nil { + return nil, err + } + + return &api.Account{ + Localpart: localpart, + UserID: userutil.MakeUserID(localpart, s.serverName), + ServerName: s.serverName, + AppServiceID: appserviceID, + AccountType: accountType, + }, nil +} + +func (s *accountsStatements) updatePassword( + ctx context.Context, localpart, passwordHash string, +) (err error) { + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + return +} + +func (s *accountsStatements) deactivateAccount( + ctx context.Context, localpart string, +) (err error) { + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + return +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*api.Account, error) { + var appserviceIDPtr sql.NullString + var acc api.Account + + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err + } + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil +} + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = sqlutil.TxStmt(txn, stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) + return +} |