Diffstat (limited to 'keyserver/internal/device_list_update.go')
-rw-r--r--  keyserver/internal/device_list_update.go  298
1 file changed, 298 insertions(+), 0 deletions(-)
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
new file mode 100644
index 00000000..19d8463d
--- /dev/null
+++ b/keyserver/internal/device_list_update.go
@@ -0,0 +1,298 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "hash/fnv"
+ "sync"
+ "time"
+
+ "github.com/matrix-org/dendrite/keyserver/api"
+ "github.com/matrix-org/dendrite/keyserver/producers"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+)
+
+// DeviceListUpdater handles device list updates from remote servers.
+//
+// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
+// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
+// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
+// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
+// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
+// updater stores the latest list along with the latest stream ID.
+//
+// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
+// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
+// the server name mod N. Work is sent via a channel which just serves to "poke" the worker, as the data is retrieved
+// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
+// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
+// for that domain.
+// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
+// we have many many servers)
+// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
+// The downsides are that:
+// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
+// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g. using context timeouts)
+// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
+// (as /keys/query allows multiple users to be specified), so being stuck behind matrix.org won't materially be any worse
+// than being stuck behind foo.bar.
+// In the event that the query fails, the worker spins up a short-lived goroutine whose sole purpose is to inject the server
+// name back into the channel after a certain amount of time. If in the interim the device lists have been updated, then
+// the database query will return no stale lists. Reinjection into the channel continues until it succeeds or the server
+// terminates, in which case the remaining stale lists will be reloaded and processed on startup.
+type DeviceListUpdater struct {
+ // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
+ // request to the remote server and race.
+ // TODO: Put in an LRU cache to bound growth
+ userIDToMutex map[string]*sync.Mutex
+	mu *sync.Mutex // protects userIDToMutex
+
+ db DeviceListUpdaterDatabase
+ producer *producers.KeyChange
+ fedClient *gomatrixserverlib.FederationClient
+ workerChans []chan gomatrixserverlib.ServerName
+}
+
+// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
+// Useful for testing.
+type DeviceListUpdaterDatabase interface {
+ // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
+ // If no domains are given, all user IDs with stale device lists are returned.
+ StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
+
+ // MarkDeviceListStale sets the stale bit for this user to isStale.
+ MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
+
+ // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
+ // for this (user, device). Does not modify the stream ID for keys.
+ StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
+
+ // PrevIDsExists returns true if all prev IDs exist for this user.
+ PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
+}
+
+// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
+func NewDeviceListUpdater(
+ db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
+ numWorkers int,
+) *DeviceListUpdater {
+ return &DeviceListUpdater{
+ userIDToMutex: make(map[string]*sync.Mutex),
+ mu: &sync.Mutex{},
+ db: db,
+ producer: producer,
+ fedClient: fedClient,
+ workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
+ }
+}
+
+// Start the device list updater, which will try to refresh any stale device lists.
+func (u *DeviceListUpdater) Start() error {
+ for i := 0; i < len(u.workerChans); i++ {
+ // Allocate a small buffer per channel.
+ // If the buffer limit is reached, backpressure will cause the processing of EDUs
+ // to stop (in this transaction) until key requests can be made.
+ ch := make(chan gomatrixserverlib.ServerName, 10)
+ u.workerChans[i] = ch
+ go u.worker(ch)
+ }
+
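+	// On startup, look for device lists that were already marked as stale (e.g. from a previous run) and poke
+	// the workers so they get resynced from the remote servers.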
+ staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
+ if err != nil {
+ return err
+ }
+ for _, userID := range staleLists {
+ u.notifyWorkers(userID)
+ }
+ return nil
+}
+
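+// mutex returns the per-user lock for the given user ID, lazily creating it if it does not already exist.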
+func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
+ u.mu.Lock()
+ defer u.mu.Unlock()
+ if u.userIDToMutex[userID] == nil {
+ u.userIDToMutex[userID] = &sync.Mutex{}
+ }
+ return u.userIDToMutex[userID]
+}
+
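+// Update processes a device list update EDU received over federation. If we have all the prev IDs for the user
+// then the update is stored directly; otherwise the user's device list is marked as stale and a worker is poked
+// to resync the full list from the remote server.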
+func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
+ isDeviceListStale, err := u.update(ctx, event)
+ if err != nil {
+ return err
+ }
+ if isDeviceListStale {
+ // poke workers to handle stale device lists
+ u.notifyWorkers(event.UserID)
+ }
+ return nil
+}
+
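+// update stores the update if we have all of its prev IDs, otherwise it marks the user's device list as stale.
+// It returns true if the device list is now stale and the workers should be notified.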
+func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
+ mu := u.mutex(event.UserID)
+ mu.Lock()
+ defer mu.Unlock()
+ // check if we have the prev IDs
+ exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
+ if err != nil {
+ return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
+ }
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "prev_ids_exist": exists,
+ "user_id": event.UserID,
+ "device_id": event.DeviceID,
+ "stream_id": event.StreamID,
+ "prev_ids": event.PrevID,
+ }).Info("DeviceListUpdater.Update")
+
+ // if we haven't missed anything update the database and notify users
+ if exists {
+ keys := []api.DeviceMessage{
+ {
+ DeviceKeys: api.DeviceKeys{
+ DeviceID: event.DeviceID,
+ DisplayName: event.DeviceDisplayName,
+ KeyJSON: event.Keys,
+ UserID: event.UserID,
+ },
+ StreamID: event.StreamID,
+ },
+ }
+ err = u.db.StoreRemoteDeviceKeys(ctx, keys)
+ if err != nil {
+ return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
+ }
+ // ALWAYS emit key changes when we've been poked over federation even if there's no change
+ // just in case this poke is important for something.
+ err = u.producer.ProduceKeyChanges(keys)
+ if err != nil {
+ return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
+ }
+ return false, nil
+ }
+
+ err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
+ if err != nil {
+ return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
+ }
+
+ return true, nil
+}
+
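+// notifyWorkers hashes the remote server name of the given user ID and pokes the worker responsible for that
+// domain so it will refresh the user's stale device list.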
+func (u *DeviceListUpdater) notifyWorkers(userID string) {
+ _, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ return
+ }
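+	// pick the worker responsible for this domain by hashing the server name modulo the number of workers,
+	// so the same domain always maps to the same worker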
+ hash := fnv.New32a()
+ _, _ = hash.Write([]byte(remoteServer))
+	index := int(hash.Sum32() % uint32(len(u.workerChans)))
+ u.workerChans[index] <- remoteServer
+}
+
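+// worker consumes pokes for the domains this worker is responsible for and refreshes their stale device lists.
+// Pokes arriving within the cooloff period for a server are ignored; servers whose requests failed are
+// reinjected into the channel once the cooloff period has passed.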
+func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
+ // It's possible to get many of the same server name in the channel, so in order
+ // to prevent processing the same server over and over we keep track of when we
+ // last made a request to the server. If we get the server name during the cooloff
+ // period, we'll ignore the poke.
+ lastProcessed := make(map[gomatrixserverlib.ServerName]time.Time)
+ cooloffPeriod := time.Minute
+ shouldProcess := func(srv gomatrixserverlib.ServerName) bool {
+ // we should process requests when now is after the last process time + cooloff
+ return time.Now().After(lastProcessed[srv].Add(cooloffPeriod))
+ }
+
+ // on failure, spin up a short-lived goroutine to inject the server name again.
+ inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
+ time.Sleep(duration)
+ ch <- srv
+ }
+
+ for serverName := range ch {
+ if !shouldProcess(serverName) {
+ // do not inject into the channel as we know there will be a sleeping goroutine
+ // which will do it after the cooloff period expires
+ continue
+ }
+ lastProcessed[serverName] = time.Now()
+ shouldRetry := u.processServer(serverName)
+ if shouldRetry {
+ go inject(serverName, cooloffPeriod) // TODO: Backoff?
+ }
+ }
+}
+
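+// processServer fetches the stale device lists for all users on the given remote server and stores the results.
+// It returns true if any request failed and the server should be retried after the cooloff period.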
+func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) bool {
+ requestTimeout := time.Minute // max amount of time we want to spend on each request
+ ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+ defer cancel()
+ logger := util.GetLogger(ctx).WithField("server_name", serverName)
+ // fetch stale device lists
+ userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
+ if err != nil {
+ logger.WithError(err).Error("failed to load stale device lists")
+ return true
+ }
+ hasFailures := false
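+	// query the full device list for each user; any failure means the server will be retried later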
+ for _, userID := range userIDs {
+ if ctx.Err() != nil {
+ // we've timed out, give up and go to the back of the queue to let another server be processed.
+ hasFailures = true
+ break
+ }
+ res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
+ if err != nil {
+ logger.WithError(err).WithField("user_id", userID).Error("failed to query device keys for user")
+ hasFailures = true
+ continue
+ }
+ err = u.updateDeviceList(ctx, &res)
+ if err != nil {
+ logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
+ hasFailures = true
+ }
+ }
+ return hasFailures
+}
+
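+// updateDeviceList persists the device keys returned by the remote server and marks the user's device list
+// as no longer stale.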
+func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
+	keys := make([]api.DeviceMessage, 0, len(res.Devices))
+	for _, device := range res.Devices {
+		keyJSON, err := json.Marshal(device.Keys)
+		if err != nil {
+			util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
+			continue
+		}
+		keys = append(keys, api.DeviceMessage{
+			StreamID: res.StreamID,
+			DeviceKeys: api.DeviceKeys{
+				DeviceID:    device.DeviceID,
+				DisplayName: device.DisplayName,
+				UserID:      res.UserID,
+				KeyJSON:     keyJSON,
+			},
+		})
+	}
+ err := u.db.StoreRemoteDeviceKeys(ctx, keys)
+ if err != nil {
+ return err
+ }
+ return u.db.MarkDeviceListStale(ctx, res.UserID, false)
+}