aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorKegsay <kegan@matrix.org>2020-08-12 22:43:02 +0100
committerGitHub <noreply@github.com>2020-08-12 22:43:02 +0100
commit820c56c165ec8f0409d23cd151a7ff89fbe09ffa (patch)
tree2266b281e13fc971d56625f416d1e03979062c43 /keyserver
parentd98ec12422c8498cf710bb34d2ed31f024aa1e15 (diff)
Fix more E2E sytests (#1265)
* WIP: Eagerly sync device lists on /user/keys/query requests Also notify servers when a user's device display name changes. Few caveats: - sytest `Device deletion propagates over federation` fails - `populateResponseWithDeviceKeysFromDatabase` is called from multiple goroutines and hence is unsafe. * Handle deleted devices correctly over federation
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/api/api.go5
-rw-r--r--keyserver/internal/device_list_update.go30
-rw-r--r--keyserver/internal/device_list_update_test.go2
-rw-r--r--keyserver/internal/internal.go126
-rw-r--r--keyserver/storage/interface.go5
-rw-r--r--keyserver/storage/postgres/device_keys_table.go12
-rw-r--r--keyserver/storage/shared/storage.go8
-rw-r--r--keyserver/storage/sqlite3/device_keys_table.go12
-rw-r--r--keyserver/storage/tables/interface.go1
9 files changed, 174 insertions, 27 deletions
diff --git a/keyserver/api/api.go b/keyserver/api/api.go
index c3481a38..442af871 100644
--- a/keyserver/api/api.go
+++ b/keyserver/api/api.go
@@ -110,6 +110,11 @@ type OneTimeKeysCount struct {
type PerformUploadKeysRequest struct {
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
+ // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
+ // the display name for their respective device, and NOT to modify the keys. The key
+ // itself doesn't change but it's easier to pretend upload new keys and reuse the same code paths.
+ // Without this flag, requests to modify device display names would delete device keys.
+ OnlyDisplayNameUpdates bool
}
// PerformUploadKeysResponse is the response to PerformUploadKeys
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index 1c4f0b97..573285e8 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -85,8 +85,9 @@ type DeviceListUpdaterDatabase interface {
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device). Does not modify the stream ID for keys.
- StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+ // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+ // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
@@ -144,6 +145,20 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
return u.userIDToMutex[userID]
}
+// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
+// Blocks until the device list is synced or the timeout is reached.
+func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
+ mu := u.mutex(userID)
+ mu.Lock()
+ err := u.db.MarkDeviceListStale(ctx, userID, true)
+ mu.Unlock()
+ if err != nil {
+ return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
+ }
+ u.notifyWorkers(userID)
+ return nil
+}
+
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
// which assumes when /send 200 OKs that the device lists have been updated.
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
@@ -178,22 +193,27 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
"stream_id": event.StreamID,
"prev_ids": event.PrevID,
"display_name": event.DeviceDisplayName,
+ "deleted": event.Deleted,
}).Info("DeviceListUpdater.Update")
// if we haven't missed anything update the database and notify users
if exists {
+ k := event.Keys
+ if event.Deleted {
+ k = nil
+ }
keys := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName,
- KeyJSON: event.Keys,
+ KeyJSON: k,
UserID: event.UserID,
},
StreamID: event.StreamID,
},
}
- err = u.db.StoreRemoteDeviceKeys(ctx, keys)
+ err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
if err != nil {
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
}
@@ -348,7 +368,7 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
},
}
}
- err := u.db.StoreRemoteDeviceKeys(ctx, keys)
+ err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
if err != nil {
return fmt.Errorf("failed to store remote device keys: %w", err)
}
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
index dcb981c4..c42a7cdf 100644
--- a/keyserver/internal/device_list_update_test.go
+++ b/keyserver/internal/device_list_update_test.go
@@ -81,7 +81,7 @@ func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context,
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys.
-func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
d.storedKeys = append(d.storedKeys, keys...)
return nil
}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 075622b7..ef52d014 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -28,6 +28,7 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -205,7 +206,15 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
maxStreamID = m.StreamID
}
}
- res.Devices = msgs
+ // remove deleted devices
+ var result []api.DeviceMessage
+ for _, m := range msgs {
+ if m.KeyJSON == nil {
+ continue
+ }
+ result = append(result, m)
+ }
+ res.Devices = result
res.StreamID = maxStreamID
}
@@ -282,27 +291,21 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
fetchRemote := make(map[string]map[string][]string)
for domain, userToDeviceMap := range domainToDeviceKeys {
for userID, deviceIDs := range userToDeviceMap {
- keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
- // if we can't query the db or there are fewer keys than requested, fetch from remote.
- // Likewise, we can't safely return keys from the db when all devices are requested as we don't
+ // we can't safely return keys from the db when all devices are requested as we don't
// know if one has just been added.
- if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
- if _, ok := fetchRemote[domain]; !ok {
- fetchRemote[domain] = make(map[string][]string)
+ if len(deviceIDs) > 0 {
+ err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, deviceIDs)
+ if err == nil {
+ continue
}
- fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
- continue
+ util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
}
- if res.DeviceKeys[userID] == nil {
- res.DeviceKeys[userID] = make(map[string]json.RawMessage)
- }
- for _, key := range keys {
- // inject the display name
- key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
- DisplayName string `json:"device_display_name,omitempty"`
- }{key.DisplayName})
- res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+ // fetch device lists from remote
+ if _, ok := fetchRemote[domain]; !ok {
+ fetchRemote[domain] = make(map[string][]string)
}
+ fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
+
}
}
return fetchRemote
@@ -324,6 +327,45 @@ func (a *KeyInternalAPI) queryRemoteKeys(
defer wg.Done()
fedCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
+ // for users who we do not have any knowledge about, try to start doing device list updates for them
+ // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
+ // lack a stream ID.
+ var userIDsForAllDevices []string
+ for userID, deviceIDs := range devKeys {
+ if len(deviceIDs) == 0 {
+ userIDsForAllDevices = append(userIDsForAllDevices, userID)
+ delete(devKeys, userID)
+ }
+ }
+ for _, userID := range userIDsForAllDevices {
+ err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ logrus.ErrorKey: err,
+ "user_id": userID,
+ "server": serverName,
+ }).Error("Failed to manually update device lists for user")
+ // try to do it via /keys/query
+ devKeys[userID] = []string{}
+ continue
+ }
+ // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
+ // user so the fact that we're populating all devices here isn't a problem so long as we have devices.
+ err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, userID, nil)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ logrus.ErrorKey: err,
+ "user_id": userID,
+ "server": serverName,
+ }).Error("Failed to manually update device lists for user")
+ // try to do it via /keys/query
+ devKeys[userID] = []string{}
+ continue
+ }
+ }
+ if len(devKeys) == 0 {
+ return
+ }
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
if err != nil {
failMu.Lock()
@@ -357,6 +399,37 @@ func (a *KeyInternalAPI) queryRemoteKeys(
}
}
+func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
+ ctx context.Context, res *api.QueryKeysResponse, userID string, deviceIDs []string,
+) error {
+ keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
+ // if we can't query the db or there are fewer keys than requested, fetch from remote.
+ if err != nil {
+ return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
+ }
+ if len(keys) < len(deviceIDs) {
+ return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
+ }
+ if len(deviceIDs) == 0 && len(keys) == 0 {
+ return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
+ }
+ if res.DeviceKeys[userID] == nil {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ }
+
+ for _, key := range keys {
+ if len(key.KeyJSON) == 0 {
+ continue // ignore deleted keys
+ }
+ // inject the display name
+ key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
+ DisplayName string `json:"device_display_name,omitempty"`
+ }{key.DisplayName})
+ res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+ }
+ return nil
+}
+
func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key
@@ -403,6 +476,10 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}
return
}
+ if req.OnlyDisplayNameUpdates {
+ // add the display name field from keysToStore into existingKeys
+ keysToStore = appendDisplayNames(existingKeys, keysToStore)
+ }
// store the device keys and emit changes
err := a.DB.StoreLocalDeviceKeys(ctx, keysToStore)
if err != nil {
@@ -475,3 +552,16 @@ func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage)
}
return a.Producer.ProduceKeyChanges(keysAdded)
}
+
+func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
+ for i, existingDevice := range existing {
+ for _, newDevice := range new {
+ if existingDevice.DeviceID != newDevice.DeviceID {
+ continue
+ }
+ existingDevice.DisplayName = newDevice.DisplayName
+ existing[i] = existingDevice
+ }
+ }
+ return existing
+}
diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go
index 2a60aacc..0ec62f56 100644
--- a/keyserver/storage/interface.go
+++ b/keyserver/storage/interface.go
@@ -43,8 +43,9 @@ type Database interface {
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
- // for this (user, device). Does not modify the stream ID for keys.
- StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+ // for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
+ // to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go
index b9d5d4c3..779d02c0 100644
--- a/keyserver/storage/postgres/device_keys_table.go
+++ b/keyserver/storage/postgres/device_keys_table.go
@@ -61,6 +61,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
@@ -68,6 +71,7 @@ type deviceKeysStatements struct {
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -93,6 +97,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
return nil, err
}
+ if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
@@ -154,6 +161,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
return nil
}
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil {
diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go
index 4279eae7..a4c35a4b 100644
--- a/keyserver/storage/shared/storage.go
+++ b/keyserver/storage/shared/storage.go
@@ -61,8 +61,14 @@ func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []i
return count == len(prevIDs), nil
}
-func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
+func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
+ for _, userID := range clearUserIDs {
+ err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
+ if err != nil {
+ return err
+ }
+ }
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go
index abe6636a..a4d71fe1 100644
--- a/keyserver/storage/sqlite3/device_keys_table.go
+++ b/keyserver/storage/sqlite3/device_keys_table.go
@@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
+const deleteAllDeviceKeysSQL = "" +
+ "DELETE FROM keyserver_device_keys WHERE user_id=$1"
+
type deviceKeysStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
@@ -65,6 +68,7 @@ type deviceKeysStatements struct {
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
+ deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@@ -88,9 +92,17 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
+ if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
+ return nil, err
+ }
return s, nil
}
+func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
+ _, err := txn.Stmt(s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
+ return err
+}
+
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go
index a4d5dede..f97e871f 100644
--- a/keyserver/storage/tables/interface.go
+++ b/keyserver/storage/tables/interface.go
@@ -38,6 +38,7 @@ type DeviceKeys interface {
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
+ DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type KeyChanges interface {