diff options
author | Kegsay <kegan@matrix.org> | 2020-07-21 14:47:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-07-21 14:47:53 +0100 |
commit | 1d72ce8b7ab759555503df37af666529749b489c (patch) | |
tree | 06ac331afec50a9a92f05062b80db0870f95ac25 /keyserver/storage/postgres | |
parent | d76eb1b99491f644be035a631a08b5874065e4d7 (diff) |
Implement claiming one-time keys locally (#1210)
* Add API shape for claiming keys
* Implement claiming one-time keys locally
Fairly boring, nothing too special going on.
Diffstat (limited to 'keyserver/storage/postgres')
-rw-r--r-- | keyserver/storage/postgres/one_time_keys_table.go | 40 |
1 files changed, 36 insertions, 4 deletions
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go index b8aee72b..a9d05548 100644 --- a/keyserver/storage/postgres/one_time_keys_table.go +++ b/keyserver/storage/postgres/one_time_keys_table.go @@ -52,11 +52,19 @@ const selectKeysSQL = "" + const selectKeysCountSQL = "" + "SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm" +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + type oneTimeKeysStatements struct { - db *sql.DB - upsertKeysStmt *sql.Stmt - selectKeysStmt *sql.Stmt - selectKeysCountStmt *sql.Stmt + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt } func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { @@ -76,6 +84,12 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { if s.selectKeysCountStmt, err = db.Prepare(selectKeysCountSQL); err != nil { return nil, err } + if s.selectKeyByAlgorithmStmt, err = db.Prepare(selectKeyByAlgorithmSQL); err != nil { + return nil, err + } + if s.deleteOneTimeKeyStmt, err = db.Prepare(deleteOneTimeKeySQL); err != nil { + return nil, err + } return s, nil } @@ -141,3 +155,21 @@ func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, keys api. return rows.Err() }) } + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := txn.StmtContext(ctx, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} |