diff options
Diffstat (limited to 'keyserver')
-rw-r--r-- | keyserver/internal/internal.go | 120 | ||||
-rw-r--r-- | keyserver/keyserver.go | 4 |
2 files changed, 114 insertions, 10 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 041732dc..e406dab4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -19,6 +19,8 @@ import ( "context" "encoding/json" "fmt" + "sync" + "time" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" @@ -30,6 +32,7 @@ import ( type KeyInternalAPI struct { DB storage.Database ThisServer gomatrixserverlib.ServerName + FedClient *gomatrixserverlib.FederationClient } func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { @@ -66,15 +69,67 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC mergeInto(res.OneTimeKeys, keys) delete(domainToDeviceKeys, string(a.ThisServer)) } - // TODO: claim remote keys + // claim remote keys + a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) +} + +func (a *KeyInternalAPI) claimRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string, +) { + resultCh := make(chan *gomatrixserverlib.RespClaimKeys, len(domainToDeviceKeys)) + // allows us to wait until all federation servers have been poked + var wg sync.WaitGroup + wg.Add(len(domainToDeviceKeys)) + // mutex for failures + var failMu sync.Mutex + + // fan out + for d, k := range domainToDeviceKeys { + go func(domain string, keysToClaim map[string]map[string]string) { + defer wg.Done() + fedCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + if err != nil { + failMu.Lock() + res.Failures[domain] = map[string]interface{}{ + "message": err.Error(), + } + failMu.Unlock() + return + } + resultCh <- &claimKeyRes + }(d, k) + } + // Close the result channel when the goroutines have quit so the for .. range exits + go func() { + wg.Wait() + close(resultCh) + }() + + for result := range resultCh { + for userID, nest := range result.OneTimeKeys { + res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage) + for deviceID, nest2 := range nest { + res.OneTimeKeys[userID][deviceID] = make(map[string]json.RawMessage) + for keyIDWithAlgo, otk := range nest2 { + keyJSON, err := json.Marshal(otk) + if err != nil { + continue + } + res.OneTimeKeys[userID][deviceID][keyIDWithAlgo] = keyJSON + } + } + } + } } func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) // make a map from domain to device keys - domainToUserToDevice := make(map[string][]api.DeviceKeys) + domainToDeviceKeys := make(map[string]map[string][]string) for userID, deviceIDs := range req.UserToDevices { _, serverName, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -100,16 +155,63 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON } } else { - for _, deviceID := range deviceIDs { - domainToUserToDevice[domain] = append(domainToUserToDevice[domain], api.DeviceKeys{ - UserID: userID, - DeviceID: deviceID, - }) - } + domainToDeviceKeys[domain] = make(map[string][]string) + domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) } } // TODO: set device display names when they are known - // TODO: perform key queries for remote devices + + // perform key queries for remote devices + a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) +} + +func (a *KeyInternalAPI) queryRemoteKeys( + ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string, +) { + resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys)) + // allows us to wait until all federation servers have been poked + var wg sync.WaitGroup + wg.Add(len(domainToDeviceKeys)) + // mutex for failures + var failMu sync.Mutex + + // fan out + for domain, deviceKeys := range domainToDeviceKeys { + go func(serverName string, devKeys map[string][]string) { + defer wg.Done() + fedCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) + if err != nil { + failMu.Lock() + res.Failures[serverName] = map[string]interface{}{ + "message": err.Error(), + } + failMu.Unlock() + return + } + resultCh <- &queryKeysResp + }(domain, deviceKeys) + } + + // Close the result channel when the goroutines have quit so the for .. range exits + go func() { + wg.Wait() + close(resultCh) + }() + + for result := range resultCh { + for userID, nest := range result.DeviceKeys { + res.DeviceKeys[userID] = make(map[string]json.RawMessage) + for deviceID, deviceKey := range nest { + keyJSON, err := json.Marshal(deviceKey) + if err != nil { + continue + } + res.DeviceKeys[userID][deviceID] = keyJSON + } + } + } } func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 3c70fc21..714b59f0 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/keyserver/internal" "github.com/matrix-org/dendrite/keyserver/inthttp" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" ) @@ -32,7 +33,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { // NewInternalAPI returns a concerete implementation of the internal API. Callers // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. -func NewInternalAPI(cfg *config.Dendrite) api.KeyInternalAPI { +func NewInternalAPI(cfg *config.Dendrite, fedClient *gomatrixserverlib.FederationClient) api.KeyInternalAPI { db, err := storage.NewDatabase( string(cfg.Database.E2EKey), cfg.DbProperties(), @@ -43,5 +44,6 @@ func NewInternalAPI(cfg *config.Dendrite) api.KeyInternalAPI { return &internal.KeyInternalAPI{ DB: db, ThisServer: cfg.Matrix.ServerName, + FedClient: fedClient, } } |