Diffstat (limited to 'userapi/storage/accounts/postgres/storage.go')
-rw-r--r--  userapi/storage/accounts/postgres/storage.go  115
1 file changed, 106 insertions, 9 deletions
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index 719e9878..b07218b2 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"encoding/json"
"errors"
+ "fmt"
"strconv"
"time"
@@ -45,7 +46,8 @@ type Database struct {
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
- keyBackups keyBackupVersionStatements
+ keyBackupVersions keyBackupVersionStatements
+ keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
@@ -94,9 +96,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
- if err = d.keyBackups.prepare(db); err != nil {
- return nil, err
- }
+ /*
+ if err = d.keyBackupVersions.prepare(db); err != nil {
+ return nil, err
+ }
+ if err = d.keyBackups.prepare(db); err != nil {
+ return nil, err
+ } */
return d, nil
}
@@ -377,7 +383,7 @@ func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData)
+ version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
@@ -387,7 +393,7 @@ func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
+ return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
@@ -396,7 +402,7 @@ func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version)
+ exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
@@ -404,10 +410,101 @@ func (d *Database) DeleteKeyBackup(
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
-) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
+) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version)
+ versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
+
+// nolint:nakedret
+func (d *Database) UpsertBackupKeys(
+ ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
+) (count int64, etag string, err error) {
+ // wrap the following logic in a txn to ensure we atomically upload keys
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
+ if err != nil {
+ return err
+ }
+ if deleted {
+ return fmt.Errorf("backup was deleted")
+ }
+ // pull out all keys for this (user_id, version)
+ existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
+ if err != nil {
+ return err
+ }
+
+ changed := false
+ // loop over all the new keys (there should be fewer of them than there are keys already backed up)
+ for _, newKey := range uploads {
+ // if we have a matching (room_id, session_id), the existing key may need to be replaced if the new key meets certain rules, so check them
+ existingRoom := existingKeys[newKey.RoomID]
+ if existingRoom != nil {
+ existingSession, ok := existingRoom[newKey.SessionID]
+ if ok {
+ if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
+ err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
+ changed = true
+ if err != nil {
+ return err
+ }
+ }
+ // if we shouldn't replace the key we do nothing with it
+ continue
+ }
+ }
+ // if we're here, either the room or the session is new, so either way we insert
+ err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey)
+ changed = true
+ if err != nil {
+ return err
+ }
+ }
+
+ count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
+ if err != nil {
+ return err
+ }
+ if changed {
+ // update the etag
+ var newETag string
+ if oldETag == "" {
+ newETag = "1"
+ } else {
+ oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
+ if err != nil {
+ return fmt.Errorf("failed to parse old etag: %s", err)
+ }
+ newETag = strconv.FormatInt(oldETagInt+1, 10)
+ }
+ etag = newETag
+ return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
+ } else {
+ etag = oldETag
+ }
+ return nil
+ })
+ return
+}
+
+// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else it could sensibly
+// live without creating circular import loops
+func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
+ // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
+ // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
+ if uploaded.IsVerified && !existing.IsVerified {
+ return true
+ }
+ // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
+ if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
+ return true
+ }
+ // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
+ if uploaded.ForwardedCount < existing.ForwardedCount {
+ return true
+ }
+ return false
+}
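
To make the precedence of the three rules above concrete, here is a minimal sketch of shouldReplaceRoomKey in use. It assumes the snippet sits in the same postgres package (so the unexported function and the api.KeyBackupSession fields used above are visible); exampleReplacementRule and the field values are invented for illustration and are not part of this commit.

// Sketch only: exampleReplacementRule is a hypothetical helper and the
// session values are made up to illustrate the rule ordering.
func exampleReplacementRule() bool {
	existing := api.KeyBackupSession{IsVerified: false, FirstMessageIndex: 10, ForwardedCount: 2}
	uploaded := api.KeyBackupSession{IsVerified: true, FirstMessageIndex: 12, ForwardedCount: 5}
	// is_verified is checked first, so the verified upload replaces the
	// unverified existing key even though its first_message_index is higher.
	return shouldReplaceRoomKey(existing, uploaded) // returns true
}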
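
The new keyBackups field on Database is a keyBackupStatements value whose implementation is not part of this diff. Inferred purely from the call sites in UpsertBackupKeys, the type would need to expose roughly the method set sketched below; the keyBackupQueries name, the exact signatures, the map-of-maps return type of selectKeys, and the api import path are assumptions rather than something taken from the commit.

package postgres // sketch assumes the same package as the file in this diff

import (
	"context"
	"database/sql"

	"github.com/matrix-org/dendrite/userapi/api" // import path assumed; not shown in this diff
)

// keyBackupQueries is a hypothetical interface capturing what UpsertBackupKeys
// needs from d.keyBackups; in the commit itself keyBackupStatements is defined
// elsewhere, presumably as a struct of prepared statements rather than an interface.
type keyBackupQueries interface {
	// selectKeys returns the keys already backed up for (userID, version),
	// keyed by room ID and then by session ID (shape inferred from the
	// existingKeys[newKey.RoomID][newKey.SessionID] lookups above).
	selectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error)
	// countKeys returns how many keys are stored for (userID, version).
	countKeys(ctx context.Context, txn *sql.Tx, userID, version string) (int64, error)
	// insertBackupKey stores a key for a (room_id, session_id) with no existing backup.
	insertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) error
	// updateBackupKey replaces an existing key for a (room_id, session_id).
	updateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) error
}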