aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-07-21 14:47:53 +0100
committerGitHub <noreply@github.com>2020-07-21 14:47:53 +0100
commit1d72ce8b7ab759555503df37af666529749b489c (patch)
tree06ac331afec50a9a92f05062b80db0870f95ac25 /keyserver
parentd76eb1b99491f644be035a631a08b5874065e4d7 (diff)
Implement claiming one-time keys locally (#1210)
* Add API shape for claiming keys * Implement claiming one-time keys locally Fairly boring, nothing too special going on.
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/api/api.go9
-rw-r--r--keyserver/internal/internal.go46
-rw-r--r--keyserver/storage/interface.go4
-rw-r--r--keyserver/storage/postgres/one_time_keys_table.go40
-rw-r--r--keyserver/storage/shared/storage.go24
-rw-r--r--keyserver/storage/sqlite3/one_time_keys_table.go40
-rw-r--r--keyserver/storage/tables/interface.go4
7 files changed, 159 insertions, 8 deletions
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index 0f6cb797..d42fb60c 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -23,6 +23,7 @@ import (
type KeyInternalAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
+ // PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
}
@@ -102,9 +103,17 @@ func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyEr
}
type PerformClaimKeysRequest struct {
+ // Map of user_id to device_id to algorithm name
+ OneTimeKeys map[string]map[string]string
+ Timeout time.Duration
}
type PerformClaimKeysResponse struct {
+ // Map of user_id to device_id to algorithm:key_id to key JSON
+ OneTimeKeys map[string]map[string]map[string]json.RawMessage
+ // Map of remote server domain to error JSON
+ Failures map[string]interface{}
+ // Set if there was a fatal error processing this action
Error *KeyError
}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 5be87aa4..041732dc 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -37,9 +37,39 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
a.uploadDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res)
}
+
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
+ res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
+ res.Failures = make(map[string]interface{})
+ // wrap request map in a top-level by-domain map
+ domainToDeviceKeys := make(map[string]map[string]map[string]string)
+ for userID, val := range req.OneTimeKeys {
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ nested, ok := domainToDeviceKeys[string(serverName)]
+ if !ok {
+ nested = make(map[string]map[string]string)
+ }
+ nested[userID] = val
+ domainToDeviceKeys[string(serverName)] = nested
+ }
+ // claim local keys
+ if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
+ keys, err := a.DB.ClaimKeys(ctx, local)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
+ }
+ }
+ mergeInto(res.OneTimeKeys, keys)
+ delete(domainToDeviceKeys, string(a.ThisServer))
+ }
+ // TODO: claim remote keys
}
+
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
@@ -166,3 +196,19 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
// TODO
}
+
+func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
+ for _, key := range src {
+ _, ok := dst[key.UserID]
+ if !ok {
+ dst[key.UserID] = make(map[string]map[string]json.RawMessage)
+ }
+ _, ok = dst[key.UserID][key.DeviceID]
+ if !ok {
+ dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
+ }
+ for keyID, keyJSON := range key.KeyJSON {
+ dst[key.UserID][key.DeviceID][keyID] = keyJSON
+ }
+ }
+}
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index a626c66a..7a0328bd 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -39,4 +39,8 @@ type Database interface {
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
+
+ // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
+ // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
+ ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
}
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go
index b8aee72b..a9d05548 100644
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ b/keyserver/storage/postgres/one_time_keys_table.go
@@ -52,11 +52,19 @@ const selectKeysSQL = "" +
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
type oneTimeKeysStatements struct {
- db *sql.DB
- upsertKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
- selectKeysCountStmt *sql.Stmt
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
}
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
+ if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
+ return nil, err
+ }
+ if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return rows.Err()
})
}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index d5ac6458..156b5b41 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -19,6 +19,7 @@ import (
"database/sql"
"encoding/json"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
@@ -48,3 +49,26 @@ func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) e
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
}
+
+func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
+ var result []api.OneTimeKeys
+ err := sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ for userID, deviceToAlgo := range userToDeviceToAlgorithm {
+ for deviceID, algo := range deviceToAlgo {
+ keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
+ if err != nil {
+ return err
+ }
+ if keyJSON != nil {
+ result = append(result, api.OneTimeKeys{
+ UserID: userID,
+ DeviceID: deviceID,
+ KeyJSON: keyJSON,
+ })
+ }
+ }
+ }
+ return nil
+ })
+ return result, err
+}
diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go
index 86e91268..fecf533e 100644
--- a/keyserver/storage/sqlite3/one_time_keys_table.go
+++ b/keyserver/storage/sqlite3/one_time_keys_table.go
@@ -52,11 +52,19 @@ const selectKeysSQL = "" +
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
type oneTimeKeysStatements struct {
- db *sql.DB
- upsertKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
- selectKeysCountStmt *sql.Stmt
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
}
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@@ -76,6 +84,12 @@ func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
+ if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
+ return nil, err
+ }
+ if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return rows.Err()
})
}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index 1f7f686b..216be773 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -16,6 +16,7 @@ package tables
import (
"context"
+ "database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
@@ -24,6 +25,9 @@ import (
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
InsertOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+ // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
+ // Returns an empty map if the key does not exist.
+ SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
}
type DeviceKeys interface {