| author | Till <2353100+S7evinK@users.noreply.github.com> | 2023-02-20 14:58:03 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-20 14:58:03 +0100 |
| commit | 4594233f89f8531fca8f696ab0ece36909130c2a (patch) | |
| tree | 18d3c451041423022e15ba5fcc4a778806ff94dc /userapi/storage | |
| parent | bd6f0c14e56af71d83d703b7c91b8cf829ca560f (diff) | |
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the
userapi.
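For orientation, here is a minimal sketch (not part of the diff) of what the split means for callers: code that previously depended on the single `storage.Database` interface can now take the narrower `storage.UserDatabase` and `storage.KeyDatabase` interfaces from `userapi/storage/interface.go`. The consumer type and method bodies below are hypothetical; only the interface names and method signatures come from the diff, and the import path is assumed from the repository layout at this commit.

```go
package example

import (
	"context"

	// Import path assumed from the Dendrite repository layout at this commit.
	"github.com/matrix-org/dendrite/userapi/storage"
)

// keyUploader is a hypothetical consumer of the split interfaces: it only needs
// key-related storage, so it depends on storage.KeyDatabase rather than the
// former all-in-one storage.Database.
type keyUploader struct {
	keyDB storage.KeyDatabase
}

// recordKeyChange stores a key change for userID and checks how many one-time
// keys the given device still has, using only methods defined in this diff.
func (u *keyUploader) recordKeyChange(ctx context.Context, userID, deviceID string) (int64, error) {
	// StoreKeyChange returns the device change ID, i.e. the position in the
	// /sync stream for this device change.
	changeID, err := u.keyDB.StoreKeyChange(ctx, userID)
	if err != nil {
		return 0, err
	}
	// OneTimeKeysCount returns a count of all OTKs for this device.
	if _, err := u.keyDB.OneTimeKeysCount(ctx, userID, deviceID); err != nil {
		return 0, err
	}
	return changeID, nil
}
```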
Diffstat (limited to 'userapi/storage')
31 files changed, 2860 insertions, 34 deletions
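Before the file-by-file changes, a short hedged sketch of how the new Postgres-backed key database introduced here might be opened. The wrapper function below is illustrative only; `postgres.NewKeyDatabase` and the tables it wires up come from this diff, while the `setup/base` and `setup/config` import paths are assumed from the surrounding repository rather than shown here.

```go
package example

import (
	"fmt"

	// Paths for BaseDendrite and DatabaseOptions are assumed from the wider repo.
	"github.com/matrix-org/dendrite/setup/base"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/userapi/storage"
	"github.com/matrix-org/dendrite/userapi/storage/postgres"
)

// openKeyDatabase is an illustrative wrapper around the constructor added in
// userapi/storage/postgres/storage.go. NewKeyDatabase creates the one-time key,
// device key, key change, stale device list and cross-signing tables and returns
// a *shared.KeyDatabase, used here through the storage.KeyDatabase interface
// defined in userapi/storage/interface.go.
func openKeyDatabase(b *base.BaseDendrite, opts *config.DatabaseOptions) (storage.KeyDatabase, error) {
	keyDB, err := postgres.NewKeyDatabase(b, opts)
	if err != nil {
		return nil, fmt.Errorf("postgres.NewKeyDatabase: %w", err)
	}
	return keyDB, nil
}
```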
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index c22b7658..27837886 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -90,7 +90,7 @@ type KeyBackup interface { type LoginToken interface { // CreateLoginToken generates a token, stores and returns it. The lifetime is - // determined by the loginTokenLifetime given to the Database constructor. + // determined by the loginTokenLifetime given to the UserDatabase constructor. CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) // RemoveLoginToken removes the named token (and may clean up other expired tokens). @@ -130,7 +130,7 @@ type Notification interface { DeleteOldNotifications(ctx context.Context) error } -type Database interface { +type UserDatabase interface { Account AccountData Device @@ -144,6 +144,78 @@ type Database interface { ThreePID } +type KeyChangeDatabase interface { + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) +} + +type KeyDatabase interface { + KeyChangeDatabase + // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination + // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database. + ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + + // StoreOneTimeKeys persists the given one-time keys. + StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + + // OneTimeKeysCount returns a count of all OTKs for this device. + OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + + // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. + DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). + // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set. + // Returns an error if there was a problem storing the keys. + StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior + // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) + + // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. 
+ // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. + DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + + // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying + // cross-signing signatures relating to that device. + DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error + + // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key + // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. + ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) + + // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). + // A to offset of types.OffsetNewest means no upper limit. + // Returns the offset of the latest key change. + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) + CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) + CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) + + StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error + StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error +} + type Statistics interface { UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error) DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 2a4777d7..05716037 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData( roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) + // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage + // when passing the data to trigger "NOT NULL" constraint + var data *json.RawMessage + if len(content) > 0 { + data = &content + } + _, err = stmt.ExecContext(ctx, localpart, serverName, 
roomID, dataType, data) return } diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go new file mode 100644 index 00000000..c0ecbd30 --- /dev/null +++ b/userapi/storage/postgres/cross_signing_keys_table.go @@ -0,0 +1,102 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningKeysSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( + user_id TEXT NOT NULL, + key_type SMALLINT NOT NULL, + key_data TEXT NOT NULL, + PRIMARY KEY (user_id, key_type) +); +` + +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + +const upsertCrossSigningKeysForUserSQL = "" + + "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" + + " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3" + +type crossSigningKeysStatements struct { + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt +} + +func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + _, err := db.Exec(crossSigningKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + if _, err := sqlutil.TxStmt(txn, 
s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go new file mode 100644 index 00000000..b0117145 --- /dev/null +++ b/userapi/storage/postgres/cross_signing_sigs_table.go @@ -0,0 +1,131 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningSigsSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) +); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); +` + +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3" + +const upsertCrossSigningSigsForTargetSQL = "" + + "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + + " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5" + +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + +type crossSigningSigsStatements struct { + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt +} + +func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + _, err := db.Exec(crossSigningSigsSchema) + if err != nil { + return nil, err + } + + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "keyserver: cross signing signature indexes", + Up: deltas.UpFixCrossSigningSignatureIndexes, + }) + if err = m.Up(context.Background()); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, 
deleteCrossSigningSigsForTargetSQL}, + }.Prepare(db) +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..0cfe9e79 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go @@ -0,0 +1,69 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + // start counting from the last max offset, else 0. We need to do a count(*) first to see if there + // even are entries in this table to know if we can query for log_offset. Without the count then + // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't + // exist on that table. 
Even though we discard the error, the txn is tainted and gets aborted :/ + var count int + _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count) + if count > 0 { + var maxOffset int64 + _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) + if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { + return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) + } + } + + _, err := tx.ExecContext(ctx, ` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 00000000..1a3d4fee --- /dev/null +++ b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id); + + CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey; + ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id); + + DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx; + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go new file mode 100644 index 00000000..a9203857 --- /dev/null +++ b/userapi/storage/postgres/device_keys_table.go @@ -0,0 +1,213 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var deviceKeysSchema = ` +-- Stores device keys for users +CREATE TABLE IF NOT EXISTS keyserver_device_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- the stream ID of this key, scoped per-user. This gets updated when the device key changes. + -- This means we do not store an unbounded append-only log of device keys, which is not actually + -- required in the spec because in the event of a missed update the server fetches the entire + -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. + stream_id BIGINT NOT NULL, + display_name TEXT, + -- Clobber based on tuple of user/device. 
+ CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)" + +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + countStreamIDsForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int64 + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { + // nullable if there are no results + var nullStream sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int64 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, 
error) { + // nullable if there are no results + var count sql.NullInt32 + err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int32), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} + +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var result []api.DeviceMessage + var displayName sql.NullString + for rows.Next() { + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: userID, + }, + } + if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 7481ac5b..88f8839c 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice( if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { return nil, fmt.Errorf("insertDeviceStmt: %w", err) } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go index 7b58f7ba..91a34c35 100644 --- a/userapi/storage/postgres/key_backup_table.go +++ b/userapi/storage/postgres/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" 
+ "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/userapi/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go new file mode 100644 index 00000000..a0049414 --- /dev/null +++ b/userapi/storage/postgres/key_changes_table.go @@ -0,0 +1,127 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) +); +` + +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. 
+const upsertKeyChangeSQL = "" + + "INSERT INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + + " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + + " RETURNING change_id" + +const selectKeyChangesSQL = "" + + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return s, err + } + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) +} + +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + + var cName string + err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + + return m.Up(ctx) +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/userapi/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go new file mode 100644 index 00000000..972a5914 --- /dev/null +++ b/userapi/storage/postgres/one_time_keys_table.go @@ -0,0 +1,194 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" + + " DO UPDATE SET key_json = $6" + +const selectOneTimeKeysSQL = "" + + "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt +} + +func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms)) + if err != nil { + return nil, err + } + defer 
internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + result := make(map[string]json.RawMessage) + var ( + algorithmWithID string + keyJSONStr string + ) + for rows.Next() { + if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil { + return nil, err + } + result[algorithmWithID] = json.RawMessage(keyJSONStr) + } + return result, rows.Err() +} + +func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go new file mode 100644 index 00000000..c823b58c --- /dev/null +++ b/userapi/storage/postgres/stale_device_lists.go @@ -0,0 +1,131 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" + +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return 
nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index 92dc4808..673d123b 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()) + if err != nil { + return nil, err + } + otk, err := NewPostgresOneTimeKeysTable(db) + if err != nil { + return nil, err + } + dk, err := NewPostgresDeviceKeysTable(db) + if err != nil { + return nil, err + } + kc, err := NewPostgresKeyChangesTable(db) + if err != nil { + return nil, err + } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } + csk, err := NewPostgresCrossSigningKeysTable(db) + if err != nil { + return nil, err + } + css, err := NewPostgresCrossSigningSigsTable(db) + if err != nil { + return nil, err + } + + return &shared.KeyDatabase{ + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, + CrossSigningKeysTable: csk, + CrossSigningSigsTable: css, + Writer: writer, + }, nil +} diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index bf94f14d..d3272a03 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -59,6 +59,17 @@ type Database struct { OpenIDTokenLifetimeMS int64 } +type KeyDatabase struct { + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists + CrossSigningKeysTable tables.CrossSigningKeys + CrossSigningSigsTable tables.CrossSigningSigs + DB *sql.DB + Writer sqlutil.Writer +} + const ( // The length of generated device IDs deviceIDByteLength = 6 @@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages( ) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) { return d.Stats.DailyRoomsMessages(ctx, nil, serverName) } + +// + +func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms) +} + +func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) { + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys) + return err + }) + return +} + +func (d *KeyDatabase) OneTimeKeysCount(ctx 
context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) +} + +func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) +} + +func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) { + count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs) + if err != nil { + return false, err + } + return count == len(prevIDs), nil +} + +func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, userID := range clearUserIDs { + err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID) + if err != nil { + return err + } + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error { + // work out the latest stream IDs for each user + userIDToStreamID := make(map[string]int64) + for _, k := range keys { + userIDToStreamID[k.UserID] = 0 + } + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID := range userIDToStreamID { + streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID) + if err != nil { + return err + } + userIDToStreamID[userID] = streamID + } + // set the stream IDs for each key + for i := range keys { + k := keys[i] + userIDToStreamID[k.UserID]++ // start stream from 1 + k.StreamID = userIDToStreamID[k.UserID] + keys[i] = k + } + return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys) + }) +} + +func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty) +} + +func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) { + var result []api.OneTimeKeys + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for userID, deviceToAlgo := range userToDeviceToAlgorithm { + for deviceID, algo := range deviceToAlgo { + keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo) + if err != nil { + return err + } + if keyJSON != nil { + result = append(result, api.OneTimeKeys{ + UserID: userID, + DeviceID: deviceID, + KeyJSON: keyJSON, + }) + } + } + } + return nil + }) + return result, err +} + +func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err + }) + return +} + +func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) +} + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. 
+func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) + }) +} + +// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying +// cross-signing signatures relating to that device. +func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + for _, deviceID := range deviceIDs { + if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err) + } + if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err) + } + if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows { + return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err) + } + } + return nil + }) +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) { + keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err) + } + results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{} + for purpose, key := range keyMap { + keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode()) + result := gomatrixserverlib.CrossSigningKey{ + UserID: userID, + Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + keyID: key, + }, + } + sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID) + if err != nil { + continue + } + for sigUserID, forSigUserID := range sigMap { + if userID != sigUserID { + continue + } + if result.Signatures == nil { + result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + if _, ok := result.Signatures[sigUserID]; !ok { + result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + for sigKeyID, sigBytes := range forSigUserID { + result.Signatures[sigUserID][sigKeyID] = sigBytes + } + } + results[purpose] = result + } + return results, nil +} + +// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any. +func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) { + return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID) +} + +// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any. 
+func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
+	return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
+}
+
+// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
+func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
+	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+		for keyType, keyData := range keyMap {
+			if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
+				return fmt.Errorf("d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser: %w", err)
+			}
+		}
+		return nil
+	})
+}
+
+// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device.
+func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
+	ctx context.Context,
+	originUserID string, originKeyID gomatrixserverlib.KeyID,
+	targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+	signature gomatrixserverlib.Base64Bytes,
+) error {
+	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+		if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, txn, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+			return fmt.Errorf("d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget: %w", err)
+		}
+		return nil
+	})
+}
+
+// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
+func (d *KeyDatabase) DeleteStaleDeviceLists(
+	ctx context.Context,
+	userIDs []string,
+) error {
+	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+		return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
+	})
+}
diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go
new file mode 100644
index 00000000..10721fcc
--- /dev/null
+++ b/userapi/storage/sqlite3/cross_signing_keys_table.go
@@ -0,0 +1,101 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningKeysSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys ( + user_id TEXT NOT NULL, + key_type INTEGER NOT NULL, + key_data TEXT NOT NULL, + PRIMARY KEY (user_id, key_type) +); +` + +const selectCrossSigningKeysForUserSQL = "" + + "SELECT key_type, key_data FROM keyserver_cross_signing_keys" + + " WHERE user_id = $1" + +const upsertCrossSigningKeysForUserSQL = "" + + "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" + + " VALUES($1, $2, $3)" + +type crossSigningKeysStatements struct { + db *sql.DB + selectCrossSigningKeysForUserStmt *sql.Stmt + upsertCrossSigningKeysForUserStmt *sql.Stmt +} + +func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) { + s := &crossSigningKeysStatements{ + db: db, + } + _, err := db.Exec(crossSigningKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL}, + {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL}, + }.Prepare(db) +} + +func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, +) (r types.CrossSigningKeyMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed") + r = types.CrossSigningKeyMap{} + for rows.Next() { + var keyTypeInt int16 + var keyData gomatrixserverlib.Base64Bytes + if err := rows.Scan(&keyTypeInt, &keyData); err != nil { + return nil, err + } + keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt] + if !ok { + return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt) + } + r[keyType] = keyData + } + return +} + +func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser( + ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes, +) error { + keyTypeInt, ok := types.KeyTypePurposeToInt[keyType] + if !ok { + return fmt.Errorf("unknown key purpose %q", keyType) + } + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil { + return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go new file mode 100644 index 00000000..2be00c9c --- /dev/null +++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go @@ -0,0 +1,129 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/dendrite/userapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var crossSigningSigsSchema = ` +CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs ( + origin_user_id TEXT NOT NULL, + origin_key_id TEXT NOT NULL, + target_user_id TEXT NOT NULL, + target_key_id TEXT NOT NULL, + signature TEXT NOT NULL, + PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id) +); + +CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id); +` + +const selectCrossSigningSigsForTargetSQL = "" + + "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" + + " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4" + +const upsertCrossSigningSigsForTargetSQL = "" + + "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" + + " VALUES($1, $2, $3, $4, $5)" + +const deleteCrossSigningSigsForTargetSQL = "" + + "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2" + +type crossSigningSigsStatements struct { + db *sql.DB + selectCrossSigningSigsForTargetStmt *sql.Stmt + upsertCrossSigningSigsForTargetStmt *sql.Stmt + deleteCrossSigningSigsForTargetStmt *sql.Stmt +} + +func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) { + s := &crossSigningSigsStatements{ + db: db, + } + _, err := db.Exec(crossSigningSigsSchema) + if err != nil { + return nil, err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "keyserver: cross signing signature indexes", + Up: deltas.UpFixCrossSigningSignatureIndexes, + }) + if err = m.Up(context.Background()); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL}, + {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL}, + {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL}, + }.Prepare(db) +} + +func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) (r types.CrossSigningSigMap, err error) { + rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed") + r = types.CrossSigningSigMap{} + for rows.Next() { + var userID string + var keyID gomatrixserverlib.KeyID + var signature gomatrixserverlib.Base64Bytes + if err := rows.Scan(&userID, &keyID, &signature); err != nil { + return nil, err + } + if _, ok := r[userID]; !ok { + r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{} + } + r[userID][keyID] = signature + } + return +} + +func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget( + ctx 
context.Context, txn *sql.Tx, + originUserID string, originKeyID gomatrixserverlib.KeyID, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, + signature gomatrixserverlib.Base64Bytes, +) error { + if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil { + return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} + +func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget( + ctx context.Context, txn *sql.Tx, + targetUserID string, targetKeyID gomatrixserverlib.KeyID, +) error { + if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil { + return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..cd0f19df --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + // start counting from the last max offset, else 0. 
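+	// The old table keyed changes by Kafka partition/offset; the new schema below uses a single
+	// AUTOINCREMENT change_id with one row per user, seeded from the old maximum offset so that
+	// previously handed-out stream positions remain valid.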
+ var maxOffset int64 + var userID string + _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) + + _, err := tx.ExecContext(ctx, ` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + // to start counting from maxOffset, insert a row with that value + if userID != "" { + _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) + return err + } + return nil +} + +func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, ` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go new file mode 100644 index 00000000..d4e38dea --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go @@ -0,0 +1,71 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package deltas
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+)
+
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+	_, err := tx.ExecContext(ctx, `
+	CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
+		origin_user_id TEXT NOT NULL,
+		origin_key_id TEXT NOT NULL,
+		target_user_id TEXT NOT NULL,
+		target_key_id TEXT NOT NULL,
+		signature TEXT NOT NULL,
+		PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+	);
+
+	INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
+		SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
+
+	DROP TABLE keyserver_cross_signing_sigs;
+	ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
+
+	CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+	`)
+	if err != nil {
+		return fmt.Errorf("failed to execute upgrade: %w", err)
+	}
+	return nil
+}
+
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+	_, err := tx.ExecContext(ctx, `
+	CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
+		origin_user_id TEXT NOT NULL,
+		origin_key_id TEXT NOT NULL,
+		target_user_id TEXT NOT NULL,
+		target_key_id TEXT NOT NULL,
+		signature TEXT NOT NULL,
+		PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
+	);
+
+	INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
+		SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
+
+	DROP TABLE keyserver_cross_signing_sigs;
+	ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
+
+	DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
+	`)
+	if err != nil {
+		return fmt.Errorf("failed to execute downgrade: %w", err)
+	}
+	return nil
+}
diff --git a/userapi/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go
new file mode 100644
index 00000000..15e69cc4
--- /dev/null
+++ b/userapi/storage/sqlite3/device_keys_table.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+	"context"
+	"database/sql"
+	"strings"
+	"time"
+
+	"github.com/matrix-org/dendrite/internal"
+	"github.com/matrix-org/dendrite/internal/sqlutil"
+	"github.com/matrix-org/dendrite/userapi/api"
+	"github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var deviceKeysSchema = `
+-- Stores device keys for users
+CREATE TABLE IF NOT EXISTS keyserver_device_keys (
+    user_id TEXT NOT NULL,
+	device_id TEXT NOT NULL,
+	ts_added_secs BIGINT NOT NULL,
+	key_json TEXT NOT NULL,
+	stream_id BIGINT NOT NULL,
+	display_name TEXT,
+	-- Clobber based on tuple of user/device.
+ UNIQUE (user_id, device_id) +); +` + +const upsertDeviceKeysSQL = "" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id)" + + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" + +const selectDeviceKeysSQL = "" + + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const selectBatchDeviceKeysSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''" + +const selectBatchDeviceKeysWithEmptiesSQL = "" + + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" + +const selectMaxStreamForUserSQL = "" + + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" + +const countStreamIDsForUserSQL = "" + + "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)" + +const deleteDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + +const deleteAllDeviceKeysSQL = "" + + "DELETE FROM keyserver_device_keys WHERE user_id=$1" + +type deviceKeysStatements struct { + db *sql.DB + upsertDeviceKeysStmt *sql.Stmt + selectDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysStmt *sql.Stmt + selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt + selectMaxStreamForUserStmt *sql.Stmt + deleteDeviceKeysStmt *sql.Stmt + deleteAllDeviceKeysStmt *sql.Stmt +} + +func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { + s := &deviceKeysStatements{ + db: db, + } + _, err := db.Exec(deviceKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL}, + {&s.selectDeviceKeysStmt, selectDeviceKeysSQL}, + {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL}, + {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL}, + {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL}, + // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime + {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL}, + {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL}, + }.Prepare(db) +} + +func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} + +func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID) + return err +} + +func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) { + deviceIDMap := make(map[string]bool) + for _, d := range deviceIDs { + deviceIDMap[d] = true + } + var stmt *sql.Stmt + if includeEmpty { + stmt = s.selectBatchDeviceKeysWithEmptiesStmt + } else { + stmt = s.selectBatchDeviceKeysStmt + } + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") + var result []api.DeviceMessage + var displayName sql.NullString + for rows.Next() { + dk := api.DeviceMessage{ + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + UserID: userID, + }, + } + if err := rows.Scan(&dk.DeviceID, 
&dk.KeyJSON, &dk.StreamID, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dk.DisplayName = displayName.String + } + // include the key if we want all keys (no device) or it was asked + if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { + result = append(result, dk) + } + } + return result, rows.Err() +} + +func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { + for i, key := range keys { + var keyJSONStr string + var streamID int64 + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if err != nil && err != sql.ErrNoRows { + return err + } + // this will be '' when there is no device + keys[i].Type = api.TypeDeviceKeyUpdate + keys[i].KeyJSON = []byte(keyJSONStr) + keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } + } + return nil +} + +func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) { + // nullable if there are no results + var nullStream sql.NullInt64 + err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream) + if err == sql.ErrNoRows { + err = nil + } + if nullStream.Valid { + streamID = nullStream.Int64 + } + return +} + +func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) { + iStreamIDs := make([]interface{}, len(streamIDs)+1) + iStreamIDs[0] = userID + for i := range streamIDs { + iStreamIDs[i+1] = streamIDs[i] + } + query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1) + // nullable if there are no results + var count sql.NullInt64 + err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count) + if err != nil { + return 0, err + } + if count.Valid { + return int(count.Int64), nil + } + return 0, nil +} + +func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error { + for _, key := range keys { + now := time.Now().Unix() + _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ) + if err != nil { + return err + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index 449e4549..65e17527 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice( if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice( LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, @@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, 
displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } - return &api.Device{ + dev := &api.Device{ ID: id, UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, @@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn * LastSeenTS: createdTimeMS, LastSeenIP: ipAddr, UserAgent: userAgent, - }, nil + } + if displayName != nil { + dev.DisplayName = *displayName + } + return dev, nil } func (s *devicesStatements) DeleteDevice( @@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices( if err != nil { return err } + defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed") stmt := sqlutil.TxStmt(txn, prep) params := make([]interface{}, len(devices)+2) params[0] = localpart diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go index 7883ffb1..ed274631 100644 --- a/userapi/storage/sqlite3/key_backup_table.go +++ b/userapi/storage/sqlite3/key_backup_table.go @@ -52,7 +52,7 @@ const updateBackupKeySQL = "" + const countKeysSQL = "" + "SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2" -const selectKeysSQL = "" + +const selectBackupKeysSQL = "" + "SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " + "WHERE user_id = $1 AND version = $2" @@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) { {&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.countKeysStmt, countKeysSQL}, - {&s.selectKeysStmt, selectKeysSQL}, + {&s.selectKeysStmt, selectBackupKeysSQL}, {&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL}, {&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL}, }.Prepare(db) diff --git a/userapi/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go new file mode 100644 index 00000000..923bb57e --- /dev/null +++ b/userapi/storage/sqlite3/key_changes_table.go @@ -0,0 +1,125 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var keyChangesSchema = ` +-- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) +); +` + +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. 
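+// Because change_id is an INTEGER PRIMARY KEY AUTOINCREMENT and user_id is UNIQUE, each
+// INSERT OR REPLACE removes the user's previous row and allocates a fresh, strictly larger
+// change_id, which the RETURNING clause hands back to the caller.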
+const upsertKeyChangeSQL = "" + + "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " RETURNING change_id" + +const selectKeyChangesSQL = "" + + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" + +type keyChangesStatements struct { + db *sql.DB + upsertKeyChangeStmt *sql.Stmt + selectKeyChangesStmt *sql.Stmt +} + +func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { + s := &keyChangesStatements{ + db: db, + } + _, err := db.Exec(keyChangesSchema) + if err != nil { + return s, err + } + + if err = executeMigration(context.Background(), db); err != nil { + return nil, err + } + + return s, sqlutil.StatementList{ + {&s.upsertKeyChangeStmt, upsertKeyChangeSQL}, + {&s.selectKeyChangesStmt, selectKeyChangesSQL}, + }.Prepare(db) +} + +func executeMigration(ctx context.Context, db *sql.DB) error { + // TODO: Remove when we are sure we are not having goose artefacts in the db + // This forces an error, which indicates the migration is already applied, since the + // column partition was removed from the table + migrationName := "keyserver: refactor key changes" + + var cName string + err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed + if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil { + return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + return nil + } + return err + } + m := sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRefactorKeyChanges, + }) + return m.Up(ctx) +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return +} + +func (s *keyChangesStatements) SelectKeyChanges( + ctx context.Context, fromOffset, toOffset int64, +) (userIDs []string, latestOffset int64, err error) { + latestOffset = fromOffset + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) + if err != nil { + return nil, 0, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed") + for rows.Next() { + var userID string + var offset int64 + if err := rows.Scan(&userID, &offset); err != nil { + return nil, 0, err + } + if offset > latestOffset { + latestOffset = offset + } + userIDs = append(userIDs, userID) + } + return +} diff --git a/userapi/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go new file mode 100644 index 00000000..a992d399 --- /dev/null +++ b/userapi/storage/sqlite3/one_time_keys_table.go @@ -0,0 +1,208 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +var oneTimeKeysSchema = ` +-- Stores one-time public keys for users +CREATE TABLE IF NOT EXISTS keyserver_one_time_keys ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + key_id TEXT NOT NULL, + algorithm TEXT NOT NULL, + ts_added_secs BIGINT NOT NULL, + key_json TEXT NOT NULL, + -- Clobber based on 4-uple of user/device/key/algorithm. + UNIQUE (user_id, device_id, key_id, algorithm) +); + +CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id); +` + +const upsertKeysSQL = "" + + "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + + " ON CONFLICT (user_id, device_id, key_id, algorithm)" + + " DO UPDATE SET key_json = $6" + +const selectOneTimeKeysSQL = "" + + "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2" + +const selectKeysCountSQL = "" + + "SELECT algorithm, COUNT(key_id) FROM " + + " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" + + " x GROUP BY algorithm" + +const deleteOneTimeKeySQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4" + +const selectKeyByAlgorithmSQL = "" + + "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1" + +const deleteOneTimeKeysSQL = "" + + "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2" + +type oneTimeKeysStatements struct { + db *sql.DB + upsertKeysStmt *sql.Stmt + selectKeysStmt *sql.Stmt + selectKeysCountStmt *sql.Stmt + selectKeyByAlgorithmStmt *sql.Stmt + deleteOneTimeKeyStmt *sql.Stmt + deleteOneTimeKeysStmt *sql.Stmt +} + +func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) { + s := &oneTimeKeysStatements{ + db: db, + } + _, err := db.Exec(oneTimeKeysSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertKeysStmt, upsertKeysSQL}, + {&s.selectKeysStmt, selectOneTimeKeysSQL}, + {&s.selectKeysCountStmt, selectKeysCountSQL}, + {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL}, + {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL}, + {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL}, + }.Prepare(db) +} + +func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { + rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed") + + wantSet := make(map[string]bool, len(keyIDsWithAlgorithms)) + for _, ka := range keyIDsWithAlgorithms { + wantSet[ka] = true + } + + result := make(map[string]json.RawMessage) + for rows.Next() { + var keyID string + var algorithm string + var keyJSONStr string + if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil { + return nil, err + } + keyIDWithAlgo := algorithm + ":" + keyID + if wantSet[keyIDWithAlgo] { + result[keyIDWithAlgo] = json.RawMessage(keyJSONStr) + } + } + return result, rows.Err() +} + +func (s 
*oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) { + counts := &api.OneTimeKeysCount{ + DeviceID: deviceID, + UserID: userID, + KeyCount: make(map[string]int), + } + rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + return counts, nil +} + +func (s *oneTimeKeysStatements) InsertOneTimeKeys( + ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys, +) (*api.OneTimeKeysCount, error) { + now := time.Now().Unix() + counts := &api.OneTimeKeysCount{ + DeviceID: keys.DeviceID, + UserID: keys.UserID, + KeyCount: make(map[string]int), + } + for keyIDWithAlgo, keyJSON := range keys.KeyJSON { + algo, keyID := keys.Split(keyIDWithAlgo) + _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext( + ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON), + ) + if err != nil { + return nil, err + } + } + rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed") + for rows.Next() { + var algorithm string + var count int + if err = rows.Scan(&algorithm, &count); err != nil { + return nil, err + } + counts.KeyCount[algorithm] = count + } + + return counts, rows.Err() +} + +func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( + ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string, +) (map[string]json.RawMessage, error) { + var keyID string + var keyJSON string + err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil + } + return nil, err + } + _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) + if err != nil { + return nil, err + } + if keyJSON == "" { + return nil, nil + } + return map[string]json.RawMessage{ + algorithm + ":" + keyID: json.RawMessage(keyJSON), + }, err +} + +func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error { + _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID) + return err +} diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go new file mode 100644 index 00000000..f078fc99 --- /dev/null +++ b/userapi/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,145 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" + +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + +type staleDeviceListsStatements struct { + db *sql.DB + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{ + db: db, + } + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) 
+ } + return result, nil +} + +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go index a1365c94..72b3ba49 100644 --- a/userapi/storage/sqlite3/stats_table.go +++ b/userapi/storage/sqlite3/stats_table.go @@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, 4, @@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res if err != nil { return 0, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) err = stmt.QueryRowContext(ctx, 1, 2, 3, @@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx) if err != nil { return nil, err } + defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed") stmt := sqlutil.TxStmt(txn, queryStmt) registeredAfter := time.Now().AddDate(0, 0, -30) diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index 85a1f706..0f3eeed1 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -30,8 +30,8 @@ import ( "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" ) -// NewDatabase creates a new accounts and profiles database -func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { +// NewUserDatabase creates a new accounts and profiles database +func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) { db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()) if err != nil { return nil, err @@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, OpenIDTokenLifetimeMS: openIDTokenLifetimeMS, }, nil } + +func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) { + db, writer, err := base.DatabaseConnection(dbProperties, 
sqlutil.NewExclusiveWriter())
+	if err != nil {
+		return nil, err
+	}
+	otk, err := NewSqliteOneTimeKeysTable(db)
+	if err != nil {
+		return nil, err
+	}
+	dk, err := NewSqliteDeviceKeysTable(db)
+	if err != nil {
+		return nil, err
+	}
+	kc, err := NewSqliteKeyChangesTable(db)
+	if err != nil {
+		return nil, err
+	}
+	sdl, err := NewSqliteStaleDeviceListsTable(db)
+	if err != nil {
+		return nil, err
+	}
+	csk, err := NewSqliteCrossSigningKeysTable(db)
+	if err != nil {
+		return nil, err
+	}
+	css, err := NewSqliteCrossSigningSigsTable(db)
+	if err != nil {
+		return nil, err
+	}
+
+	return &shared.KeyDatabase{
+		OneTimeKeysTable:      otk,
+		DeviceKeysTable:       dk,
+		KeyChangesTable:       kc,
+		StaleDeviceListsTable: sdl,
+		CrossSigningKeysTable: csk,
+		CrossSigningSigsTable: css,
+		Writer:                writer,
+	}, nil
+}
diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go
index 42221e75..0329fb46 100644
--- a/userapi/storage/storage.go
+++ b/userapi/storage/storage.go
@@ -29,15 +29,36 @@ import (
 	"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
 )
 
-// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
+// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
 // and sets postgres connection parameters
-func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
+func NewUserDatabase(
+	base *base.BaseDendrite,
+	dbProperties *config.DatabaseOptions,
+	serverName gomatrixserverlib.ServerName,
+	bcryptCost int,
+	openIDTokenLifetimeMS int64,
+	loginTokenLifetime time.Duration,
+	serverNoticesLocalpart string,
+) (UserDatabase, error) {
 	switch {
 	case dbProperties.ConnectionString.IsSQLite():
-		return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+		return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
 	case dbProperties.ConnectionString.IsPostgres():
 		return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
 	default:
 		return nil, fmt.Errorf("unexpected database type")
 	}
 }
+
+// NewKeyDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
+// and sets postgres connection parameters.
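+// A typical call, mirroring the storage tests later in this change, is:
+//
+//	keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)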
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewKeyDatabase(base, dbProperties) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewKeyDatabase(base, dbProperties) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 23aafff0..f52e7e17 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -4,9 +4,12 @@ import ( "context" "encoding/json" "fmt" + "reflect" + "sync" "testing" "time" + "github.com/matrix-org/dendrite/userapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" @@ -29,14 +32,14 @@ var ( ctx = context.Background() ) -func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { +func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) { base, baseclose := testrig.CreateBaseDendrite(t, dbType) connStr, close := test.PrepareDBConnectionString(t, dbType) - db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server") if err != nil { - t.Fatalf("NewUserAPIDatabase returned %s", err) + t.Fatalf("NewUserDatabase returned %s", err) } return db, func() { close() @@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun // Tests storing and getting account data func Test_AccountData(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) { // Tests the creation of accounts func Test_Accounts(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() alice := test.NewUser(t) aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) @@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) { accessToken := util.RandomString(16) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") @@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) { room := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() wantAuthData := json.RawMessage("my auth data") @@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) { func Test_LoginToken(t *testing.T) { alice := test.NewUser(t) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create a new token @@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) { token := util.RandomString(24) test.WithAllDatabases(t, func(t 
*testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS @@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // create account, which also creates a profile @@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() appID := util.RandomString(8) @@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() threePID := util.RandomString(8) medium := util.RandomString(8) @@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) { room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, close := mustCreateDatabase(t, dbType) + db, close := mustCreateUserDatabase(t, dbType) defer close() // generate some dummy notifications for i := 0; i < 10; i++ { @@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) { assert.Equal(t, int64(0), total) }) } + +func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database) + if err != nil { + t.Fatalf("failed to create new database: %v", err) + } + return db, close +} + +func MustNotError(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + t.Fatalf("operation failed: %s", err) +} + +func TestKeyChanges(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) + } + if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesNoDupes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, 
types.OffsetNewest) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) + } + if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +func TestKeyChangesUpperLimit(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) + if err != nil { + t.Fatalf("Failed to KeyChanges: %s", err) + } + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) + } + if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { + t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) + } + }) +} + +var dbLock sync.Mutex +var deviceArray = []string{"AAA", "another_device"} + +// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user, +// and that they are returned correctly when querying for device keys. +func TestDeviceKeysStreamIDGeneration(t *testing.T) { + var err error + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clean := mustCreateKeyDatabase(t, dbType) + defer clean() + alice := "@alice:TestDeviceKeysStreamIDGeneration" + bob := "@bob:TestDeviceKeysStreamIDGeneration" + msgs := []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: bob, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 1 as this is a different user + }, + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "another_device", + UserID: alice, + KeyJSON: []byte(`{"key":"v1"}`), + }, + // StreamID: 2 as this is a 2nd device key + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID) + } + if msgs[1].StreamID != 1 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID) + } + if msgs[2].StreamID != 2 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID) + } + + // updating a device sets the next stream ID for that user + msgs = []api.DeviceMessage{ + { + Type: api.TypeDeviceKeyUpdate, + DeviceKeys: &api.DeviceKeys{ + DeviceID: "AAA", + UserID: alice, + KeyJSON: []byte(`{"key":"v2"}`), + }, + // StreamID: 3 + }, + } + MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) + if msgs[0].StreamID != 3 { + t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID) + } + + dbLock.Lock() + defer dbLock.Unlock() + // Querying for device keys returns the latest stream IDs + msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false) + + if err != nil { + t.Fatalf("DeviceKeysForUser returned error: %s", err) + } + wantStreamIDs := map[string]int64{ + "AAA": 3, + 
"another_device": 2, + } + if len(msgs) != len(wantStreamIDs) { + t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs)) + } + for _, m := range msgs { + if m.StreamID != wantStreamIDs[m.DeviceID] { + t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID]) + } + } + }) +} diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go index 5d5d292e..163e3e17 100644 --- a/userapi/storage/storage_wasm.go +++ b/userapi/storage/storage_wasm.go @@ -32,10 +32,10 @@ func NewUserAPIDatabase( openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string, -) (Database, error) { +) (UserDatabase, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) + return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 9221e571..693e7303 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -20,10 +20,10 @@ import ( "encoding/json" "time" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/types" ) @@ -145,3 +145,47 @@ const ( // uint32. AllNotifications NotificationFilter = (1 << 31) - 1 ) + +type OneTimeKeys interface { + SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) + CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) + InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) + // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON. + // Returns an empty map if the key does not exist. + SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) + DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error +} + +type DeviceKeys interface { + SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error + SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) + CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) + SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) + DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error + DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error +} + +type KeyChanges interface { + InsertKeyChange(ctx context.Context, userID string) (int64, error) + // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets. 
+ // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset. + SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) +} + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error +} + +type CrossSigningKeys interface { + SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error) + UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error +} + +type CrossSigningSigs interface { + SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error) + UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error +} diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go new file mode 100644 index 00000000..b9bdafda --- /dev/null +++ b/userapi/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/userapi/storage/postgres" + "github.com/matrix-org/dendrite/userapi/storage/sqlite3" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi/storage/tables" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + 
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Query all servers
+ wantStaleUsers = []string{alice.ID, bob.ID, charlie}
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Delete stale devices
+ deleteUsers := []string{alice.ID, bob.ID}
+ if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
+ t.Fatalf("failed to delete stale device lists: %s", err)
+ }
+
+ // Verify we don't get anything back after deleting
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+
+ if gotCount := len(gotStaleUsers); gotCount > 0 {
+ t.Fatalf("expected no stale users, got %d", gotCount)
+ }
+ })
+}
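
Editor's note: a minimal usage sketch for the key change surface introduced above. The helper name fetchNewKeyChanges and its packaging are illustrative assumptions, not code from this commit; StoreKeyChange, KeyChanges and types.OffsetNewest are the calls exercised by TestKeyChanges, and the offset semantics follow the SelectKeyChanges comment (exclusive of fromOffset, inclusive of toOffset).

package storage_test

import (
	"context"

	"github.com/matrix-org/dendrite/userapi/storage"
	"github.com/matrix-org/dendrite/userapi/types"
)

// fetchNewKeyChanges is an illustrative consumer helper (not part of this commit).
// KeyChanges is exclusive of fromOffset and inclusive of toOffset; passing
// types.OffsetNewest as the upper bound returns everything recorded after lastOffset.
// The caller persists the returned latest offset and passes it back as fromOffset on
// the next call, so each key change is observed at most once.
func fetchNewKeyChanges(ctx context.Context, db storage.KeyDatabase, lastOffset int64) (userIDs []string, latest int64, err error) {
	return db.KeyChanges(ctx, lastOffset, types.OffsetNewest)
}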
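
Likewise, a rough sketch of how the tables.OneTimeKeys contract might be consumed when claiming a key. The claimOneTimeKey helper and the use of sqlutil.WithTransaction as the transaction wrapper are assumptions for illustration only; per the interface comment, SelectAndDeleteOneTimeKey pairs the read with the delete and returns an empty map when no key of that algorithm remains for the device.

package tables_test

import (
	"context"
	"database/sql"
	"encoding/json"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
)

// claimOneTimeKey removes and returns a single one-time key for the given algorithm.
// Running the call inside a transaction keeps the select-and-delete atomic with any
// surrounding bookkeeping the caller performs.
func claimOneTimeKey(ctx context.Context, db *sql.DB, tab tables.OneTimeKeys, userID, deviceID, algorithm string) (map[string]json.RawMessage, error) {
	var claimed map[string]json.RawMessage
	err := sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
		var txErr error
		// An empty map with a nil error means nothing was available to claim.
		claimed, txErr = tab.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algorithm)
		return txErr
	})
	return claimed, err
}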