aboutsummaryrefslogtreecommitdiff
path: root/keyserver/internal/internal.go
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver/internal/internal.go')
-rw-r--r--keyserver/internal/internal.go120
1 files changed, 111 insertions, 9 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 041732dc..e406dab4 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -19,6 +19,8 @@ import (
"context"
"encoding/json"
"fmt"
+ "sync"
+ "time"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
@@ -30,6 +32,7 @@ import (
type KeyInternalAPI struct {
DB storage.Database
ThisServer gomatrixserverlib.ServerName
+ FedClient *gomatrixserverlib.FederationClient
}
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
@@ -66,15 +69,67 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC
mergeInto(res.OneTimeKeys, keys)
delete(domainToDeviceKeys, string(a.ThisServer))
}
- // TODO: claim remote keys
+ // claim remote keys
+ a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
+}
+
+func (a *KeyInternalAPI) claimRemoteKeys(
+ ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
+) {
+ resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys))
+ // allows us to wait until all federation servers have been poked
+ var wg sync.WaitGroup
+ wg.Add(len(domainToDeviceKeys))
+ // mutex for failures
+ var failMu sync.Mutex
+
+ // fan out
+ for d, k := range domainToDeviceKeys {
+ go func(domain string, keysToClaim map[string]map[string]string) {
+ defer wg.Done()
+ fedCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
+ if err != nil {
+ failMu.Lock()
+ res.Failures[domain] = map[string]interface{}{
+ "message": err.Error(),
+ }
+ failMu.Unlock()
+ return
+ }
+ resultCh <- &claimKeyRes
+ }(d, k)
+ }
+ // Close the result channel when the goroutines have quit so the for .. range exits
+ go func() {
+ wg.Wait()
+ close(resultCh)
+ }()
+
+ for result := range resultCh {
+ for userID, nest := range result.OneTimeKeys {
+ res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
+ for deviceID, nest2 := range nest {
+ res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage)
+ for keyIDWithAlgo, otk := range nest2 {
+ keyJSON, err := json.Marshal(otk)
+ if err != nil {
+ continue
+ }
+ res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON
+ }
+ }
+ }
+ }
}
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
// make a map from domain to device keys
- domainToUserToDevice := make(map[string][]api.DeviceKeys)
+ domainToDeviceKeys := make(map[string]map[string][]string)
for userID, deviceIDs := range req.UserToDevices {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
@@ -100,16 +155,63 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
}
} else {
- for _, deviceID := range deviceIDs {
- domainToUserToDevice[domain] = append(domainToUserToDevice[domain], api.DeviceKeys{
- UserID: userID,
- DeviceID: deviceID,
- })
- }
+ domainToDeviceKeys[domain] = make(map[string][]string)
+ domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
}
}
// TODO: set device display names when they are known
- // TODO: perform key queries for remote devices
+
+ // perform key queries for remote devices
+ a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
+}
+
+func (a *KeyInternalAPI) queryRemoteKeys(
+ ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
+) {
+ resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
+ // allows us to wait until all federation servers have been poked
+ var wg sync.WaitGroup
+ wg.Add(len(domainToDeviceKeys))
+ // mutex for failures
+ var failMu sync.Mutex
+
+ // fan out
+ for domain, deviceKeys := range domainToDeviceKeys {
+ go func(serverName string, devKeys map[string][]string) {
+ defer wg.Done()
+ fedCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
+ if err != nil {
+ failMu.Lock()
+ res.Failures[serverName] = map[string]interface{}{
+ "message": err.Error(),
+ }
+ failMu.Unlock()
+ return
+ }
+ resultCh <- &queryKeysResp
+ }(domain, deviceKeys)
+ }
+
+ // Close the result channel when the goroutines have quit so the for .. range exits
+ go func() {
+ wg.Wait()
+ close(resultCh)
+ }()
+
+ for result := range resultCh {
+ for userID, nest := range result.DeviceKeys {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ for deviceID, deviceKey := range nest {
+ keyJSON, err := json.Marshal(deviceKey)
+ if err != nil {
+ continue
+ }
+ res.DeviceKeys[userID][deviceID] = keyJSON
+ }
+ }
+ }
}
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {