Diffstat (limited to 'keyserver/internal/device_list_update.go')
-rw-r--r--  keyserver/internal/device_list_update.go  298
1 file changed, 298 insertions, 0 deletions
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
new file mode 100644
index 00000000..19d8463d
--- /dev/null
+++ b/keyserver/internal/device_list_update.go
@@ -0,0 +1,298 @@

// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package internal

import (
	"context"
	"encoding/json"
	"fmt"
	"hash/fnv"
	"sync"
	"time"

	"github.com/matrix-org/dendrite/keyserver/api"
	"github.com/matrix-org/dendrite/keyserver/producers"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/util"
	"github.com/sirupsen/logrus"
)

// DeviceListUpdater handles device list updates from remote servers.
//
// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
// a worker to get the latest device list for this user. Note: stream IDs are scoped per user, so missing a prev_id
// for a (user, device) does not mean that DEVICE is outdated, as the previous ID could be for a different device:
// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
// updater stores the latest list along with the latest stream ID.
//
// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
// the server name mod N. Work is sent via a channel which just serves to "poke" the worker, as the data is retrieved
// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
//   - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
//     for that domain.
//   - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
//     we have many, many servers).
//   - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
// The downsides are that:
//   - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
//     workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g. using context timeouts)
//     we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
//     (as /keys/query allows multiple users to be specified), so being stuck behind matrix.org won't materially be any worse
//     than being stuck behind foo.bar.
// In the event that the query fails, the worker spins up a short-lived goroutine whose sole purpose is to inject the server
// name back into the channel after a certain amount of time. If in the interim the device lists have been updated, then
// the database query will return no stale lists. Reinjection into the channel continues until success or the server terminates,
// when it will be reloaded on startup.
type DeviceListUpdater struct {
	// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
	// request to the remote server and race.
	// TODO: Put in an LRU cache to bound growth
	userIDToMutex map[string]*sync.Mutex
	mu            *sync.Mutex // protects userIDToMutex

	db          DeviceListUpdaterDatabase
	producer    *producers.KeyChange
	fedClient   *gomatrixserverlib.FederationClient
	workerChans []chan gomatrixserverlib.ServerName
}

// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
// Useful for testing.
type DeviceListUpdaterDatabase interface {
	// StaleDeviceLists returns a list of user IDs, ending with the domains provided, who have stale device lists.
	// If no domains are given, all user IDs with stale device lists are returned.
	StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)

	// MarkDeviceListStale sets the stale bit for this user to isStale.
	MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error

	// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
	// for this (user, device). Does not modify the stream ID for keys.
	StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error

	// PrevIDsExists returns true if all prev IDs exist for this user.
	PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
}
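Since the updater only depends on this narrow interface ("Useful for testing"), a test could satisfy it with an in-memory fake. The sketch below is hypothetical and not part of this change: the fakeDB name and its map-based bookkeeping are invented, it would live in a _test.go file in this package, and it additionally needs the strings import. A real fake's StoreRemoteDeviceKeys would also replace existing (user, device) rows rather than appending.

// fakeDB is a hypothetical in-memory DeviceListUpdaterDatabase for tests.
type fakeDB struct {
	mu    sync.Mutex
	stale map[string]bool // user_id -> stale bit
	keys  []api.DeviceMessage
}

func newFakeDB() *fakeDB {
	return &fakeDB{stale: make(map[string]bool)}
}

func (f *fakeDB) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
	f.mu.Lock()
	defer f.mu.Unlock()
	var userIDs []string
	for userID, isStale := range f.stale {
		if !isStale {
			continue
		}
		if len(domains) == 0 {
			userIDs = append(userIDs, userID)
			continue
		}
		for _, domain := range domains {
			// user IDs look like @localpart:domain
			if strings.HasSuffix(userID, ":"+string(domain)) {
				userIDs = append(userIDs, userID)
				break
			}
		}
	}
	return userIDs, nil
}

func (f *fakeDB) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
	f.mu.Lock()
	defer f.mu.Unlock()
	f.stale[userID] = isStale
	return nil
}

func (f *fakeDB) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
	// Appending is enough for a sketch; a real fake would replace rows
	// with the same (user, device).
	f.mu.Lock()
	defer f.mu.Unlock()
	f.keys = append(f.keys, keys...)
	return nil
}

func (f *fakeDB) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
	return true, nil // pretend every prev ID is known; vary per test
}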
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater(
	db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
	numWorkers int,
) *DeviceListUpdater {
	return &DeviceListUpdater{
		userIDToMutex: make(map[string]*sync.Mutex),
		mu:            &sync.Mutex{},
		db:            db,
		producer:      producer,
		fedClient:     fedClient,
		workerChans:   make([]chan gomatrixserverlib.ServerName, numWorkers),
	}
}

// Start the device list updater, which will try to refresh any stale device lists.
func (u *DeviceListUpdater) Start() error {
	for i := 0; i < len(u.workerChans); i++ {
		// Allocate a small buffer per channel.
		// If the buffer limit is reached, backpressure will cause the processing of EDUs
		// to stop (in this transaction) until key requests can be made.
		ch := make(chan gomatrixserverlib.ServerName, 10)
		u.workerChans[i] = ch
		go u.worker(ch)
	}

	staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
	if err != nil {
		return err
	}
	for _, userID := range staleLists {
		u.notifyWorkers(userID)
	}
	return nil
}

func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
	u.mu.Lock()
	defer u.mu.Unlock()
	if u.userIDToMutex[userID] == nil {
		u.userIDToMutex[userID] = &sync.Mutex{}
	}
	return u.userIDToMutex[userID]
}
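For context, this is roughly how a caller might wire the updater up. The helper names below (startUpdater, handleDeviceListEDU) and the worker count are illustrative, not part of the diff; the parameters stand in for values the calling component already has.

// startUpdater is a hypothetical helper showing the intended wiring.
func startUpdater(
	db DeviceListUpdaterDatabase,
	producer *producers.KeyChange,
	fedClient *gomatrixserverlib.FederationClient,
) (*DeviceListUpdater, error) {
	updater := NewDeviceListUpdater(db, producer, fedClient, 8) // 8 workers
	if err := updater.Start(); err != nil {
		return nil, err
	}
	return updater, nil
}

// handleDeviceListEDU is a hypothetical federation hook: decode the EDU
// content and hand it to the updater (Update is defined below).
func handleDeviceListEDU(ctx context.Context, updater *DeviceListUpdater, eduContent []byte) {
	var event gomatrixserverlib.DeviceListUpdateEvent
	if err := json.Unmarshal(eduContent, &event); err != nil {
		return
	}
	if err := updater.Update(ctx, event); err != nil {
		logrus.WithError(err).Error("failed to process device list update")
	}
}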
// Update stores the update directly if it is sequential (we have its prev IDs); otherwise it marks
// the user's device list as stale and pokes a worker to resync it.
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
	isDeviceListStale, err := u.update(ctx, event)
	if err != nil {
		return err
	}
	if isDeviceListStale {
		// poke workers to handle stale device lists
		u.notifyWorkers(event.UserID)
	}
	return nil
}

func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
	mu := u.mutex(event.UserID)
	mu.Lock()
	defer mu.Unlock()
	// check if we have the prev IDs
	exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
	if err != nil {
		return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
	}
	util.GetLogger(ctx).WithFields(logrus.Fields{
		"prev_ids_exist": exists,
		"user_id":        event.UserID,
		"device_id":      event.DeviceID,
		"stream_id":      event.StreamID,
		"prev_ids":       event.PrevID,
	}).Info("DeviceListUpdater.Update")

	// if we haven't missed anything, update the database and notify users
	if exists {
		keys := []api.DeviceMessage{
			{
				DeviceKeys: api.DeviceKeys{
					DeviceID:    event.DeviceID,
					DisplayName: event.DeviceDisplayName,
					KeyJSON:     event.Keys,
					UserID:      event.UserID,
				},
				StreamID: event.StreamID,
			},
		}
		err = u.db.StoreRemoteDeviceKeys(ctx, keys)
		if err != nil {
			return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
		}
		// ALWAYS emit key changes when we've been poked over federation, even if there's no change,
		// just in case this poke is important for something.
		err = u.producer.ProduceKeyChanges(keys)
		if err != nil {
			return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
		}
		return false, nil
	}

	err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
	if err != nil {
		return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
	}

	return true, nil
}

func (u *DeviceListUpdater) notifyWorkers(userID string) {
	_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
	if err != nil {
		return
	}
	hash := fnv.New32a()
	_, _ = hash.Write([]byte(remoteServer))
	index := int(hash.Sum32()) % len(u.workerChans)
	u.workerChans[index] <- remoteServer
}
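notifyWorkers is the sharding step from the type's doc comment: FNV-32a of the server name, mod the number of workers, so the same server always lands on the same worker and at most one /keys/query per server is in flight at once. A standalone sketch of that property (server names invented):

package main

import (
	"fmt"
	"hash/fnv"
)

// workerIndex mirrors the hashing in notifyWorkers.
func workerIndex(serverName string, numWorkers int) int {
	h := fnv.New32a()
	_, _ = h.Write([]byte(serverName))
	return int(h.Sum32()) % numWorkers
}

func main() {
	for _, s := range []string{"matrix.org", "example.com", "foo.bar"} {
		// Deterministic: rerunning always prints the same assignments.
		fmt.Printf("%-12s -> worker %d of 4\n", s, workerIndex(s, 4))
	}
}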
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
	// It's possible to get many of the same server name in the channel, so in order
	// to prevent processing the same server over and over we keep track of when we
	// last made a request to the server. If we get the server name during the cooloff
	// period, we'll ignore the poke.
	lastProcessed := make(map[gomatrixserverlib.ServerName]time.Time)
	cooloffPeriod := time.Minute
	shouldProcess := func(srv gomatrixserverlib.ServerName) bool {
		// we should process requests when now is after the last process time + cooloff
		return time.Now().After(lastProcessed[srv].Add(cooloffPeriod))
	}

	// on failure, spin up a short-lived goroutine to inject the server name again.
	inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) {
		time.Sleep(duration)
		ch <- srv
	}

	for serverName := range ch {
		if !shouldProcess(serverName) {
			// do not inject into the channel as we know there will be a sleeping goroutine
			// which will do it after the cooloff period expires
			continue
		}
		lastProcessed[serverName] = time.Now()
		shouldRetry := u.processServer(serverName)
		if shouldRetry {
			go inject(serverName, cooloffPeriod) // TODO: Backoff?
		}
	}
}

func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) bool {
	requestTimeout := time.Minute // max amount of time we want to spend on each request
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()
	logger := util.GetLogger(ctx).WithField("server_name", serverName)
	// fetch stale device lists
	userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
	if err != nil {
		logger.WithError(err).Error("failed to load stale device lists")
		return true
	}
	hasFailures := false
	for _, userID := range userIDs {
		if ctx.Err() != nil {
			// we've timed out; give up and go to the back of the queue to let another server be processed.
			hasFailures = true
			break
		}
		res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
		if err != nil {
			logger.WithError(err).WithField("user_id", userID).Error("failed to query device keys for user")
			hasFailures = true
			continue
		}
		err = u.updateDeviceList(ctx, &res)
		if err != nil {
			logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
			hasFailures = true
		}
	}
	return hasFailures
}

func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
	// append rather than index into a pre-sized slice so that a skipped
	// device doesn't leave a zero-value DeviceMessage behind.
	keys := make([]api.DeviceMessage, 0, len(res.Devices))
	for _, device := range res.Devices {
		keyJSON, err := json.Marshal(device.Keys)
		if err != nil {
			util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
			continue
		}
		keys = append(keys, api.DeviceMessage{
			StreamID: res.StreamID,
			DeviceKeys: api.DeviceKeys{
				DeviceID:    device.DeviceID,
				DisplayName: device.DisplayName,
				UserID:      res.UserID,
				KeyJSON:     keyJSON,
			},
		})
	}
	err := u.db.StoreRemoteDeviceKeys(ctx, keys)
	if err != nil {
		return err
	}
	return u.db.MarkDeviceListStale(ctx, res.UserID, false)
}
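The retry delay above is a flat cooloffPeriod; the `// TODO: Backoff?` suggests growing it per failure. A minimal sketch, assuming the worker loop tracked an attempt count per server and reset it when processServer reports no failures (nothing below is in the diff):

// backoffDuration doubles the delay per consecutive failure, capped at an hour.
func backoffDuration(base time.Duration, attempts int) time.Duration {
	d := base << uint(attempts) // base * 2^attempts
	if d <= 0 || d > time.Hour { // <= 0 guards against shift overflow
		return time.Hour
	}
	return d
}

worker() would then call go inject(serverName, backoffDuration(cooloffPeriod, attempts[serverName])) in place of the fixed go inject(serverName, cooloffPeriod).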