aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/accounts/postgres/key_backup_version_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/accounts/postgres/key_backup_version_table.go')
-rw-r--r--userapi/storage/accounts/postgres/key_backup_version_table.go144
1 files changed, 144 insertions, 0 deletions
diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go
new file mode 100644
index 00000000..1b693e56
--- /dev/null
+++ b/userapi/storage/accounts/postgres/key_backup_version_table.go
@@ -0,0 +1,144 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "strconv"
+)
+
+const keyBackupVersionTableSchema = `
+CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq;
+
+-- the metadata for each generation of encrypted e2e session backups
+CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
+ user_id TEXT NOT NULL,
+ -- this means no 2 users will ever have the same version of e2e session backups which strictly
+ -- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
+ version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
+ algorithm TEXT NOT NULL,
+ auth_data TEXT NOT NULL,
+ deleted SMALLINT DEFAULT 0 NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
+`
+
+const insertKeyBackupSQL = "" +
+ "INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version"
+
+const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well?
+ "UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
+
+const deleteKeyBackupSQL = "" +
+ "UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
+
+const selectKeyBackupSQL = "" +
+ "SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
+
+const selectLatestVersionSQL = "" +
+ "SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
+
+type keyBackupVersionStatements struct {
+ insertKeyBackupStmt *sql.Stmt
+ updateKeyBackupAuthDataStmt *sql.Stmt
+ deleteKeyBackupStmt *sql.Stmt
+ selectKeyBackupStmt *sql.Stmt
+ selectLatestVersionStmt *sql.Stmt
+}
+
+func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
+ _, err = db.Exec(keyBackupVersionTableSchema)
+ if err != nil {
+ return
+ }
+ if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil {
+ return
+ }
+ if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil {
+ return
+ }
+ if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil {
+ return
+ }
+ if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil {
+ return
+ }
+ if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
+ return
+ }
+ return
+}
+
+func (s *keyBackupVersionStatements) insertKeyBackup(
+ ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage,
+) (version string, err error) {
+ var versionInt int64
+ err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt)
+ return strconv.FormatInt(versionInt, 10), err
+}
+
+func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
+ ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
+) error {
+ versionInt, err := strconv.ParseInt(version, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid version")
+ }
+ _, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt)
+ return err
+}
+
+func (s *keyBackupVersionStatements) deleteKeyBackup(
+ ctx context.Context, txn *sql.Tx, userID, version string,
+) (bool, error) {
+ versionInt, err := strconv.ParseInt(version, 10, 64)
+ if err != nil {
+ return false, fmt.Errorf("invalid version")
+ }
+ result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt)
+ if err != nil {
+ return false, err
+ }
+ ra, err := result.RowsAffected()
+ if err != nil {
+ return false, err
+ }
+ return ra == 1, nil
+}
+
+func (s *keyBackupVersionStatements) selectKeyBackup(
+ ctx context.Context, txn *sql.Tx, userID, version string,
+) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
+ var versionInt int64
+ if version == "" {
+ err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
+ } else {
+ versionInt, err = strconv.ParseInt(version, 10, 64)
+ }
+ if err != nil {
+ return
+ }
+ versionResult = strconv.FormatInt(versionInt, 10)
+ var deletedInt int
+ var authDataStr string
+ err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt)
+ deleted = deletedInt == 1
+ authData = json.RawMessage(authDataStr)
+ return
+}