aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-02-18 11:31:05 +0000
committerGitHub <noreply@github.com>2022-02-18 11:31:05 +0000
commit153bfbbea579dfa10e8e804036f17c1a33b6fe80 (patch)
treee135dcefc59618d7b86cd8687c1a2a304385ce45 /userapi
parent0a7dea44505f703af1e7e069602ca95aa5a83700 (diff)
Merge both user API databases into one (#2186)
* Merge user API databases into one * Remove DeviceDatabase from config * Fix tests * Try that again * Clean up keyserver device keys when the devices no longer exist in the user API * Tweak ordering * Fix UserExists flag, device check * Allow including empty entries so we can clean them up * Remove logging
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api_logintoken.go7
-rw-r--r--userapi/internal/api.go67
-rw-r--r--userapi/internal/api_logintoken.go8
-rw-r--r--userapi/storage/devices/interface.go52
-rw-r--r--userapi/storage/devices/postgres/storage.go270
-rw-r--r--userapi/storage/devices/sqlite3/storage.go271
-rw-r--r--userapi/storage/devices/storage.go42
-rw-r--r--userapi/storage/devices/storage_wasm.go39
-rw-r--r--userapi/storage/interface.go (renamed from userapi/storage/accounts/interface.go)31
-rw-r--r--userapi/storage/postgres/account_data_table.go (renamed from userapi/storage/accounts/postgres/account_data_table.go)0
-rw-r--r--userapi/storage/postgres/accounts_table.go (renamed from userapi/storage/accounts/postgres/accounts_table.go)0
-rw-r--r--userapi/storage/postgres/deltas/20200929203058_is_active.go (renamed from userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go)0
-rw-r--r--userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go (renamed from userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go)5
-rw-r--r--userapi/storage/postgres/deltas/2022021013023800_add_account_type.go (renamed from userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go)0
-rw-r--r--userapi/storage/postgres/devices_table.go (renamed from userapi/storage/devices/postgres/devices_table.go)3
-rw-r--r--userapi/storage/postgres/key_backup_table.go (renamed from userapi/storage/accounts/postgres/key_backup_table.go)0
-rw-r--r--userapi/storage/postgres/key_backup_version_table.go (renamed from userapi/storage/accounts/postgres/key_backup_version_table.go)0
-rw-r--r--userapi/storage/postgres/logintoken_table.go (renamed from userapi/storage/devices/postgres/logintoken_table.go)3
-rw-r--r--userapi/storage/postgres/openid_table.go (renamed from userapi/storage/accounts/postgres/openid_table.go)0
-rw-r--r--userapi/storage/postgres/profile_table.go (renamed from userapi/storage/accounts/postgres/profile_table.go)0
-rw-r--r--userapi/storage/postgres/storage.go (renamed from userapi/storage/accounts/postgres/storage.go)216
-rw-r--r--userapi/storage/postgres/threepid_table.go (renamed from userapi/storage/accounts/postgres/threepid_table.go)0
-rw-r--r--userapi/storage/sqlite3/account_data_table.go (renamed from userapi/storage/accounts/sqlite3/account_data_table.go)0
-rw-r--r--userapi/storage/sqlite3/accounts_table.go (renamed from userapi/storage/accounts/sqlite3/accounts_table.go)0
-rw-r--r--userapi/storage/sqlite3/constraint_wasm.go (renamed from userapi/storage/accounts/sqlite3/constraint_wasm.go)0
-rw-r--r--userapi/storage/sqlite3/deltas/20200929203058_is_active.go (renamed from userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go)0
-rw-r--r--userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go (renamed from userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go)5
-rw-r--r--userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go (renamed from userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go)0
-rw-r--r--userapi/storage/sqlite3/devices_table.go (renamed from userapi/storage/devices/sqlite3/devices_table.go)3
-rw-r--r--userapi/storage/sqlite3/key_backup_table.go (renamed from userapi/storage/accounts/sqlite3/key_backup_table.go)0
-rw-r--r--userapi/storage/sqlite3/key_backup_version_table.go (renamed from userapi/storage/accounts/sqlite3/key_backup_version_table.go)0
-rw-r--r--userapi/storage/sqlite3/logintoken_table.go (renamed from userapi/storage/devices/sqlite3/logintoken_table.go)3
-rw-r--r--userapi/storage/sqlite3/openid_table.go (renamed from userapi/storage/accounts/sqlite3/openid_table.go)0
-rw-r--r--userapi/storage/sqlite3/profile_table.go (renamed from userapi/storage/accounts/sqlite3/profile_table.go)0
-rw-r--r--userapi/storage/sqlite3/storage.go (renamed from userapi/storage/accounts/sqlite3/storage.go)216
-rw-r--r--userapi/storage/sqlite3/threepid_table.go (renamed from userapi/storage/accounts/sqlite3/threepid_table.go)0
-rw-r--r--userapi/storage/storage.go (renamed from userapi/storage/accounts/storage.go)13
-rw-r--r--userapi/storage/storage_wasm.go (renamed from userapi/storage/accounts/storage_wasm.go)8
-rw-r--r--userapi/userapi.go22
-rw-r--r--userapi/userapi_test.go15
40 files changed, 537 insertions, 762 deletions
diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go
index f3aa037e..e2207bb5 100644
--- a/userapi/api/api_logintoken.go
+++ b/userapi/api/api_logintoken.go
@@ -19,6 +19,13 @@ import (
"time"
)
+// DefaultLoginTokenLifetime determines how old a valid token may be.
+//
+// NOTSPEC: The current spec says "SHOULD be limited to around five
+// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
+// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
+const DefaultLoginTokenLifetime = 2 * time.Minute
+
type LoginTokenInternalAPI interface {
// PerformLoginTokenCreation creates a new login token and associates it with the provided data.
PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index f96d4804..f54cc613 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -31,13 +31,11 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
type UserInternalAPI struct {
- AccountDB accounts.Database
- DeviceDB devices.Database
+ DB storage.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
@@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
- return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
+ return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
- acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
+ acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
- if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
+ if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
- if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
@@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
- dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
+ dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
+ devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
- err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
+ err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil {
return err
@@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
- if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
+ if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
@@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
- dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
+ dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
- err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
+ err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
- prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
+ prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
- profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
+ profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
@@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
- devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
+ devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
@@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
}
- devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
+ devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
return err
}
+ res.UserExists = true
res.Devices = devs
return nil
}
@@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
- data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
+ data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
@@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
- global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
+ global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil {
return err
}
@@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
- device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
+ device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if err != nil {
return err
}
- acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
+ acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
@@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
- account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
+ account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
- err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
+ err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil
return err
}
@@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
- exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
+ exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
@@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
- openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
+ openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
@@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
return nil
}
- exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
+ exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
@@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Create metadata
if req.Version == "" {
- version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
+ version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
@@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Update metadata
if len(req.Keys.Rooms) == 0 {
- err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
+ err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
@@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version
- version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
+ version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
})
}
}
- count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
+ count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
@@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
- version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
+ version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
@@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted
if !req.ReturnKeys {
- res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
+ res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
- result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
+ result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return
diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go
index 86ffc58f..f1bf391e 100644
--- a/userapi/internal/api_logintoken.go
+++ b/userapi/internal/api_logintoken.go
@@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
}
- tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
+ tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil {
return err
}
@@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
- return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
+ return a.DB.RemoveLoginToken(ctx, req.Token)
}
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
- tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
+ tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil {
res.Data = nil
if err == sql.ErrNoRows {
@@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
}
- if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
+ if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil
diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go
deleted file mode 100644
index 8ff91cf1..00000000
--- a/userapi/storage/devices/interface.go
+++ /dev/null
@@ -1,52 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package devices
-
-import (
- "context"
-
- "github.com/matrix-org/dendrite/userapi/api"
-)
-
-type Database interface {
- GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
- GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
- GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
- GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
- // CreateDevice makes a new device associated with the given user ID localpart.
- // If there is already a device with the same device ID for this user, that access token will be revoked
- // and replaced with the given accessToken. If the given accessToken is already in use for another device,
- // an error will be returned.
- // If no device ID is given one is generated.
- // Returns the device on success.
- CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
- UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
- UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
- RemoveDevice(ctx context.Context, deviceID, localpart string) error
- RemoveDevices(ctx context.Context, localpart string, devices []string) error
- // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
- RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
-
- // CreateLoginToken generates a token, stores and returns it. The lifetime is
- // determined by the loginTokenLifetime given to the Database constructor.
- CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
-
- // RemoveLoginToken removes the named token (and may clean up other expired tokens).
- RemoveLoginToken(ctx context.Context, token string) error
-
- // GetLoginTokenDataByToken returns the data associated with the given token.
- // May return sql.ErrNoRows.
- GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
-}
diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go
deleted file mode 100644
index fd9d513f..00000000
--- a/userapi/storage/devices/postgres/storage.go
+++ /dev/null
@@ -1,270 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package postgres
-
-import (
- "context"
- "crypto/rand"
- "database/sql"
- "encoding/base64"
- "time"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/devices/postgres/deltas"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-const (
- // The length of generated device IDs
- deviceIDByteLength = 6
- loginTokenByteLength = 32
-)
-
-// Database represents a device database.
-type Database struct {
- db *sql.DB
- devices devicesStatements
- loginTokens loginTokenStatements
- loginTokenLifetime time.Duration
-}
-
-// NewDatabase creates a new device database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
- db, err := sqlutil.Open(dbProperties)
- if err != nil {
- return nil, err
- }
- var d devicesStatements
- var lt loginTokenStatements
-
- // Create tables before executing migrations so we don't fail if the table is missing,
- // and THEN prepare statements so we don't fail due to referencing new columns
- if err = d.execSchema(db); err != nil {
- return nil, err
- }
- if err = lt.execSchema(db); err != nil {
- return nil, err
- }
-
- m := sqlutil.NewMigrations()
- deltas.LoadLastSeenTSIP(m)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
-
- if err = d.prepare(db, serverName); err != nil {
- return nil, err
- }
- if err = lt.prepare(db); err != nil {
- return nil, err
- }
-
- return &Database{db, d, lt, loginTokenLifetime}, nil
-}
-
-// GetDeviceByAccessToken returns the device matching the given access token.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByAccessToken(
- ctx context.Context, token string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByToken(ctx, token)
-}
-
-// GetDeviceByID returns the device matching the given ID.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByID(
- ctx context.Context, localpart, deviceID string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByID(ctx, localpart, deviceID)
-}
-
-// GetDevicesByLocalpart returns the devices matching the given localpart.
-func (d *Database) GetDevicesByLocalpart(
- ctx context.Context, localpart string,
-) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
-}
-
-func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
- return d.devices.selectDevicesByID(ctx, deviceIDs)
-}
-
-// CreateDevice makes a new device associated with the given user ID localpart.
-// If there is already a device with the same device ID for this user, that access token will be revoked
-// and replaced with the given accessToken. If the given accessToken is already in use for another device,
-// an error will be returned.
-// If no device ID is given one is generated.
-// Returns the device on success.
-func (d *Database) CreateDevice(
- ctx context.Context, localpart string, deviceID *string, accessToken string,
- displayName *string, ipAddr, userAgent string,
-) (dev *api.Device, returnErr error) {
- if deviceID != nil {
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- var err error
- // Revoke existing tokens for this device
- if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
- return err
- }
-
- dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- } else {
- // We generate device IDs in a loop in case its already taken.
- // We cap this at going round 5 times to ensure we don't spin forever
- var newDeviceID string
- for i := 1; i <= 5; i++ {
- newDeviceID, returnErr = generateDeviceID()
- if returnErr != nil {
- return
- }
-
- returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- var err error
- dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- if returnErr == nil {
- return
- }
- }
- }
- return
-}
-
-// generateDeviceID creates a new device id. Returns an error if failed to generate
-// random bytes.
-func generateDeviceID() (string, error) {
- b := make([]byte, deviceIDByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- // url-safe no padding
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// UpdateDevice updates the given device with the display name.
-// Returns SQL error if there are problems and nil on success.
-func (d *Database) UpdateDevice(
- ctx context.Context, localpart, deviceID string, displayName *string,
-) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
- })
-}
-
-// RemoveDevice revokes a device by deleting the entry in the database
-// matching with the given device ID and user ID localpart.
-// If the device doesn't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevice(
- ctx context.Context, deviceID, localpart string,
-) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveDevices revokes one or more devices by deleting the entry in the database
-// matching with the given device IDs and user ID localpart.
-// If the devices don't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevices(
- ctx context.Context, localpart string, devices []string,
-) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveAllDevices revokes devices by deleting the entry in the
-// database matching the given user ID localpart.
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart, exceptDeviceID string,
-) (devices []api.Device, err error) {
- err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
- if err != nil {
- return err
- }
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
- return
-}
-
-// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
-func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
- })
-}
-
-// CreateLoginToken generates a token, stores and returns it. The lifetime is
-// determined by the loginTokenLifetime given to the Database constructor.
-func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
- tok, err := generateLoginToken()
- if err != nil {
- return nil, err
- }
- meta := &api.LoginTokenMetadata{
- Token: tok,
- Expiration: time.Now().Add(d.loginTokenLifetime),
- }
-
- err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.loginTokens.insert(ctx, txn, meta, data)
- })
- if err != nil {
- return nil, err
- }
-
- return meta, nil
-}
-
-func generateLoginToken() (string, error) {
- b := make([]byte, loginTokenByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// RemoveLoginToken removes the named token (and may clean up other expired tokens).
-func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
- return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.loginTokens.deleteByToken(ctx, txn, token)
- })
-}
-
-// GetLoginTokenDataByToken returns the data associated with the given token.
-// May return sql.ErrNoRows.
-func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
- return d.loginTokens.selectByToken(ctx, token)
-}
diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go
deleted file mode 100644
index 6e90413b..00000000
--- a/userapi/storage/devices/sqlite3/storage.go
+++ /dev/null
@@ -1,271 +0,0 @@
-// Copyright 2017 Vector Creations Ltd
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package sqlite3
-
-import (
- "context"
- "crypto/rand"
- "database/sql"
- "encoding/base64"
- "time"
-
- "github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3/deltas"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-const (
- // The length of generated device IDs
- deviceIDByteLength = 6
-
- loginTokenByteLength = 32
-)
-
-// Database represents a device database.
-type Database struct {
- db *sql.DB
- writer sqlutil.Writer
- devices devicesStatements
- loginTokens loginTokenStatements
- loginTokenLifetime time.Duration
-}
-
-// NewDatabase creates a new device database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) {
- db, err := sqlutil.Open(dbProperties)
- if err != nil {
- return nil, err
- }
- writer := sqlutil.NewExclusiveWriter()
- var d devicesStatements
- var lt loginTokenStatements
-
- // Create tables before executing migrations so we don't fail if the table is missing,
- // and THEN prepare statements so we don't fail due to referencing new columns
- if err = d.execSchema(db); err != nil {
- return nil, err
- }
- if err = lt.execSchema(db); err != nil {
- return nil, err
- }
-
- m := sqlutil.NewMigrations()
- deltas.LoadLastSeenTSIP(m)
- if err = m.RunDeltas(db, dbProperties); err != nil {
- return nil, err
- }
- if err = d.prepare(db, writer, serverName); err != nil {
- return nil, err
- }
- if err = lt.prepare(db); err != nil {
- return nil, err
- }
- return &Database{db, writer, d, lt, loginTokenLifetime}, nil
-}
-
-// GetDeviceByAccessToken returns the device matching the given access token.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByAccessToken(
- ctx context.Context, token string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByToken(ctx, token)
-}
-
-// GetDeviceByID returns the device matching the given ID.
-// Returns sql.ErrNoRows if no matching device was found.
-func (d *Database) GetDeviceByID(
- ctx context.Context, localpart, deviceID string,
-) (*api.Device, error) {
- return d.devices.selectDeviceByID(ctx, localpart, deviceID)
-}
-
-// GetDevicesByLocalpart returns the devices matching the given localpart.
-func (d *Database) GetDevicesByLocalpart(
- ctx context.Context, localpart string,
-) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
-}
-
-func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
- return d.devices.selectDevicesByID(ctx, deviceIDs)
-}
-
-// CreateDevice makes a new device associated with the given user ID localpart.
-// If there is already a device with the same device ID for this user, that access token will be revoked
-// and replaced with the given accessToken. If the given accessToken is already in use for another device,
-// an error will be returned.
-// If no device ID is given one is generated.
-// Returns the device on success.
-func (d *Database) CreateDevice(
- ctx context.Context, localpart string, deviceID *string, accessToken string,
- displayName *string, ipAddr, userAgent string,
-) (dev *api.Device, returnErr error) {
- if deviceID != nil {
- returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- var err error
- // Revoke existing tokens for this device
- if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
- return err
- }
-
- dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- } else {
- // We generate device IDs in a loop in case its already taken.
- // We cap this at going round 5 times to ensure we don't spin forever
- var newDeviceID string
- for i := 1; i <= 5; i++ {
- newDeviceID, returnErr = generateDeviceID()
- if returnErr != nil {
- return
- }
-
- returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- var err error
- dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
- return err
- })
- if returnErr == nil {
- return
- }
- }
- }
- return
-}
-
-// generateDeviceID creates a new device id. Returns an error if failed to generate
-// random bytes.
-func generateDeviceID() (string, error) {
- b := make([]byte, deviceIDByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- // url-safe no padding
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// UpdateDevice updates the given device with the display name.
-// Returns SQL error if there are problems and nil on success.
-func (d *Database) UpdateDevice(
- ctx context.Context, localpart, deviceID string, displayName *string,
-) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
- })
-}
-
-// RemoveDevice revokes a device by deleting the entry in the database
-// matching with the given device ID and user ID localpart.
-// If the device doesn't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevice(
- ctx context.Context, deviceID, localpart string,
-) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveDevices revokes one or more devices by deleting the entry in the database
-// matching with the given device IDs and user ID localpart.
-// If the devices don't exist, it will not return an error
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveDevices(
- ctx context.Context, localpart string, devices []string,
-) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
-}
-
-// RemoveAllDevices revokes devices by deleting the entry in the
-// database matching the given user ID localpart.
-// If something went wrong during the deletion, it will return the SQL error.
-func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart, exceptDeviceID string,
-) (devices []api.Device, err error) {
- err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
- if err != nil {
- return err
- }
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
- return err
- }
- return nil
- })
- return
-}
-
-// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
-func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
- })
-}
-
-// CreateLoginToken generates a token, stores and returns it. The lifetime is
-// determined by the loginTokenLifetime given to the Database constructor.
-func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
- tok, err := generateLoginToken()
- if err != nil {
- return nil, err
- }
- meta := &api.LoginTokenMetadata{
- Token: tok,
- Expiration: time.Now().Add(d.loginTokenLifetime),
- }
-
- err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.loginTokens.insert(ctx, txn, meta, data)
- })
- if err != nil {
- return nil, err
- }
-
- return meta, nil
-}
-
-func generateLoginToken() (string, error) {
- b := make([]byte, loginTokenByteLength)
- _, err := rand.Read(b)
- if err != nil {
- return "", err
- }
- return base64.RawURLEncoding.EncodeToString(b), nil
-}
-
-// RemoveLoginToken removes the named token (and may clean up other expired tokens).
-func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
- return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- return d.loginTokens.deleteByToken(ctx, txn, token)
- })
-}
-
-// GetLoginTokenDataByToken returns the data associated with the given token.
-// May return sql.ErrNoRows.
-func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
- return d.loginTokens.selectByToken(ctx, token)
-}
diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go
deleted file mode 100644
index 15cf8150..00000000
--- a/userapi/storage/devices/storage.go
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-//go:build !wasm
-// +build !wasm
-
-package devices
-
-import (
- "fmt"
- "time"
-
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
- "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
-// and sets postgres connection parameters. loginTokenLifetime determines how long a
-// login token from CreateLoginToken is valid.
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
- case dbProperties.ConnectionString.IsPostgres():
- return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime)
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go
deleted file mode 100644
index 3de7880b..00000000
--- a/userapi/storage/devices/storage_wasm.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2020 The Matrix.org Foundation C.I.C.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package devices
-
-import (
- "fmt"
- "time"
-
- "github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
- "github.com/matrix-org/gomatrixserverlib"
-)
-
-func NewDatabase(
- dbProperties *config.DatabaseOptions,
- serverName gomatrixserverlib.ServerName,
- loginTokenLifetime time.Duration,
-) (Database, error) {
- switch {
- case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime)
- case dbProperties.ConnectionString.IsPostgres():
- return nil, fmt.Errorf("can't use Postgres implementation")
- default:
- return nil, fmt.Errorf("unexpected database type")
- }
-}
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/interface.go
index a2185774..a131dac4 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/interface.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package storage
import (
"context"
@@ -60,6 +60,35 @@ type Database interface {
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
+
+ GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
+ GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
+ GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
+ GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
+ // CreateDevice makes a new device associated with the given user ID localpart.
+ // If there is already a device with the same device ID for this user, that access token will be revoked
+ // and replaced with the given accessToken. If the given accessToken is already in use for another device,
+ // an error will be returned.
+ // If no device ID is given one is generated.
+ // Returns the device on success.
+ CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
+ UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
+ UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
+ RemoveDevice(ctx context.Context, deviceID, localpart string) error
+ RemoveDevices(ctx context.Context, localpart string, devices []string) error
+ // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
+ RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
+
+ // CreateLoginToken generates a token, stores and returns it. The lifetime is
+ // determined by the loginTokenLifetime given to the Database constructor.
+ CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
+
+ // RemoveLoginToken removes the named token (and may clean up other expired tokens).
+ RemoveLoginToken(ctx context.Context, token string) error
+
+ // GetLoginTokenDataByToken returns the data associated with the given token.
+ // May return sql.ErrNoRows.
+ GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/accounts/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go
index 8ba890e7..8ba890e7 100644
--- a/userapi/storage/accounts/postgres/account_data_table.go
+++ b/userapi/storage/postgres/account_data_table.go
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go
index 9e3e456a..9e3e456a 100644
--- a/userapi/storage/accounts/postgres/accounts_table.go
+++ b/userapi/storage/postgres/accounts_table.go
diff --git a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go b/userapi/storage/postgres/deltas/20200929203058_is_active.go
index 32d3235b..32d3235b 100644
--- a/userapi/storage/accounts/postgres/deltas/20200929203058_is_active.go
+++ b/userapi/storage/postgres/deltas/20200929203058_is_active.go
diff --git a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
index 290f854c..1bbb0a9d 100644
--- a/userapi/storage/devices/postgres/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/postgres/deltas/20201001204705_last_seen_ts_ip.go
@@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
-}
-
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
diff --git a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
index 2fae00cb..2fae00cb 100644
--- a/userapi/storage/accounts/postgres/deltas/2022021013023800_add_account_type.go
+++ b/userapi/storage/postgres/deltas/2022021013023800_add_account_type.go
diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 7de9f5f9..64cc0b71 100644
--- a/userapi/storage/devices/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -117,6 +117,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error {
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
+ if err = s.execSchema(db); err != nil {
+ return
+ }
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go
index c1402d4d..c1402d4d 100644
--- a/userapi/storage/accounts/postgres/key_backup_table.go
+++ b/userapi/storage/postgres/key_backup_table.go
diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/postgres/key_backup_version_table.go
index d73447b4..d73447b4 100644
--- a/userapi/storage/accounts/postgres/key_backup_version_table.go
+++ b/userapi/storage/postgres/key_backup_version_table.go
diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/postgres/logintoken_table.go
index f601fc7d..508a6898 100644
--- a/userapi/storage/devices/postgres/logintoken_table.go
+++ b/userapi/storage/postgres/logintoken_table.go
@@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
+ if err := s.execSchema(db); err != nil {
+ return err
+ }
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
diff --git a/userapi/storage/accounts/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go
index 190d141b..190d141b 100644
--- a/userapi/storage/accounts/postgres/openid_table.go
+++ b/userapi/storage/postgres/openid_table.go
diff --git a/userapi/storage/accounts/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go
index 9313864b..9313864b 100644
--- a/userapi/storage/accounts/postgres/profile_table.go
+++ b/userapi/storage/postgres/profile_table.go
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/postgres/storage.go
index d31efd25..73419279 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -16,7 +16,9 @@ package postgres
import (
"context"
+ "crypto/rand"
"database/sql"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -30,7 +32,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
// Import the postgres database driver.
_ "github.com/lib/pq"
@@ -47,14 +49,23 @@ type Database struct {
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
+ devices devicesStatements
+ loginTokens loginTokenStatements
+ loginTokenLifetime time.Duration
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
+const (
+ // The length of generated device IDs
+ deviceIDByteLength = 6
+ loginTokenByteLength = 32
+)
+
// NewDatabase creates a new accounts and profiles database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
+func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
@@ -63,6 +74,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
+ loginTokenLifetime: loginTokenLifetime,
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
@@ -74,6 +86,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
+ //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
@@ -103,6 +116,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
+ if err = d.devices.prepare(db, serverName); err != nil {
+ return nil, err
+ }
+ if err = d.loginTokens.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -515,3 +534,196 @@ func (d *Database) UpsertBackupKeys(
})
return
}
+
+// GetDeviceByAccessToken returns the device matching the given access token.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByAccessToken(
+ ctx context.Context, token string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByToken(ctx, token)
+}
+
+// GetDeviceByID returns the device matching the given ID.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByID(ctx, localpart, deviceID)
+}
+
+// GetDevicesByLocalpart returns the devices matching the given localpart.
+func (d *Database) GetDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]api.Device, error) {
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
+}
+
+func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+ return d.devices.selectDevicesByID(ctx, deviceIDs)
+}
+
+// CreateDevice makes a new device associated with the given user ID localpart.
+// If there is already a device with the same device ID for this user, that access token will be revoked
+// and replaced with the given accessToken. If the given accessToken is already in use for another device,
+// an error will be returned.
+// If no device ID is given one is generated.
+// Returns the device on success.
+func (d *Database) CreateDevice(
+ ctx context.Context, localpart string, deviceID *string, accessToken string,
+ displayName *string, ipAddr, userAgent string,
+) (dev *api.Device, returnErr error) {
+ if deviceID != nil {
+ returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ return err
+ }
+
+ dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ } else {
+ // We generate device IDs in a loop in case its already taken.
+ // We cap this at going round 5 times to ensure we don't spin forever
+ var newDeviceID string
+ for i := 1; i <= 5; i++ {
+ newDeviceID, returnErr = generateDeviceID()
+ if returnErr != nil {
+ return
+ }
+
+ returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ var err error
+ dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ if returnErr == nil {
+ return
+ }
+ }
+ }
+ return
+}
+
+// generateDeviceID creates a new device id. Returns an error if failed to generate
+// random bytes.
+func generateDeviceID() (string, error) {
+ b := make([]byte, deviceIDByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ // url-safe no padding
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// UpdateDevice updates the given device with the display name.
+// Returns SQL error if there are problems and nil on success.
+func (d *Database) UpdateDevice(
+ ctx context.Context, localpart, deviceID string, displayName *string,
+) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ })
+}
+
+// RemoveDevice revokes a device by deleting the entry in the database
+// matching with the given device ID and user ID localpart.
+// If the device doesn't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevice(
+ ctx context.Context, deviceID, localpart string,
+) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveDevices revokes one or more devices by deleting the entry in the database
+// matching with the given device IDs and user ID localpart.
+// If the devices don't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevices(
+ ctx context.Context, localpart string, devices []string,
+) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveAllDevices revokes devices by deleting the entry in the
+// database matching the given user ID localpart.
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveAllDevices(
+ ctx context.Context, localpart, exceptDeviceID string,
+) (devices []api.Device, err error) {
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
+ if err != nil {
+ return err
+ }
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
+// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
+func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
+ })
+}
+
+// CreateLoginToken generates a token, stores and returns it. The lifetime is
+// determined by the loginTokenLifetime given to the Database constructor.
+func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
+ tok, err := generateLoginToken()
+ if err != nil {
+ return nil, err
+ }
+ meta := &api.LoginTokenMetadata{
+ Token: tok,
+ Expiration: time.Now().Add(d.loginTokenLifetime),
+ }
+
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.loginTokens.insert(ctx, txn, meta, data)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return meta, nil
+}
+
+func generateLoginToken() (string, error) {
+ b := make([]byte, loginTokenByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// RemoveLoginToken removes the named token (and may clean up other expired tokens).
+func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
+ return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ return d.loginTokens.deleteByToken(ctx, txn, token)
+ })
+}
+
+// GetLoginTokenDataByToken returns the data associated with the given token.
+// May return sql.ErrNoRows.
+func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
+ return d.loginTokens.selectByToken(ctx, token)
+}
diff --git a/userapi/storage/accounts/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go
index 9280fc87..9280fc87 100644
--- a/userapi/storage/accounts/postgres/threepid_table.go
+++ b/userapi/storage/postgres/threepid_table.go
diff --git a/userapi/storage/accounts/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go
index 871f996e..871f996e 100644
--- a/userapi/storage/accounts/sqlite3/account_data_table.go
+++ b/userapi/storage/sqlite3/account_data_table.go
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go
index 5a918e03..5a918e03 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/sqlite3/accounts_table.go
diff --git a/userapi/storage/accounts/sqlite3/constraint_wasm.go b/userapi/storage/sqlite3/constraint_wasm.go
index 6c4ee762..6c4ee762 100644
--- a/userapi/storage/accounts/sqlite3/constraint_wasm.go
+++ b/userapi/storage/sqlite3/constraint_wasm.go
diff --git a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
index c69614e8..c69614e8 100644
--- a/userapi/storage/accounts/sqlite3/deltas/20200929203058_is_active.go
+++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go
diff --git a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
index 26209826..ebf90800 100644
--- a/userapi/storage/devices/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
+++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go
@@ -5,13 +5,8 @@ import (
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
- "github.com/pressly/goose"
)
-func LoadFromGoose() {
- goose.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
-}
-
func LoadLastSeenTSIP(m *sqlutil.Migrations) {
m.AddMigration(UpLastSeenTSIP, DownLastSeenTSIP)
}
diff --git a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
index 9b058ded..9b058ded 100644
--- a/userapi/storage/accounts/sqlite3/deltas/2022021012490600_add_account_type.go
+++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 955d8ac7..119ecdf9 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -106,6 +106,9 @@ func (s *devicesStatements) execSchema(db *sql.DB) error {
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
+ if err = s.execSchema(db); err != nil {
+ return
+ }
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go
index 837d38cf..837d38cf 100644
--- a/userapi/storage/accounts/sqlite3/key_backup_table.go
+++ b/userapi/storage/sqlite3/key_backup_table.go
diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/sqlite3/key_backup_version_table.go
index 4211ed0f..4211ed0f 100644
--- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go
+++ b/userapi/storage/sqlite3/key_backup_version_table.go
diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/sqlite3/logintoken_table.go
index 75ef272f..52322b46 100644
--- a/userapi/storage/devices/sqlite3/logintoken_table.go
+++ b/userapi/storage/sqlite3/logintoken_table.go
@@ -51,6 +51,9 @@ CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_exp
// prepare runs statement preparation.
func (s *loginTokenStatements) prepare(db *sql.DB) error {
+ if err := s.execSchema(db); err != nil {
+ return err
+ }
return sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"},
diff --git a/userapi/storage/accounts/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go
index 98c0488b..98c0488b 100644
--- a/userapi/storage/accounts/sqlite3/openid_table.go
+++ b/userapi/storage/sqlite3/openid_table.go
diff --git a/userapi/storage/accounts/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go
index a92e9566..a92e9566 100644
--- a/userapi/storage/accounts/sqlite3/profile_table.go
+++ b/userapi/storage/sqlite3/profile_table.go
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 0bab16ca..56ec1b6a 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -16,7 +16,9 @@ package sqlite3
import (
"context"
+ "crypto/rand"
"database/sql"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -31,7 +33,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// Database represents an account database
@@ -47,6 +49,9 @@ type Database struct {
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
+ devices devicesStatements
+ loginTokens loginTokenStatements
+ loginTokenLifetime time.Duration
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
@@ -57,8 +62,14 @@ type Database struct {
threepidsMu sync.Mutex
}
+const (
+ // The length of generated device IDs
+ deviceIDByteLength = 6
+ loginTokenByteLength = 32
+)
+
// NewDatabase creates a new accounts and profiles database
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (*Database, error) {
+func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) {
db, err := sqlutil.Open(dbProperties)
if err != nil {
return nil, err
@@ -67,6 +78,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
+ loginTokenLifetime: loginTokenLifetime,
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
@@ -78,6 +90,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
}
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m)
+ //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m)
if err = m.RunDeltas(db, dbProperties); err != nil {
return nil, err
@@ -108,6 +121,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
+ if err = d.devices.prepare(db, d.writer, serverName); err != nil {
+ return nil, err
+ }
+ if err = d.loginTokens.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -547,3 +566,196 @@ func (d *Database) UpsertBackupKeys(
})
return
}
+
+// GetDeviceByAccessToken returns the device matching the given access token.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByAccessToken(
+ ctx context.Context, token string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByToken(ctx, token)
+}
+
+// GetDeviceByID returns the device matching the given ID.
+// Returns sql.ErrNoRows if no matching device was found.
+func (d *Database) GetDeviceByID(
+ ctx context.Context, localpart, deviceID string,
+) (*api.Device, error) {
+ return d.devices.selectDeviceByID(ctx, localpart, deviceID)
+}
+
+// GetDevicesByLocalpart returns the devices matching the given localpart.
+func (d *Database) GetDevicesByLocalpart(
+ ctx context.Context, localpart string,
+) ([]api.Device, error) {
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
+}
+
+func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+ return d.devices.selectDevicesByID(ctx, deviceIDs)
+}
+
+// CreateDevice makes a new device associated with the given user ID localpart.
+// If there is already a device with the same device ID for this user, that access token will be revoked
+// and replaced with the given accessToken. If the given accessToken is already in use for another device,
+// an error will be returned.
+// If no device ID is given one is generated.
+// Returns the device on success.
+func (d *Database) CreateDevice(
+ ctx context.Context, localpart string, deviceID *string, accessToken string,
+ displayName *string, ipAddr, userAgent string,
+) (dev *api.Device, returnErr error) {
+ if deviceID != nil {
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
+ return err
+ }
+
+ dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ } else {
+ // We generate device IDs in a loop in case its already taken.
+ // We cap this at going round 5 times to ensure we don't spin forever
+ var newDeviceID string
+ for i := 1; i <= 5; i++ {
+ newDeviceID, returnErr = generateDeviceID()
+ if returnErr != nil {
+ return
+ }
+
+ returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ var err error
+ dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ if returnErr == nil {
+ return
+ }
+ }
+ }
+ return
+}
+
+// generateDeviceID creates a new device id. Returns an error if failed to generate
+// random bytes.
+func generateDeviceID() (string, error) {
+ b := make([]byte, deviceIDByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ // url-safe no padding
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// UpdateDevice updates the given device with the display name.
+// Returns SQL error if there are problems and nil on success.
+func (d *Database) UpdateDevice(
+ ctx context.Context, localpart, deviceID string, displayName *string,
+) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
+ })
+}
+
+// RemoveDevice revokes a device by deleting the entry in the database
+// matching with the given device ID and user ID localpart.
+// If the device doesn't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevice(
+ ctx context.Context, deviceID, localpart string,
+) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveDevices revokes one or more devices by deleting the entry in the database
+// matching with the given device IDs and user ID localpart.
+// If the devices don't exist, it will not return an error
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveDevices(
+ ctx context.Context, localpart string, devices []string,
+) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+}
+
+// RemoveAllDevices revokes devices by deleting the entry in the
+// database matching the given user ID localpart.
+// If something went wrong during the deletion, it will return the SQL error.
+func (d *Database) RemoveAllDevices(
+ ctx context.Context, localpart, exceptDeviceID string,
+) (devices []api.Device, err error) {
+ err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
+ if err != nil {
+ return err
+ }
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
+// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
+func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
+ })
+}
+
+// CreateLoginToken generates a token, stores and returns it. The lifetime is
+// determined by the loginTokenLifetime given to the Database constructor.
+func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
+ tok, err := generateLoginToken()
+ if err != nil {
+ return nil, err
+ }
+ meta := &api.LoginTokenMetadata{
+ Token: tok,
+ Expiration: time.Now().Add(d.loginTokenLifetime),
+ }
+
+ err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.loginTokens.insert(ctx, txn, meta, data)
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return meta, nil
+}
+
+func generateLoginToken() (string, error) {
+ b := make([]byte, loginTokenByteLength)
+ _, err := rand.Read(b)
+ if err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(b), nil
+}
+
+// RemoveLoginToken removes the named token (and may clean up other expired tokens).
+func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
+ return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ return d.loginTokens.deleteByToken(ctx, txn, token)
+ })
+}
+
+// GetLoginTokenDataByToken returns the data associated with the given token.
+// May return sql.ErrNoRows.
+func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
+ return d.loginTokens.selectByToken(ctx, token)
+}
diff --git a/userapi/storage/accounts/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go
index 9dc0e2d2..9dc0e2d2 100644
--- a/userapi/storage/accounts/sqlite3/threepid_table.go
+++ b/userapi/storage/sqlite3/threepid_table.go
diff --git a/userapi/storage/accounts/storage.go b/userapi/storage/storage.go
index f43f7efd..4711439a 100644
--- a/userapi/storage/accounts/storage.go
+++ b/userapi/storage/storage.go
@@ -15,26 +15,27 @@
//go:build !wasm
// +build !wasm
-package accounts
+package storage
import (
"fmt"
+ "time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/postgres"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
-func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64) (Database, error) {
+func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
+ return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
- return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
+ return postgres.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
default:
return nil, fmt.Errorf("unexpected database type")
}
diff --git a/userapi/storage/accounts/storage_wasm.go b/userapi/storage/storage_wasm.go
index 11a88a20..701dcd83 100644
--- a/userapi/storage/accounts/storage_wasm.go
+++ b/userapi/storage/storage_wasm.go
@@ -12,13 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package accounts
+package storage
import (
"fmt"
+ "time"
"github.com/matrix-org/dendrite/setup/config"
- "github.com/matrix-org/dendrite/userapi/storage/accounts/sqlite3"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
@@ -27,10 +28,11 @@ func NewDatabase(
serverName gomatrixserverlib.ServerName,
bcryptCost int,
openIDTokenLifetimeMS int64,
+ loginTokenLifetime time.Duration,
) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS)
+ return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
diff --git a/userapi/userapi.go b/userapi/userapi.go
index c7e1f667..4a5793ab 100644
--- a/userapi/userapi.go
+++ b/userapi/userapi.go
@@ -23,18 +23,10 @@ import (
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/internal"
"github.com/matrix-org/dendrite/userapi/inthttp"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
"github.com/sirupsen/logrus"
)
-// defaultLoginTokenLifetime determines how old a valid token may be.
-//
-// NOTSPEC: The current spec says "SHOULD be limited to around five
-// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low.
-// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325).
-const defaultLoginTokenLifetime = 2 * time.Minute
-
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
@@ -44,26 +36,24 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
- accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
+ accountDB storage.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
- deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime)
+ db, err := storage.NewDatabase(&cfg.AccountDatabase, cfg.Matrix.ServerName, cfg.BCryptCost, int64(api.DefaultLoginTokenLifetime*time.Millisecond), api.DefaultLoginTokenLifetime)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to device db")
}
- return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI)
+ return newInternalAPI(db, cfg, appServices, keyAPI)
}
func newInternalAPI(
- accountDB accounts.Database,
- deviceDB devices.Database,
+ db storage.Database,
cfg *config.UserAPI,
appServices []config.ApplicationService,
keyAPI keyapi.KeyInternalAPI,
) api.UserInternalAPI {
return &internal.UserInternalAPI{
- AccountDB: accountDB,
- DeviceDB: deviceDB,
+ DB: db,
ServerName: cfg.Matrix.ServerName,
AppServices: appServices,
KeyAPI: keyAPI,
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 141dd96d..4214c07f 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -31,8 +31,7 @@ import (
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/inthttp"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
const (
@@ -43,23 +42,19 @@ type apiTestOpts struct {
loginTokenLifetime time.Duration
}
-func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) {
+func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, storage.Database) {
if opts.loginTokenLifetime == 0 {
- opts.loginTokenLifetime = defaultLoginTokenLifetime
+ opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
dbopts := &config.DatabaseOptions{
ConnectionString: "file::memory:",
MaxOpenConnections: 1,
MaxIdleConnections: 1,
}
- accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS)
+ accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime)
if err != nil {
t.Fatalf("failed to create account DB: %s", err)
}
- deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime)
- if err != nil {
- t.Fatalf("failed to create device DB: %s", err)
- }
cfg := &config.UserAPI{
Matrix: &config.Global{
@@ -67,7 +62,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, a
},
}
- return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB
+ return newInternalAPI(accountDB, cfg, nil, nil), accountDB
}
func TestQueryProfile(t *testing.T) {