aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal/api.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/internal/api.go')
-rw-r--r--userapi/internal/api.go67
1 files changed, 33 insertions, 34 deletions
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index f96d4804..f54cc613 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -31,13 +31,11 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
- "github.com/matrix-org/dendrite/userapi/storage/accounts"
- "github.com/matrix-org/dendrite/userapi/storage/devices"
+ "github.com/matrix-org/dendrite/userapi/storage"
)
type UserInternalAPI struct {
- AccountDB accounts.Database
- DeviceDB devices.Database
+ DB storage.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
@@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
- return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
+ return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
- acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
+ acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
- if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
+ if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
- if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
@@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
- dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
+ dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
+ devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
- err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
+ err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil {
return err
@@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
- if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
+ if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
@@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
- dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
+ dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
- err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
+ err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
- prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
+ prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
- profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
+ profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
@@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
- devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
+ devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
@@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
}
- devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
+ devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
return err
}
+ res.UserExists = true
res.Devices = devs
return nil
}
@@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
- data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
+ data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
@@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
- global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
+ global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil {
return err
}
@@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
- device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
+ device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if err != nil {
return err
}
- acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
+ acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
@@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
- account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
+ account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
- err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
+ err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil
return err
}
@@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
- exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
+ exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
@@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
- openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
+ openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
@@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
return nil
}
- exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
+ exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
@@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Create metadata
if req.Version == "" {
- version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
+ version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
@@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Update metadata
if len(req.Keys.Rooms) == 0 {
- err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
+ err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
@@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version
- version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
+ version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
})
}
}
- count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
+ count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
@@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
- version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
+ version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
@@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted
if !req.ReturnKeys {
- res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
+ res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
- result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
+ result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return