diff options
author | santhoshivan23 <47689668+santhoshivan23@users.noreply.github.com> | 2023-06-22 22:07:21 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-22 16:37:21 +0000 |
commit | 45082d4dcefadceada1b4374f3876365887cfd4a (patch) | |
tree | 899fda990a71f16eb05073098196a2a1a1218bd3 /userapi | |
parent | a734b112c6577a23b87c6b54c50fb2e9a629cf2b (diff) |
feat: admin APIs for token authenticated registration (#3101)
### Pull Request Checklist
<!-- Please read
https://matrix-org.github.io/dendrite/development/contributing before
submitting your pull request -->
* [x] I have added Go unit tests or [Complement integration
tests](https://github.com/matrix-org/complement) for this PR _or_ I have
justified why this PR doesn't need tests
* [x] Pull request includes a [sign off below using a legally
identifiable
name](https://matrix-org.github.io/dendrite/development/contributing#sign-off)
_or_ I have already signed off privately
Signed-off-by: `Santhoshivan Amudhan santhoshivan23@gmail.com`
Diffstat (limited to 'userapi')
-rw-r--r-- | userapi/api/api.go | 6 | ||||
-rw-r--r-- | userapi/internal/user_api.go | 32 | ||||
-rw-r--r-- | userapi/storage/interface.go | 11 | ||||
-rw-r--r-- | userapi/storage/postgres/registration_tokens_table.go | 222 | ||||
-rw-r--r-- | userapi/storage/postgres/storage.go | 5 | ||||
-rw-r--r-- | userapi/storage/shared/storage.go | 38 | ||||
-rw-r--r-- | userapi/storage/sqlite3/registration_tokens_table.go | 222 | ||||
-rw-r--r-- | userapi/storage/sqlite3/storage.go | 6 | ||||
-rw-r--r-- | userapi/storage/tables/interface.go | 10 |
9 files changed, 551 insertions, 1 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go index 05040264..a0dce975 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" ) @@ -94,6 +95,11 @@ type ClientUserAPI interface { QueryPushers(ctx context.Context, req *QueryPushersRequest, res *QueryPushersResponse) error QueryPushRules(ctx context.Context, userID string) (*pushrules.AccountRuleSets, error) QueryAccountAvailability(ctx context.Context, req *QueryAccountAvailabilityRequest, res *QueryAccountAvailabilityResponse) error + PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error + PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 32f3d84b..4305c13a 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -33,6 +33,7 @@ import ( "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/pushgateway" @@ -63,6 +64,37 @@ type UserInternalAPI struct { Updater *DeviceListUpdater } +func (a *UserInternalAPI) PerformAdminCreateRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) { + exists, err := a.DB.RegistrationTokenExists(ctx, *registrationToken.Token) + if err != nil { + return false, err + } + if exists { + return false, fmt.Errorf("token: %s already exists", *registrationToken.Token) + } + _, err = a.DB.InsertRegistrationToken(ctx, registrationToken) + if err != nil { + return false, fmt.Errorf("Error creating token: %s"+err.Error(), *registrationToken.Token) + } + return true, nil +} + +func (a *UserInternalAPI) PerformAdminListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return a.DB.ListRegistrationTokens(ctx, returnAll, valid) +} + +func (a *UserInternalAPI) PerformAdminGetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return a.DB.GetRegistrationToken(ctx, tokenString) +} + +func (a *UserInternalAPI) PerformAdminDeleteRegistrationToken(ctx context.Context, tokenString string) error { + return a.DB.DeleteRegistrationToken(ctx, tokenString) +} + +func (a *UserInternalAPI) PerformAdminUpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) { + return a.DB.UpdateRegistrationToken(ctx, tokenString, newAttributes) +} + func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 4f5e99a8..125b3158 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -23,6 +23,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/userapi/api" @@ -30,6 +31,15 @@ import ( "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokens interface { + RegistrationTokenExists(ctx context.Context, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, tokenString string) error + UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) +} + type Profile interface { GetProfileByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) @@ -144,6 +154,7 @@ type UserDatabase interface { Pusher Statistics ThreePID + RegistrationTokens } type KeyChangeDatabase interface { diff --git a/userapi/storage/postgres/registration_tokens_table.go b/userapi/storage/postgres/registration_tokens_table.go new file mode 100644 index 00000000..3c3e3fdd --- /dev/null +++ b/userapi/storage/postgres/registration_tokens_table.go @@ -0,0 +1,222 @@ +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewPostgresRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + var rows *sql.Rows + var err error + if returnAll { + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, rows.Err() +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 72e7c9cd..d01ccc77 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -53,6 +53,10 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * return nil, err } + registationTokensTable, err := NewPostgresRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresRegistrationsTokenTable: %w", err) + } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) @@ -125,6 +129,7 @@ func NewDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperties * ThreePIDs: threePIDTable, Pushers: pusherTable, Notifications: notificationsTable, + RegistrationTokens: registationTokensTable, Stats: statsTable, ServerName: serverName, DB: db, diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index 537bbbf4..b7acb203 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -31,6 +31,7 @@ import ( "github.com/matrix-org/gomatrixserverlib/spec" "golang.org/x/crypto/bcrypt" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/internal/pushrules" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -43,6 +44,7 @@ import ( type Database struct { DB *sql.DB Writer sqlutil.Writer + RegistrationTokens tables.RegistrationTokensTable Accounts tables.AccountsTable Profiles tables.ProfileTable AccountDatas tables.AccountDataTable @@ -78,6 +80,42 @@ const ( loginTokenByteLength = 32 ) +func (d *Database) RegistrationTokenExists(ctx context.Context, token string) (bool, error) { + return d.RegistrationTokens.RegistrationTokenExists(ctx, nil, token) +} + +func (d *Database) InsertRegistrationToken(ctx context.Context, registrationToken *clientapi.RegistrationToken) (created bool, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + created, err = d.RegistrationTokens.InsertRegistrationToken(ctx, txn, registrationToken) + return err + }) + return +} + +func (d *Database) ListRegistrationTokens(ctx context.Context, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) { + return d.RegistrationTokens.ListRegistrationTokens(ctx, nil, returnAll, valid) +} + +func (d *Database) GetRegistrationToken(ctx context.Context, tokenString string) (*clientapi.RegistrationToken, error) { + return d.RegistrationTokens.GetRegistrationToken(ctx, nil, tokenString) +} + +func (d *Database) DeleteRegistrationToken(ctx context.Context, tokenString string) (err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.RegistrationTokens.DeleteRegistrationToken(ctx, nil, tokenString) + return err + }) + return +} + +func (d *Database) UpdateRegistrationToken(ctx context.Context, tokenString string, newAttributes map[string]interface{}) (updatedToken *clientapi.RegistrationToken, err error) { + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + updatedToken, err = d.RegistrationTokens.UpdateRegistrationToken(ctx, txn, tokenString, newAttributes) + return err + }) + return +} + // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( diff --git a/userapi/storage/sqlite3/registration_tokens_table.go b/userapi/storage/sqlite3/registration_tokens_table.go new file mode 100644 index 00000000..89795473 --- /dev/null +++ b/userapi/storage/sqlite3/registration_tokens_table.go @@ -0,0 +1,222 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/clientapi/api" + internal "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "golang.org/x/exp/constraints" +) + +const registrationTokensSchema = ` +CREATE TABLE IF NOT EXISTS userapi_registration_tokens ( + token TEXT PRIMARY KEY, + pending BIGINT, + completed BIGINT, + uses_allowed BIGINT, + expiry_time BIGINT +); +` + +const selectTokenSQL = "" + + "SELECT token FROM userapi_registration_tokens WHERE token = $1" + +const insertTokenSQL = "" + + "INSERT INTO userapi_registration_tokens (token, uses_allowed, expiry_time, pending, completed) VALUES ($1, $2, $3, $4, $5)" + +const listAllTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens" + +const listValidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed > pending + completed OR uses_allowed IS NULL) AND" + + "(expiry_time > $1 OR expiry_time IS NULL)" + +const listInvalidTokensSQL = "" + + "SELECT * FROM userapi_registration_tokens WHERE" + + "(uses_allowed <= pending + completed OR expiry_time <= $1)" + +const getTokenSQL = "" + + "SELECT pending, completed, uses_allowed, expiry_time FROM userapi_registration_tokens WHERE token = $1" + +const deleteTokenSQL = "" + + "DELETE FROM userapi_registration_tokens WHERE token = $1" + +const updateTokenUsesAllowedAndExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2, expiry_time = $3 WHERE token = $1" + +const updateTokenUsesAllowedSQL = "" + + "UPDATE userapi_registration_tokens SET uses_allowed = $2 WHERE token = $1" + +const updateTokenExpiryTimeSQL = "" + + "UPDATE userapi_registration_tokens SET expiry_time = $2 WHERE token = $1" + +type registrationTokenStatements struct { + selectTokenStatement *sql.Stmt + insertTokenStatement *sql.Stmt + listAllTokensStatement *sql.Stmt + listValidTokensStatement *sql.Stmt + listInvalidTokenStatement *sql.Stmt + getTokenStatement *sql.Stmt + deleteTokenStatement *sql.Stmt + updateTokenUsesAllowedAndExpiryTimeStatement *sql.Stmt + updateTokenUsesAllowedStatement *sql.Stmt + updateTokenExpiryTimeStatement *sql.Stmt +} + +func NewSQLiteRegistrationTokensTable(db *sql.DB) (tables.RegistrationTokensTable, error) { + s := ®istrationTokenStatements{} + _, err := db.Exec(registrationTokensSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectTokenStatement, selectTokenSQL}, + {&s.insertTokenStatement, insertTokenSQL}, + {&s.listAllTokensStatement, listAllTokensSQL}, + {&s.listValidTokensStatement, listValidTokensSQL}, + {&s.listInvalidTokenStatement, listInvalidTokensSQL}, + {&s.getTokenStatement, getTokenSQL}, + {&s.deleteTokenStatement, deleteTokenSQL}, + {&s.updateTokenUsesAllowedAndExpiryTimeStatement, updateTokenUsesAllowedAndExpiryTimeSQL}, + {&s.updateTokenUsesAllowedStatement, updateTokenUsesAllowedSQL}, + {&s.updateTokenExpiryTimeStatement, updateTokenExpiryTimeSQL}, + }.Prepare(db) +} + +func (s *registrationTokenStatements) RegistrationTokenExists(ctx context.Context, tx *sql.Tx, token string) (bool, error) { + var existingToken string + stmt := sqlutil.TxStmt(tx, s.selectTokenStatement) + err := stmt.QueryRowContext(ctx, token).Scan(&existingToken) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + +func (s *registrationTokenStatements) InsertRegistrationToken(ctx context.Context, tx *sql.Tx, registrationToken *api.RegistrationToken) (bool, error) { + stmt := sqlutil.TxStmt(tx, s.insertTokenStatement) + _, err := stmt.ExecContext( + ctx, + *registrationToken.Token, + getInsertValue(registrationToken.UsesAllowed), + getInsertValue(registrationToken.ExpiryTime), + *registrationToken.Pending, + *registrationToken.Completed) + if err != nil { + return false, err + } + return true, nil +} + +func getInsertValue[t constraints.Integer](in *t) any { + if in == nil { + return nil + } + return *in +} + +func (s *registrationTokenStatements) ListRegistrationTokens(ctx context.Context, tx *sql.Tx, returnAll bool, valid bool) ([]api.RegistrationToken, error) { + var stmt *sql.Stmt + var tokens []api.RegistrationToken + var tokenString string + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + var rows *sql.Rows + var err error + if returnAll { + stmt = sqlutil.TxStmt(tx, s.listAllTokensStatement) + rows, err = stmt.QueryContext(ctx) + } else if valid { + stmt = sqlutil.TxStmt(tx, s.listValidTokensStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } else { + stmt = sqlutil.TxStmt(tx, s.listInvalidTokenStatement) + rows, err = stmt.QueryContext(ctx, time.Now().UnixNano()/int64(time.Millisecond)) + } + if err != nil { + return tokens, err + } + defer internal.CloseAndLogIfError(ctx, rows, "ListRegistrationTokens: rows.close() failed") + for rows.Next() { + err = rows.Scan(&tokenString, &pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return tokens, err + } + tokenString := tokenString + pending := pending + completed := completed + usesAllowed := usesAllowed + expiryTime := expiryTime + + tokenMap := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + tokens = append(tokens, tokenMap) + } + return tokens, rows.Err() +} + +func (s *registrationTokenStatements) GetRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) (*api.RegistrationToken, error) { + stmt := sqlutil.TxStmt(tx, s.getTokenStatement) + var pending, completed, usesAllowed *int32 + var expiryTime *int64 + err := stmt.QueryRowContext(ctx, tokenString).Scan(&pending, &completed, &usesAllowed, &expiryTime) + if err != nil { + return nil, err + } + token := api.RegistrationToken{ + Token: &tokenString, + Pending: pending, + Completed: completed, + UsesAllowed: usesAllowed, + ExpiryTime: expiryTime, + } + return &token, nil +} + +func (s *registrationTokenStatements) DeleteRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string) error { + stmt := sqlutil.TxStmt(tx, s.deleteTokenStatement) + _, err := stmt.ExecContext(ctx, tokenString) + if err != nil { + return err + } + return nil +} + +func (s *registrationTokenStatements) UpdateRegistrationToken(ctx context.Context, tx *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*api.RegistrationToken, error) { + var stmt *sql.Stmt + usesAllowed, usesAllowedPresent := newAttributes["usesAllowed"] + expiryTime, expiryTimePresent := newAttributes["expiryTime"] + if usesAllowedPresent && expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedAndExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed, expiryTime) + if err != nil { + return nil, err + } + } else if usesAllowedPresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenUsesAllowedStatement) + _, err := stmt.ExecContext(ctx, tokenString, usesAllowed) + if err != nil { + return nil, err + } + } else if expiryTimePresent { + stmt = sqlutil.TxStmt(tx, s.updateTokenExpiryTimeStatement) + _, err := stmt.ExecContext(ctx, tokenString, expiryTime) + if err != nil { + return nil, err + } + } + return s.GetRegistrationToken(ctx, tx, tokenString) +} diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index acd9678f..48f5c842 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -50,7 +50,10 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti if err = m.Up(ctx); err != nil { return nil, err } - + registationTokensTable, err := NewSQLiteRegistrationTokensTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteRegistrationsTokenTable: %w", err) + } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) @@ -130,6 +133,7 @@ func NewUserDatabase(ctx context.Context, conMan sqlutil.Connections, dbProperti LoginTokenLifetime: loginTokenLifetime, BcryptCost: bcryptCost, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, + RegistrationTokens: registationTokensTable, }, nil } diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 3c6214e7..3a0be73e 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -25,10 +25,20 @@ import ( "github.com/matrix-org/gomatrixserverlib/fclient" "github.com/matrix-org/gomatrixserverlib/spec" + clientapi "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/userapi/types" ) +type RegistrationTokensTable interface { + RegistrationTokenExists(ctx context.Context, txn *sql.Tx, token string) (bool, error) + InsertRegistrationToken(ctx context.Context, txn *sql.Tx, registrationToken *clientapi.RegistrationToken) (bool, error) + ListRegistrationTokens(ctx context.Context, txn *sql.Tx, returnAll bool, valid bool) ([]clientapi.RegistrationToken, error) + GetRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) (*clientapi.RegistrationToken, error) + DeleteRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string) error + UpdateRegistrationToken(ctx context.Context, txn *sql.Tx, tokenString string, newAttributes map[string]interface{}) (*clientapi.RegistrationToken, error) +} + type AccountDataTable interface { InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName spec.ServerName, roomID, dataType string, content json.RawMessage) error SelectAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) |