diff options
Diffstat (limited to 'keyserver/internal/internal.go')
-rw-r--r-- | keyserver/internal/internal.go | 46 |
1 files changed, 46 insertions, 0 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 5be87aa4..041732dc 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -37,9 +37,39 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform a.uploadDeviceKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res) } + func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { + res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage) + res.Failures = make(map[string]interface{}) + // wrap request map in a top-level by-domain map + domainToDeviceKeys := make(map[string]map[string]map[string]string) + for userID, val := range req.OneTimeKeys { + _, serverName, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + continue // ignore invalid users + } + nested, ok := domainToDeviceKeys[string(serverName)] + if !ok { + nested = make(map[string]map[string]string) + } + nested[userID] = val + domainToDeviceKeys[string(serverName)] = nested + } + // claim local keys + if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok { + keys, err := a.DB.ClaimKeys(ctx, local) + if err != nil { + res.Error = &api.KeyError{ + Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err), + } + } + mergeInto(res.OneTimeKeys, keys) + delete(domainToDeviceKeys, string(a.ThisServer)) + } + // TODO: claim remote keys } + func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { res.DeviceKeys = make(map[string]map[string]json.RawMessage) res.Failures = make(map[string]interface{}) @@ -166,3 +196,19 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) { // TODO } + +func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) { + for _, key := range src { + _, ok := dst[key.UserID] + if !ok { + dst[key.UserID] = make(map[string]map[string]json.RawMessage) + } + _, ok = dst[key.UserID][key.DeviceID] + if !ok { + dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage) + } + for keyID, keyJSON := range key.KeyJSON { + dst[key.UserID][key.DeviceID][keyID] = keyJSON + } + } +} |