aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-10-27 15:34:26 +0100
committerNeil Alexander <neilalexander@users.noreply.github.com>2022-10-27 15:34:26 +0100
commita2706e6498287a5b052ef47413175bf7551b36b1 (patch)
tree69a56445dd4167e0110ac4ac38db47a4a98d6057 /keyserver
parenta785532463852796ab1676e1e60ae8f2132eb49d (diff)
Refactor `claimRemoteKeys`
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/internal/internal.go59
1 files changed, 25 insertions, 34 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index ff0968b2..92ee80d8 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -128,58 +128,49 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
func (a *KeyInternalAPI) claimRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
) {
- resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys))
- // allows us to wait until all federation servers have been poked
- var wg sync.WaitGroup
+ var wg sync.WaitGroup // Wait for fan-out goroutines to finish
+ var mu sync.Mutex // Protects the response struct
+ var claimed int // Number of keys claimed in total
+ var failures int // Number of servers we failed to ask
+
+ util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
wg.Add(len(domainToDeviceKeys))
- // mutex for failures
- var failMu sync.Mutex
- util.GetLogger(ctx).WithField("num_servers", len(domainToDeviceKeys)).Info("Claiming remote keys from servers")
- // fan out
for d, k := range domainToDeviceKeys {
go func(domain string, keysToClaim map[string]map[string]string) {
- defer wg.Done()
fedCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
+ defer wg.Done()
+
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
+
+ mu.Lock()
+ defer mu.Unlock()
+
if err != nil {
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
- failMu.Lock()
res.Failures[domain] = map[string]interface{}{
"message": err.Error(),
}
- failMu.Unlock()
+ failures++
return
}
- resultCh <- &claimKeyRes
- }(d, k)
- }
- // Close the result channel when the goroutines have quit so the for .. range exits
- go func() {
- wg.Wait()
- close(resultCh)
- }()
-
- keysClaimed := 0
- for result := range resultCh {
- for userID, nest := range result.OneTimeKeys {
- res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
- for deviceID, nest2 := range nest {
- res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
- for keyIDWithAlgo, otk := range nest2 {
- keyJSON, err := json.Marshal(otk)
- if err != nil {
- continue
- }
- res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
- keysClaimed++
+ for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
+ res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
+ for deviceID, keys := range deviceIDToKeys {
+ res.OneTimeKeys[userID][deviceID] = keys
+ claimed += len(keys)
}
}
- }
+ }(d, k)
}
- util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
+
+ wg.Wait()
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "num_keys": claimed,
+ "num_failures": failures,
+ }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
}
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {