aboutsummaryrefslogtreecommitdiff
path: root/keyserver/internal/internal.go
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver/internal/internal.go')
-rw-r--r--keyserver/internal/internal.go46
1 files changed, 46 insertions, 0 deletions
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 5be87aa4..041732dc 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -37,9 +37,39 @@ func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.Perform
a.uploadDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res)
}
+
func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) {
+ res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
+ res.Failures = make(map[string]interface{})
+ // wrap request map in a top-level by-domain map
+ domainToDeviceKeys := make(map[string]map[string]map[string]string)
+ for userID, val := range req.OneTimeKeys {
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ nested, ok := domainToDeviceKeys[string(serverName)]
+ if !ok {
+ nested = make(map[string]map[string]string)
+ }
+ nested[userID] = val
+ domainToDeviceKeys[string(serverName)] = nested
+ }
+ // claim local keys
+ if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok {
+ keys, err := a.DB.ClaimKeys(ctx, local)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
+ }
+ }
+ mergeInto(res.OneTimeKeys, keys)
+ delete(domainToDeviceKeys, string(a.ThisServer))
+ }
+ // TODO: claim remote keys
}
+
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.Failures = make(map[string]interface{})
@@ -166,3 +196,19 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) {
// TODO
}
+
+func mergeInto(dst map[string]map[string]map[string]json.RawMessage, src []api.OneTimeKeys) {
+ for _, key := range src {
+ _, ok := dst[key.UserID]
+ if !ok {
+ dst[key.UserID] = make(map[string]map[string]json.RawMessage)
+ }
+ _, ok = dst[key.UserID][key.DeviceID]
+ if !ok {
+ dst[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
+ }
+ for keyID, keyJSON := range key.KeyJSON {
+ dst[key.UserID][key.DeviceID][keyID] = keyJSON
+ }
+ }
+}