diff options
Diffstat (limited to 'keyserver/internal/internal.go')
-rw-r--r-- | keyserver/internal/internal.go | 120 |
1 files changed, 111 insertions, 9 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 041732dc..e406dab4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -19,6 +19,8 @@ import ( "context" "encoding/json" "fmt" + "sync" + "time" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" @@ -30,6 +32,7 @@ import ( type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName + FedClient *gomatrixserverlib.FederationClient } func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { @@ -66,15 +69,67 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC mergeInto(res.OneTimeKeys, keys) delete(domainToDeviceKeys, string(a.ThisServer)) } - // TODO: claim remote keys + // claim remote keys + a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) +} + +func (a *KeyInternalAPI) claimRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, +) { + resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys)) + // allows us to wait until all federation servers have been poked + var wg sync.WaitGroup + wg.Add(len(domainToDeviceKeys)) + // mutex for failures + var failMu sync.Mutex + + // fan out + for d, k := range domainToDeviceKeys { + go func(domain string, keysToClaim map[string]map[string]string) { + defer wg.Done() + fedCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + if err != nil { + failMu.Lock() + res.Failures[domain] = map[string]interface{}{ + "message": err.Error(), + } + failMu.Unlock() + return + } + resultCh <- &claimKeyRes + }(d, k) + } + // Close the result channel when the goroutines have quit so the for .. range exits + go func() { + wg.Wait() + close(resultCh) + }() + + for result := range resultCh { + for userID, nest := range result.OneTimeKeys { + res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) + for deviceID, nest2 := range nest { + res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage) + for keyIDWithAlgo, otk := range nest2 { + keyJSON, err := json.Marshal(otk) + if err != nil { + continue + } + res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON + } + } + } + } } func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // make a map from domain to device keys - domainToUserToDevice := make(map[string][]api.DeviceKeys) + domainToDeviceKeys := make(map[string]map[string][]string) for userID, deviceIDs := range req.UserToDevices { _, serverName, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -100,16 +155,63 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { - for _, deviceID := range deviceIDs { - domainToUserToDevice[domain] = append(domainToUserToDevice[domain], api.DeviceKeys{ - UserID: userID, - DeviceID: deviceID, - }) - } + domainToDeviceKeys[domain] = make(map[string][]string) + domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) } } // TODO: set device display names when they are known - // TODO: perform key queries for remote devices + + // perform key queries for remote devices + a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) +} + +func (a *KeyInternalAPI) queryRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, +) { + resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) + // allows us to wait until all federation servers have been poked + var wg sync.WaitGroup + wg.Add(len(domainToDeviceKeys)) + // mutex for failures + var failMu sync.Mutex + + // fan out + for domain, deviceKeys := range domainToDeviceKeys { + go func(serverName string, devKeys map[string][]string) { + defer wg.Done() + fedCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) + if err != nil { + failMu.Lock() + res.Failures[serverName] = map[string]interface{}{ + "message": err.Error(), + } + failMu.Unlock() + return + } + resultCh <- &queryKeysResp + }(domain, deviceKeys) + } + + // Close the result channel when the goroutines have quit so the for .. range exits + go func() { + wg.Wait() + close(resultCh) + }() + + for result := range resultCh { + for userID, nest := range result.DeviceKeys { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + for deviceID, deviceKey := range nest { + keyJSON, err := json.Marshal(deviceKey) + if err != nil { + continue + } + res.DeviceKeys[userID][deviceID] = keyJSON + } + } + } } func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { |