aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/accounts/postgres/accounts_table.go
diff options
context:
space:
mode:
authorS7evinK <2353100+S7evinK@users.noreply.github.com>2022-02-16 18:55:38 +0100
committerGitHub <noreply@github.com>2022-02-16 18:55:38 +0100
commit5a39512f5f35b13adea3afc2e366e01ec73924de (patch)
treeac0e5cd6de8798e45cf0b5b37440ae08f4c7ba90 /userapi/storage/accounts/postgres/accounts_table.go
parente9b672a34e08bce9d12b2a2454c19fde6e52036e (diff)
Add account type (#2171)
* Add account_type for sqlite3 * Add account_type for postgres * Remove CreateGuestAccount from interface * Add new AccountTypes & update test * Use newly added AccountType for account creation * Add migrations * Reuse type * Add AccounnType to Device, so it can be verified on requests * Rename migration, add missing update for appservices * Rename sqlite3 migration * Add missing AccountType to return value * Update sqlite migration Change allowance check on /admin/whois * Fix migration, add IS NULL * Move accountType to completeRegistration * Fix migrations * Add passing test
Diffstat (limited to 'userapi/storage/accounts/postgres/accounts_table.go')
-rw-r--r--userapi/storage/accounts/postgres/accounts_table.go24
1 files changed, 14 insertions, 10 deletions
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go
index b57aa901..9e3e456a 100644
--- a/userapi/storage/accounts/postgres/accounts_table.go
+++ b/userapi/storage/accounts/postgres/accounts_table.go
@@ -19,10 +19,11 @@ import (
"database/sql"
"time"
+ "github.com/matrix-org/gomatrixserverlib"
+
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
@@ -39,16 +40,18 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- Identifies which application service this account belongs to, if any.
appservice_id TEXT,
-- If the account is currently active
- is_deactivated BOOLEAN DEFAULT FALSE
+ is_deactivated BOOLEAN DEFAULT FALSE,
+ -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4)
+ account_type SMALLINT NOT NULL
-- TODO:
- -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff?
+ -- upgraded_ts, devices, any email reset stuff?
);
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
`
const insertAccountSQL = "" +
- "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+ "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
const updatePasswordSQL = "" +
"UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
@@ -57,7 +60,7 @@ const deactivateAccountSQL = "" +
"UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1"
const selectAccountByLocalpartSQL = "" +
- "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
+ "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1"
const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
@@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
- ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
+ ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.insertAccountStmt)
var err error
- if appserviceID == "" {
- _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
+ if accountType != api.AccountTypeAppService {
+ _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
} else {
- _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
+ _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
}
if err != nil {
return nil, err
@@ -116,6 +119,7 @@ func (s *accountsStatements) insertAccount(
UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName,
AppServiceID: appserviceID,
+ AccountType: accountType,
}, nil
}
@@ -147,7 +151,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
var acc api.Account
stmt := s.selectAccountByLocalpartStmt
- err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr)
+ err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
if err != nil {
if err != sql.ErrNoRows {
log.WithError(err).Error("Unable to retrieve user from the db")