diff options
Diffstat (limited to 'userapi/storage/accounts/postgres')
-rw-r--r-- | userapi/storage/accounts/postgres/key_backup_table.go | 54 | ||||
-rw-r--r-- | userapi/storage/accounts/postgres/key_backup_version_table.go | 1 | ||||
-rw-r--r-- | userapi/storage/accounts/postgres/storage.go | 65 |
3 files changed, 85 insertions, 35 deletions
diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go index 0dc5879b..ec651826 100644 --- a/userapi/storage/accounts/postgres/key_backup_table.go +++ b/userapi/storage/accounts/postgres/key_backup_table.go @@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys ( is_verified BOOLEAN NOT NULL, session_data TEXT NOT NULL ); -CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version); +CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version); ` const insertBackupKeySQL = "" + @@ -53,14 +54,23 @@ const selectKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + "WHERE user_id = $1 AND version = $2" +const selectKeysByRoomIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3" + +const selectKeysByRoomIDAndSessionIDSQL = "" + + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " + + "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4" + type keyBackupStatements struct { - insertBackupKeyStmt *sql.Stmt - updateBackupKeyStmt *sql.Stmt - countKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt + insertBackupKeyStmt *sql.Stmt + updateBackupKeyStmt *sql.Stmt + countKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysByRoomIDStmt *sql.Stmt + selectKeysByRoomIDAndSessionIDStmt *sql.Stmt } -// nolint:unused func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupTableSchema) if err != nil { @@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil { return } + if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil { + return + } + if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil { + return + } return } @@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) selectKeys( ctx context.Context, txn *sql.Tx, userID, version string, ) (map[string]map[string]api.KeyBackupSession, error) { - result := make(map[string]map[string]api.KeyBackupSession) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) if err != nil { return nil, err } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomID( + ctx context.Context, txn *sql.Tx, userID, version, roomID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( + ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, +) (map[string]map[string]api.KeyBackupSession, error) { + rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) + if err != nil { + return nil, err + } + return unpackKeys(ctx, rows) +} + +func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) { + result := make(map[string]map[string]api.KeyBackupSession) defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed") for rows.Next() { var key api.InternalKeyBackupSession diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go index 323a842d..aca575df 100644 --- a/userapi/storage/accounts/postgres/key_backup_version_table.go +++ b/userapi/storage/accounts/postgres/key_backup_version_table.go @@ -67,7 +67,6 @@ type keyBackupVersionStatements struct { updateKeyBackupETagStmt *sql.Stmt } -// nolint:unused func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { _, err = db.Exec(keyBackupVersionTableSchema) if err != nil { diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go index b07218b2..9d6fd13a 100644 --- a/userapi/storage/accounts/postgres/storage.go +++ b/userapi/storage/accounts/postgres/storage.go @@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.openIDTokens.prepare(db, serverName); err != nil { return nil, err } - /* - if err = d.keyBackupVersions.prepare(db); err != nil { - return nil, err - } - if err = d.keyBackups.prepare(db); err != nil { - return nil, err - } */ + if err = d.keyBackupVersions.prepare(db); err != nil { + return nil, err + } + if err = d.keyBackups.prepare(db); err != nil { + return nil, err + } return d, nil } @@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup( return } +func (d *Database) GetBackupKeys( + ctx context.Context, version, userID, filterRoomID, filterSessionID string, +) (result map[string]map[string]api.KeyBackupSession, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + if filterSessionID != "" { + result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID) + return err + } + if filterRoomID != "" { + result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID) + return err + } + result, err = d.keyBackups.selectKeys(ctx, txn, userID, version) + return err + }) + return +} + +func (d *Database) CountBackupKeys( + ctx context.Context, version, userID string, +) (count int64, err error) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + count, err = d.keyBackups.countKeys(ctx, txn, userID, version) + if err != nil { + return err + } + return nil + }) + return +} + // nolint:nakedret func (d *Database) UpsertBackupKeys( ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession, @@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys( if existingRoom != nil { existingSession, ok := existingRoom[newKey.SessionID] if ok { - if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) { + if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) { err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey) changed = true if err != nil { @@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys( }) return } - -// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't -// create circular import loops -func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool { - // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 - // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true" - if uploaded.IsVerified && !existing.IsVerified { - return true - } - // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index" - if uploaded.FirstMessageIndex < existing.FirstMessageIndex { - return true - } - // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count" - if uploaded.ForwardedCount < existing.ForwardedCount { - return true - } - return false -} |