diff options
author | Kegsay <kegan@matrix.org> | 2020-06-17 12:05:56 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-17 12:05:56 +0100 |
commit | e09d24e7323e73791e7bb31fa7fac1d3acf0c299 (patch) | |
tree | 85600bfe84e938ffeb48d7a152f968f86b78fcdf /userapi/storage/accounts/postgres/accounts_table.go | |
parent | 5d5aa0a31d60941c7ece95b4b516044cb8a10cce (diff) |
Move account/device DBs to userapi (#1141)
Diffstat (limited to 'userapi/storage/accounts/postgres/accounts_table.go')
-rw-r--r-- | userapi/storage/accounts/postgres/accounts_table.go | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go new file mode 100644 index 00000000..931ffb73 --- /dev/null +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -0,0 +1,157 @@ +// Copyright 2017 Vector Creations Ltd +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + + log "github.com/sirupsen/logrus" +) + +const accountsSchema = ` +-- Stores data about accounts. +CREATE TABLE IF NOT EXISTS account_accounts ( + -- The Matrix user ID localpart for this account + localpart TEXT NOT NULL PRIMARY KEY, + -- When this account was first created, as a unix timestamp (ms resolution). + created_ts BIGINT NOT NULL, + -- The password hash for this account. Can be NULL if this is a passwordless account. + password_hash TEXT, + -- Identifies which application service this account belongs to, if any. + appservice_id TEXT + -- TODO: + -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? +); +-- Create sequence for autogenerated numeric usernames +CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; +` + +const insertAccountSQL = "" + + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + +const selectAccountByLocalpartSQL = "" + + "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + +const selectPasswordHashSQL = "" + + "SELECT password_hash FROM account_accounts WHERE localpart = $1" + +const selectNewNumericLocalpartSQL = "" + + "SELECT nextval('numeric_username_seq')" + +// TODO: Update password + +type accountsStatements struct { + insertAccountStmt *sql.Stmt + selectAccountByLocalpartStmt *sql.Stmt + selectPasswordHashStmt *sql.Stmt + selectNewNumericLocalpartStmt *sql.Stmt + serverName gomatrixserverlib.ServerName +} + +func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { + _, err = db.Exec(accountsSchema) + if err != nil { + return + } + if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil { + return + } + if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil { + return + } + if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil { + return + } + if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil { + return + } + s.serverName = server + return +} + +// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, +// this account will be passwordless. Returns an error if this account already exists. Returns the account +// on success. +func (s *accountsStatements) insertAccount( + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, +) (*api.Account, error) { + createdTimeMS := time.Now().UnixNano() / 1000000 + stmt := txn.Stmt(s.insertAccountStmt) + + var err error + if appserviceID == "" { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + } else { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + } + if err != nil { + return nil, err + } + + return &api.Account{ + Localpart: localpart, + UserID: userutil.MakeUserID(localpart, s.serverName), + ServerName: s.serverName, + AppServiceID: appserviceID, + }, nil +} + +func (s *accountsStatements) selectPasswordHash( + ctx context.Context, localpart string, +) (hash string, err error) { + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + return +} + +func (s *accountsStatements) selectAccountByLocalpart( + ctx context.Context, localpart string, +) (*api.Account, error) { + var appserviceIDPtr sql.NullString + var acc api.Account + + stmt := s.selectAccountByLocalpartStmt + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + if err != nil { + if err != sql.ErrNoRows { + log.WithError(err).Error("Unable to retrieve user from the db") + } + return nil, err + } + if appserviceIDPtr.Valid { + acc.AppServiceID = appserviceIDPtr.String + } + + acc.UserID = userutil.MakeUserID(localpart, s.serverName) + acc.ServerName = s.serverName + + return &acc, nil +} + +func (s *accountsStatements) selectNewNumericLocalpart( + ctx context.Context, txn *sql.Tx, +) (id int64, err error) { + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = txn.Stmt(stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) + return +} |