diff options
Diffstat (limited to 'keyserver')
-rw-r--r-- | keyserver/internal/device_list_update.go | 298 | ||||
-rw-r--r-- | keyserver/internal/internal.go | 58 | ||||
-rw-r--r-- | keyserver/keyserver.go | 18 | ||||
-rw-r--r-- | keyserver/storage/interface.go | 8 | ||||
-rw-r--r-- | keyserver/storage/shared/storage.go | 12 |
5 files changed, 331 insertions, 63 deletions
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go new file mode 100644 index 00000000..19d8463d --- /dev/null +++ b/keyserver/internal/device_list_update.go @@ -0,0 +1,298 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "hash/fnv" + "sync" + "time" + + "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/producers" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// DeviceListUpdater handles device list updates from remote servers. +// +// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock). +// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies +// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id +// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device: +// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the +// updater stores the latest list along with the latest stream ID. +// +// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers. +// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing +// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved +// from the database (which allows us to batch requests to the same server). This has a number of desirable properties: +// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible +// for that domain. +// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where +// we have many many servers) +// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers. +// The downsides are that: +// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free +// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts) +// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests +// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse +// than being stuck behind foo.bar +// In the event that the query fails, the worker spins up a short-lived goroutine whose sole purpose is to inject the server +// name back into the channel after a certain amount of time. If in the interim the device lists have been updated, then +// the database query will return no stale lists. Reinjection into the channel continues until success or the server terminates, +// when it will be reloaded on startup. +type DeviceListUpdater struct { + // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 + // request to the remote server and race. + // TODO: Put in an LRU cache to bound growth + userIDToMutex map[string]*sync.Mutex + mu *sync.Mutex // protects UserIDToMutex + + db DeviceListUpdaterDatabase + producer *producers.KeyChange + fedClient *gomatrixserverlib.FederationClient + workerChans []chan gomatrixserverlib.ServerName +} + +// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. +// Useful for testing. +type DeviceListUpdaterDatabase interface { + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error + + // StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key + // for this (user, device). Does not modify the stream ID for keys. + StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error + + // PrevIDsExists returns true if all prev IDs exist for this user. + PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) +} + +// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale. +func NewDeviceListUpdater( + db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient, + numWorkers int, +) *DeviceListUpdater { + return &DeviceListUpdater{ + userIDToMutex: make(map[string]*sync.Mutex), + mu: &sync.Mutex{}, + db: db, + producer: producer, + fedClient: fedClient, + workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), + } +} + +// Start the device list updater, which will try to refresh any stale device lists. +func (u *DeviceListUpdater) Start() error { + for i := 0; i < len(u.workerChans); i++ { + // Allocate a small buffer per channel. + // If the buffer limit is reached, backpressure will cause the processing of EDUs + // to stop (in this transaction) until key requests can be made. + ch := make(chan gomatrixserverlib.ServerName, 10) + u.workerChans[i] = ch + go u.worker(ch) + } + + staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + for _, userID := range staleLists { + u.notifyWorkers(userID) + } + return nil +} + +func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { + u.mu.Lock() + defer u.mu.Unlock() + if u.userIDToMutex[userID] == nil { + u.userIDToMutex[userID] = &sync.Mutex{} + } + return u.userIDToMutex[userID] +} + +func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error { + isDeviceListStale, err := u.update(ctx, event) + if err != nil { + return err + } + if isDeviceListStale { + // poke workers to handle stale device lists + u.notifyWorkers(event.UserID) + } + return nil +} + +func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) { + mu := u.mutex(event.UserID) + mu.Lock() + defer mu.Unlock() + // check if we have the prev IDs + exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID) + if err != nil { + return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err) + } + util.GetLogger(ctx).WithFields(logrus.Fields{ + "prev_ids_exist": exists, + "user_id": event.UserID, + "device_id": event.DeviceID, + "stream_id": event.StreamID, + "prev_ids": event.PrevID, + }).Info("DeviceListUpdater.Update") + + // if we haven't missed anything update the database and notify users + if exists { + keys := []api.DeviceMessage{ + { + DeviceKeys: api.DeviceKeys{ + DeviceID: event.DeviceID, + DisplayName: event.DeviceDisplayName, + KeyJSON: event.Keys, + UserID: event.UserID, + }, + StreamID: event.StreamID, + }, + } + err = u.db.StoreRemoteDeviceKeys(ctx, keys) + if err != nil { + return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err) + } + // ALWAYS emit key changes when we've been poked over federation even if there's no change + // just in case this poke is important for something. + err = u.producer.ProduceKeyChanges(keys) + if err != nil { + return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err) + } + return false, nil + } + + err = u.db.MarkDeviceListStale(ctx, event.UserID, true) + if err != nil { + return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err) + } + + return true, nil +} + +func (u *DeviceListUpdater) notifyWorkers(userID string) { + _, remoteServer, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return + } + hash := fnv.New32a() + _, _ = hash.Write([]byte(remoteServer)) + index := int(hash.Sum32()) % len(u.workerChans) + u.workerChans[index] <- remoteServer +} + +func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { + // It's possible to get many of the same server name in the channel, so in order + // to prevent processing the same server over and over we keep track of when we + // last made a request to the server. If we get the server name during the cooloff + // period, we'll ignore the poke. + lastProcessed := make(map[gomatrixserverlib.ServerName]time.Time) + cooloffPeriod := time.Minute + shouldProcess := func(srv gomatrixserverlib.ServerName) bool { + // we should process requests when now is after the last process time + cooloff + return time.Now().After(lastProcessed[srv].Add(cooloffPeriod)) + } + + // on failure, spin up a short-lived goroutine to inject the server name again. + inject := func(srv gomatrixserverlib.ServerName, duration time.Duration) { + time.Sleep(duration) + ch <- srv + } + + for serverName := range ch { + if !shouldProcess(serverName) { + // do not inject into the channel as we know there will be a sleeping goroutine + // which will do it after the cooloff period expires + continue + } + lastProcessed[serverName] = time.Now() + shouldRetry := u.processServer(serverName) + if shouldRetry { + go inject(serverName, cooloffPeriod) // TODO: Backoff? + } + } +} + +func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) bool { + requestTimeout := time.Minute // max amount of time we want to spend on each request + ctx, cancel := context.WithTimeout(context.Background(), requestTimeout) + defer cancel() + logger := util.GetLogger(ctx).WithField("server_name", serverName) + // fetch stale device lists + userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) + if err != nil { + logger.WithError(err).Error("failed to load stale device lists") + return true + } + hasFailures := false + for _, userID := range userIDs { + if ctx.Err() != nil { + // we've timed out, give up and go to the back of the queue to let another server be processed. + hasFailures = true + break + } + res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) + if err != nil { + logger.WithError(err).WithField("user_id", userID).Error("failed to query device keys for user") + hasFailures = true + continue + } + err = u.updateDeviceList(ctx, &res) + if err != nil { + logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it") + hasFailures = true + } + } + return hasFailures +} + +func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error { + keys := make([]api.DeviceMessage, len(res.Devices)) + for i, device := range res.Devices { + keyJSON, err := json.Marshal(device.Keys) + if err != nil { + util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device") + continue + } + keys[i] = api.DeviceMessage{ + StreamID: res.StreamID, + DeviceKeys: api.DeviceKeys{ + DeviceID: device.DeviceID, + DisplayName: device.DisplayName, + UserID: res.UserID, + KeyJSON: keyJSON, + }, + } + } + err := u.db.StoreRemoteDeviceKeys(ctx, keys) + if err != nil { + return err + } + return u.db.MarkDeviceListStale(ctx, res.UserID, false) +} diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index d6e24566..ff298c07 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -38,74 +38,22 @@ type KeyInternalAPI struct { FedClient *gomatrixserverlib.FederationClient UserAPI userapi.UserInternalAPI Producer *producers.KeyChange - // A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1 - // request to the remote server and race. - // TODO: Put in an LRU cache to bound growth - UserIDToMutex map[string]*sync.Mutex - Mutex *sync.Mutex // protects UserIDToMutex + Updater *DeviceListUpdater } func (a *KeyInternalAPI) SetUserAPI(i userapi.UserInternalAPI) { a.UserAPI = i } -func (a *KeyInternalAPI) mutex(userID string) *sync.Mutex { - a.Mutex.Lock() - defer a.Mutex.Unlock() - if a.UserIDToMutex[userID] == nil { - a.UserIDToMutex[userID] = &sync.Mutex{} - } - return a.UserIDToMutex[userID] -} - func (a *KeyInternalAPI) InputDeviceListUpdate( ctx context.Context, req *api.InputDeviceListUpdateRequest, res *api.InputDeviceListUpdateResponse, ) { - mu := a.mutex(req.Event.UserID) - mu.Lock() - defer mu.Unlock() - // check if we have the prev IDs - exists, err := a.DB.PrevIDsExists(ctx, req.Event.UserID, req.Event.PrevID) + err := a.Updater.Update(ctx, req.Event) if err != nil { res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to check if prev ids exist: %s", err), + Err: fmt.Sprintf("failed to update device list: %s", err), } - return } - - // if we haven't missed anything update the database and notify users - if exists { - keys := []api.DeviceMessage{ - { - DeviceKeys: api.DeviceKeys{ - DeviceID: req.Event.DeviceID, - DisplayName: req.Event.DeviceDisplayName, - KeyJSON: req.Event.Keys, - UserID: req.Event.UserID, - }, - StreamID: req.Event.StreamID, - }, - } - err = a.DB.StoreRemoteDeviceKeys(ctx, keys) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to store remote device keys: %s", err), - } - return - } - // ALWAYS emit key changes when we've been poked over federation just in case - // this poke is important for something. - err = a.Producer.ProduceKeyChanges(keys) - if err != nil { - res.Error = &api.KeyError{ - Err: fmt.Sprintf("failed to emit remote device key changes: %s", err), - } - } - return - } - - // if we're missing an ID go and fetch it from the remote HS - } func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 79d9cec9..4a6fbe3c 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -15,8 +15,6 @@ package keyserver import ( - "sync" - "github.com/Shopify/sarama" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/config" @@ -52,12 +50,16 @@ func NewInternalAPI( Producer: producer, DB: db, } + updater := internal.NewDeviceListUpdater(db, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + err = updater.Start() + if err != nil { + logrus.WithError(err).Panicf("failed to start device list updater") + } return &internal.KeyInternalAPI{ - DB: db, - ThisServer: cfg.Matrix.ServerName, - FedClient: fedClient, - Producer: keyChangeProducer, - Mutex: &sync.Mutex{}, - UserIDToMutex: make(map[string]*sync.Mutex), + DB: db, + ThisServer: cfg.Matrix.ServerName, + FedClient: fedClient, + Producer: keyChangeProducer, + Updater: updater, } } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index f67bbf71..2a60aacc 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -19,6 +19,7 @@ import ( "encoding/json" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/gomatrixserverlib" ) type Database interface { @@ -64,4 +65,11 @@ type Database interface { // A to offset of sarama.OffsetNewest means no upper limit. // Returns the offset of the latest key change. KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. + // If no domains are given, all user IDs with stale device lists are returned. + StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + + // MarkDeviceListStale sets the stale bit for this user to isStale. + MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error } diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 78729774..68964be6 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) type Database struct { @@ -124,3 +125,14 @@ func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset i func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset) } + +// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. +// If no domains are given, all user IDs with stale device lists are returned. +func (d *Database) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { + return nil, nil // TODO +} + +// MarkDeviceListStale sets the stale bit for this user to isStale. +func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error { + return nil // TODO +} |