aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal/key_api.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/internal/key_api.go')
-rw-r--r--userapi/internal/key_api.go798
1 files changed, 798 insertions, 0 deletions
diff --git a/userapi/internal/key_api.go b/userapi/internal/key_api.go
new file mode 100644
index 00000000..be816fe5
--- /dev/null
+++ b/userapi/internal/key_api.go
@@ -0,0 +1,798 @@
+// Copyright 2020 The Matrix.org Foundation C.I.C.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/sirupsen/logrus"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+
+ "github.com/matrix-org/dendrite/userapi/api"
+)
+
+func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
+ userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ return nil
+ }
+ res.Offset = latest
+ res.UserIDs = userIDs
+ return nil
+}
+
+func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
+ res.KeyErrors = make(map[string]map[string]*api.KeyError)
+ if len(req.DeviceKeys) > 0 {
+ a.uploadLocalDeviceKeys(ctx, req, res)
+ }
+ if len(req.OneTimeKeys) > 0 {
+ a.uploadOneTimeKeys(ctx, req, res)
+ }
+ otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ return err
+ }
+ res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
+ return nil
+}
+
+func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
+ res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
+ res.Failures = make(map[string]interface{})
+ // wrap request map in a top-level by-domain map
+ domainToDeviceKeys := make(map[string]map[string]map[string]string)
+ for userID, val := range req.OneTimeKeys {
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ nested, ok := domainToDeviceKeys[string(serverName)]
+ if !ok {
+ nested = make(map[string]map[string]string)
+ }
+ nested[userID] = val
+ domainToDeviceKeys[string(serverName)] = nested
+ }
+ for domain, local := range domainToDeviceKeys {
+ if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ continue
+ }
+ // claim local keys
+ keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
+ }
+ }
+ util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
+ for _, key := range keys {
+ _, ok := res.OneTimeKeys[key.UserID]
+ if !ok {
+ res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
+ }
+ _, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
+ if !ok {
+ res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
+ }
+ for keyID, keyJSON := range key.KeyJSON {
+ res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
+ }
+ }
+ delete(domainToDeviceKeys, domain)
+ }
+ if len(domainToDeviceKeys) > 0 {
+ a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) claimRemoteKeys(
+ ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
+) {
+ var wg sync.WaitGroup // Wait for fan-out goroutines to finish
+ var mu sync.Mutex // Protects the response struct
+ var claimed int // Number of keys claimed in total
+ var failures int // Number of servers we failed to ask
+
+ util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
+ wg.Add(len(domainToDeviceKeys))
+
+ for d, k := range domainToDeviceKeys {
+ go func(domain string, keysToClaim map[string]map[string]string) {
+ fedCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ defer wg.Done()
+
+ claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
+
+ mu.Lock()
+ defer mu.Unlock()
+
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
+ res.Failures[domain] = map[string]interface{}{
+ "message": err.Error(),
+ }
+ failures++
+ return
+ }
+
+ for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
+ res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
+ for deviceID, keys := range deviceIDToKeys {
+ res.OneTimeKeys[userID][deviceID] = keys
+ claimed += len(keys)
+ }
+ }
+ }(d, k)
+ }
+
+ wg.Wait()
+ util.GetLogger(ctx).WithFields(logrus.Fields{
+ "num_keys": claimed,
+ "num_failures": failures,
+ }).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
+}
+
+func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
+ if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to delete device keys: %s", err),
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
+ count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
+ }
+ return nil
+ }
+ res.Count = *count
+ return nil
+}
+
+func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
+ msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
+ }
+ return nil
+ }
+ maxStreamID := int64(0)
+ // remove deleted devices
+ var result []api.DeviceMessage
+ for _, m := range msgs {
+ if m.StreamID > maxStreamID {
+ maxStreamID = m.StreamID
+ }
+ if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
+ continue
+ }
+ result = append(result, m)
+ }
+ res.Devices = result
+ res.StreamID = maxStreamID
+ return nil
+}
+
+// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
+// in our database.
+func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
+ knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
+ if err != nil {
+ return err
+ }
+ if len(knownDevices) == 0 {
+ return nil // fmt.Errorf("unknown user %s", req.UserID)
+ }
+
+ for i := range knownDevices {
+ if knownDevices[i].DeviceID == req.DeviceID {
+ return nil // we already know about this device
+ }
+ }
+
+ return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
+}
+
+// nolint:gocyclo
+func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
+ var respMu sync.Mutex
+ res.DeviceKeys = make(map[string]map[string]json.RawMessage)
+ res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
+ res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
+ res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
+ res.Failures = make(map[string]interface{})
+
+ // make a map from domain to device keys
+ domainToDeviceKeys := make(map[string]map[string][]string)
+ domainToCrossSigningKeys := make(map[string]map[string]struct{})
+ for userID, deviceIDs := range req.UserToDevices {
+ _, serverName, err := gomatrixserverlib.SplitID('@', userID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ domain := string(serverName)
+ // query local devices
+ if a.Config.Matrix.IsLocalServerName(serverName) {
+ deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query local device keys: %s", err),
+ }
+ return nil
+ }
+
+ // pull out display names after we have the keys so we handle wildcards correctly
+ var dids []string
+ for _, dk := range deviceKeys {
+ dids = append(dids, dk.DeviceID)
+ }
+ var queryRes api.QueryDeviceInfosResponse
+ err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{
+ DeviceIDs: dids,
+ }, &queryRes)
+ if err != nil {
+ util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
+ }
+
+ if res.DeviceKeys[userID] == nil {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ }
+ for _, dk := range deviceKeys {
+ if len(dk.KeyJSON) == 0 {
+ continue // don't include blank keys
+ }
+ // inject display name if known (either locally or remotely)
+ displayName := dk.DisplayName
+ if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
+ displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
+ }
+ dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
+ DisplayName string `json:"device_display_name,omitempty"`
+ }{displayName})
+ res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
+ }
+ } else {
+ domainToDeviceKeys[domain] = make(map[string][]string)
+ domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
+ }
+ // work out if our cross-signing request for this user was
+ // satisfied, if not add them to the list of things to fetch
+ if _, ok := res.MasterKeys[userID]; !ok {
+ if _, ok := domainToCrossSigningKeys[domain]; !ok {
+ domainToCrossSigningKeys[domain] = make(map[string]struct{})
+ }
+ domainToCrossSigningKeys[domain][userID] = struct{}{}
+ }
+ if _, ok := res.SelfSigningKeys[userID]; !ok {
+ if _, ok := domainToCrossSigningKeys[domain]; !ok {
+ domainToCrossSigningKeys[domain] = make(map[string]struct{})
+ }
+ domainToCrossSigningKeys[domain][userID] = struct{}{}
+ }
+ }
+
+ // attempt to satisfy key queries from the local database first as we should get device updates pushed to us
+ domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
+ if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
+ // perform key queries for remote devices
+ a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
+ }
+
+ // Now that we've done the potentially expensive work of asking the federation,
+ // try filling the cross-signing keys from the database that we know about.
+ a.crossSigningKeysFromDatabase(ctx, req, res)
+
+ // Finally, append signatures that we know about
+ // TODO: This is horrible because we need to round-trip the signature from
+ // JSON, add the signatures and marshal it again, for some reason?
+
+ for targetUserID, masterKey := range res.MasterKeys {
+ if masterKey.Signatures == nil {
+ masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for targetKeyID := range masterKey.Keys {
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
+ if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
+ continue
+ }
+ if len(sigMap) == 0 {
+ continue
+ }
+ for sourceUserID, forSourceUser := range sigMap {
+ for sourceKeyID, sourceSig := range forSourceUser {
+ if _, ok := masterKey.Signatures[sourceUserID]; !ok {
+ masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
+ }
+ }
+ }
+ }
+
+ for targetUserID, forUserID := range res.DeviceKeys {
+ for targetKeyID, key := range forUserID {
+ sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
+ if err != nil {
+ // Stop executing the function if the context was canceled/the deadline was exceeded,
+ // as we can't continue without a valid context.
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ return nil
+ }
+ logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
+ continue
+ }
+ if len(sigMap) == 0 {
+ continue
+ }
+ var deviceKey gomatrixserverlib.DeviceKeys
+ if err = json.Unmarshal(key, &deviceKey); err != nil {
+ continue
+ }
+ if deviceKey.Signatures == nil {
+ deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ for sourceUserID, forSourceUser := range sigMap {
+ for sourceKeyID, sourceSig := range forSourceUser {
+ if _, ok := deviceKey.Signatures[sourceUserID]; !ok {
+ deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
+ }
+ deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
+ }
+ }
+ if js, err := json.Marshal(deviceKey); err == nil {
+ res.DeviceKeys[targetUserID][targetKeyID] = js
+ }
+ }
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) remoteKeysFromDatabase(
+ ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
+) map[string]map[string][]string {
+ fetchRemote := make(map[string]map[string][]string)
+ for domain, userToDeviceMap := range domainToDeviceKeys {
+ for userID, deviceIDs := range userToDeviceMap {
+ // we can't safely return keys from the db when all devices are requested as we don't
+ // know if one has just been added.
+ if len(deviceIDs) > 0 {
+ err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
+ if err == nil {
+ continue
+ }
+ util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
+ }
+ // fetch device lists from remote
+ if _, ok := fetchRemote[domain]; !ok {
+ fetchRemote[domain] = make(map[string][]string)
+ }
+ fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
+
+ }
+ }
+ return fetchRemote
+}
+
+func (a *UserInternalAPI) queryRemoteKeys(
+ ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
+ domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
+) {
+ resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
+ // allows us to wait until all federation servers have been poked
+ var wg sync.WaitGroup
+ // mutex for writing directly to res (e.g failures)
+ var respMu sync.Mutex
+
+ domains := map[string]struct{}{}
+ for domain := range domainToDeviceKeys {
+ if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ continue
+ }
+ domains[domain] = struct{}{}
+ }
+ for domain := range domainToCrossSigningKeys {
+ if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
+ continue
+ }
+ domains[domain] = struct{}{}
+ }
+ wg.Add(len(domains))
+
+ // fan out
+ for domain := range domains {
+ go a.queryRemoteKeysOnServer(
+ ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
+ &wg, &respMu, timeout, resultCh, res,
+ )
+ }
+
+ // Close the result channel when the goroutines have quit so the for .. range exits
+ go func() {
+ wg.Wait()
+ close(resultCh)
+ }()
+
+ processResult := func(result *gomatrixserverlib.RespQueryKeys) {
+ respMu.Lock()
+ defer respMu.Unlock()
+ for userID, nest := range result.DeviceKeys {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ for deviceID, deviceKey := range nest {
+ keyJSON, err := json.Marshal(deviceKey)
+ if err != nil {
+ continue
+ }
+ res.DeviceKeys[userID][deviceID] = keyJSON
+ }
+ }
+
+ for userID, body := range result.MasterKeys {
+ res.MasterKeys[userID] = body
+ }
+
+ for userID, body := range result.SelfSigningKeys {
+ res.SelfSigningKeys[userID] = body
+ }
+
+ // TODO: do we want to persist these somewhere now
+ // that we have fetched them?
+ }
+
+ for result := range resultCh {
+ processResult(result)
+ }
+}
+
+func (a *UserInternalAPI) queryRemoteKeysOnServer(
+ ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
+ wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
+ res *api.QueryKeysResponse,
+) {
+ defer wg.Done()
+ fedCtx := ctx
+ if timeout > 0 {
+ var cancel context.CancelFunc
+ fedCtx, cancel = context.WithTimeout(ctx, timeout)
+ defer cancel()
+ }
+ // for users who we do not have any knowledge about, try to start doing device list updates for them
+ // by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
+ // lack a stream ID.
+ userIDsForAllDevices := map[string]struct{}{}
+ for userID, deviceIDs := range devKeys {
+ if len(deviceIDs) == 0 {
+ userIDsForAllDevices[userID] = struct{}{}
+ }
+ }
+ // for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
+ // a device list update, so we'll populate those back into the /keys/query list if not
+ for userID := range crossSigningKeys {
+ if devKeys == nil {
+ devKeys = map[string][]string{}
+ }
+ if _, ok := userIDsForAllDevices[userID]; !ok {
+ devKeys[userID] = []string{}
+ }
+ }
+ for userID := range userIDsForAllDevices {
+ err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ logrus.ErrorKey: err,
+ "user_id": userID,
+ "server": serverName,
+ }).Error("Failed to manually update device lists for user")
+ // try to do it via /keys/query
+ devKeys[userID] = []string{}
+ continue
+ }
+ // refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
+ // user so the fact that we're populating all devices here isn't a problem so long as we have devices.
+ err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ logrus.ErrorKey: err,
+ "user_id": userID,
+ "server": serverName,
+ }).Error("Failed to manually update device lists for user")
+ // try to do it via /keys/query
+ devKeys[userID] = []string{}
+ continue
+ }
+ }
+ if len(devKeys) == 0 {
+ return
+ }
+ queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
+ if err == nil {
+ resultCh <- &queryKeysResp
+ return
+ }
+ respMu.Lock()
+ res.Failures[serverName] = map[string]interface{}{
+ "message": err.Error(),
+ }
+ respMu.Unlock()
+
+ // last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
+ // is down, better to return something than nothing at all. Clients can know about the failure by
+ // inspecting the failures map though so they can know it's a cached response.
+ for userID, dkeys := range devKeys {
+ // drop the error as it's already a failure at this point
+ _ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
+ }
+
+ // Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
+ respMu.Lock()
+ if len(res.DeviceKeys) > 0 {
+ delete(res.Failures, serverName)
+ }
+ respMu.Unlock()
+}
+
+func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
+ ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
+) error {
+ keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
+ // if we can't query the db or there are fewer keys than requested, fetch from remote.
+ if err != nil {
+ return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
+ }
+ if len(keys) < len(deviceIDs) {
+ return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
+ }
+ if len(deviceIDs) == 0 && len(keys) == 0 {
+ return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
+ }
+ respMu.Lock()
+ if res.DeviceKeys[userID] == nil {
+ res.DeviceKeys[userID] = make(map[string]json.RawMessage)
+ }
+ respMu.Unlock()
+
+ for _, key := range keys {
+ if len(key.KeyJSON) == 0 {
+ continue // ignore deleted keys
+ }
+ // inject the display name
+ key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
+ DisplayName string `json:"device_display_name,omitempty"`
+ }{key.DisplayName})
+ respMu.Lock()
+ res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
+ respMu.Unlock()
+ }
+ return nil
+}
+
+func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ // get a list of devices from the user API that actually exist, as
+ // we won't store keys for devices that don't exist
+ uapidevices := &api.QueryDevicesResponse{}
+ if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
+ res.Error = &api.KeyError{
+ Err: err.Error(),
+ }
+ return
+ }
+ if !uapidevices.UserExists {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("user %q does not exist", req.UserID),
+ }
+ return
+ }
+ existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
+ for _, key := range uapidevices.Devices {
+ existingDeviceMap[key.ID] = struct{}{}
+ }
+
+ // Get all of the user existing device keys so we can check for changes.
+ existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
+ }
+ return
+ }
+
+ // Work out whether we have device keys in the keyserver for devices that
+ // no longer exist in the user API. This is mostly an exercise to ensure
+ // that we keep some integrity between the two.
+ var toClean []gomatrixserverlib.KeyID
+ for _, k := range existingKeys {
+ if _, ok := existingDeviceMap[k.DeviceID]; !ok {
+ toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
+ }
+ }
+
+ if len(toClean) > 0 {
+ if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
+ logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
+ } else {
+ logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
+ }
+ }
+
+ var keysToStore []api.DeviceMessage
+
+ if req.OnlyDisplayNameUpdates {
+ for _, existingKey := range existingKeys {
+ for _, newKey := range req.DeviceKeys {
+ switch {
+ case existingKey.UserID != newKey.UserID:
+ continue
+ case existingKey.DeviceID != newKey.DeviceID:
+ continue
+ case existingKey.DisplayName != newKey.DisplayName:
+ existingKey.DisplayName = newKey.DisplayName
+ }
+ }
+ keysToStore = append(keysToStore, existingKey)
+ }
+ } else {
+ // assert that the user ID / device ID are not lying for each key
+ for _, key := range req.DeviceKeys {
+ var serverName gomatrixserverlib.ServerName
+ _, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
+ if err != nil {
+ continue // ignore invalid users
+ }
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ continue // ignore remote users
+ }
+ if len(key.KeyJSON) == 0 {
+ keysToStore = append(keysToStore, key.WithStreamID(0))
+ continue // deleted keys don't need sanity checking
+ }
+ // check that the device in question actually exists in the user
+ // API before we try and store a key for it
+ if _, ok := existingDeviceMap[key.DeviceID]; !ok {
+ continue
+ }
+ gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
+ gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
+ if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
+ keysToStore = append(keysToStore, key.WithStreamID(0))
+ continue
+ }
+
+ res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf(
+ "user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
+ gotUserID, key.UserID, gotDeviceID, key.DeviceID,
+ ),
+ })
+ }
+ }
+
+ // store the device keys and emit changes
+ err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
+ }
+ return
+ }
+ err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
+ if err != nil {
+ util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
+ }
+}
+
+func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
+ if req.UserID == "" {
+ res.Error = &api.KeyError{
+ Err: "user ID missing",
+ }
+ }
+ if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
+ counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
+ if err != nil {
+ res.Error = &api.KeyError{
+ Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
+ }
+ }
+ if counts != nil {
+ res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
+ }
+ return
+ }
+ for _, key := range req.OneTimeKeys {
+ // grab existing keys based on (user/device/algorithm/key ID)
+ keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
+ i := 0
+ for keyIDWithAlgo := range key.KeyJSON {
+ keyIDsWithAlgorithms[i] = keyIDWithAlgo
+ i++
+ }
+ existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
+ if err != nil {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: "failed to query existing one-time keys: " + err.Error(),
+ })
+ continue
+ }
+ for keyIDWithAlgo := range existingKeys {
+ // if keys exist and the JSON doesn't match, error out as the key already exists
+ if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
+ })
+ continue
+ }
+ }
+ // store one-time keys
+ counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
+ if err != nil {
+ res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
+ Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
+ })
+ continue
+ }
+ // collect counts
+ res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
+ }
+
+}
+
+func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
+ // if we only want to update the display names, we can skip the checks below
+ if onlyUpdateDisplayName {
+ return producer.ProduceKeyChanges(new)
+ }
+ // find keys in new that are not in existing
+ var keysAdded []api.DeviceMessage
+ for _, newKey := range new {
+ exists := false
+ for _, existingKey := range existing {
+ // Do not treat the absence of keys as equal, or else we will not emit key changes
+ // when users delete devices which never had a key to begin with as both KeyJSONs are nil.
+ if existingKey.DeviceKeysEqual(&newKey) {
+ exists = true
+ break
+ }
+ }
+ if !exists {
+ keysAdded = append(keysAdded, newKey)
+ }
+ }
+ return producer.ProduceKeyChanges(keysAdded)
+}