diff options
Diffstat (limited to 'userapi/internal/device_list_update.go')
-rw-r--r-- | userapi/internal/device_list_update.go | 31 |
1 files changed, 16 insertions, 15 deletions
diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go index a274e1ae..d60e522e 100644 --- a/userapi/internal/device_list_update.go +++ b/userapi/internal/device_list_update.go @@ -26,6 +26,7 @@ import ( rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" @@ -98,8 +99,8 @@ type DeviceListUpdater struct { api DeviceListUpdaterAPI producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI - workerChans []chan gomatrixserverlib.ServerName - thisServer gomatrixserverlib.ServerName + workerChans []chan spec.ServerName + thisServer spec.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -113,7 +114,7 @@ type DeviceListUpdater struct { type DeviceListUpdaterDatabase interface { // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. - StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) // MarkDeviceListStale sets the stale bit for this user to isStale. MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error @@ -146,7 +147,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer spec.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -157,7 +158,7 @@ func NewDeviceListUpdater( producer: producer, fedClient: fedClient, thisServer: thisServer, - workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), + workerChans: make([]chan spec.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, rsAPI: rsAPI, @@ -170,12 +171,12 @@ func (u *DeviceListUpdater) Start() error { // Allocate a small buffer per channel. // If the buffer limit is reached, backpressure will cause the processing of EDUs // to stop (in this transaction) until key requests can be made. - ch := make(chan gomatrixserverlib.ServerName, 10) + ch := make(chan spec.ServerName, 10) u.workerChans[i] = ch go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } @@ -195,7 +196,7 @@ func (u *DeviceListUpdater) Start() error { // CleanUp removes stale device entries for users we don't share a room with anymore func (u *DeviceListUpdater) CleanUp() error { - staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{}) if err != nil { return err } @@ -223,7 +224,7 @@ func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { // ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it. // Blocks until the device list is synced or the timeout is reached. -func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error { +func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName spec.ServerName, userID string) error { mu := u.mutex(userID) mu.Lock() err := u.db.MarkDeviceListStale(ctx, userID, true) @@ -369,12 +370,12 @@ func (u *DeviceListUpdater) clearChannel(userID string) { } } -func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { - retries := make(map[gomatrixserverlib.ServerName]time.Time) +func (u *DeviceListUpdater) worker(ch chan spec.ServerName) { + retries := make(map[spec.ServerName]time.Time) retriesMu := &sync.Mutex{} // restarter goroutine which will inject failed servers into ch when it is time go func() { - var serversToRetry []gomatrixserverlib.ServerName + var serversToRetry []spec.ServerName for { serversToRetry = serversToRetry[:0] // reuse memory time.Sleep(time.Second) @@ -413,7 +414,7 @@ func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) { } } -func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) { +func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) { ctx := u.process.Context() logger := util.GetLogger(ctx).WithField("server_name", serverName) deviceListUpdateCount.WithLabelValues(string(serverName)).Inc() @@ -421,7 +422,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam waitTime := defaultWaitTime // How long should we wait to try again? successCount := 0 // How many user requests failed? - userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName}) + userIDs, err := u.db.StaleDeviceLists(ctx, []spec.ServerName{serverName}) if err != nil { logger.WithError(err).Error("Failed to load stale device lists") return waitTime, true @@ -457,7 +458,7 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam return waitTime, !allUsersSucceeded } -func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) { +func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) { ctx, cancel := context.WithTimeout(ctx, requestTimeout) defer cancel() logger := util.GetLogger(ctx).WithFields(logrus.Fields{ |