diff options
10 files changed, 108 insertions, 40 deletions
diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go index 83d3ee72..9f6e3e1e 100644 --- a/clientapi/auth/storage/accounts/interface.go +++ b/clientapi/auth/storage/accounts/interface.go @@ -30,6 +30,7 @@ type Database interface { SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error) + CreateGuestAccount(ctx context.Context) (*authtypes.Account, error) UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error) GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error) diff --git a/clientapi/auth/storage/accounts/postgres/account_data_table.go b/clientapi/auth/storage/accounts/postgres/account_data_table.go index d0cfcc0c..4573999b 100644 --- a/clientapi/auth/storage/accounts/postgres/account_data_table.go +++ b/clientapi/auth/storage/accounts/postgres/account_data_table.go @@ -72,9 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ) (err error) { - stmt := s.insertAccountDataStmt + stmt := txn.Stmt(s.insertAccountDataStmt) _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) return } diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go index 6b8ed372..85c1938a 100644 --- a/clientapi/auth/storage/accounts/postgres/accounts_table.go +++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go @@ -91,10 +91,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ) (*authtypes.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 - stmt := s.insertAccountStmt + stmt := txn.Stmt(s.insertAccountStmt) var err error if appserviceID == "" { @@ -146,8 +146,12 @@ func (s *accountsStatements) selectAccountByLocalpart( } func (s *accountsStatements) selectNewNumericLocalpart( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (id int64, err error) { - err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = txn.Stmt(stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) return } diff --git a/clientapi/auth/storage/accounts/postgres/profile_table.go b/clientapi/auth/storage/accounts/postgres/profile_table.go index 38c76c40..d2cbeb8e 100644 --- a/clientapi/auth/storage/accounts/postgres/profile_table.go +++ b/clientapi/auth/storage/accounts/postgres/profile_table.go @@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { } func (s *profilesStatements) insertProfile( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { - _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") + _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go index cb74d131..8115dca4 100644 --- a/clientapi/auth/storage/accounts/postgres/storage.go +++ b/clientapi/auth/storage/accounts/postgres/storage.go @@ -18,6 +18,7 @@ import ( "context" "database/sql" "errors" + "strconv" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/common" @@ -118,11 +119,37 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, txn, localpart, "", "") + return err + }) + return acc, err +} + // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, nil. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (acc *authtypes.Account, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + return err + }) + return +} + +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*authtypes.Account, error) { var err error @@ -134,13 +161,14 @@ func (d *Database) CreateAccount( return nil, err } } - if err := d.profiles.insertProfile(ctx, localpart); err != nil { + if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if common.IsUniqueConstraintViolationErr(err) { return nil, nil } return nil, err } - if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{ + + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ "global": { "content": [], "override": [], @@ -151,7 +179,7 @@ func (d *Database) CreateAccount( }`); err != nil { return nil, err } - return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) + return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) } // SaveMembership saves the user matching a given localpart as a member of a given @@ -258,7 +286,9 @@ func (d *Database) newMembership( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType, content string, ) error { - return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) } // GetAccountData returns account data related to a given localpart @@ -288,7 +318,7 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx) + return d.accounts.selectNewNumericLocalpart(ctx, nil) } func hashPassword(plaintext string) (hash string, err error) { diff --git a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go index c2143881..b6bb6361 100644 --- a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go @@ -72,10 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) { } func (s *accountDataStatements) insertAccountData( - ctx context.Context, localpart, roomID, dataType, content string, + ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string, ) (err error) { - stmt := s.insertAccountDataStmt - _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) + _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) return } diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go index b029951f..fd6a09cd 100644 --- a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go @@ -89,16 +89,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) insertAccount( - ctx context.Context, localpart, hash, appserviceID string, + ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, ) (*authtypes.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error if appserviceID == "" { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil) + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) + _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID) } if err != nil { return nil, err @@ -144,8 +144,12 @@ func (s *accountsStatements) selectAccountByLocalpart( } func (s *accountsStatements) selectNewNumericLocalpart( - ctx context.Context, + ctx context.Context, txn *sql.Tx, ) (id int64, err error) { - err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id) + stmt := s.selectNewNumericLocalpartStmt + if txn != nil { + stmt = txn.Stmt(stmt) + } + err = stmt.QueryRowContext(ctx).Scan(&id) return } diff --git a/clientapi/auth/storage/accounts/sqlite3/profile_table.go b/clientapi/auth/storage/accounts/sqlite3/profile_table.go index 7af8307e..9b5192a0 100644 --- a/clientapi/auth/storage/accounts/sqlite3/profile_table.go +++ b/clientapi/auth/storage/accounts/sqlite3/profile_table.go @@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) { } func (s *profilesStatements) insertProfile( - ctx context.Context, localpart string, + ctx context.Context, txn *sql.Tx, localpart string, ) (err error) { - _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "") + _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "") return } diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go index 3e62d10d..9124640c 100644 --- a/clientapi/auth/storage/accounts/sqlite3/storage.go +++ b/clientapi/auth/storage/accounts/sqlite3/storage.go @@ -18,6 +18,8 @@ import ( "context" "database/sql" "errors" + "strconv" + "sync" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/common" @@ -39,6 +41,8 @@ type Database struct { threepids threepidStatements filter filterStatements serverName gomatrixserverlib.ServerName + + createGuestAccountMu sync.Mutex } // NewDatabase creates a new accounts and profiles database @@ -76,7 +80,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) if err = f.prepare(db); err != nil { return nil, err } - return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil + return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil } // GetAccountByPassword returns the account associated with the given localpart and password. @@ -118,14 +122,46 @@ func (d *Database) SetDisplayName( return d.profiles.setDisplayName(ctx, localpart, displayName) } +// CreateGuestAccount makes a new guest account and creates an empty profile +// for this account. +func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + // We need to lock so we sequentially create numeric localparts. If we don't, two calls to + // this function will cause the same number to be selected and one will fail with 'database is locked' + // when the first txn upgrades to a write txn. + // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed. + d.createGuestAccountMu.Lock() + defer d.createGuestAccountMu.Unlock() + + var numLocalpart int64 + numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn) + if err != nil { + return err + } + localpart := strconv.FormatInt(numLocalpart, 10) + acc, err = d.createAccount(ctx, txn, localpart, "", "") + return err + }) + return acc, err +} + // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, nil. func (d *Database) CreateAccount( ctx context.Context, localpart, plaintextPassword, appserviceID string, +) (acc *authtypes.Account, err error) { + err = common.WithTransaction(d.db, func(txn *sql.Tx) error { + acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID) + return err + }) + return +} + +func (d *Database) createAccount( + ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, ) (*authtypes.Account, error) { var err error - // Generate a password hash if this is not a password-less user hash := "" if plaintextPassword != "" { @@ -134,13 +170,14 @@ func (d *Database) CreateAccount( return nil, err } } - if err := d.profiles.insertProfile(ctx, localpart); err != nil { + if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil { if common.IsUniqueConstraintViolationErr(err) { return nil, nil } return nil, err } - if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{ + + if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{ "global": { "content": [], "override": [], @@ -151,7 +188,7 @@ func (d *Database) CreateAccount( }`); err != nil { return nil, err } - return d.accounts.insertAccount(ctx, localpart, hash, appserviceID) + return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID) } // SaveMembership saves the user matching a given localpart as a member of a given @@ -258,7 +295,9 @@ func (d *Database) newMembership( func (d *Database) SaveAccountData( ctx context.Context, localpart, roomID, dataType, content string, ) error { - return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content) + return common.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content) + }) } // GetAccountData returns account data related to a given localpart @@ -288,7 +327,7 @@ func (d *Database) GetAccountDataByType( func (d *Database) GetNewNumericLocalpart( ctx context.Context, ) (int64, error) { - return d.accounts.selectNewNumericLocalpart(ctx) + return d.accounts.selectNewNumericLocalpart(ctx, nil) } func hashPassword(plaintext string) (hash string, err error) { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index ba24e527..2de7b273 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -516,16 +516,7 @@ func handleGuestRegistration( accountDB accounts.Database, deviceDB devices.Database, ) util.JSONResponse { - - //Generate numeric local part for guest user - id, err := accountDB.GetNewNumericLocalpart(req.Context()) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed") - return jsonerror.InternalServerError() - } - - localpart := strconv.FormatInt(id, 10) - acc, err := accountDB.CreateAccount(req.Context(), localpart, "", "") + acc, err := accountDB.CreateGuestAccount(req.Context()) if err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, |