aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go25
-rw-r--r--userapi/internal/api.go28
-rw-r--r--userapi/storage/accounts/interface.go2
-rw-r--r--userapi/storage/accounts/postgres/key_backup_table.go54
-rw-r--r--userapi/storage/accounts/postgres/key_backup_version_table.go1
-rw-r--r--userapi/storage/accounts/postgres/storage.go65
-rw-r--r--userapi/storage/accounts/sqlite3/key_backup_table.go54
-rw-r--r--userapi/storage/accounts/sqlite3/key_backup_version_table.go1
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go65
9 files changed, 219 insertions, 76 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 7e18d72f..b0d91856 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -67,6 +67,23 @@ type KeyBackupSession struct {
SessionData json.RawMessage `json:"session_data"`
}
+func (a *KeyBackupSession) ShouldReplaceRoomKey(newKey *KeyBackupSession) bool {
+ // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
+ // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
+ if newKey.IsVerified && !a.IsVerified {
+ return true
+ }
+ // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
+ if newKey.FirstMessageIndex < a.FirstMessageIndex {
+ return true
+ }
+ // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
+ if newKey.ForwardedCount < a.ForwardedCount {
+ return true
+ }
+ return false
+}
+
// Internal KeyBackupData for passing to/from the storage layer
type InternalKeyBackupSession struct {
KeyBackupSession
@@ -88,6 +105,10 @@ type PerformKeyBackupResponse struct {
type QueryKeyBackupRequest struct {
UserID string
Version string // the version to query, if blank it means the latest
+
+ ReturnKeys bool // whether to return keys in the backup response or just the metadata
+ KeysForRoomID string // optional string to return keys which belong to this room
+ KeysForSessionID string // optional string to return keys which belong to this (room, session)
}
type QueryKeyBackupResponse struct {
@@ -96,9 +117,11 @@ type QueryKeyBackupResponse struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
- Count int `json:"count"`
+ Count int64 `json:"count"`
ETag string `json:"etag"`
Version string `json:"version"`
+
+ Keys map[string]map[string]KeyBackupSession // the keys if ReturnKeys=true
}
// InputAccountDataRequest is the request for InputAccountData
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 27e17963..a2bc8ecf 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -475,6 +475,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
+ res.Exists = err == nil
res.Version = req.Version
return
}
@@ -483,8 +484,8 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
- // ensure the version metadata exists
- version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
+ // you can only upload keys for the CURRENT version
+ version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@@ -493,6 +494,11 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
res.Error = "backup was deleted"
return
}
+ if version != req.Version {
+ res.BadInput = true
+ res.Error = fmt.Sprintf("%s isn't the current version, %s is.", req.Version, version)
+ return
+ }
res.Exists = true
res.Version = version
@@ -529,9 +535,21 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
}
res.Algorithm = algorithm
res.AuthData = authData
+ res.ETag = etag
res.Exists = !deleted
- // TODO:
- res.Count = 0
- res.ETag = etag
+ if !req.ReturnKeys {
+ res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
+ if err != nil {
+ res.Error = fmt.Sprintf("failed to count keys: %s", err)
+ }
+ return
+ }
+
+ result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
+ if err != nil {
+ res.Error = fmt.Sprintf("failed to query keys: %s", err)
+ return
+ }
+ res.Keys = result
}
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go
index 4fd9c177..887f7193 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/accounts/interface.go
@@ -61,6 +61,8 @@ type Database interface {
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
+ GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
+ CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
}
// Err3PIDInUse is the error returned when trying to save an association involving
diff --git a/userapi/storage/accounts/postgres/key_backup_table.go b/userapi/storage/accounts/postgres/key_backup_table.go
index 0dc5879b..ec651826 100644
--- a/userapi/storage/accounts/postgres/key_backup_table.go
+++ b/userapi/storage/accounts/postgres/key_backup_table.go
@@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
-CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
+CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
+CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
`
const insertBackupKeySQL = "" +
@@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
+const selectKeysByRoomIDSQL = "" +
+ "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
+ "WHERE user_id = $1 AND version = $2 AND room_id = $3"
+
+const selectKeysByRoomIDAndSessionIDSQL = "" +
+ "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
+ "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
+
type keyBackupStatements struct {
- insertBackupKeyStmt *sql.Stmt
- updateBackupKeyStmt *sql.Stmt
- countKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
+ insertBackupKeyStmt *sql.Stmt
+ updateBackupKeyStmt *sql.Stmt
+ countKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysByRoomIDStmt *sql.Stmt
+ selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
-// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
@@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
+ if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
+ return
+ }
+ if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
+ return
+ }
return
}
@@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
- result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
+ return unpackKeys(ctx, rows)
+}
+
+func (s *keyBackupStatements) selectKeysByRoomID(
+ ctx context.Context, txn *sql.Tx, userID, version, roomID string,
+) (map[string]map[string]api.KeyBackupSession, error) {
+ rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
+ if err != nil {
+ return nil, err
+ }
+ return unpackKeys(ctx, rows)
+}
+
+func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
+ ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
+) (map[string]map[string]api.KeyBackupSession, error) {
+ rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
+ if err != nil {
+ return nil, err
+ }
+ return unpackKeys(ctx, rows)
+}
+
+func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
+ result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession
diff --git a/userapi/storage/accounts/postgres/key_backup_version_table.go b/userapi/storage/accounts/postgres/key_backup_version_table.go
index 323a842d..aca575df 100644
--- a/userapi/storage/accounts/postgres/key_backup_version_table.go
+++ b/userapi/storage/accounts/postgres/key_backup_version_table.go
@@ -67,7 +67,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
-// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index b07218b2..9d6fd13a 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
- /*
- if err = d.keyBackupVersions.prepare(db); err != nil {
- return nil, err
- }
- if err = d.keyBackups.prepare(db); err != nil {
- return nil, err
- } */
+ if err = d.keyBackupVersions.prepare(db); err != nil {
+ return nil, err
+ }
+ if err = d.keyBackups.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup(
return
}
+func (d *Database) GetBackupKeys(
+ ctx context.Context, version, userID, filterRoomID, filterSessionID string,
+) (result map[string]map[string]api.KeyBackupSession, err error) {
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ if filterSessionID != "" {
+ result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
+ return err
+ }
+ if filterRoomID != "" {
+ result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
+ return err
+ }
+ result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
+ return err
+ })
+ return
+}
+
+func (d *Database) CountBackupKeys(
+ ctx context.Context, version, userID string,
+) (count int64, err error) {
+ err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
+ count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
- if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
+ if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
@@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys(
})
return
}
-
-// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
-// create circular import loops
-func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
- // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
- // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
- if uploaded.IsVerified && !existing.IsVerified {
- return true
- }
- // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
- if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
- return true
- }
- // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
- if uploaded.ForwardedCount < existing.ForwardedCount {
- return true
- }
- return false
-}
diff --git a/userapi/storage/accounts/sqlite3/key_backup_table.go b/userapi/storage/accounts/sqlite3/key_backup_table.go
index 268bda93..c1a698e6 100644
--- a/userapi/storage/accounts/sqlite3/key_backup_table.go
+++ b/userapi/storage/accounts/sqlite3/key_backup_table.go
@@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
-CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
+CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
+CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
`
const insertBackupKeySQL = "" +
@@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
+const selectKeysByRoomIDSQL = "" +
+ "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
+ "WHERE user_id = $1 AND version = $2 AND room_id = $3"
+
+const selectKeysByRoomIDAndSessionIDSQL = "" +
+ "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
+ "WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
+
type keyBackupStatements struct {
- insertBackupKeyStmt *sql.Stmt
- updateBackupKeyStmt *sql.Stmt
- countKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
+ insertBackupKeyStmt *sql.Stmt
+ updateBackupKeyStmt *sql.Stmt
+ countKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysByRoomIDStmt *sql.Stmt
+ selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
-// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
@@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
+ if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
+ return
+ }
+ if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
+ return
+ }
return
}
@@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
- result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
+ return unpackKeys(ctx, rows)
+}
+
+func (s *keyBackupStatements) selectKeysByRoomID(
+ ctx context.Context, txn *sql.Tx, userID, version, roomID string,
+) (map[string]map[string]api.KeyBackupSession, error) {
+ rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
+ if err != nil {
+ return nil, err
+ }
+ return unpackKeys(ctx, rows)
+}
+
+func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
+ ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
+) (map[string]map[string]api.KeyBackupSession, error) {
+ rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
+ if err != nil {
+ return nil, err
+ }
+ return unpackKeys(ctx, rows)
+}
+
+func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
+ result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession
diff --git a/userapi/storage/accounts/sqlite3/key_backup_version_table.go b/userapi/storage/accounts/sqlite3/key_backup_version_table.go
index 72e9b132..9a58fee7 100644
--- a/userapi/storage/accounts/sqlite3/key_backup_version_table.go
+++ b/userapi/storage/accounts/sqlite3/key_backup_version_table.go
@@ -65,7 +65,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
-// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 4fae621f..728ae901 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -100,13 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
- /*
- if err = d.keyBackupVersions.prepare(db); err != nil {
- return nil, err
- }
- if err = d.keyBackups.prepare(db); err != nil {
- return nil, err
- } */
+ if err = d.keyBackupVersions.prepare(db); err != nil {
+ return nil, err
+ }
+ if err = d.keyBackups.prepare(db); err != nil {
+ return nil, err
+ }
return d, nil
}
@@ -459,6 +458,37 @@ func (d *Database) GetKeyBackup(
return
}
+func (d *Database) GetBackupKeys(
+ ctx context.Context, version, userID, filterRoomID, filterSessionID string,
+) (result map[string]map[string]api.KeyBackupSession, err error) {
+ err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ if filterSessionID != "" {
+ result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
+ return err
+ }
+ if filterRoomID != "" {
+ result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
+ return err
+ }
+ result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
+ return err
+ })
+ return
+}
+
+func (d *Database) CountBackupKeys(
+ ctx context.Context, version, userID string,
+) (count int64, err error) {
+ err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
+ count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
+ if err != nil {
+ return err
+ }
+ return nil
+ })
+ return
+}
+
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@@ -486,7 +516,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
- if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
+ if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
@@ -531,22 +561,3 @@ func (d *Database) UpsertBackupKeys(
})
return
}
-
-// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
-// create circular import loops
-func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
- // https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
- // "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
- if uploaded.IsVerified && !existing.IsVerified {
- return true
- }
- // "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
- if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
- return true
- }
- // "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
- if uploaded.ForwardedCount < existing.ForwardedCount {
- return true
- }
- return false
-}