aboutsummaryrefslogtreecommitdiff
path: root/keyserver
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver')
-rw-r--r--keyserver/consumers/devicelistupdate.go26
-rw-r--r--keyserver/consumers/signingkeyupdate.go28
-rw-r--r--keyserver/internal/device_list_update.go6
-rw-r--r--keyserver/internal/device_list_update_test.go14
-rw-r--r--keyserver/internal/internal.go4
-rw-r--r--keyserver/keyserver.go2
6 files changed, 45 insertions, 35 deletions
diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go
index 575e4128..cd911f8c 100644
--- a/keyserver/consumers/devicelistupdate.go
+++ b/keyserver/consumers/devicelistupdate.go
@@ -30,12 +30,12 @@ import (
// DeviceListUpdateConsumer consumes device list updates that came in over federation.
type DeviceListUpdateConsumer struct {
- ctx context.Context
- jetstream nats.JetStreamContext
- durable string
- topic string
- updater *internal.DeviceListUpdater
- serverName gomatrixserverlib.ServerName
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ updater *internal.DeviceListUpdater
+ isLocalServerName func(gomatrixserverlib.ServerName) bool
}
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
@@ -46,12 +46,12 @@ func NewDeviceListUpdateConsumer(
updater *internal.DeviceListUpdater,
) *DeviceListUpdateConsumer {
return &DeviceListUpdateConsumer{
- ctx: process.Context(),
- jetstream: js,
- durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
- topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
- updater: updater,
- serverName: cfg.Matrix.ServerName,
+ ctx: process.Context(),
+ jetstream: js,
+ durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
+ topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
+ updater: updater,
+ isLocalServerName: cfg.Matrix.IsLocalServerName,
}
}
@@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
return true
- } else if serverName == t.serverName {
+ } else if t.isLocalServerName(serverName) {
return true
} else if serverName != origin {
return true
diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go
index 366e259b..bcceaad1 100644
--- a/keyserver/consumers/signingkeyupdate.go
+++ b/keyserver/consumers/signingkeyupdate.go
@@ -31,12 +31,13 @@ import (
// SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
type SigningKeyUpdateConsumer struct {
- ctx context.Context
- jetstream nats.JetStreamContext
- durable string
- topic string
- keyAPI *internal.KeyInternalAPI
- cfg *config.KeyServer
+ ctx context.Context
+ jetstream nats.JetStreamContext
+ durable string
+ topic string
+ keyAPI *internal.KeyInternalAPI
+ cfg *config.KeyServer
+ isLocalServerName func(gomatrixserverlib.ServerName) bool
}
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
@@ -47,12 +48,13 @@ func NewSigningKeyUpdateConsumer(
keyAPI *internal.KeyInternalAPI,
) *SigningKeyUpdateConsumer {
return &SigningKeyUpdateConsumer{
- ctx: process.Context(),
- jetstream: js,
- durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
- topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
- keyAPI: keyAPI,
- cfg: cfg,
+ ctx: process.Context(),
+ jetstream: js,
+ durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
+ topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
+ keyAPI: keyAPI,
+ cfg: cfg,
+ isLocalServerName: cfg.Matrix.IsLocalServerName,
}
}
@@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
logrus.WithError(err).Error("failed to split user id")
return true
- } else if serverName == t.cfg.Matrix.ServerName {
+ } else if t.isLocalServerName(serverName) {
logrus.Warn("dropping device key update from ourself")
return true
} else if serverName != origin {
diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go
index 3f7c0d8b..8ff9dfc3 100644
--- a/keyserver/internal/device_list_update.go
+++ b/keyserver/internal/device_list_update.go
@@ -96,6 +96,7 @@ type DeviceListUpdater struct {
producer KeyChangeProducer
fedClient fedsenderapi.KeyserverFederationAPI
workerChans []chan gomatrixserverlib.ServerName
+ thisServer gomatrixserverlib.ServerName
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
// block on or timeout via a select.
@@ -139,6 +140,7 @@ func NewDeviceListUpdater(
process *process.ProcessContext, db DeviceListUpdaterDatabase,
api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
+ thisServer gomatrixserverlib.ServerName,
) *DeviceListUpdater {
return &DeviceListUpdater{
process: process,
@@ -148,6 +150,7 @@ func NewDeviceListUpdater(
api: api,
producer: producer,
fedClient: fedClient,
+ thisServer: thisServer,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
userIDToChan: make(map[string]chan bool),
userIDToChanMu: &sync.Mutex{},
@@ -435,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go
"server_name": serverName,
"user_id": userID,
})
-
- res, err := u.fedClient.GetUserDevices(ctx, serverName, userID)
+ res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return time.Minute * 10, err
diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go
index 28a13a0a..a374c951 100644
--- a/keyserver/internal/device_list_update_test.go
+++ b/keyserver/internal/device_list_update_test.go
@@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
_, pkey, _ := ed25519.GenerateKey(nil)
fedClient := gomatrixserverlib.NewFederationClient(
- gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey,
+ []*gomatrixserverlib.SigningIdentity{
+ {
+ ServerName: gomatrixserverlib.ServerName("example.test"),
+ KeyID: gomatrixserverlib.KeyID("ed25519:test"),
+ PrivateKey: pkey,
+ },
+ },
)
fedClient.Client = *gomatrixserverlib.NewClient(
gomatrixserverlib.WithTransport(&roundTripper{tripper}),
@@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) {
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1)
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
@@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)),
}, nil
})
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2)
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) {
close(incomingFedReq)
return <-fedCh, nil
})
- updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1)
+ updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go
index 37c55c8f..9a08a0bb 100644
--- a/keyserver/internal/internal.go
+++ b/keyserver/internal/internal.go
@@ -146,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys(
defer cancel()
defer wg.Done()
- claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim)
+ claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
mu.Lock()
defer mu.Unlock()
@@ -559,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
if len(devKeys) == 0 {
return
}
- queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys)
+ queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
if err == nil {
resultCh <- &queryKeysResp
return
diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go
index 0a4b8fde..a86c2da4 100644
--- a/keyserver/keyserver.go
+++ b/keyserver/keyserver.go
@@ -58,7 +58,7 @@ func NewInternalAPI(
FedClient: fedClient,
Producer: keyChangeProducer,
}
- updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
+ updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
ap.Updater = updater
go func() {
if err := updater.Start(); err != nil {