aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--clientapi/auth/storage/accounts/interface.go1
-rw-r--r--clientapi/auth/storage/accounts/postgres/account_data_table.go4
-rw-r--r--clientapi/auth/storage/accounts/postgres/accounts_table.go12
-rw-r--r--clientapi/auth/storage/accounts/postgres/profile_table.go4
-rw-r--r--clientapi/auth/storage/accounts/postgres/storage.go40
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/account_data_table.go5
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/accounts_table.go14
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/profile_table.go4
-rw-r--r--clientapi/auth/storage/accounts/sqlite3/storage.go53
-rw-r--r--clientapi/routing/register.go11
10 files changed, 108 insertions, 40 deletions
diff --git a/clientapi/auth/storage/accounts/interface.go b/clientapi/auth/storage/accounts/interface.go
index 83d3ee72..9f6e3e1e 100644
--- a/clientapi/auth/storage/accounts/interface.go
+++ b/clientapi/auth/storage/accounts/interface.go
@@ -30,6 +30,7 @@ type Database interface {
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
CreateAccount(ctx context.Context, localpart, plaintextPassword, appserviceID string) (*authtypes.Account, error)
+ CreateGuestAccount(ctx context.Context) (*authtypes.Account, error)
UpdateMemberships(ctx context.Context, eventsToAdd []gomatrixserverlib.Event, idsToRemove []string) error
GetMembershipInRoomByLocalpart(ctx context.Context, localpart, roomID string) (authtypes.Membership, error)
GetMembershipsByLocalpart(ctx context.Context, localpart string) (memberships []authtypes.Membership, err error)
diff --git a/clientapi/auth/storage/accounts/postgres/account_data_table.go b/clientapi/auth/storage/accounts/postgres/account_data_table.go
index d0cfcc0c..4573999b 100644
--- a/clientapi/auth/storage/accounts/postgres/account_data_table.go
+++ b/clientapi/auth/storage/accounts/postgres/account_data_table.go
@@ -72,9 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
+ ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) {
- stmt := s.insertAccountDataStmt
+ stmt := txn.Stmt(s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
return
}
diff --git a/clientapi/auth/storage/accounts/postgres/accounts_table.go b/clientapi/auth/storage/accounts/postgres/accounts_table.go
index 6b8ed372..85c1938a 100644
--- a/clientapi/auth/storage/accounts/postgres/accounts_table.go
+++ b/clientapi/auth/storage/accounts/postgres/accounts_table.go
@@ -91,10 +91,10 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
- ctx context.Context, localpart, hash, appserviceID string,
+ ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
- stmt := s.insertAccountStmt
+ stmt := txn.Stmt(s.insertAccountStmt)
var err error
if appserviceID == "" {
@@ -146,8 +146,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
}
func (s *accountsStatements) selectNewNumericLocalpart(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
- err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
+ stmt := s.selectNewNumericLocalpartStmt
+ if txn != nil {
+ stmt = txn.Stmt(stmt)
+ }
+ err = stmt.QueryRowContext(ctx).Scan(&id)
return
}
diff --git a/clientapi/auth/storage/accounts/postgres/profile_table.go b/clientapi/auth/storage/accounts/postgres/profile_table.go
index 38c76c40..d2cbeb8e 100644
--- a/clientapi/auth/storage/accounts/postgres/profile_table.go
+++ b/clientapi/auth/storage/accounts/postgres/profile_table.go
@@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}
func (s *profilesStatements) insertProfile(
- ctx context.Context, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
- _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
+ _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return
}
diff --git a/clientapi/auth/storage/accounts/postgres/storage.go b/clientapi/auth/storage/accounts/postgres/storage.go
index cb74d131..8115dca4 100644
--- a/clientapi/auth/storage/accounts/postgres/storage.go
+++ b/clientapi/auth/storage/accounts/postgres/storage.go
@@ -18,6 +18,7 @@ import (
"context"
"database/sql"
"errors"
+ "strconv"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
@@ -118,11 +119,37 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
+// CreateGuestAccount makes a new guest account and creates an empty profile
+// for this account.
+func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var numLocalpart int64
+ numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
+ if err != nil {
+ return err
+ }
+ localpart := strconv.FormatInt(numLocalpart, 10)
+ acc, err = d.createAccount(ctx, txn, localpart, "", "")
+ return err
+ })
+ return acc, err
+}
+
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
+) (acc *authtypes.Account, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
+ return err
+ })
+ return
+}
+
+func (d *Database) createAccount(
+ ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
@@ -134,13 +161,14 @@ func (d *Database) CreateAccount(
return nil, err
}
}
- if err := d.profiles.insertProfile(ctx, localpart); err != nil {
+ if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
- if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
+
+ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
@@ -151,7 +179,7 @@ func (d *Database) CreateAccount(
}`); err != nil {
return nil, err
}
- return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
+ return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
@@ -258,7 +286,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
- return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
+ })
}
// GetAccountData returns account data related to a given localpart
@@ -288,7 +318,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
- return d.accounts.selectNewNumericLocalpart(ctx)
+ return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func hashPassword(plaintext string) (hash string, err error) {
diff --git a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go
index c2143881..b6bb6361 100644
--- a/clientapi/auth/storage/accounts/sqlite3/account_data_table.go
+++ b/clientapi/auth/storage/accounts/sqlite3/account_data_table.go
@@ -72,10 +72,9 @@ func (s *accountDataStatements) prepare(db *sql.DB) (err error) {
}
func (s *accountDataStatements) insertAccountData(
- ctx context.Context, localpart, roomID, dataType, content string,
+ ctx context.Context, txn *sql.Tx, localpart, roomID, dataType, content string,
) (err error) {
- stmt := s.insertAccountDataStmt
- _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content)
+ _, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return
}
diff --git a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go
index b029951f..fd6a09cd 100644
--- a/clientapi/auth/storage/accounts/sqlite3/accounts_table.go
+++ b/clientapi/auth/storage/accounts/sqlite3/accounts_table.go
@@ -89,16 +89,16 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success.
func (s *accountsStatements) insertAccount(
- ctx context.Context, localpart, hash, appserviceID string,
+ ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string,
) (*authtypes.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
stmt := s.insertAccountStmt
var err error
if appserviceID == "" {
- _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil)
+ _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil)
} else {
- _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
+ _, err = txn.Stmt(stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID)
}
if err != nil {
return nil, err
@@ -144,8 +144,12 @@ func (s *accountsStatements) selectAccountByLocalpart(
}
func (s *accountsStatements) selectNewNumericLocalpart(
- ctx context.Context,
+ ctx context.Context, txn *sql.Tx,
) (id int64, err error) {
- err = s.selectNewNumericLocalpartStmt.QueryRowContext(ctx).Scan(&id)
+ stmt := s.selectNewNumericLocalpartStmt
+ if txn != nil {
+ stmt = txn.Stmt(stmt)
+ }
+ err = stmt.QueryRowContext(ctx).Scan(&id)
return
}
diff --git a/clientapi/auth/storage/accounts/sqlite3/profile_table.go b/clientapi/auth/storage/accounts/sqlite3/profile_table.go
index 7af8307e..9b5192a0 100644
--- a/clientapi/auth/storage/accounts/sqlite3/profile_table.go
+++ b/clientapi/auth/storage/accounts/sqlite3/profile_table.go
@@ -73,9 +73,9 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}
func (s *profilesStatements) insertProfile(
- ctx context.Context, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart string,
) (err error) {
- _, err = s.insertProfileStmt.ExecContext(ctx, localpart, "", "")
+ _, err = txn.Stmt(s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return
}
diff --git a/clientapi/auth/storage/accounts/sqlite3/storage.go b/clientapi/auth/storage/accounts/sqlite3/storage.go
index 3e62d10d..9124640c 100644
--- a/clientapi/auth/storage/accounts/sqlite3/storage.go
+++ b/clientapi/auth/storage/accounts/sqlite3/storage.go
@@ -18,6 +18,8 @@ import (
"context"
"database/sql"
"errors"
+ "strconv"
+ "sync"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/common"
@@ -39,6 +41,8 @@ type Database struct {
threepids threepidStatements
filter filterStatements
serverName gomatrixserverlib.ServerName
+
+ createGuestAccountMu sync.Mutex
}
// NewDatabase creates a new accounts and profiles database
@@ -76,7 +80,7 @@ func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName)
if err = f.prepare(db); err != nil {
return nil, err
}
- return &Database{db, partitions, a, p, m, ac, t, f, serverName}, nil
+ return &Database{db, partitions, a, p, m, ac, t, f, serverName, sync.Mutex{}}, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
@@ -118,14 +122,46 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
+// CreateGuestAccount makes a new guest account and creates an empty profile
+// for this account.
+func (d *Database) CreateGuestAccount(ctx context.Context) (acc *authtypes.Account, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ // We need to lock so we sequentially create numeric localparts. If we don't, two calls to
+ // this function will cause the same number to be selected and one will fail with 'database is locked'
+ // when the first txn upgrades to a write txn.
+ // We know we'll be the only process since this is sqlite ;) so a lock here will be all that is needed.
+ d.createGuestAccountMu.Lock()
+ defer d.createGuestAccountMu.Unlock()
+
+ var numLocalpart int64
+ numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
+ if err != nil {
+ return err
+ }
+ localpart := strconv.FormatInt(numLocalpart, 10)
+ acc, err = d.createAccount(ctx, txn, localpart, "", "")
+ return err
+ })
+ return acc, err
+}
+
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, nil.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string,
+) (acc *authtypes.Account, err error) {
+ err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID)
+ return err
+ })
+ return
+}
+
+func (d *Database) createAccount(
+ ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string,
) (*authtypes.Account, error) {
var err error
-
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
@@ -134,13 +170,14 @@ func (d *Database) CreateAccount(
return nil, err
}
}
- if err := d.profiles.insertProfile(ctx, localpart); err != nil {
+ if err := d.profiles.insertProfile(ctx, txn, localpart); err != nil {
if common.IsUniqueConstraintViolationErr(err) {
return nil, nil
}
return nil, err
}
- if err := d.SaveAccountData(ctx, localpart, "", "m.push_rules", `{
+
+ if err := d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", `{
"global": {
"content": [],
"override": [],
@@ -151,7 +188,7 @@ func (d *Database) CreateAccount(
}`); err != nil {
return nil, err
}
- return d.accounts.insertAccount(ctx, localpart, hash, appserviceID)
+ return d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID)
}
// SaveMembership saves the user matching a given localpart as a member of a given
@@ -258,7 +295,9 @@ func (d *Database) newMembership(
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType, content string,
) error {
- return d.accountDatas.insertAccountData(ctx, localpart, roomID, dataType, content)
+ return common.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
+ })
}
// GetAccountData returns account data related to a given localpart
@@ -288,7 +327,7 @@ func (d *Database) GetAccountDataByType(
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
- return d.accounts.selectNewNumericLocalpart(ctx)
+ return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func hashPassword(plaintext string) (hash string, err error) {
diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go
index ba24e527..2de7b273 100644
--- a/clientapi/routing/register.go
+++ b/clientapi/routing/register.go
@@ -516,16 +516,7 @@ func handleGuestRegistration(
accountDB accounts.Database,
deviceDB devices.Database,
) util.JSONResponse {
-
- //Generate numeric local part for guest user
- id, err := accountDB.GetNewNumericLocalpart(req.Context())
- if err != nil {
- util.GetLogger(req.Context()).WithError(err).Error("accountDB.GetNewNumericLocalpart failed")
- return jsonerror.InternalServerError()
- }
-
- localpart := strconv.FormatInt(id, 10)
- acc, err := accountDB.CreateAccount(req.Context(), localpart, "", "")
+ acc, err := accountDB.CreateGuestAccount(req.Context())
if err != nil {
return util.JSONResponse{
Code: http.StatusInternalServerError,