diff options
author | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-04-25 14:22:46 +0100 |
---|---|---|
committer | Neil Alexander <neilalexander@users.noreply.github.com> | 2022-04-25 14:22:46 +0100 |
commit | aad81b7b4dcf971508cde266c5ae99e35261bf27 (patch) | |
tree | a8edc02d68b1622706b594df4dd93fdd690f0b47 | |
parent | 446819e4ac405393ae7834107adc5761afce8a34 (diff) |
Only call key update process functions if there are updates, don't send things to ourselves over federation
-rw-r--r-- | federationapi/queue/queue.go | 2 | ||||
-rw-r--r-- | federationapi/routing/send.go | 49 | ||||
-rw-r--r-- | keyserver/internal/internal.go | 9 |
3 files changed, 51 insertions, 9 deletions
diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 5b548127..c45bbd1d 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -210,6 +210,7 @@ func (oqs *OutgoingQueues) SendEvent( destmap[d] = struct{}{} } delete(destmap, oqs.origin) + delete(destmap, oqs.signing.ServerName) // Check if any of the destinations are prohibited by server ACLs. for destination := range destmap { @@ -275,6 +276,7 @@ func (oqs *OutgoingQueues) SendEDU( destmap[d] = struct{}{} } delete(destmap, oqs.origin) + delete(destmap, oqs.signing.ServerName) // There is absolutely no guarantee that the EDU will have a room_id // field, as it is not required by the spec. However, if it *does* diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index f2b902b6..2c01afb1 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -124,6 +124,7 @@ func Send( t := txnReq{ rsAPI: rsAPI, keys: keys, + ourServerName: cfg.Matrix.ServerName, federation: federation, servers: servers, keyAPI: keyAPI, @@ -183,6 +184,7 @@ type txnReq struct { gomatrixserverlib.Transaction rsAPI api.RoomserverInternalAPI keyAPI keyapi.KeyInternalAPI + ourServerName gomatrixserverlib.ServerName keys gomatrixserverlib.JSONVerifier federation txnFederationClient roomsMu *internal.MutexByRoom @@ -303,6 +305,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res return &gomatrixserverlib.RespSend{PDUs: results}, nil } +// nolint:gocyclo func (t *txnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { eduCountTotal.Inc() @@ -318,13 +321,11 @@ func (t *txnReq) processEDUs(ctx context.Context) { util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") continue } - _, domain, err := gomatrixserverlib.SplitID('@', typingPayload.UserID) - if err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from typing event sender") + if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { continue - } - if domain != t.Origin { - util.GetLogger(ctx).Debugf("Dropping typing event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { continue } if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { @@ -337,6 +338,13 @@ func (t *txnReq) processEDUs(ctx context.Context) { util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") continue } + if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } for userID, byUser := range directPayload.Messages { for deviceID, message := range byUser { // TODO: check that the user and the device actually exist here @@ -405,6 +413,13 @@ func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) e return err } for _, content := range payload.Push { + if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } presence, ok := syncTypes.PresenceFromString(content.Presence) if !ok { continue @@ -424,7 +439,13 @@ func (t *txnReq) processSigningKeyUpdate(ctx context.Context, e gomatrixserverli }).Debug("Failed to unmarshal signing key update") return err } - + if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } keys := gomatrixserverlib.CrossSigningKeys{} if updatePayload.MasterKey != nil { keys.MasterKey = *updatePayload.MasterKey @@ -450,6 +471,13 @@ func (t *txnReq) processReceiptEvent(ctx context.Context, timestamp gomatrixserverlib.Timestamp, eventIDs []string, ) error { + if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } // store every event for _, eventID := range eventIDs { if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { @@ -466,6 +494,13 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli util.GetLogger(ctx).WithError(err).Error("Failed to unmarshal device list update event") return } + if _, serverName, err := gomatrixserverlib.SplitID('@', payload.UserID); err != nil { + return + } else if serverName == t.ourServerName { + return + } else if serverName != t.Origin { + return + } var inputRes keyapi.InputDeviceListUpdateResponse t.keyAPI.InputDeviceListUpdate(context.Background(), &keyapi.InputDeviceListUpdateRequest{ Event: payload, diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index e70de767..e571c7e5 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -71,8 +71,12 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { res.KeyErrors = make(map[string]map[string]*api.KeyError) - a.uploadLocalDeviceKeys(ctx, req, res) - a.uploadOneTimeKeys(ctx, req, res) + if len(req.DeviceKeys) > 0 { + a.uploadLocalDeviceKeys(ctx, req, res) + } + if len(req.OneTimeKeys) > 0 { + a.uploadOneTimeKeys(ctx, req, res) + } } func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) { @@ -663,6 +667,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per // add the display name field from keysToStore into existingKeys keysToStore = appendDisplayNames(existingKeys, keysToStore) } + // store the device keys and emit changes err = a.DB.StoreLocalDeviceKeys(ctx, keysToStore) if err != nil { |