From e09d24e7323e73791e7bb31fa7fac1d3acf0c299 Mon Sep 17 00:00:00 2001
From: Kegsay <kegan@matrix.org>
Date: Wed, 17 Jun 2020 12:05:56 +0100
Subject: Move account/device DBs to userapi (#1141)

---
 .../storage/accounts/postgres/accounts_table.go    | 157 +++++++++++++++++++++
 1 file changed, 157 insertions(+)
 create mode 100644 userapi/storage/accounts/postgres/accounts_table.go

(limited to 'userapi/storage/accounts/postgres/accounts_table.go')

diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go
new file mode 100644
index 00000000..931ffb73
--- /dev/null
+++ b/userapi/storage/accounts/postgres/accounts_table.go
@@ -0,0 +1,157 @@
+// Copyright 2017 Vector Creations Ltd
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+	"context"
+	"database/sql"
+	"time"
+
+	"github.com/matrix-org/dendrite/clientapi/userutil"
+	"github.com/matrix-org/dendrite/userapi/api"
+	"github.com/matrix-org/gomatrixserverlib"
+
+	log "github.com/sirupsen/logrus"
+)
+
+const accountsSchema = `
+-- Stores data about accounts.
+CREATE TABLE IF NOT EXISTS account_accounts (
+    -- The Matrix user ID localpart for this account
+    localpart TEXT NOT NULL PRIMARY KEY,
+    -- When this account was first created, as a unix timestamp (ms resolution).
+    created_ts BIGINT NOT NULL,
+    -- The password hash for this account. Can be NULL if this is a passwordless account.
+    password_hash TEXT,
+    -- Identifies which application service this account belongs to, if any.
+    appservice_id TEXT
+    -- TODO:
+    -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
+);
+-- Create sequence for autogenerated numeric usernames
+CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
+`
+
+const insertAccountSQL = "" +
+	"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+
+const selectAccountByLocalpartSQL = "" +
+	"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
+
+const selectPasswordHashSQL = "" +
+	"SELECT password_hash FROM account_accounts WHERE localpart = $1"
+
+const selectNewNumericLocalpartSQL = "" +
+	"SELECT nextval('numeric_username_seq')"
+
+// TODO: Update password
+
+type accountsStatements struct {
+	insertAccountStmt             *sql.Stmt
+	selectAccountByLocalpartStmt  *sql.Stmt
+	selectPasswordHashStmt        *sql.Stmt
+	selectNewNumericLocalpartStmt *sql.Stmt
+	serverName                    gomatrixserverlib.ServerName
+}
+
+func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+	_, err = db.Exec(accountsSchema)
+	if err != nil {
+		return
+	}
+	if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
+		return
+	}
+	if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
+		return
+	}
+	if s.selectPasswordHashStmt, err = db.Prepare(selectPasswordHashSQL); err != nil {
+		return
+	}
+	if s.selectNewNumericLocalpartStmt, err = db.Prepare(selectNewNumericLocalpartSQL); err != nil {
+		return
+	}
+	s.serverName = server
+	return
+}
+
+// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
+// this account will be passwordless. Returns an error if this account already exists. Returns the account
+// on success.
+func (s *accountsStatements) insertAccount(
+	ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
+) (*api.Account, error) {
+	createdTimeMS := time.Now().UnixNano() / 1000000
+	stmt := txn.Stmt(s.insertAccountStmt)
+
+	var err error
+	if appserviceID == "" {
+		_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
+	} else {
+		_, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	return &api.Account{
+		Localpart:    localpart,
+		UserID:       userutil.MakeUserID(localpart, s.serverName),
+		ServerName:   s.serverName,
+		AppServiceID: appserviceID,
+	}, nil
+}
+
+func (s *accountsStatements) selectPasswordHash(
+	ctx context.Context, localpart string,
+) (hash string, err error) {
+	err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
+	return
+}
+
+func (s *accountsStatements) selectAccountByLocalpart(
+	ctx context.Context, localpart string,
+) (*api.Account, error) {
+	var appserviceIDPtr sql.NullString
+	var acc api.Account
+
+	stmt := s.selectAccountByLocalpartStmt
+	err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
+	if err != nil {
+		if err != sql.ErrNoRows {
+			log.WithError(err).Error("Unable to retrieve user from the db")
+		}
+		return nil, err
+	}
+	if appserviceIDPtr.Valid {
+		acc.AppServiceID = appserviceIDPtr.String
+	}
+
+	acc.UserID = userutil.MakeUserID(localpart, s.serverName)
+	acc.ServerName = s.serverName
+
+	return &acc, nil
+}
+
+func (s *accountsStatements) selectNewNumericLocalpart(
+	ctx context.Context, txn *sql.Tx,
+) (id int64, err error) {
+	stmt := s.selectNewNumericLocalpartStmt
+	if txn != nil {
+		stmt = txn.Stmt(stmt)
+	}
+	err = stmt.QueryRowContext(ctx).Scan(&id)
+	return
+}
-- 
cgit v1.2.3