aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal/device_list_update.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/internal/device_list_update.go')
-rw-r--r--userapi/internal/device_list_update.go45
1 files changed, 36 insertions, 9 deletions
diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go
index 3fccf56b..2f33589f 100644
--- a/userapi/internal/device_list_update.go
+++ b/userapi/internal/device_list_update.go
@@ -180,11 +180,13 @@ func (u *DeviceListUpdater) Start() error {
if err != nil {
return err
}
+
+ newStaleLists := dedupeStaleLists(staleLists)
offset, step := time.Second*10, time.Second
- if max := len(staleLists); max > 120 {
+ if max := len(newStaleLists); max > 120 {
step = (time.Second * 120) / time.Duration(max)
}
- for _, userID := range staleLists {
+ for _, userID := range newStaleLists {
userID := userID // otherwise we are only sending the last entry
time.AfterFunc(offset, func() {
u.notifyWorkers(userID)
@@ -416,6 +418,12 @@ func (u *DeviceListUpdater) worker(ch chan spec.ServerName) {
func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) {
ctx := u.process.Context()
+ // If the process.Context is canceled, there is no need to go further.
+ // This avoids spamming the logs when shutting down
+ if errors.Is(ctx.Err(), context.Canceled) {
+ return defaultWaitTime, false
+ }
+
logger := util.GetLogger(ctx).WithField("server_name", serverName)
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
@@ -428,13 +436,6 @@ func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Dura
return waitTime, true
}
- defer func() {
- for _, userID := range userIDs {
- // always clear the channel to unblock Update calls regardless of success/failure
- u.clearChannel(userID)
- }
- }()
-
for _, userID := range userIDs {
userWait, err := u.processServerUser(ctx, serverName, userID)
if err != nil {
@@ -461,6 +462,11 @@ func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Dura
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
+
+ // If we are processing more than one user per server, this unblocks further calls to Update
+ // immediately instead of just after **all** users have been processed.
+ defer u.clearChannel(userID)
+
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"server_name": serverName,
"user_id": userID,
@@ -579,3 +585,24 @@ func (u *DeviceListUpdater) updateDeviceList(res *fclient.RespUserDevices) error
}
return nil
}
+
+// dedupeStaleLists de-duplicates the stateList entries using the domain.
+// This is used on startup, processServer is getting all users anyway, so
+// there is no need to send every user to the workers.
+func dedupeStaleLists(staleLists []string) []string {
+ seenDomains := make(map[spec.ServerName]struct{})
+ newStaleLists := make([]string, 0, len(staleLists))
+ for _, userID := range staleLists {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ // non-fatal and should not block starting up
+ continue
+ }
+ if _, ok := seenDomains[domain]; ok {
+ continue
+ }
+ newStaleLists = append(newStaleLists, userID)
+ seenDomains[domain] = struct{}{}
+ }
+ return newStaleLists
+}