aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorNeil <neil@nats.io>2024-12-17 19:19:15 +0100
committerGitHub <noreply@github.com>2024-12-17 19:19:15 +0100
commit78dbf21c5f92bf8245d456a18c21f0771f994618 (patch)
tree5ee55e1506db9821f0b0e107794b037008792c20 /userapi
parentc3d7a34c155b8f1987001ea0b8b33528c77d6839 (diff)
Support for fallback keys (#3451)
Backports support for fallback keys from Harmony, which should make E2EE more reliable in the face of OTK exhaustion. Signed-off-by: Neil Alexander <git@neilalexander.dev> Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com> [skip ci]
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go36
-rw-r--r--userapi/internal/key_api.go55
-rw-r--r--userapi/storage/interface.go9
-rw-r--r--userapi/storage/postgres/fallback_keys_table.go134
-rw-r--r--userapi/storage/postgres/storage.go5
-rw-r--r--userapi/storage/shared/storage.go23
-rw-r--r--userapi/storage/sqlite3/fallback_keys_table.go132
-rw-r--r--userapi/storage/sqlite3/storage.go5
-rw-r--r--userapi/storage/storage_test.go39
-rw-r--r--userapi/storage/tables/interface.go7
10 files changed, 433 insertions, 12 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index 6da12fc9..26482129 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -788,12 +788,30 @@ type OneTimeKeysCount struct {
KeyCount map[string]int
}
+// FallbackKeys represents a set of fallback keys for a single device
+// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
+type FallbackKeys struct {
+ // The user who owns this device
+ UserID string
+ // The device ID of this device
+ DeviceID string
+ // A map of algorithm:key_id => key JSON
+ KeyJSON map[string]json.RawMessage
+}
+
+// Split a key in KeyJSON into algorithm and key ID
+func (k *FallbackKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
+ segments := strings.Split(keyIDWithAlgo, ":")
+ return segments[0], segments[1]
+}
+
// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
- UserID string // Required - User performing the request
- DeviceID string // Optional - Device performing the request, for fetching OTK count
- DeviceKeys []DeviceKeys
- OneTimeKeys []OneTimeKeys
+ UserID string // Required - User performing the request
+ DeviceID string // Optional - Device performing the request, for fetching OTK count
+ DeviceKeys []DeviceKeys
+ OneTimeKeys []OneTimeKeys
+ FallbackKeys []FallbackKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
// the display name for their respective device, and NOT to modify the keys. The key
// itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
@@ -810,8 +828,9 @@ type PerformUploadKeysResponse struct {
// A fatal error when processing e.g database failures
Error *KeyError
// A map of user_id -> device_id -> Error for tracking failures.
- KeyErrors map[string]map[string]*KeyError
- OneTimeKeyCounts []OneTimeKeysCount
+ KeyErrors map[string]map[string]*KeyError
+ OneTimeKeyCounts []OneTimeKeysCount
+ FallbackKeysUnusedAlgorithms []string
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain
@@ -917,8 +936,9 @@ type QueryOneTimeKeysRequest struct {
type QueryOneTimeKeysResponse struct {
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
- Count OneTimeKeysCount
- Error *KeyError
+ Count OneTimeKeysCount
+ UnusedFallbackAlgorithms []string
+ Error *KeyError
}
type QueryDeviceMessagesRequest struct {
diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go
index 09ead2c5..6cb11bcd 100644
--- a/userapi/internal/key_api.go
+++ b/userapi/internal/key_api.go
@@ -44,14 +44,22 @@ func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perfor
if len(req.DeviceKeys) > 0 {
a.uploadLocalDeviceKeys(ctx, req, res)
}
- if len(req.OneTimeKeys) > 0 {
- a.uploadOneTimeKeys(ctx, req, res)
+ if len(req.OneTimeKeys) > 0 || len(req.FallbackKeys) > 0 {
+ a.uploadOneTimeAndFallbackKeys(ctx, req, res)
}
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
return err
}
+ algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
+ }
+ return nil
+ }
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
+ res.FallbackKeysUnusedAlgorithms = algos
return nil
}
@@ -169,7 +177,15 @@ func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOn
}
return nil
}
+ algos, err := a.KeyDatabase.UnusedFallbackKeyAlgorithms(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to query unused fallback algorithms: %s", err),
+ }
+ return nil
+ }
res.Count = *count
+ res.UnusedFallbackAlgorithms = algos
return nil
}
@@ -507,6 +523,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), spec.ServerName(serverName), userID)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return
+ }
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
@@ -520,6 +539,9 @@ func (a *UserInternalAPI) queryRemoteKeysOnServer(
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
if err != nil {
+ if errors.Is(err, context.Canceled) {
+ return
+ }
logrus.WithFields(logrus.Fields{
logrus.ErrorKey: err,
"user_id": userID,
@@ -715,7 +737,7 @@ func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Pe
}
}
-func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+func (a *UserInternalAPI) uploadOneTimeAndFallbackKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if req.UserID == "" {
res.Error = &api.KeyError{
Err: "user ID missing",
@@ -768,7 +790,32 @@ func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perfor
// collect counts
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
}
-
+ if len(req.FallbackKeys) > 0 {
+ if err := a.KeyDatabase.DeleteFallbackKeys(ctx, req.UserID, req.DeviceID); err != nil {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf("%s device %s : failed to clear fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
+ })
+ return
+ }
+ for _, key := range req.FallbackKeys {
+ // grab existing keys based on (user/device/algorithm/key ID)
+ keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
+ i := 0
+ for keyIDWithAlgo := range key.KeyJSON {
+ keyIDsWithAlgorithms[i] = keyIDWithAlgo
+ i++
+ }
+ unused, err := a.KeyDatabase.StoreFallbackKeys(ctx, key)
+ if err != nil {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf("%s device %s : failed to store fallback keys: %s", req.UserID, req.DeviceID, err.Error()),
+ })
+ continue
+ }
+ // collect counts
+ res.FallbackKeysUnusedAlgorithms = unused
+ }
+ }
}
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index 7767f6cd..2a46a7fd 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -167,6 +167,15 @@ type KeyDatabase interface {
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+ // StoreFallbackKeys persists the given fallback keys.
+ StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) ([]string, error)
+
+ // UnusedFallbackKeyAlgorithms returns unused fallback algorithms for this user/device.
+ UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
+
+ // DeleteFallbackKeys deletes all fallback keys for the user.
+ DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error
+
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
diff --git a/userapi/storage/postgres/fallback_keys_table.go b/userapi/storage/postgres/fallback_keys_table.go
new file mode 100644
index 00000000..acae7ed6
--- /dev/null
+++ b/userapi/storage/postgres/fallback_keys_table.go
@@ -0,0 +1,134 @@
+// Copyright 2024 New Vector Ltd.
+// Copyright 2017 Vector Creations Ltd
+//
+// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
+// Please see LICENSE files in the repository root for full details.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+
+ "github.com/element-hq/dendrite/internal"
+ "github.com/element-hq/dendrite/internal/sqlutil"
+ "github.com/element-hq/dendrite/userapi/api"
+ "github.com/element-hq/dendrite/userapi/storage/tables"
+)
+
+var fallbackKeysSchema = `
+-- Stores one-time public keys for users
+CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ key_id TEXT NOT NULL,
+ algorithm TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ used BOOLEAN NOT NULL,
+ -- Clobber based on tuple of user/device/algorithm.
+ CONSTRAINT keyserver_fallback_keys_unique UNIQUE (user_id, device_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
+`
+
+const upsertFallbackKeysSQL = "" +
+ "INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, false)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_fallback_keys_unique" +
+ " DO UPDATE SET key_id = $3, key_json = $6, used = false"
+
+const selectFallbackUnusedAlgorithmsSQL = "" +
+ "SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
+
+const selectFallbackKeysByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
+
+const deleteFallbackKeysSQL = "" +
+ "DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
+
+const updateFallbackKeyUsedSQL = "" +
+ "UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
+
+type fallbackKeysStatements struct {
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectUnusedAlgorithmsStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteFallbackKeysStmt *sql.Stmt
+ updateFallbackKeyUsedStmt *sql.Stmt
+}
+
+func NewPostgresFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
+ s := &fallbackKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(fallbackKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertFallbackKeysSQL},
+ {&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
+ {&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
+ {&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
+ {&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
+ }.Prepare(db)
+}
+
+func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
+ rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ algos := []string{}
+ for rows.Next() {
+ var algorithm string
+ if err = rows.Scan(&algorithm); err != nil {
+ return nil, err
+ }
+ algos = append(algos, algorithm)
+ }
+ return algos, rows.Err()
+}
+
+func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
+ now := time.Now().Unix()
+ for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
+ algo, keyID := keys.Split(keyIDWithAlgo)
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
+ ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+ return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
+}
+
+func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index 696e1aa6..c7fb9d29 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -141,6 +141,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
+ fk, err := NewPostgresFallbackKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
@@ -164,6 +168,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
+ FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index 2b1885cd..44ace733 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -57,6 +57,7 @@ type Database struct {
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
+ FallbackKeysTable tables.FallbackKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
@@ -937,6 +938,22 @@ func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID str
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
+func (d *KeyDatabase) StoreFallbackKeys(ctx context.Context, keys api.FallbackKeys) (unused []string, err error) {
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ unused, err = d.FallbackKeysTable.InsertFallbackKeys(ctx, txn, keys)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) DeleteFallbackKeys(ctx context.Context, userID, deviceID string) error {
+ return d.FallbackKeysTable.DeleteFallbackKeys(ctx, nil, userID, deviceID)
+}
+
+func (d *KeyDatabase) UnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
+ return d.FallbackKeysTable.SelectUnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
+}
+
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
@@ -999,6 +1016,12 @@ func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map
if err != nil {
return err
}
+ if len(keyJSON) == 0 {
+ keyJSON, err = d.FallbackKeysTable.SelectAndUpdateFallbackKey(ctx, txn, userID, deviceID, algo)
+ if err != nil {
+ return err
+ }
+ }
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,
diff --git a/userapi/storage/sqlite3/fallback_keys_table.go b/userapi/storage/sqlite3/fallback_keys_table.go
new file mode 100644
index 00000000..2eb99813
--- /dev/null
+++ b/userapi/storage/sqlite3/fallback_keys_table.go
@@ -0,0 +1,132 @@
+// Copyright 2024 New Vector Ltd.
+// Copyright 2017 Vector Creations Ltd
+//
+// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
+// Please see LICENSE files in the repository root for full details.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+
+ "github.com/element-hq/dendrite/internal"
+ "github.com/element-hq/dendrite/internal/sqlutil"
+ "github.com/element-hq/dendrite/userapi/api"
+ "github.com/element-hq/dendrite/userapi/storage/tables"
+)
+
+var fallbackKeysSchema = `
+-- Stores one-time public keys for users
+CREATE TABLE IF NOT EXISTS keyserver_fallback_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ key_id TEXT NOT NULL,
+ algorithm TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ used BOOLEAN NOT NULL
+);
+CREATE UNIQUE INDEX IF NOT EXISTS keyserver_fallback_keys_unique_idx ON keyserver_fallback_keys(user_id, device_id, algorithm);
+CREATE INDEX IF NOT EXISTS keyserver_fallback_keys_idx ON keyserver_fallback_keys (user_id, device_id);
+`
+
+const upsertFallbackKeysSQL = "" +
+ "INSERT INTO keyserver_fallback_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json, used)" +
+ " VALUES ($1, $2, $3, $4, $5, $6, false)" +
+ " ON CONFLICT (user_id, device_id, algorithm)" +
+ " DO UPDATE SET key_id = $3, key_json = $6, used = false"
+
+const selectFallbackUnusedAlgorithmsSQL = "" +
+ "SELECT algorithm FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND used = false"
+
+const selectFallbackKeysByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 ORDER BY used ASC LIMIT 1"
+
+const deleteFallbackKeysSQL = "" +
+ "DELETE FROM keyserver_fallback_keys WHERE user_id = $1 AND device_id = $2"
+
+const updateFallbackKeyUsedSQL = "" +
+ "UPDATE keyserver_fallback_keys SET used=true WHERE user_id = $1 AND device_id = $2 AND key_id = $3 AND algorithm = $4"
+
+type fallbackKeysStatements struct {
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectUnusedAlgorithmsStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteFallbackKeysStmt *sql.Stmt
+ updateFallbackKeyUsedStmt *sql.Stmt
+}
+
+func NewSqliteFallbackKeysTable(db *sql.DB) (tables.FallbackKeys, error) {
+ s := &fallbackKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(fallbackKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertFallbackKeysSQL},
+ {&s.selectUnusedAlgorithmsStmt, selectFallbackUnusedAlgorithmsSQL},
+ {&s.selectKeyByAlgorithmStmt, selectFallbackKeysByAlgorithmSQL},
+ {&s.deleteFallbackKeysStmt, deleteFallbackKeysSQL},
+ {&s.updateFallbackKeyUsedStmt, updateFallbackKeyUsedSQL},
+ }.Prepare(db)
+}
+
+func (s *fallbackKeysStatements) SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error) {
+ rows, err := s.selectUnusedAlgorithmsStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ algos := []string{}
+ for rows.Next() {
+ var algorithm string
+ if err = rows.Scan(&algorithm); err != nil {
+ return nil, err
+ }
+ algos = append(algos, algorithm)
+ }
+ return algos, rows.Err()
+}
+
+func (s *fallbackKeysStatements) InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error) {
+ now := time.Now().Unix()
+ for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
+ algo, keyID := keys.Split(keyIDWithAlgo)
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
+ ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+ return s.SelectUnusedFallbackKeyAlgorithms(ctx, keys.UserID, keys.DeviceID)
+}
+
+func (s *fallbackKeysStatements) DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteFallbackKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *fallbackKeysStatements) SelectAndUpdateFallbackKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.updateFallbackKeyUsedStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index c57cc153..6d906191 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -138,6 +138,10 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
if err != nil {
return nil, err
}
+ fk, err := NewSqliteFallbackKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
@@ -161,6 +165,7 @@ func NewKeyDatabase(conMan *sqlutil.Connections, dbProperties *config.DatabaseOp
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
+ FallbackKeysTable: fk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 68198b37..189c1dd8 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -809,3 +809,42 @@ func TestOneTimeKeys(t *testing.T) {
}
})
}
+
+func TestFallbackKeys(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ userID := "@alice:localhost"
+ deviceID := "alice_device"
+ fk := api.FallbackKeys{
+ UserID: userID,
+ DeviceID: deviceID,
+ KeyJSON: map[string]json.RawMessage{"curve25519:KEY1": []byte(`{"key":"v1"}`)},
+ }
+
+ _, err := db.StoreFallbackKeys(ctx, fk)
+ MustNotError(t, err)
+
+ unused, err := db.UnusedFallbackKeyAlgorithms(ctx, userID, deviceID)
+ MustNotError(t, err)
+ if c := len(unused); c != 1 {
+ t.Fatalf("Expected 1 unused key algorithm, got %d", c)
+ }
+ if unused[0] != "curve25519" {
+ t.Fatalf("Expected unused key algorithm to be 'curve25519', got '%s'", unused[0])
+ }
+
+ // No other one-time keys have been uploaded so we expect to get the fallback key instead.
+ claimed, err := db.ClaimKeys(ctx, map[string]map[string]string{userID: {deviceID: "curve25519"}})
+ MustNotError(t, err)
+
+ switch {
+ case claimed[0].UserID != fk.UserID:
+ t.Fatalf("Claimed user ID ID doesn't match, got %q, want %q", claimed[0].UserID, fk.DeviceID)
+ case claimed[0].DeviceID != fk.DeviceID:
+ t.Fatalf("Claimed device ID doesn't match, got %q, want %q", claimed[0].DeviceID, fk.DeviceID)
+ case claimed[0].KeyJSON["curve25519:KEY1"] == nil:
+ t.Fatalf("Claimed key JSON for curve25519:KEY1 not found")
+ }
+ })
+}
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 7d4cfbae..44f31a5c 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -170,6 +170,13 @@ type DeviceKeys interface {
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
+type FallbackKeys interface {
+ SelectUnusedFallbackKeyAlgorithms(ctx context.Context, userID, deviceID string) ([]string, error)
+ InsertFallbackKeys(ctx context.Context, txn *sql.Tx, keys api.FallbackKeys) ([]string, error)
+ DeleteFallbackKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+ SelectAndUpdateFallbackKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
+}
+
type KeyChanges interface {
InsertKeyChange(ctx context.Context, userID string) (int64, error)
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.