aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-02-20 14:58:03 +0100
committerGitHub <noreply@github.com>2023-02-20 14:58:03 +0100
commit4594233f89f8531fca8f696ab0ece36909130c2a (patch)
tree18d3c451041423022e15ba5fcc4a778806ff94dc /userapi/storage
parentbd6f0c14e56af71d83d703b7c91b8cf829ca560f (diff)
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
Diffstat (limited to 'userapi/storage')
-rw-r--r--userapi/storage/interface.go76
-rw-r--r--userapi/storage/postgres/account_data_table.go8
-rw-r--r--userapi/storage/postgres/cross_signing_keys_table.go102
-rw-r--r--userapi/storage/postgres/cross_signing_sigs_table.go131
-rw-r--r--userapi/storage/postgres/deltas/2022012016470000_key_changes.go69
-rw-r--r--userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go47
-rw-r--r--userapi/storage/postgres/device_keys_table.go213
-rw-r--r--userapi/storage/postgres/devices_table.go8
-rw-r--r--userapi/storage/postgres/key_backup_table.go4
-rw-r--r--userapi/storage/postgres/key_changes_table.go127
-rw-r--r--userapi/storage/postgres/one_time_keys_table.go194
-rw-r--r--userapi/storage/postgres/stale_device_lists.go131
-rw-r--r--userapi/storage/postgres/storage.go41
-rw-r--r--userapi/storage/shared/storage.go235
-rw-r--r--userapi/storage/sqlite3/cross_signing_keys_table.go101
-rw-r--r--userapi/storage/sqlite3/cross_signing_sigs_table.go129
-rw-r--r--userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go66
-rw-r--r--userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go71
-rw-r--r--userapi/storage/sqlite3/device_keys_table.go213
-rw-r--r--userapi/storage/sqlite3/devices_table.go17
-rw-r--r--userapi/storage/sqlite3/key_backup_table.go4
-rw-r--r--userapi/storage/sqlite3/key_changes_table.go125
-rw-r--r--userapi/storage/sqlite3/one_time_keys_table.go208
-rw-r--r--userapi/storage/sqlite3/stale_device_lists.go145
-rw-r--r--userapi/storage/sqlite3/stats_table.go3
-rw-r--r--userapi/storage/sqlite3/storage.go45
-rw-r--r--userapi/storage/storage.go27
-rw-r--r--userapi/storage/storage_test.go210
-rw-r--r--userapi/storage/storage_wasm.go4
-rw-r--r--userapi/storage/tables/interface.go46
-rw-r--r--userapi/storage/tables/stale_device_lists_test.go94
31 files changed, 2860 insertions, 34 deletions
diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go
index c22b7658..27837886 100644
--- a/userapi/storage/interface.go
+++ b/userapi/storage/interface.go
@@ -90,7 +90,7 @@ type KeyBackup interface {
type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is
- // determined by the loginTokenLifetime given to the Database constructor.
+ // determined by the loginTokenLifetime given to the UserDatabase constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
@@ -130,7 +130,7 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
-type Database interface {
+type UserDatabase interface {
Account
AccountData
Device
@@ -144,6 +144,78 @@ type Database interface {
ThreePID
}
+type KeyChangeDatabase interface {
+ // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
+ // `userID` is the the user who has changed their keys in some way.
+ StoreKeyChange(ctx context.Context, userID string) (int64, error)
+}
+
+type KeyDatabase interface {
+ KeyChangeDatabase
+ // ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
+ // of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
+ ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+
+ // StoreOneTimeKeys persists the given one-time keys.
+ StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+
+ // OneTimeKeysCount returns a count of all OTKs for this device.
+ OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+
+ // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
+ DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+
+ // StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device).
+ // The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
+ // Returns an error if there was a problem storing the keys.
+ StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+
+ // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+ // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
+
+ // PrevIDsExists returns true if all prev IDs exist for this user.
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
+
+ // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
+ // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
+ DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
+
+ // DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
+ // cross-signing signatures relating to that device.
+ DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
+
+ // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
+ // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
+ ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
+
+ // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
+ // A to offset of types.OffsetNewest means no upper limit.
+ // Returns the offset of the latest key change.
+ KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+
+ // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+ // If no domains are given, all user IDs with stale device lists are returned.
+ StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+
+ // MarkDeviceListStale sets the stale bit for this user to isStale.
+ MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
+
+ CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
+ CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
+ CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
+
+ StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
+ StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+
+ DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+ ) error
+}
+
type Statistics interface {
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)
diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go
index 2a4777d7..05716037 100644
--- a/userapi/storage/postgres/account_data_table.go
+++ b/userapi/storage/postgres/account_data_table.go
@@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
- _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
+ // Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
+ // when passing the data to trigger "NOT NULL" constraint
+ var data *json.RawMessage
+ if len(content) > 0 {
+ data = &content
+ }
+ _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
return
}
diff --git a/userapi/storage/postgres/cross_signing_keys_table.go b/userapi/storage/postgres/cross_signing_keys_table.go
new file mode 100644
index 00000000..c0ecbd30
--- /dev/null
+++ b/userapi/storage/postgres/cross_signing_keys_table.go
@@ -0,0 +1,102 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningKeysSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
+ user_id TEXT NOT NULL,
+ key_type SMALLINT NOT NULL,
+ key_data TEXT NOT NULL,
+ PRIMARY KEY (user_id, key_type)
+);
+`
+
+const selectCrossSigningKeysForUserSQL = "" +
+ "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
+ " WHERE user_id = $1"
+
+const upsertCrossSigningKeysForUserSQL = "" +
+ "INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
+ " VALUES($1, $2, $3)" +
+ " ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
+
+type crossSigningKeysStatements struct {
+ db *sql.DB
+ selectCrossSigningKeysForUserStmt *sql.Stmt
+ upsertCrossSigningKeysForUserStmt *sql.Stmt
+}
+
+func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
+ s := &crossSigningKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
+ {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string,
+) (r types.CrossSigningKeyMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
+ r = types.CrossSigningKeyMap{}
+ for rows.Next() {
+ var keyTypeInt int16
+ var keyData gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
+ return nil, err
+ }
+ keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
+ if !ok {
+ return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
+ }
+ r[keyType] = keyData
+ }
+ return
+}
+
+func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
+) error {
+ keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
+ if !ok {
+ return fmt.Errorf("unknown key purpose %q", keyType)
+ }
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/cross_signing_sigs_table.go b/userapi/storage/postgres/cross_signing_sigs_table.go
new file mode 100644
index 00000000..b0117145
--- /dev/null
+++ b/userapi/storage/postgres/cross_signing_sigs_table.go
@@ -0,0 +1,131 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningSigsSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+`
+
+const selectCrossSigningSigsForTargetSQL = "" +
+ "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
+ " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
+
+const upsertCrossSigningSigsForTargetSQL = "" +
+ "INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
+ " VALUES($1, $2, $3, $4, $5)" +
+ " ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
+
+const deleteCrossSigningSigsForTargetSQL = "" +
+ "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
+
+type crossSigningSigsStatements struct {
+ db *sql.DB
+ selectCrossSigningSigsForTargetStmt *sql.Stmt
+ upsertCrossSigningSigsForTargetStmt *sql.Stmt
+ deleteCrossSigningSigsForTargetStmt *sql.Stmt
+}
+
+func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
+ s := &crossSigningSigsStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningSigsSchema)
+ if err != nil {
+ return nil, err
+ }
+
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: cross signing signature indexes",
+ Up: deltas.UpFixCrossSigningSignatureIndexes,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
+ {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
+ {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) (r types.CrossSigningSigMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
+ r = types.CrossSigningSigMap{}
+ for rows.Next() {
+ var userID string
+ var keyID gomatrixserverlib.KeyID
+ var signature gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&userID, &keyID, &signature); err != nil {
+ return nil, err
+ }
+ if _, ok := r[userID]; !ok {
+ r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ r[userID][keyID] = signature
+ }
+ return
+}
+
+func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
+
+func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
+ return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/deltas/2022012016470000_key_changes.go b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go
new file mode 100644
index 00000000..0cfe9e79
--- /dev/null
+++ b/userapi/storage/postgres/deltas/2022012016470000_key_changes.go
@@ -0,0 +1,69 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ // start counting from the last max offset, else 0. We need to do a count(*) first to see if there
+ // even are entries in this table to know if we can query for log_offset. Without the count then
+ // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
+ // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
+ var count int
+ _ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
+ if count > 0 {
+ var maxOffset int64
+ _ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
+ if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
+ return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
+ }
+ }
+
+ _, err := tx.ExecContext(ctx, `
+ -- make the new table
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
+ DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ partition BIGINT NOT NULL,
+ log_offset BIGINT NOT NULL,
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go
new file mode 100644
index 00000000..1a3d4fee
--- /dev/null
+++ b/userapi/storage/postgres/deltas/2022042612000000_xsigning_idx.go
@@ -0,0 +1,47 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
+ ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
+
+ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
+ ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
+
+ DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/postgres/device_keys_table.go b/userapi/storage/postgres/device_keys_table.go
new file mode 100644
index 00000000..a9203857
--- /dev/null
+++ b/userapi/storage/postgres/device_keys_table.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/lib/pq"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var deviceKeysSchema = `
+-- Stores device keys for users
+CREATE TABLE IF NOT EXISTS keyserver_device_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
+ -- This means we do not store an unbounded append-only log of device keys, which is not actually
+ -- required in the spec because in the event of a missed update the server fetches the entire
+ -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
+ stream_id BIGINT NOT NULL,
+ display_name TEXT,
+ -- Clobber based on tuple of user/device.
+ CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
+);
+`
+
+const upsertDeviceKeysSQL = "" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
+
+const selectDeviceKeysSQL = "" +
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectBatchDeviceKeysSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
+
+const countStreamIDsForUserSQL = "" +
+ "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
+
+const deleteDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
+type deviceKeysStatements struct {
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ countStreamIDsForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
+}
+
+func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
+ s := &deviceKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(deviceKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ for i, key := range keys {
+ var keyJSONStr string
+ var streamID int64
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+ // this will be '' when there is no device
+ keys[i].Type = api.TypeDeviceKeyUpdate
+ keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt64
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int64
+ }
+ return
+}
+
+func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
+ // nullable if there are no results
+ var count sql.NullInt32
+ err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
+ if err != nil {
+ return 0, err
+ }
+ if count.Valid {
+ return int(count.Int32), nil
+ }
+ return 0, nil
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
+ deviceIDMap := make(map[string]bool)
+ for _, d := range deviceIDs {
+ deviceIDMap[d] = true
+ }
+ var result []api.DeviceMessage
+ var displayName sql.NullString
+ for rows.Next() {
+ dk := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
+ }
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
+ // include the key if we want all keys (no device) or it was asked
+ if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
+ result = append(result, dk)
+ }
+ }
+ return result, rows.Err()
+}
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 7481ac5b..88f8839c 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
diff --git a/userapi/storage/postgres/key_backup_table.go b/userapi/storage/postgres/key_backup_table.go
index 7b58f7ba..91a34c35 100644
--- a/userapi/storage/postgres/key_backup_table.go
+++ b/userapi/storage/postgres/key_backup_table.go
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
- {&s.selectKeysStmt, selectKeysSQL},
+ {&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)
diff --git a/userapi/storage/postgres/key_changes_table.go b/userapi/storage/postgres/key_changes_table.go
new file mode 100644
index 00000000..a0049414
--- /dev/null
+++ b/userapi/storage/postgres/key_changes_table.go
@@ -0,0 +1,127 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var keyChangesSchema = `
+-- Stores key change information about users. Used to determine when to send updated device lists to clients.
+CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
+CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
+ user_id TEXT NOT NULL,
+ CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
+);
+`
+
+// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
+// have changed, hence we can just keep bumping the change ID for this user.
+const upsertKeyChangeSQL = "" +
+ "INSERT INTO keyserver_key_changes (user_id)" +
+ " VALUES ($1)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
+ " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
+ " RETURNING change_id"
+
+const selectKeyChangesSQL = "" +
+ "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
+
+type keyChangesStatements struct {
+ db *sql.DB
+ upsertKeyChangeStmt *sql.Stmt
+ selectKeyChangesStmt *sql.Stmt
+}
+
+func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+ s := &keyChangesStatements{
+ db: db,
+ }
+ _, err := db.Exec(keyChangesSchema)
+ if err != nil {
+ return s, err
+ }
+
+ if err = executeMigration(context.Background(), db); err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
+ {&s.selectKeyChangesStmt, selectKeyChangesSQL},
+ }.Prepare(db)
+}
+
+func executeMigration(ctx context.Context, db *sql.DB) error {
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column partition was removed from the table
+ migrationName := "keyserver: refactor key changes"
+
+ var cName string
+ err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
+ if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
+ return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
+ }
+ return nil
+ }
+ return err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: migrationName,
+ Up: deltas.UpRefactorKeyChanges,
+ })
+
+ return m.Up(ctx)
+}
+
+func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
+ err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
+ return
+}
+
+func (s *keyChangesStatements) SelectKeyChanges(
+ ctx context.Context, fromOffset, toOffset int64,
+) (userIDs []string, latestOffset int64, err error) {
+ latestOffset = fromOffset
+ rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ var offset int64
+ if err := rows.Scan(&userID, &offset); err != nil {
+ return nil, 0, err
+ }
+ if offset > latestOffset {
+ latestOffset = offset
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return
+}
diff --git a/userapi/storage/postgres/one_time_keys_table.go b/userapi/storage/postgres/one_time_keys_table.go
new file mode 100644
index 00000000..972a5914
--- /dev/null
+++ b/userapi/storage/postgres/one_time_keys_table.go
@@ -0,0 +1,194 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+
+ "github.com/lib/pq"
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var oneTimeKeysSchema = `
+-- Stores one-time public keys for users
+CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ key_id TEXT NOT NULL,
+ algorithm TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- Clobber based on 4-uple of user/device/key/algorithm.
+ CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
+`
+
+const upsertKeysSQL = "" +
+ "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
+ " DO UPDATE SET key_json = $6"
+
+const selectOneTimeKeysSQL = "" +
+ "SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
+
+const selectKeysCountSQL = "" +
+ "SELECT algorithm, COUNT(key_id) FROM " +
+ " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
+ " x GROUP BY algorithm"
+
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
+const deleteOneTimeKeysSQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
+
+type oneTimeKeysStatements struct {
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
+ deleteOneTimeKeysStmt *sql.Stmt
+}
+
+func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
+ s := &oneTimeKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(oneTimeKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertKeysSQL},
+ {&s.selectKeysStmt, selectOneTimeKeysSQL},
+ {&s.selectKeysCountStmt, selectKeysCountSQL},
+ {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
+ {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
+ {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
+
+ result := make(map[string]json.RawMessage)
+ var (
+ algorithmWithID string
+ keyJSONStr string
+ )
+ for rows.Next() {
+ if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
+ return nil, err
+ }
+ result[algorithmWithID] = json.RawMessage(keyJSONStr)
+ }
+ return result, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ counts := &api.OneTimeKeysCount{
+ DeviceID: deviceID,
+ UserID: userID,
+ KeyCount: make(map[string]int),
+ }
+ rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+ return counts, nil
+}
+
+func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
+ now := time.Now().Unix()
+ counts := &api.OneTimeKeysCount{
+ DeviceID: keys.DeviceID,
+ UserID: keys.UserID,
+ KeyCount: make(map[string]int),
+ }
+ for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
+ algo, keyID := keys.Split(keyIDWithAlgo)
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
+ ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+ rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+
+ return counts, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
+
+func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
diff --git a/userapi/storage/postgres/stale_device_lists.go b/userapi/storage/postgres/stale_device_lists.go
new file mode 100644
index 00000000..c823b58c
--- /dev/null
+++ b/userapi/storage/postgres/stale_device_lists.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package postgres
+
+import (
+ "context"
+ "database/sql"
+ "time"
+
+ "github.com/lib/pq"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var staleDeviceListsSchema = `
+-- Stores whether a user's device lists are stale or not.
+CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
+ user_id TEXT PRIMARY KEY NOT NULL,
+ domain TEXT NOT NULL,
+ is_stale BOOLEAN NOT NULL,
+ ts_added_secs BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
+`
+
+const upsertStaleDeviceListSQL = "" +
+ "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
+ " VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id)" +
+ " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
+
+const selectStaleDeviceListsWithDomainsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
+
+const selectStaleDeviceListsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
+
+const deleteStaleDevicesSQL = "" +
+ "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
+
+type staleDeviceListsStatements struct {
+ upsertStaleDeviceListStmt *sql.Stmt
+ selectStaleDeviceListsWithDomainsStmt *sql.Stmt
+ selectStaleDeviceListsStmt *sql.Stmt
+ deleteStaleDeviceListsStmt *sql.Stmt
+}
+
+func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{}
+ _, err := db.Exec(staleDeviceListsSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
+ {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
+ {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
+ {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
+ }.Prepare(db)
+}
+
+func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+ _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
+ return err
+}
+
+func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ // we only query for 1 domain or all domains so optimise for those use cases
+ if len(domains) == 0 {
+ rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
+ if err != nil {
+ return nil, err
+ }
+ return rowsToUserIDs(ctx, rows)
+ }
+ var result []string
+ for _, domain := range domains {
+ rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
+ if err != nil {
+ return nil, err
+ }
+ userIDs, err := rowsToUserIDs(ctx, rows)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, userIDs...)
+ }
+ return result, nil
+}
+
+// DeleteStaleDeviceLists removes users from stale device lists
+func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
+ ctx context.Context, txn *sql.Tx, userIDs []string,
+) error {
+ stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
+ _, err := stmt.ExecContext(ctx, pq.Array(userIDs))
+ return err
+}
+
+func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
+ defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}
diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go
index 92dc4808..673d123b 100644
--- a/userapi/storage/postgres/storage.go
+++ b/userapi/storage/postgres/storage.go
@@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
+
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
+ db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
+ if err != nil {
+ return nil, err
+ }
+ otk, err := NewPostgresOneTimeKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ dk, err := NewPostgresDeviceKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ kc, err := NewPostgresKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
+ sdl, err := NewPostgresStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
+ csk, err := NewPostgresCrossSigningKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ css, err := NewPostgresCrossSigningSigsTable(db)
+ if err != nil {
+ return nil, err
+ }
+
+ return &shared.KeyDatabase{
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
+ CrossSigningKeysTable: csk,
+ CrossSigningSigsTable: css,
+ Writer: writer,
+ }, nil
+}
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index bf94f14d..d3272a03 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -59,6 +59,17 @@ type Database struct {
OpenIDTokenLifetimeMS int64
}
+type KeyDatabase struct {
+ OneTimeKeysTable tables.OneTimeKeys
+ DeviceKeysTable tables.DeviceKeys
+ KeyChangesTable tables.KeyChanges
+ StaleDeviceListsTable tables.StaleDeviceLists
+ CrossSigningKeysTable tables.CrossSigningKeys
+ CrossSigningSigsTable tables.CrossSigningSigs
+ DB *sql.DB
+ Writer sqlutil.Writer
+}
+
const (
// The length of generated device IDs
deviceIDByteLength = 6
@@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}
+
+//
+
+func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
+}
+
+func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
+ _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
+}
+
+func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
+}
+
+func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
+ count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
+ if err != nil {
+ return false, err
+ }
+ return count == len(prevIDs), nil
+}
+
+func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for _, userID := range clearUserIDs {
+ err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
+}
+
+func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+ // work out the latest stream IDs for each user
+ userIDToStreamID := make(map[string]int64)
+ for _, k := range keys {
+ userIDToStreamID[k.UserID] = 0
+ }
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for userID := range userIDToStreamID {
+ streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ userIDToStreamID[userID] = streamID
+ }
+ // set the stream IDs for each key
+ for i := range keys {
+ k := keys[i]
+ userIDToStreamID[k.UserID]++ // start stream from 1
+ k.StreamID = userIDToStreamID[k.UserID]
+ keys[i] = k
+ }
+ return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
+ })
+}
+
+func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
+}
+
+func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
+ var result []api.OneTimeKeys
+ err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for userID, deviceToAlgo := range userToDeviceToAlgorithm {
+ for deviceID, algo := range deviceToAlgo {
+ keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
+ if err != nil {
+ return err
+ }
+ if keyJSON != nil {
+ result = append(result, api.OneTimeKeys{
+ UserID: userID,
+ DeviceID: deviceID,
+ KeyJSON: keyJSON,
+ })
+ }
+ }
+ }
+ return nil
+ })
+ return result, err
+}
+
+func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
+ err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
+ return err
+ })
+ return
+}
+
+func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
+ return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
+}
+
+// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+// If no domains are given, all user IDs with stale device lists are returned.
+func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
+}
+
+// MarkDeviceListStale sets the stale bit for this user to isStale.
+func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
+ return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
+ return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
+ })
+}
+
+// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
+// cross-signing signatures relating to that device.
+func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for _, deviceID := range deviceIDs {
+ if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
+ }
+ if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
+ }
+ if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
+ return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
+func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
+ keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
+ if err != nil {
+ return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
+ }
+ results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
+ for purpose, key := range keyMap {
+ keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
+ result := gomatrixserverlib.CrossSigningKey{
+ UserID: userID,
+ Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
+ Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
+ keyID: key,
+ },
+ }
+ sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
+ if err != nil {
+ continue
+ }
+ for sigUserID, forSigUserID := range sigMap {
+ if userID != sigUserID {
+ continue
+ }
+ if result.Signatures == nil {
+ result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ if _, ok := result.Signatures[sigUserID]; !ok {
+ result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for sigKeyID, sigBytes := range forSigUserID {
+ result.Signatures[sigUserID][sigKeyID] = sigBytes
+ }
+ }
+ results[purpose] = result
+ }
+ return results, nil
+}
+
+// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
+func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
+ return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
+}
+
+// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
+func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
+ return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
+}
+
+// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
+func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ for keyType, keyData := range keyMap {
+ if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
+ return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
+ }
+ }
+ return nil
+ })
+}
+
+// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
+func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
+ ctx context.Context,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
+ }
+ return nil
+ })
+}
+
+// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
+func (d *KeyDatabase) DeleteStaleDeviceLists(
+ ctx context.Context,
+ userIDs []string,
+) error {
+ return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
+ })
+}
diff --git a/userapi/storage/sqlite3/cross_signing_keys_table.go b/userapi/storage/sqlite3/cross_signing_keys_table.go
new file mode 100644
index 00000000..10721fcc
--- /dev/null
+++ b/userapi/storage/sqlite3/cross_signing_keys_table.go
@@ -0,0 +1,101 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningKeysSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
+ user_id TEXT NOT NULL,
+ key_type INTEGER NOT NULL,
+ key_data TEXT NOT NULL,
+ PRIMARY KEY (user_id, key_type)
+);
+`
+
+const selectCrossSigningKeysForUserSQL = "" +
+ "SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
+ " WHERE user_id = $1"
+
+const upsertCrossSigningKeysForUserSQL = "" +
+ "INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
+ " VALUES($1, $2, $3)"
+
+type crossSigningKeysStatements struct {
+ db *sql.DB
+ selectCrossSigningKeysForUserStmt *sql.Stmt
+ upsertCrossSigningKeysForUserStmt *sql.Stmt
+}
+
+func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
+ s := &crossSigningKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
+ {&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string,
+) (r types.CrossSigningKeyMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
+ r = types.CrossSigningKeyMap{}
+ for rows.Next() {
+ var keyTypeInt int16
+ var keyData gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
+ return nil, err
+ }
+ keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
+ if !ok {
+ return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
+ }
+ r[keyType] = keyData
+ }
+ return
+}
+
+func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
+ ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
+) error {
+ keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
+ if !ok {
+ return fmt.Errorf("unknown key purpose %q", keyType)
+ }
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/cross_signing_sigs_table.go b/userapi/storage/sqlite3/cross_signing_sigs_table.go
new file mode 100644
index 00000000..2be00c9c
--- /dev/null
+++ b/userapi/storage/sqlite3/cross_signing_sigs_table.go
@@ -0,0 +1,129 @@
+// Copyright 2021 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/dendrite/userapi/types"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var crossSigningSigsSchema = `
+CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+`
+
+const selectCrossSigningSigsForTargetSQL = "" +
+ "SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
+ " WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"
+
+const upsertCrossSigningSigsForTargetSQL = "" +
+ "INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
+ " VALUES($1, $2, $3, $4, $5)"
+
+const deleteCrossSigningSigsForTargetSQL = "" +
+ "DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
+
+type crossSigningSigsStatements struct {
+ db *sql.DB
+ selectCrossSigningSigsForTargetStmt *sql.Stmt
+ upsertCrossSigningSigsForTargetStmt *sql.Stmt
+ deleteCrossSigningSigsForTargetStmt *sql.Stmt
+}
+
+func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
+ s := &crossSigningSigsStatements{
+ db: db,
+ }
+ _, err := db.Exec(crossSigningSigsSchema)
+ if err != nil {
+ return nil, err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: "keyserver: cross signing signature indexes",
+ Up: deltas.UpFixCrossSigningSignatureIndexes,
+ })
+ if err = m.Up(context.Background()); err != nil {
+ return nil, err
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
+ {&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
+ {&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
+ }.Prepare(db)
+}
+
+func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) (r types.CrossSigningSigMap, err error) {
+ rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
+ r = types.CrossSigningSigMap{}
+ for rows.Next() {
+ var userID string
+ var keyID gomatrixserverlib.KeyID
+ var signature gomatrixserverlib.Base64Bytes
+ if err := rows.Scan(&userID, &keyID, &signature); err != nil {
+ return nil, err
+ }
+ if _, ok := r[userID]; !ok {
+ r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ r[userID][keyID] = signature
+ }
+ return
+}
+
+func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ originUserID string, originKeyID gomatrixserverlib.KeyID,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+ signature gomatrixserverlib.Base64Bytes,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
+ return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
+
+func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
+ ctx context.Context, txn *sql.Tx,
+ targetUserID string, targetKeyID gomatrixserverlib.KeyID,
+) error {
+ if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
+ return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go
new file mode 100644
index 00000000..cd0f19df
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022012016470000_key_changes.go
@@ -0,0 +1,66 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ // start counting from the last max offset, else 0.
+ var maxOffset int64
+ var userID string
+ _ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
+
+ _, err := tx.ExecContext(ctx, `
+ -- make the new table
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (user_id)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ // to start counting from maxOffset, insert a row with that value
+ if userID != "" {
+ _, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
+ return err
+ }
+ return nil
+}
+
+func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
+ DROP TABLE IF EXISTS keyserver_key_changes;
+ CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ partition BIGINT NOT NULL,
+ offset BIGINT NOT NULL,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (partition, offset)
+ );
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
new file mode 100644
index 00000000..d4e38dea
--- /dev/null
+++ b/userapi/storage/sqlite3/deltas/2022042612000000_xsigning_idx.go
@@ -0,0 +1,71 @@
+// Copyright 2022 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package deltas
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+)
+
+func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
+ );
+
+ INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
+ SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
+
+ DROP TABLE keyserver_cross_signing_sigs;
+ ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
+
+ CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute upgrade: %w", err)
+ }
+ return nil
+}
+
+func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
+ _, err := tx.ExecContext(ctx, `
+ CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
+ origin_user_id TEXT NOT NULL,
+ origin_key_id TEXT NOT NULL,
+ target_user_id TEXT NOT NULL,
+ target_key_id TEXT NOT NULL,
+ signature TEXT NOT NULL,
+ PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
+ );
+
+ INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
+ SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
+
+ DROP TABLE keyserver_cross_signing_sigs;
+ ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
+
+ DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
+ `)
+ if err != nil {
+ return fmt.Errorf("failed to execute downgrade: %w", err)
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/device_keys_table.go b/userapi/storage/sqlite3/device_keys_table.go
new file mode 100644
index 00000000..15e69cc4
--- /dev/null
+++ b/userapi/storage/sqlite3/device_keys_table.go
@@ -0,0 +1,213 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var deviceKeysSchema = `
+-- Stores device keys for users
+CREATE TABLE IF NOT EXISTS keyserver_device_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ stream_id BIGINT NOT NULL,
+ display_name TEXT,
+ -- Clobber based on tuple of user/device.
+ UNIQUE (user_id, device_id)
+);
+`
+
+const upsertDeviceKeysSQL = "" +
+ "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT (user_id, device_id)" +
+ " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
+
+const selectDeviceKeysSQL = "" +
+ "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectBatchDeviceKeysSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
+
+const selectBatchDeviceKeysWithEmptiesSQL = "" +
+ "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
+
+const selectMaxStreamForUserSQL = "" +
+ "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
+
+const countStreamIDsForUserSQL = "" +
+ "SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
+
+const deleteDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
+
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
+type deviceKeysStatements struct {
+ db *sql.DB
+ upsertDeviceKeysStmt *sql.Stmt
+ selectDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysStmt *sql.Stmt
+ selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
+ selectMaxStreamForUserStmt *sql.Stmt
+ deleteDeviceKeysStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
+}
+
+func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
+ s := &deviceKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(deviceKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
+ {&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
+ {&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
+ {&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
+ // {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
+ {&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
+ {&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
+
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
+func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
+ deviceIDMap := make(map[string]bool)
+ for _, d := range deviceIDs {
+ deviceIDMap[d] = true
+ }
+ var stmt *sql.Stmt
+ if includeEmpty {
+ stmt = s.selectBatchDeviceKeysWithEmptiesStmt
+ } else {
+ stmt = s.selectBatchDeviceKeysStmt
+ }
+ rows, err := stmt.QueryContext(ctx, userID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
+ var result []api.DeviceMessage
+ var displayName sql.NullString
+ for rows.Next() {
+ dk := api.DeviceMessage{
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ UserID: userID,
+ },
+ }
+ if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dk.DisplayName = displayName.String
+ }
+ // include the key if we want all keys (no device) or it was asked
+ if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
+ result = append(result, dk)
+ }
+ }
+ return result, rows.Err()
+}
+
+func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
+ for i, key := range keys {
+ var keyJSONStr string
+ var streamID int64
+ var displayName sql.NullString
+ err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
+ if err != nil && err != sql.ErrNoRows {
+ return err
+ }
+ // this will be '' when there is no device
+ keys[i].Type = api.TypeDeviceKeyUpdate
+ keys[i].KeyJSON = []byte(keyJSONStr)
+ keys[i].StreamID = streamID
+ if displayName.Valid {
+ keys[i].DisplayName = displayName.String
+ }
+ }
+ return nil
+}
+
+func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
+ // nullable if there are no results
+ var nullStream sql.NullInt64
+ err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
+ if err == sql.ErrNoRows {
+ err = nil
+ }
+ if nullStream.Valid {
+ streamID = nullStream.Int64
+ }
+ return
+}
+
+func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
+ iStreamIDs := make([]interface{}, len(streamIDs)+1)
+ iStreamIDs[0] = userID
+ for i := range streamIDs {
+ iStreamIDs[i+1] = streamIDs[i]
+ }
+ query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
+ // nullable if there are no results
+ var count sql.NullInt64
+ err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
+ if err != nil {
+ return 0, err
+ }
+ if count.Valid {
+ return int(count.Int64), nil
+ }
+ return 0, nil
+}
+
+func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
+ for _, key := range keys {
+ now := time.Now().Unix()
+ _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
+ ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
+ )
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 449e4549..65e17527 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
@@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
- return &api.Device{
+ dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
- }, nil
+ }
+ if displayName != nil {
+ dev.DisplayName = *displayName
+ }
+ return dev, nil
}
func (s *devicesStatements) DeleteDevice(
@@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
if err != nil {
return err
}
+ defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+2)
params[0] = localpart
diff --git a/userapi/storage/sqlite3/key_backup_table.go b/userapi/storage/sqlite3/key_backup_table.go
index 7883ffb1..ed274631 100644
--- a/userapi/storage/sqlite3/key_backup_table.go
+++ b/userapi/storage/sqlite3/key_backup_table.go
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
- {&s.selectKeysStmt, selectKeysSQL},
+ {&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)
diff --git a/userapi/storage/sqlite3/key_changes_table.go b/userapi/storage/sqlite3/key_changes_table.go
new file mode 100644
index 00000000..923bb57e
--- /dev/null
+++ b/userapi/storage/sqlite3/key_changes_table.go
@@ -0,0 +1,125 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var keyChangesSchema = `
+-- Stores key change information about users. Used to determine when to send updated device lists to clients.
+CREATE TABLE IF NOT EXISTS keyserver_key_changes (
+ change_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ -- The key owner
+ user_id TEXT NOT NULL,
+ UNIQUE (user_id)
+);
+`
+
+// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
+// have changed, hence we can just keep bumping the change ID for this user.
+const upsertKeyChangeSQL = "" +
+ "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
+ " VALUES ($1)" +
+ " RETURNING change_id"
+
+const selectKeyChangesSQL = "" +
+ "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
+
+type keyChangesStatements struct {
+ db *sql.DB
+ upsertKeyChangeStmt *sql.Stmt
+ selectKeyChangesStmt *sql.Stmt
+}
+
+func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
+ s := &keyChangesStatements{
+ db: db,
+ }
+ _, err := db.Exec(keyChangesSchema)
+ if err != nil {
+ return s, err
+ }
+
+ if err = executeMigration(context.Background(), db); err != nil {
+ return nil, err
+ }
+
+ return s, sqlutil.StatementList{
+ {&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
+ {&s.selectKeyChangesStmt, selectKeyChangesSQL},
+ }.Prepare(db)
+}
+
+func executeMigration(ctx context.Context, db *sql.DB) error {
+ // TODO: Remove when we are sure we are not having goose artefacts in the db
+ // This forces an error, which indicates the migration is already applied, since the
+ // column partition was removed from the table
+ migrationName := "keyserver: refactor key changes"
+
+ var cName string
+ err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
+ if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
+ return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
+ }
+ return nil
+ }
+ return err
+ }
+ m := sqlutil.NewMigrator(db)
+ m.AddMigrations(sqlutil.Migration{
+ Version: migrationName,
+ Up: deltas.UpRefactorKeyChanges,
+ })
+ return m.Up(ctx)
+}
+
+func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
+ err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
+ return
+}
+
+func (s *keyChangesStatements) SelectKeyChanges(
+ ctx context.Context, fromOffset, toOffset int64,
+) (userIDs []string, latestOffset int64, err error) {
+ latestOffset = fromOffset
+ rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
+ if err != nil {
+ return nil, 0, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
+ for rows.Next() {
+ var userID string
+ var offset int64
+ if err := rows.Scan(&userID, &offset); err != nil {
+ return nil, 0, err
+ }
+ if offset > latestOffset {
+ latestOffset = offset
+ }
+ userIDs = append(userIDs, userID)
+ }
+ return
+}
diff --git a/userapi/storage/sqlite3/one_time_keys_table.go b/userapi/storage/sqlite3/one_time_keys_table.go
new file mode 100644
index 00000000..a992d399
--- /dev/null
+++ b/userapi/storage/sqlite3/one_time_keys_table.go
@@ -0,0 +1,208 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+var oneTimeKeysSchema = `
+-- Stores one-time public keys for users
+CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ key_id TEXT NOT NULL,
+ algorithm TEXT NOT NULL,
+ ts_added_secs BIGINT NOT NULL,
+ key_json TEXT NOT NULL,
+ -- Clobber based on 4-uple of user/device/key/algorithm.
+ UNIQUE (user_id, device_id, key_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
+`
+
+const upsertKeysSQL = "" +
+ "INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
+ " VALUES ($1, $2, $3, $4, $5, $6)" +
+ " ON CONFLICT (user_id, device_id, key_id, algorithm)" +
+ " DO UPDATE SET key_json = $6"
+
+const selectOneTimeKeysSQL = "" +
+ "SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
+
+const selectKeysCountSQL = "" +
+ "SELECT algorithm, COUNT(key_id) FROM " +
+ " (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
+ " x GROUP BY algorithm"
+
+const deleteOneTimeKeySQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
+
+const selectKeyByAlgorithmSQL = "" +
+ "SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
+
+const deleteOneTimeKeysSQL = "" +
+ "DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
+
+type oneTimeKeysStatements struct {
+ db *sql.DB
+ upsertKeysStmt *sql.Stmt
+ selectKeysStmt *sql.Stmt
+ selectKeysCountStmt *sql.Stmt
+ selectKeyByAlgorithmStmt *sql.Stmt
+ deleteOneTimeKeyStmt *sql.Stmt
+ deleteOneTimeKeysStmt *sql.Stmt
+}
+
+func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
+ s := &oneTimeKeysStatements{
+ db: db,
+ }
+ _, err := db.Exec(oneTimeKeysSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertKeysStmt, upsertKeysSQL},
+ {&s.selectKeysStmt, selectOneTimeKeysSQL},
+ {&s.selectKeysCountStmt, selectKeysCountSQL},
+ {&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
+ {&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
+ {&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
+ }.Prepare(db)
+}
+
+func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
+ rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
+
+ wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
+ for _, ka := range keyIDsWithAlgorithms {
+ wantSet[ka] = true
+ }
+
+ result := make(map[string]json.RawMessage)
+ for rows.Next() {
+ var keyID string
+ var algorithm string
+ var keyJSONStr string
+ if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
+ return nil, err
+ }
+ keyIDWithAlgo := algorithm + ":" + keyID
+ if wantSet[keyIDWithAlgo] {
+ result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
+ }
+ }
+ return result, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
+ counts := &api.OneTimeKeysCount{
+ DeviceID: deviceID,
+ UserID: userID,
+ KeyCount: make(map[string]int),
+ }
+ rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+ return counts, nil
+}
+
+func (s *oneTimeKeysStatements) InsertOneTimeKeys(
+ ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
+) (*api.OneTimeKeysCount, error) {
+ now := time.Now().Unix()
+ counts := &api.OneTimeKeysCount{
+ DeviceID: keys.DeviceID,
+ UserID: keys.UserID,
+ KeyCount: make(map[string]int),
+ }
+ for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
+ algo, keyID := keys.Split(keyIDWithAlgo)
+ _, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
+ ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+ rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
+ for rows.Next() {
+ var algorithm string
+ var count int
+ if err = rows.Scan(&algorithm, &count); err != nil {
+ return nil, err
+ }
+ counts.KeyCount[algorithm] = count
+ }
+
+ return counts, rows.Err()
+}
+
+func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
+ ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
+) (map[string]json.RawMessage, error) {
+ var keyID string
+ var keyJSON string
+ err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ return nil, err
+ }
+ _, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
+ if err != nil {
+ return nil, err
+ }
+ if keyJSON == "" {
+ return nil, nil
+ }
+ return map[string]json.RawMessage{
+ algorithm + ":" + keyID: json.RawMessage(keyJSON),
+ }, err
+}
+
+func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
+ _, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
+ return err
+}
diff --git a/userapi/storage/sqlite3/stale_device_lists.go b/userapi/storage/sqlite3/stale_device_lists.go
new file mode 100644
index 00000000..f078fc99
--- /dev/null
+++ b/userapi/storage/sqlite3/stale_device_lists.go
@@ -0,0 +1,145 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqlite3
+
+import (
+ "context"
+ "database/sql"
+ "strings"
+ "time"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+
+ "github.com/matrix-org/dendrite/internal"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+ "github.com/matrix-org/gomatrixserverlib"
+)
+
+var staleDeviceListsSchema = `
+-- Stores whether a user's device lists are stale or not.
+CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
+ user_id TEXT PRIMARY KEY NOT NULL,
+ domain TEXT NOT NULL,
+ is_stale BOOLEAN NOT NULL,
+ ts_added_secs BIGINT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
+`
+
+const upsertStaleDeviceListSQL = "" +
+ "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
+ " VALUES ($1, $2, $3, $4)" +
+ " ON CONFLICT (user_id)" +
+ " DO UPDATE SET is_stale = $3, ts_added_secs = $4"
+
+const selectStaleDeviceListsWithDomainsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
+
+const selectStaleDeviceListsSQL = "" +
+ "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
+
+const deleteStaleDevicesSQL = "" +
+ "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
+
+type staleDeviceListsStatements struct {
+ db *sql.DB
+ upsertStaleDeviceListStmt *sql.Stmt
+ selectStaleDeviceListsWithDomainsStmt *sql.Stmt
+ selectStaleDeviceListsStmt *sql.Stmt
+ // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
+}
+
+func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
+ s := &staleDeviceListsStatements{
+ db: db,
+ }
+ _, err := db.Exec(staleDeviceListsSchema)
+ if err != nil {
+ return nil, err
+ }
+ return s, sqlutil.StatementList{
+ {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
+ {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
+ {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
+ // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
+ }.Prepare(db)
+}
+
+func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return err
+ }
+ _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
+ return err
+}
+
+func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
+ // we only query for 1 domain or all domains so optimise for those use cases
+ if len(domains) == 0 {
+ rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
+ if err != nil {
+ return nil, err
+ }
+ return rowsToUserIDs(ctx, rows)
+ }
+ var result []string
+ for _, domain := range domains {
+ rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
+ if err != nil {
+ return nil, err
+ }
+ userIDs, err := rowsToUserIDs(ctx, rows)
+ if err != nil {
+ return nil, err
+ }
+ result = append(result, userIDs...)
+ }
+ return result, nil
+}
+
+// DeleteStaleDeviceLists removes users from stale device lists
+func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
+ ctx context.Context, txn *sql.Tx, userIDs []string,
+) error {
+ qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
+ stmt, err := s.db.Prepare(qry)
+ if err != nil {
+ return err
+ }
+ defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
+ stmt = sqlutil.TxStmt(txn, stmt)
+
+ params := make([]any, len(userIDs))
+ for i := range userIDs {
+ params[i] = userIDs[i]
+ }
+
+ _, err = stmt.ExecContext(ctx, params...)
+ return err
+}
+
+func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
+ defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
+ for rows.Next() {
+ var userID string
+ if err := rows.Scan(&userID); err != nil {
+ return nil, err
+ }
+ result = append(result, userID)
+ }
+ return result, rows.Err()
+}
diff --git a/userapi/storage/sqlite3/stats_table.go b/userapi/storage/sqlite3/stats_table.go
index a1365c94..72b3ba49 100644
--- a/userapi/storage/sqlite3/stats_table.go
+++ b/userapi/storage/sqlite3/stats_table.go
@@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
if err != nil {
return 0, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3, 4,
@@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
if err != nil {
return 0, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3,
@@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
if err != nil {
return nil, err
}
+ defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
registeredAfter := time.Now().AddDate(0, 0, -30)
diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go
index 85a1f706..0f3eeed1 100644
--- a/userapi/storage/sqlite3/storage.go
+++ b/userapi/storage/sqlite3/storage.go
@@ -30,8 +30,8 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
-// NewDatabase creates a new accounts and profiles database
-func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
+// NewUserDatabase creates a new accounts and profiles database
+func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
@@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
+
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
+ db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
+ if err != nil {
+ return nil, err
+ }
+ otk, err := NewSqliteOneTimeKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ dk, err := NewSqliteDeviceKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ kc, err := NewSqliteKeyChangesTable(db)
+ if err != nil {
+ return nil, err
+ }
+ sdl, err := NewSqliteStaleDeviceListsTable(db)
+ if err != nil {
+ return nil, err
+ }
+ csk, err := NewSqliteCrossSigningKeysTable(db)
+ if err != nil {
+ return nil, err
+ }
+ css, err := NewSqliteCrossSigningSigsTable(db)
+ if err != nil {
+ return nil, err
+ }
+
+ return &shared.KeyDatabase{
+ OneTimeKeysTable: otk,
+ DeviceKeysTable: dk,
+ KeyChangesTable: kc,
+ StaleDeviceListsTable: sdl,
+ CrossSigningKeysTable: csk,
+ CrossSigningSigsTable: css,
+ Writer: writer,
+ }, nil
+}
diff --git a/userapi/storage/storage.go b/userapi/storage/storage.go
index 42221e75..0329fb46 100644
--- a/userapi/storage/storage.go
+++ b/userapi/storage/storage.go
@@ -29,15 +29,36 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
-// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
+// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
-func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
+func NewUserDatabase(
+ base *base.BaseDendrite,
+ dbProperties *config.DatabaseOptions,
+ serverName gomatrixserverlib.ServerName,
+ bcryptCost int,
+ openIDTokenLifetimeMS int64,
+ loginTokenLifetime time.Duration,
+ serverNoticesLocalpart string,
+) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+ return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
default:
return nil, fmt.Errorf("unexpected database type")
}
}
+
+// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
+// and sets postgres connection parameters.
+func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
+ switch {
+ case dbProperties.ConnectionString.IsSQLite():
+ return sqlite3.NewKeyDatabase(base, dbProperties)
+ case dbProperties.ConnectionString.IsPostgres():
+ return postgres.NewKeyDatabase(base, dbProperties)
+ default:
+ return nil, fmt.Errorf("unexpected database type")
+ }
+}
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index 23aafff0..f52e7e17 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -4,9 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
+ "reflect"
+ "sync"
"testing"
"time"
+ "github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
@@ -29,14 +32,14 @@ var (
ctx = context.Background()
)
-func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
+func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
- db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
- t.Fatalf("NewUserAPIDatabase returned %s", err)
+ t.Fatalf("NewUserDatabase returned %s", err)
}
return db, func() {
close()
@@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
// Tests storing and getting account data
func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
@@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
// Tests the creation of accounts
func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
@@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
accessToken := util.RandomString(16)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
@@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
@@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
func Test_LoginToken(t *testing.T) {
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create a new token
@@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
@@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create account, which also creates a profile
@@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
appID := util.RandomString(8)
@@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
@@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, close := mustCreateDatabase(t, dbType)
+ db, close := mustCreateUserDatabase(t, dbType)
defer close()
// generate some dummy notifications
for i := 0; i < 10; i++ {
@@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(0), total)
})
}
+
+func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
+ base, close := testrig.CreateBaseDendrite(t, dbType)
+ db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
+ if err != nil {
+ t.Fatalf("failed to create new database: %v", err)
+ }
+ return db, close
+}
+
+func MustNotError(t *testing.T, err error) {
+ t.Helper()
+ if err == nil {
+ return
+ }
+ t.Fatalf("operation failed: %s", err)
+}
+
+func TestKeyChanges(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ _, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
+ MustNotError(t, err)
+ deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeIDC {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+func TestKeyChangesNoDupes(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ if deviceChangeIDA == deviceChangeIDB {
+ t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
+ }
+ deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeID {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+func TestKeyChangesUpperLimit(t *testing.T) {
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
+ MustNotError(t, err)
+ deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
+ MustNotError(t, err)
+ _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
+ MustNotError(t, err)
+ userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
+ if err != nil {
+ t.Fatalf("Failed to KeyChanges: %s", err)
+ }
+ if latest != deviceChangeIDB {
+ t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
+ }
+ if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
+ t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
+ }
+ })
+}
+
+var dbLock sync.Mutex
+var deviceArray = []string{"AAA", "another_device"}
+
+// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
+// and that they are returned correctly when querying for device keys.
+func TestDeviceKeysStreamIDGeneration(t *testing.T) {
+ var err error
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, clean := mustCreateKeyDatabase(t, dbType)
+ defer clean()
+ alice := "@alice:TestDeviceKeysStreamIDGeneration"
+ bob := "@bob:TestDeviceKeysStreamIDGeneration"
+ msgs := []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: bob,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 1 as this is a different user
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "another_device",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v1"}`),
+ },
+ // StreamID: 2 as this is a 2nd device key
+ },
+ }
+ MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 1 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
+ }
+ if msgs[1].StreamID != 1 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
+ }
+ if msgs[2].StreamID != 2 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
+ }
+
+ // updating a device sets the next stream ID for that user
+ msgs = []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate,
+ DeviceKeys: &api.DeviceKeys{
+ DeviceID: "AAA",
+ UserID: alice,
+ KeyJSON: []byte(`{"key":"v2"}`),
+ },
+ // StreamID: 3
+ },
+ }
+ MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
+ if msgs[0].StreamID != 3 {
+ t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
+ }
+
+ dbLock.Lock()
+ defer dbLock.Unlock()
+ // Querying for device keys returns the latest stream IDs
+ msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
+
+ if err != nil {
+ t.Fatalf("DeviceKeysForUser returned error: %s", err)
+ }
+ wantStreamIDs := map[string]int64{
+ "AAA": 3,
+ "another_device": 2,
+ }
+ if len(msgs) != len(wantStreamIDs) {
+ t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
+ }
+ for _, m := range msgs {
+ if m.StreamID != wantStreamIDs[m.DeviceID] {
+ t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
+ }
+ }
+ })
+}
diff --git a/userapi/storage/storage_wasm.go b/userapi/storage/storage_wasm.go
index 5d5d292e..163e3e17 100644
--- a/userapi/storage/storage_wasm.go
+++ b/userapi/storage/storage_wasm.go
@@ -32,10 +32,10 @@ func NewUserAPIDatabase(
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
-) (Database, error) {
+) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
- return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
+ return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index 9221e571..693e7303 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -20,10 +20,10 @@ import (
"encoding/json"
"time"
+ "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
- "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types"
)
@@ -145,3 +145,47 @@ const (
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)
+
+type OneTimeKeys interface {
+ SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
+ CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
+ InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
+ // SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
+ // Returns an empty map if the key does not exist.
+ SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
+ DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+}
+
+type DeviceKeys interface {
+ SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
+ InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
+ SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
+ CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
+ SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
+ DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
+ DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
+}
+
+type KeyChanges interface {
+ InsertKeyChange(ctx context.Context, userID string) (int64, error)
+ // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
+ // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
+ SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+}
+
+type StaleDeviceLists interface {
+ InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
+ SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+ DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
+}
+
+type CrossSigningKeys interface {
+ SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
+ UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
+}
+
+type CrossSigningSigs interface {
+ SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
+ UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
+ DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
+}
diff --git a/userapi/storage/tables/stale_device_lists_test.go b/userapi/storage/tables/stale_device_lists_test.go
new file mode 100644
index 00000000..b9bdafda
--- /dev/null
+++ b/userapi/storage/tables/stale_device_lists_test.go
@@ -0,0 +1,94 @@
+package tables_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/matrix-org/dendrite/userapi/storage/postgres"
+ "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
+ "github.com/matrix-org/gomatrixserverlib"
+
+ "github.com/matrix-org/dendrite/internal/sqlutil"
+ "github.com/matrix-org/dendrite/setup/config"
+
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/userapi/storage/tables"
+)
+
+func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ db, err := sqlutil.Open(&config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ }, nil)
+ if err != nil {
+ t.Fatalf("failed to open database: %s", err)
+ }
+ switch dbType {
+ case test.DBTypePostgres:
+ tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
+ case test.DBTypeSQLite:
+ tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
+ }
+ if err != nil {
+ t.Fatalf("failed to create new table: %s", err)
+ }
+ return tab, close
+}
+
+func TestStaleDeviceLists(t *testing.T) {
+ alice := test.NewUser(t)
+ bob := test.NewUser(t)
+ charlie := "@charlie:localhost"
+ ctx := context.Background()
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ tab, closeDB := mustCreateTable(t, dbType)
+ defer closeDB()
+
+ if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+ if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+ if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
+ t.Fatalf("failed to insert stale device: %s", err)
+ }
+
+ // Query one server
+ wantStaleUsers := []string{alice.ID, bob.ID}
+ gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Query all servers
+ wantStaleUsers = []string{alice.ID, bob.ID, charlie}
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+ if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
+ t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
+ }
+
+ // Delete stale devices
+ deleteUsers := []string{alice.ID, bob.ID}
+ if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
+ t.Fatalf("failed to delete stale device lists: %s", err)
+ }
+
+ // Verify we don't get anything back after deleting
+ gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
+ if err != nil {
+ t.Fatalf("failed to query stale device lists: %s", err)
+ }
+
+ if gotCount := len(gotStaleUsers); gotCount > 0 {
+ t.Fatalf("expected no stale users, got %d", gotCount)
+ }
+ })
+}