diff options
author | kegsay <kegan@matrix.org> | 2021-07-27 17:08:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-27 17:08:53 +0100 |
commit | b3754d68fcbe9022eb0bf4f8eda7102b7c27e62d (patch) | |
tree | fb289fc2baf56205292d0f0ce28c9f6924278b01 /userapi/storage/accounts/postgres/storage.go | |
parent | a060df91e206903e4e3cbf7b7d2dabddfa0bf788 (diff) |
Key Backups (2/3) : Add E2E backup key tables (#1945)
* Add PUT key backup endpoints and glue them to PerformKeyBackup
* Add tables for storing backup keys and glue them into the user API
* Don't create tables whilst still WIPing
* writer on sqlite please
* Linting
Diffstat (limited to 'userapi/storage/accounts/postgres/storage.go')
-rw-r--r-- | userapi/storage/accounts/postgres/storage.go | 115 |
1 files changed, 106 insertions, 9 deletions
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index 719e9878..b07218b2 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -19,6 +19,7 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "strconv" "time" @@ -45,7 +46,8 @@ type Database struct { accountDatas accountDataStatements threepids threepidStatements openIDTokens tokenStatements - keyBackups keyBackupVersionStatements + keyBackupVersions keyBackupVersionStatements + keyBackups keyBackupStatements serverName gomatrixserverlib.ServerName bcryptCost int openIDTokenLifetimeMS int64 @@ -94,9 +96,13 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } + /* + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } */ return d, nil } @@ -377,7 +383,7 @@ func (d *Database) CreateKeyBackup( ctx context.Context, userID, algorithm string, authData json.RawMessage, ) (version string, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData) + version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "") return err }) return @@ -387,7 +393,7 @@ func (d *Database) UpdateKeyBackupAuthData( ctx context.Context, userID, version string, authData json.RawMessage, ) (err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData) + return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData) }) return } @@ -396,7 +402,7 @@ func (d *Database) DeleteKeyBackup( ctx context.Context, userID, version string, ) (exists bool, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version) + exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version) return err }) return @@ -404,10 +410,101 @@ func (d *Database) DeleteKeyBackup( func (d *Database) GetKeyBackup( ctx context.Context, userID, version string, -) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) { +) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { - versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version) + versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) return err }) return } + +// nolint:nakedret +func (d *Database) UpsertBackupKeys( + ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, +) (count int64, etag string, err error) { + // wrap the following logic in a txn to ensure we atomically upload keys + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + _, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version) + if err != nil { + return err + } + if deleted { + return fmt.Errorf("backup was deleted") + } + // pull out all keys for this (user_id, version) + existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version) + if err != nil { + return err + } + + changed := false + // loop over all the new keys (which should be smaller than the set of backed up keys) + for _, newKey := range uploads { + // if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them. + existingRoom := existingKeys[newKey.RoomID] + if existingRoom != nil { + existingSession, ok := existingRoom[newKey.SessionID] + if ok { + if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { + err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return err + } + } + // if we shouldn't replace the key we do nothing with it + continue + } + } + // if we're here, either the room or session are new, either way, we insert + err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) + changed = true + if err != nil { + return err + } + } + + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + if changed { + // update the etag + var newETag string + if oldETag == "" { + newETag = "1" + } else { + oldETagInt, err := strconv.ParseInt(oldETag, 10, 64) + if err != nil { + return fmt.Errorf("failed to parse old etag: %s", err) + } + newETag = strconv.FormatInt(oldETagInt+1, 10) + } + etag = newETag + return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag) + } else { + etag = oldETag + } + return nil + }) + return +} + +// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't +// create circular import loops +func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool { + // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 + // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" + if uploaded.IsVerified && !existing.IsVerified { + return true + } + // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" + if uploaded.FirstMessageIndex < existing.FirstMessageIndex { + return true + } + // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" + if uploaded.ForwardedCount < existing.ForwardedCount { + return true + } + return false +} |