diff options
author | S7evinK <2353100+S7evinK@users.noreply.github.com> | 2022-02-16 18:55:38 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-02-16 18:55:38 +0100 |
commit | 5a39512f5f35b13adea3afc2e366e01ec73924de (patch) | |
tree | ac0e5cd6de8798e45cf0b5b37440ae08f4c7ba90 /userapi/storage/accounts/postgres/accounts_table.go | |
parent | e9b672a34e08bce9d12b2a2454c19fde6e52036e (diff) |
Add account type (#2171)
* Add account_type for sqlite3
* Add account_type for postgres
* Remove CreateGuestAccount from interface
* Add new AccountTypes & update test
* Use newly added AccountType for account creation
* Add migrations
* Reuse type
* Add AccounnType to Device, so it can be verified on requests
* Rename migration, add missing update for appservices
* Rename sqlite3 migration
* Add missing AccountType to return value
* Update sqlite migration
Change allowance check on /admin/whois
* Fix migration, add IS NULL
* Move accountType to completeRegistration
* Fix migrations
* Add passing test
Diffstat (limited to 'userapi/storage/accounts/postgres/accounts_table.go')
-rw-r--r-- | userapi/storage/accounts/postgres/accounts_table.go | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go index b57aa901..9e3e456a 100644 --- a/userapi/storage/accounts/postgres/accounts_table.go +++ b/userapi/storage/accounts/postgres/accounts_table.go @@ -19,10 +19,11 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -39,16 +40,18 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT FALSE + is_deactivated BOOLEAN DEFAULT FALSE, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type SMALLINT NOT NULL -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); -- Create sequence for autogenerated numeric usernames CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -57,7 +60,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = TRUE WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" @@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error - if appserviceID == "" { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -116,6 +119,7 @@ func (s *accountsStatements) insertAccount( UserID: userutil.MakeUserID(localpart, s.serverName), ServerName: s.serverName, AppServiceID: appserviceID, + AccountType: accountType, }, nil } @@ -147,7 +151,7 @@ func (s *accountsStatements) selectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") |