aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-04-28 17:49:38 +0200
committerGitHub <noreply@github.com>2023-04-28 17:49:38 +0200
commit9e9617ff84c3310fa3db6f32cdfc0207ec546963 (patch)
tree13fc5313e782f05c4569d7e49bcaea4bf8229dab /userapi/internal
parent6b47cf0f6ac9176b7e5a5bd6f357722ee0f5e384 (diff)
Add key backup tests (#3071)
Also slightly refactors the functions and methods to rely less on the req/res pattern we had for polylith. Returns `M_WRONG_ROOM_KEYS_VERSION` for some endpoints as per the spec
Diffstat (limited to 'userapi/internal')
-rw-r--r--userapi/internal/user_api.go87
1 files changed, 35 insertions, 52 deletions
diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go
index cdd08344..ea97fd35 100644
--- a/userapi/internal/user_api.go
+++ b/userapi/internal/user_api.go
@@ -25,6 +25,7 @@ import (
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/gomatrixserverlib"
@@ -678,62 +679,43 @@ func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOp
return nil
}
-func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) error {
- // Delete metadata
- if req.DeleteBackup {
- if req.Version == "" {
- res.BadInput = true
- res.Error = "must specify a version to delete"
- return nil
- }
- exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
- if err != nil {
- res.Error = fmt.Sprintf("failed to delete backup: %s", err)
- }
- res.Exists = exists
- res.Version = req.Version
- return nil
- }
+func (a *UserInternalAPI) DeleteKeyBackup(ctx context.Context, userID, version string) (bool, error) {
+ return a.DB.DeleteKeyBackup(ctx, userID, version)
+}
+
+func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.PerformKeyBackupRequest) (string, error) {
// Create metadata
- if req.Version == "" {
- version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
- if err != nil {
- res.Error = fmt.Sprintf("failed to create backup: %s", err)
- }
- res.Exists = err == nil
- res.Version = version
- return nil
- }
+ return a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
+}
+
+func (a *UserInternalAPI) UpdateBackupKeyAuthData(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) {
+ res := &api.PerformKeyBackupResponse{}
// Update metadata
if len(req.Keys.Rooms) == 0 {
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
- if err != nil {
- res.Error = fmt.Sprintf("failed to update backup: %s", err)
- }
res.Exists = err == nil
res.Version = req.Version
- return nil
+ if err != nil {
+ return res, fmt.Errorf("failed to update backup: %w", err)
+ }
+ return res, nil
}
// Upload Keys for a specific version metadata
- a.uploadBackupKeys(ctx, req, res)
- return nil
+ return a.uploadBackupKeys(ctx, req)
}
-func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
+func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest) (*api.PerformKeyBackupResponse, error) {
+ res := &api.PerformKeyBackupResponse{}
// you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
- res.Error = fmt.Sprintf("failed to query version: %s", err)
- return
+ return res, fmt.Errorf("failed to query version: %w", err)
}
if deleted {
- res.Error = "backup was deleted"
- return
+ return res, fmt.Errorf("backup was deleted")
}
if version != req.Version {
- res.BadInput = true
- res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
- return
+ return res, jsonerror.WrongBackupVersionError(version)
}
res.Exists = true
res.Version = version
@@ -751,23 +733,25 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
- res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
- return
+ return res, fmt.Errorf("failed to upsert keys: %w", err)
}
res.KeyCount = count
res.KeyETag = etag
+ return res, nil
}
-func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) error {
+func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest) (*api.QueryKeyBackupResponse, error) {
+ res := &api.QueryKeyBackupResponse{}
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
- if err == sql.ErrNoRows {
- res.Exists = false
- return nil
+ if errors.Is(err, sql.ErrNoRows) {
+ return res, nil
}
- res.Error = fmt.Sprintf("failed to query key backup: %s", err)
- return nil
+ if errors.Is(err, strconv.ErrSyntax) {
+ return res, nil
+ }
+ return res, fmt.Errorf("failed to query key backup: %s", err)
}
res.Algorithm = algorithm
res.AuthData = authData
@@ -777,18 +761,17 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
if !req.ReturnKeys {
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
- res.Error = fmt.Sprintf("failed to count keys: %s", err)
+ return res, fmt.Errorf("failed to count keys: %w", err)
}
- return nil
+ return res, nil
}
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
- res.Error = fmt.Sprintf("failed to query keys: %s", err)
- return nil
+ return res, fmt.Errorf("failed to query keys: %s", err)
}
res.Keys = result
- return nil
+ return res, nil
}
func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.QueryNotificationsRequest, res *api.QueryNotificationsResponse) error {