aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/postgres/device_keys_table.go
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-02-20 14:58:03 +0100
committerGitHub <noreply@github.com>2023-02-20 14:58:03 +0100
commit4594233f89f8531fca8f696ab0ece36909130c2a (patch)
tree18d3c451041423022e15ba5fcc4a778806ff94dc /userapi/storage/postgres/device_keys_table.go
parentbd6f0c14e56af71d83d703b7c91b8cf829ca560f (diff)
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
Diffstat (limited to 'userapi/storage/postgres/device_keys_table.go')
-rw-r--r--userapi/storage/postgres/device_keys_table.go213
1 files changed, 213 insertions, 0 deletions
diff --git a/userapi/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go
new file mode 100644
index 00000000..a9203857
--- /dev/null
+++ b/userapi/storage/postgres/device_keys_table.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/lib/pq"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var deviceKeysSchema = `
+-- Stores device keys for users
+CREATE TABLE IF NOT EXISTS keyserver_device_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
+ -- This means we do not store an unbounded append-only log of device keys, which is not actually
+ -- required in the spec because in the event of a missed update the server fetches the entire
+ -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
+ stream_id BIGINT NOT NULL,
+ display_name TEXT,
+ -- Clobber based on tuple of user/device.
+ CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
+);
+`
+
+const upsertDeviceKeysSQL = "" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
+
+const selectDeviceKeysSQL = "" +
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectBatchDeviceKeysSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
+
+const countStreamIDsForUserSQL = "" +
+ "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
+
+const deleteDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
+type deviceKeysStatements struct {
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ countStreamIDsForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
+}
+
+func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
+ s := &deviceKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(deviceKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ for i, key := range keys {
+ var keyJSONStr string
+ var streamID int64
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+ // this will be '' when there is no device
+ keys[i].Type = api.TypeDeviceKeyUpdate
+ keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt64
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int64
+ }
+ return
+}
+
+func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
+ // nullable if there are no results
+ var count sql.NullInt32
+ err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
+ if err != nil {
+ return 0, err
+ }
+ if count.Valid {
+ return int(count.Int32), nil
+ }
+ return 0, nil
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
+ deviceIDMap := make(map[string]bool)
+ for _, d := range deviceIDs {
+ deviceIDMap[d] = true
+ }
+ var result []api.DeviceMessage
+ var displayName sql.NullString
+ for rows.Next() {
+ dk := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
+ }
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
+ // include the key if we want all keys (no device) or it was asked
+ if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
+ result = append(result, dk)
+ }
+ }
+ return result, rows.Err()
+}