aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-10-23 11:09:05 +0200
committerGitHub <noreply@github.com>2023-10-23 11:09:05 +0200
commit8c23c1150c37a88e078037b8c4b47f4efecab727 (patch)
treeab2f4b3a34fa217e81cadee2f2ffba8d07ccf6d6 /userapi
parentfe2955a4dbf40562374d79bc991e0d7028f0f239 (diff)
Tweaks around the device list updater (#3227)
I hope the comments explain the changes. `notifyWorkers` notifies a worker which then calls `processServer`, which in turn gets all users and calls `processServerUser`. There is no need to call `processServer` for the same domain on startup.
Diffstat (limited to 'userapi')
-rw-r--r--userapi/internal/device_list_update.go45
-rw-r--r--userapi/internal/device_list_update_test.go46
2 files changed, 82 insertions, 9 deletions
diff --git a/userapi/internal/device_list_update.go b/userapi/internal/device_list_update.go
index 3fccf56b..2f33589f 100644
--- a/userapi/internal/device_list_update.go
+++ b/userapi/internal/device_list_update.go
@@ -180,11 +180,13 @@ func (u *DeviceListUpdater) Start() error {
if err != nil {
return err
}
+
+ newStaleLists := dedupeStaleLists(staleLists)
offset, step := time.Second*10, time.Second
- if max := len(staleLists); max > 120 {
+ if max := len(newStaleLists); max > 120 {
step = (time.Second * 120) / time.Duration(max)
}
- for _, userID := range staleLists {
+ for _, userID := range newStaleLists {
userID := userID // otherwise we are only sending the last entry
time.AfterFunc(offset, func() {
u.notifyWorkers(userID)
@@ -416,6 +418,12 @@ func (u *DeviceListUpdater) worker(ch chan spec.ServerName) {
func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) {
ctx := u.process.Context()
+ // If the process.Context is canceled, there is no need to go further.
+ // This avoids spamming the logs when shutting down
+ if errors.Is(ctx.Err(), context.Canceled) {
+ return defaultWaitTime, false
+ }
+
logger := util.GetLogger(ctx).WithField("server_name", serverName)
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
@@ -428,13 +436,6 @@ func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Dura
return waitTime, true
}
- defer func() {
- for _, userID := range userIDs {
- // always clear the channel to unblock Update calls regardless of success/failure
- u.clearChannel(userID)
- }
- }()
-
for _, userID := range userIDs {
userWait, err := u.processServerUser(ctx, serverName, userID)
if err != nil {
@@ -461,6 +462,11 @@ func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Dura
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) {
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
+
+ // If we are processing more than one user per server, this unblocks further calls to Update
+ // immediately instead of just after **all** users have been processed.
+ defer u.clearChannel(userID)
+
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"server_name": serverName,
"user_id": userID,
@@ -579,3 +585,24 @@ func (u *DeviceListUpdater) updateDeviceList(res *fclient.RespUserDevices) error
}
return nil
}
+
+// dedupeStaleLists de-duplicates the stateList entries using the domain.
+// This is used on startup, processServer is getting all users anyway, so
+// there is no need to send every user to the workers.
+func dedupeStaleLists(staleLists []string) []string {
+ seenDomains := make(map[spec.ServerName]struct{})
+ newStaleLists := make([]string, 0, len(staleLists))
+ for _, userID := range staleLists {
+ _, domain, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ // non-fatal and should not block starting up
+ continue
+ }
+ if _, ok := seenDomains[domain]; ok {
+ continue
+ }
+ newStaleLists = append(newStaleLists, userID)
+ seenDomains[domain] = struct{}{}
+ }
+ return newStaleLists
+}
diff --git a/userapi/internal/device_list_update_test.go b/userapi/internal/device_list_update_test.go
index 10b9c652..38fd8b58 100644
--- a/userapi/internal/device_list_update_test.go
+++ b/userapi/internal/device_list_update_test.go
@@ -428,3 +428,49 @@ func TestDeviceListUpdater_CleanUp(t *testing.T) {
}
})
}
+
+func Test_dedupeStateList(t *testing.T) {
+ alice := "@alice:localhost"
+ bob := "@bob:localhost"
+ charlie := "@charlie:notlocalhost"
+ invalidUserID := "iaminvalid:localhost"
+
+ tests := []struct {
+ name string
+ staleLists []string
+ want []string
+ }{
+ {
+ name: "empty stateLists",
+ staleLists: []string{},
+ want: []string{},
+ },
+ {
+ name: "single entry",
+ staleLists: []string{alice},
+ want: []string{alice},
+ },
+ {
+ name: "multiple entries without dupe servers",
+ staleLists: []string{alice, charlie},
+ want: []string{alice, charlie},
+ },
+ {
+ name: "multiple entries with dupe servers",
+ staleLists: []string{alice, bob, charlie},
+ want: []string{alice, charlie},
+ },
+ {
+ name: "list with invalid userID",
+ staleLists: []string{alice, bob, invalidUserID},
+ want: []string{alice},
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := dedupeStaleLists(tt.staleLists); !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("dedupeStaleLists() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}