aboutsummaryrefslogtreecommitdiff
path: root/keyserver/storage/postgres
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-07-21 14:47:53 +0100
committerGitHub <noreply@github.com>2020-07-21 14:47:53 +0100
commit1d72ce8b7ab759555503df37af666529749b489c (patch)
tree06ac331afec50a9a92f05062b80db0870f95ac25 /keyserver/storage/postgres
parentd76eb1b99491f644be035a631a08b5874065e4d7 (diff)
Implement claiming one-time keys locally (#1210)
* Add API shape for claiming keys * Implement claiming one-time keys locally Fairly boring, nothing too special going on.
Diffstat (limited to 'keyserver/storage/postgres')
-rw-r--r--keyserver/storage/postgres/one_time_keys_table.go40
1 files changed, 36 insertions, 4 deletions
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go
index b8aee72b..a9d05548 100644
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ b/keyserver/storage/postgres/one_time_keys_table.go
@@ -52,11 +52,19 @@ const selectKeysSQL = "" +
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
type oneTimeKeysStatements struct {
- db *sql.DB
- upsertKeysStmt *sql.Stmt
- selectKeysStmt *sql.Stmt
- selectKeysCountStmt *sql.Stmt
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
}
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
@@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil {
return nil, err
}
+ if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil {
+ return nil, err
+ }
+ if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api.
return rows.Err()
})
}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}