diff options
author | Kegsay <kegan@matrix.org> | 2020-08-07 17:32:13 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-07 17:32:13 +0100 |
commit | f371783da765f96fc3764091e95fb8cb8004e208 (patch) | |
tree | dd1748eca0719054be508f2f026cb0e314d2565f /keyserver/storage | |
parent | 30c2325eaf85f28f438f9a3c7b703978eee66cf7 (diff) |
Finish inbound E2E device lists (#1243)
* Add tests for device list updates
* Add stale_device_lists table and use db before asking remote for device keys
* Fetch remote keys if all devices are requested
* Add display_name col to store remote device names
Few other tweaks to make `Server correctly handles incoming m.device_list_update`
pass.
* Fix sqlite otk bug
* Unbuffered channel to block /send causing sytest to not race anymore
* Linting and fix bug whereby we didn't send updated dl tokens to the client causing a tightloop on /sync sometimes
* No longer assert staleness as Update blocks on workers now
* Back out tweaks
* Bugfixes
Diffstat (limited to 'keyserver/storage')
-rw-r--r-- | keyserver/storage/postgres/device_keys_table.go | 25 | ||||
-rw-r--r-- | keyserver/storage/postgres/stale_device_lists.go | 118 | ||||
-rw-r--r-- | keyserver/storage/postgres/storage.go | 13 | ||||
-rw-r--r-- | keyserver/storage/shared/storage.go | 13 | ||||
-rw-r--r-- | keyserver/storage/sqlite3/device_keys_table.go | 25 | ||||
-rw-r--r-- | keyserver/storage/sqlite3/one_time_keys_table.go | 3 | ||||
-rw-r--r-- | keyserver/storage/sqlite3/stale_device_lists.go | 118 | ||||
-rw-r--r-- | keyserver/storage/sqlite3/storage.go | 13 | ||||
-rw-r--r-- | keyserver/storage/tables/interface.go | 6 |
9 files changed, 304 insertions, 30 deletions
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index d321860d..b9d5d4c3 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -37,22 +37,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( -- required in the spec because in the event of a missed update the server fetches the entire -- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs. stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -99,13 +100,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -140,7 +145,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err @@ -165,11 +170,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go new file mode 100644 index 00000000..63281adf --- /dev/null +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index a1d1c0fe..de2fabfd 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -38,10 +38,15 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s if err != nil { return nil, err } + sdl, err := NewPostgresStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 68964be6..4279eae7 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -26,10 +26,11 @@ import ( ) type Database struct { - DB *sql.DB - OneTimeKeysTable tables.OneTimeKeys - DeviceKeysTable tables.DeviceKeys - KeyChangesTable tables.KeyChanges + DB *sql.DB + OneTimeKeysTable tables.OneTimeKeys + DeviceKeysTable tables.DeviceKeys + KeyChangesTable tables.KeyChanges + StaleDeviceListsTable tables.StaleDeviceLists } func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { @@ -129,10 +130,10 @@ func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { - return nil, nil // TODO + return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains) } // MarkDeviceListStale sets the stale bit for this user to isStale. func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { - return nil // TODO + return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale) } diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 15d9c775..abe6636a 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -34,22 +34,23 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys ( ts_added_secs BIGINT NOT NULL, key_json TEXT NOT NULL, stream_id BIGINT NOT NULL, + display_name TEXT, -- Clobber based on tuple of user/device. UNIQUE (user_id, device_id) ); ` const upsertDeviceKeysSQL = "" + - "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" + - " VALUES ($1, $2, $3, $4, $5)" + + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" + + " VALUES ($1, $2, $3, $4, $5, $6)" + " ON CONFLICT (user_id, device_id)" + - " DO UPDATE SET key_json = $4, stream_id = $5" + " DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6" const selectDeviceKeysSQL = "" + - "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" + "SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" const selectBatchDeviceKeysSQL = "" + - "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1" + "SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1" const selectMaxStreamForUserSQL = "" + "SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1" @@ -106,11 +107,15 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID dk.UserID = userID var keyJSON string var streamID int - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil { + var displayName sql.NullString + if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { return nil, err } dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID + if displayName.Valid { + dk.DisplayName = displayName.String + } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { result = append(result, dk) @@ -123,13 +128,17 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys [] for i, key := range keys { var keyJSONStr string var streamID int - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID) + var displayName sql.NullString + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device keys[i].KeyJSON = []byte(keyJSONStr) keys[i].StreamID = streamID + if displayName.Valid { + keys[i].DisplayName = displayName.String + } } return nil } @@ -171,7 +180,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, + ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, ) if err != nil { return err diff --git a/keyserver/storage/sqlite3/one_time_keys_table.go b/keyserver/storage/sqlite3/one_time_keys_table.go index f910479f..907966a7 100644 --- a/keyserver/storage/sqlite3/one_time_keys_table.go +++ b/keyserver/storage/sqlite3/one_time_keys_table.go @@ -196,6 +196,9 @@ func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey( _, err = txn.StmtContext(ctx, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID) return err }) + if keyJSON == "" { + return nil, nil + } return map[string]json.RawMessage{ algorithm + ":" + keyID: json.RawMessage(keyJSON), }, err diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go new file mode 100644 index 00000000..a989476d --- /dev/null +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -0,0 +1,118 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +var staleDeviceListsSchema = ` +-- Stores whether a user's device lists are stale or not. +CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists ( + user_id TEXT PRIMARY KEY NOT NULL, + domain TEXT NOT NULL, + is_stale BOOLEAN NOT NULL, + ts_added_secs BIGINT NOT NULL +); + +CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale); +` + +const upsertStaleDeviceListSQL = "" + + "INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" + + " VALUES ($1, $2, $3, $4)" + + " ON CONFLICT (user_id)" + + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" + +const selectStaleDeviceListsWithDomainsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + +const selectStaleDeviceListsSQL = "" + + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + +type staleDeviceListsStatements struct { + upsertStaleDeviceListStmt *sql.Stmt + selectStaleDeviceListsWithDomainsStmt *sql.Stmt + selectStaleDeviceListsStmt *sql.Stmt +} + +func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { + s := &staleDeviceListsStatements{} + _, err := db.Exec(staleDeviceListsSchema) + if err != nil { + return nil, err + } + if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { + return nil, err + } + if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { + return nil, err + } + return s, nil +} + +func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return err + } + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + return err +} + +func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + // we only query for 1 domain or all domains so optimise for those use cases + if len(domains) == 0 { + rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true) + if err != nil { + return nil, err + } + return rowsToUserIDs(ctx, rows) + } + var result []string + for _, domain := range domains { + rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain)) + if err != nil { + return nil, err + } + userIDs, err := rowsToUserIDs(ctx, rows) + if err != nil { + return nil, err + } + result = append(result, userIDs...) + } + return result, nil +} + +func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { + defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") + for rows.Next() { + var userID string + if err := rows.Scan(&userID); err != nil { + return nil, err + } + result = append(result, userID) + } + return result, rows.Err() +} diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index f9771cf1..bbfd1e79 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -41,10 +41,15 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) { if err != nil { return nil, err } + sdl, err := NewSqliteStaleDeviceListsTable(db) + if err != nil { + return nil, err + } return &shared.Database{ - DB: db, - OneTimeKeysTable: otk, - DeviceKeysTable: dk, - KeyChangesTable: kc, + DB: db, + OneTimeKeysTable: otk, + DeviceKeysTable: dk, + KeyChangesTable: kc, + StaleDeviceListsTable: sdl, }, nil } diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index ac932d56..a4d5dede 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -20,6 +20,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" ) type OneTimeKeys interface { @@ -45,3 +46,8 @@ type KeyChanges interface { // Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset. SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) } + +type StaleDeviceLists interface { + InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error + SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) +} |