diff options
Diffstat (limited to 'userapi/storage/accounts/sqlite3')
4 files changed, 89 insertions, 44 deletions
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go index 8a7c8fba..5a918e03 100644 --- a/userapi/storage/accounts/sqlite3/accounts_table.go +++ b/userapi/storage/accounts/sqlite3/accounts_table.go @@ -19,10 +19,11 @@ import ( "database/sql" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -39,14 +40,16 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- Identifies which application service this account belongs to, if any. appservice_id TEXT, -- If the account is currently active - is_deactivated BOOLEAN DEFAULT 0 + is_deactivated BOOLEAN DEFAULT 0, + -- The account_type (user = 1, guest = 2, admin = 3, appservice = 4) + account_type INTEGER NOT NULL -- TODO: - -- is_guest, is_admin, upgraded_ts, devices, any email reset stuff? + -- upgraded_ts, devices, any email reset stuff? ); ` const insertAccountSQL = "" + - "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)" + "INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" const updatePasswordSQL = "" + "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2" @@ -55,7 +58,7 @@ const deactivateAccountSQL = "" + "UPDATE account_accounts SET is_deactivated = 1 WHERE localpart = $1" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1" + "SELECT localpart, appservice_id, account_type FROM account_accounts WHERE localpart = $1" const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" @@ -96,16 +99,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error - if appserviceID == "" { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) + if accountType != api.AccountTypeAppService { + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -147,7 +150,7 @@ func (s *accountsStatements) selectAccountByLocalpart( var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr) + err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go index 9fddb05a..c69614e8 100644 --- a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go @@ -4,12 +4,14 @@ import ( "database/sql" "fmt" - "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" ) func LoadFromGoose() { goose.AddMigration(UpIsActive, DownIsActive) + goose.AddMigration(UpAddAccountType, DownAddAccountType) } func LoadIsActive(m *sqlutil.Migrations) { diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go new file mode 100644 index 00000000..9b058ded --- /dev/null +++ b/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go @@ -0,0 +1,54 @@ +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/pressly/goose" + + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +func init() { + goose.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func LoadAddAccountType(m *sqlutil.Migrations) { + m.AddMigration(UpAddAccountType, DownAddAccountType) +} + +func UpAddAccountType(tx *sql.Tx) error { + // initially set every account to useraccount, change appservice and guest accounts afterwards + // (user = 1, guest = 2, admin = 3, appservice = 4) + _, err := tx.Exec(`ALTER TABLE account_accounts RENAME TO account_accounts_tmp; +CREATE TABLE account_accounts ( + localpart TEXT NOT NULL PRIMARY KEY, + created_ts BIGINT NOT NULL, + password_hash TEXT, + appservice_id TEXT, + is_deactivated BOOLEAN DEFAULT 0, + account_type INTEGER NOT NULL +); +INSERT + INTO account_accounts ( + localpart, created_ts, password_hash, appservice_id, account_type + ) SELECT + localpart, created_ts, password_hash, appservice_id, 1 + FROM account_accounts_tmp +; +UPDATE account_accounts SET account_type = 4 WHERE appservice_id <> ''; +UPDATE account_accounts SET account_type = 2 WHERE localpart GLOB '[0-9]*'; +DROP TABLE account_accounts_tmp;`) + if err != nil { + return fmt.Errorf("failed to add column: %w", err) + } + return nil +} + +func DownAddAccountType(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE account_accounts DROP COLUMN account_type;`) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go index 2b731b75..0bab16ca 100644 --- a/userapi/storage/accounts/sqlite3/storage.go +++ b/userapi/storage/accounts/sqlite3/storage.go @@ -24,13 +24,14 @@ import ( "sync" "time" + "github.com/matrix-org/gomatrixserverlib" + "golang.org/x/crypto/bcrypt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas" - "github.com/matrix-org/gomatrixserverlib" - "golang.org/x/crypto/bcrypt" ) // Database represents an account database @@ -77,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver } m := sqlutil.NewMigrations() deltas.LoadIsActive(m) + deltas.LoadAddAccountType(m) if err = m.RunDeltas(db, dbProperties); err != nil { return nil, err } @@ -170,38 +172,11 @@ func (d *Database) SetPassword( }) } -// CreateGuestAccount makes a new guest account and creates an empty profile -// for this account. -func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) { - // We need to lock so we sequentially create numeric localparts. If we don't, two calls to - // this function will cause the same number to be selected and one will fail with 'database is locked' - // when the first txn upgrades to a write txn. We also need to lock the account creation else we can - // race with CreateAccount - // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. - d.profilesMu.Lock() - d.accountDatasMu.Lock() - d.accountsMu.Lock() - defer d.profilesMu.Unlock() - defer d.accountDatasMu.Unlock() - defer d.accountsMu.Unlock() - err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - var numLocalpart int64 - numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) - if err != nil { - return err - } - localpart := strconv.FormatInt(numLocalpart, 10) - acc, err = d.createAccount(ctx, txn, localpart, "", "") - return err - }) - return acc, err -} - // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, + ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { // Create one account at a time else we can get 'database is locked'. d.profilesMu.Lock() @@ -211,7 +186,18 @@ func (d *Database) CreateAccount( defer d.accountDatasMu.Unlock() defer d.accountsMu.Unlock() err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + // For guest accounts, we create a new numeric local part + if accountType == api.AccountTypeGuest { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart = strconv.FormatInt(numLocalpart, 10) + plaintextPassword = "" + appserviceID = "" + } + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) return err }) return @@ -220,7 +206,7 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -232,7 +218,7 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID); err != nil { + if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil { |