aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--keyserver/storage/postgres/one_time_keys_table.go24
1 files changed, 9 insertions, 15 deletions
diff --git a/keyserver/storage/postgres/one_time_keys_table.go b/keyserver/storage/postgres/one_time_keys_table.go
index 6e32838b..cc397ba8 100644
--- a/keyserver/storage/postgres/one_time_keys_table.go
+++ b/keyserver/storage/postgres/one_time_keys_table.go
@@ -20,6 +20,7 @@ import (
"encoding/json"
"time"
+ "github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
@@ -47,7 +48,7 @@ const upsertKeysSQL = "" +
" DO UPDATE SET key_json = $6"
const selectKeysSQL = "" +
- "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
+ "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 GROUP BY algorithm"
@@ -94,29 +95,22 @@ func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
- rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
+ rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
- wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
- for _, ka := range keyIDsWithAlgorithms {
- wantSet[ka] = true
- }
-
result := make(map[string]json.RawMessage)
+ var (
+ algorithmWithID string
+ keyJSONStr string
+ )
for rows.Next() {
- var keyID string
- var algorithm string
- var keyJSONStr string
- if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
+ if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
return nil, err
}
- keyIDWithAlgo := algorithm + ":" + keyID
- if wantSet[keyIDWithAlgo] {
- result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
- }
+ result[algorithmWithID] = json.RawMessage(keyJSONStr)
}
return result, rows.Err()
}