aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/accounts/postgres
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/accounts/postgres')
-rw-r--r--userapi/storage/accounts/postgres/key_backup_table.go54
-rw-r--r--userapi/storage/accounts/postgres/key_backup_version_table.go1
-rw-r--r--userapi/storage/accounts/postgres/storage.go65
3 files changed, 85 insertions, 35 deletions
diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go
index 0dc5879b..ec651826 100644
--- a/userapi/storage/accounts/postgres/key_backup_table.go
+++ b/userapi/storage/accounts/postgres/key_backup_table.go
@@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
-CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
+CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
+CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
`
const insertBackupKeySQL = "" +
@@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
+const selectKeysByRoomIDSQL = "" +
+ "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
+ "WHERE user_id = $1 AND version = $2 AND room_id = $3"
+
+const selectKeysByRoomIDAndSessionIDSQL = "" +
+ "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
+ "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
+
type keyBackupStatements struct {
- insertBackupKeyStmt *sql.Stmt
- updateBackupKeyStmt *sql.Stmt
- countKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
+ insertBackupKeyStmt *sql.Stmt
+ updateBackupKeyStmt *sql.Stmt
+ countKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysByRoomIDStmt *sql.Stmt
+ selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
-// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
@@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
+ if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
+ return
+ }
+ if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
+ return
+ }
return
}
@@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
- result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
+ return unpackKeys(ctx, rows)
+}
+
+func (s *keyBackupStatements) selectKeysByRoomID(
+ ctx context.Context, txn *sql.Tx, userID, version, roomID string,
+) (map[string]map[string]api.KeyBackupSession, error) {
+ rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
+ if err != nil {
+ return nil, err
+ }
+ return unpackKeys(ctx, rows)
+}
+
+func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
+ ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
+) (map[string]map[string]api.KeyBackupSession, error) {
+ rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
+ if err != nil {
+ return nil, err
+ }
+ return unpackKeys(ctx, rows)
+}
+
+func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
+ result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession
diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go
index 323a842d..aca575df 100644
--- a/userapi/storage/accounts/postgres/key_backup_version_table.go
+++ b/userapi/storage/accounts/postgres/key_backup_version_table.go
@@ -67,7 +67,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
-// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index b07218b2..9d6fd13a 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
- /*
- if err = d.keyBackupVersions.prepare(db); err != nil {
- return nil, err
- }
- if err = d.keyBackups.prepare(db); err != nil {
- return nil, err
- } */
+ if err = d.keyBackupVersions.prepare(db); err != nil {
+ return nil, err
+ }
+ if err = d.keyBackups.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup(
return
}
+func (d *Database) GetBackupKeys(
+ ctx context.Context, version, userID, filterRoomID, filterSessionID string,
+) (result map[string]map[string]api.KeyBackupSession, err error) {
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if filterSessionID != "" {
+ result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
+ return err
+ }
+ if filterRoomID != "" {
+ result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
+ return err
+ }
+ result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
+ return err
+ })
+ return
+}
+
+func (d *Database) CountBackupKeys(
+ ctx context.Context, version, userID string,
+) (count int64, err error) {
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
- if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
+ if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
@@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys(
})
return
}
-
-// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
-// create circular import loops
-func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
- // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
- // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
- if uploaded.IsVerified && !existing.IsVerified {
- return true
- }
- // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
- if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
- return true
- }
- // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
- if uploaded.ForwardedCount < existing.ForwardedCount {
- return true
- }
- return false
-}