aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2024-01-08 19:14:29 +0100
committerGitHub <noreply@github.com>2024-01-08 19:14:29 +0100
commit13c5173273852d5f16f3c3d46f20fb1fd33f99d0 (patch)
tree763ad292bad7580e23025fea9bcd37ac1a2ed70c
parentedd02ec468d405b6376fdbc4527a20357c0f6cef (diff)
Fix notary keys requests for all keys (#3296)
This should be more spec compliant: > If no key IDs are given to be queried, the notary server should query for all keys.
-rw-r--r--federationapi/federationapi_test.go129
-rw-r--r--federationapi/internal/query.go9
-rw-r--r--federationapi/routing/keys.go9
3 files changed, 144 insertions, 3 deletions
diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go
index 1ea8c40e..79f4b3f2 100644
--- a/federationapi/federationapi_test.go
+++ b/federationapi/federationapi_test.go
@@ -5,11 +5,14 @@ import (
"crypto/ed25519"
"encoding/json"
"fmt"
+ "net/http"
+ "net/http/httptest"
"strings"
"sync"
"testing"
"time"
+ "github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/internal/sqlutil"
@@ -17,7 +20,10 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
+ "github.com/matrix-org/util"
"github.com/nats-io/nats.go"
+ "github.com/stretchr/testify/assert"
+ "github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationapi/api"
@@ -362,3 +368,126 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
}
}
}
+
+func TestNotaryServer(t *testing.T) {
+ testCases := []struct {
+ name string
+ httpBody string
+ pubKeyRequest *gomatrixserverlib.PublicKeyNotaryLookupRequest
+ validateFunc func(t *testing.T, response util.JSONResponse)
+ }{
+ {
+ name: "empty httpBody",
+ validateFunc: func(t *testing.T, resp util.JSONResponse) {
+ assert.Equal(t, http.StatusBadRequest, resp.Code)
+ nk, ok := resp.JSON.(spec.MatrixError)
+ assert.True(t, ok)
+ assert.Equal(t, spec.ErrorBadJSON, nk.ErrCode)
+ },
+ },
+ {
+ name: "valid but empty httpBody",
+ httpBody: "{}",
+ validateFunc: func(t *testing.T, resp util.JSONResponse) {
+ want := util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: routing.NotaryKeysResponse{ServerKeys: []json.RawMessage{}},
+ }
+ assert.Equal(t, want, resp)
+ },
+ },
+ {
+ name: "request all keys using an empty criteria",
+ httpBody: `{"server_keys":{"servera":{}}}`,
+ validateFunc: func(t *testing.T, resp util.JSONResponse) {
+ assert.Equal(t, http.StatusOK, resp.Code)
+ nk, ok := resp.JSON.(routing.NotaryKeysResponse)
+ assert.True(t, ok)
+ assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str)
+ assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists())
+ },
+ },
+ {
+ name: "request all keys using null as the criteria",
+ httpBody: `{"server_keys":{"servera":null}}`,
+ validateFunc: func(t *testing.T, resp util.JSONResponse) {
+ assert.Equal(t, http.StatusOK, resp.Code)
+ nk, ok := resp.JSON.(routing.NotaryKeysResponse)
+ assert.True(t, ok)
+ assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str)
+ assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists())
+ },
+ },
+ {
+ name: "request specific key",
+ httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}}}}`,
+ validateFunc: func(t *testing.T, resp util.JSONResponse) {
+ assert.Equal(t, http.StatusOK, resp.Code)
+ nk, ok := resp.JSON.(routing.NotaryKeysResponse)
+ assert.True(t, ok)
+ assert.Equal(t, "servera", gjson.GetBytes(nk.ServerKeys[0], "server_name").Str)
+ assert.True(t, gjson.GetBytes(nk.ServerKeys[0], "verify_keys.ed25519:someID").Exists())
+ },
+ },
+ {
+ name: "request multiple servers",
+ httpBody: `{"server_keys":{"servera":{"ed25519:someID":{}},"serverb":{"ed25519:someID":{}}}}`,
+ validateFunc: func(t *testing.T, resp util.JSONResponse) {
+ assert.Equal(t, http.StatusOK, resp.Code)
+ nk, ok := resp.JSON.(routing.NotaryKeysResponse)
+ assert.True(t, ok)
+ wantServers := map[string]struct{}{
+ "servera": {},
+ "serverb": {},
+ }
+ for _, js := range nk.ServerKeys {
+ serverName := gjson.GetBytes(js, "server_name").Str
+ _, ok = wantServers[serverName]
+ assert.True(t, ok, "unexpected servername: %s", serverName)
+ delete(wantServers, serverName)
+ assert.True(t, gjson.GetBytes(js, "verify_keys.ed25519:someID").Exists())
+ }
+ if len(wantServers) > 0 {
+ t.Fatalf("expected response to also contain: %#v", wantServers)
+ }
+ },
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ cfg, processCtx, close := testrig.CreateConfig(t, dbType)
+ defer close()
+ cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
+ caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics)
+ natsInstance := jetstream.NATSInstance{}
+ fc := &fedClient{
+ keys: map[spec.ServerName]struct {
+ key ed25519.PrivateKey
+ keyID gomatrixserverlib.KeyID
+ }{
+ "servera": {
+ key: test.PrivateKeyA,
+ keyID: "ed25519:someID",
+ },
+ "serverb": {
+ key: test.PrivateKeyB,
+ keyID: "ed25519:someID",
+ },
+ },
+ }
+
+ fedAPI := federationapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, fc, nil, caches, nil, true)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(tc.httpBody))
+ req.Host = string(cfg.Global.ServerName)
+
+ resp := routing.NotaryKeys(req, &cfg.FederationAPI, fedAPI, tc.pubKeyRequest)
+ // assert that we received the expected response
+ tc.validateFunc(t, resp)
+ })
+ }
+
+ })
+}
diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go
index e53f19ff..21e77c48 100644
--- a/federationapi/internal/query.go
+++ b/federationapi/internal/query.go
@@ -43,6 +43,15 @@ func (a *FederationInternalAPI) fetchServerKeysFromCache(
ctx context.Context, req *api.QueryServerKeysRequest,
) ([]gomatrixserverlib.ServerKeys, error) {
var results []gomatrixserverlib.ServerKeys
+
+ // We got a request for _all_ server keys, return them.
+ if len(req.KeyIDToCriteria) == 0 {
+ serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{})
+ if len(serverKeysResponses) == 0 {
+ return nil, fmt.Errorf("failed to find server key response for server %s", req.ServerName)
+ }
+ return serverKeysResponses, nil
+ }
for keyID, criteria := range req.KeyIDToCriteria {
serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{keyID})
if len(serverKeysResponses) == 0 {
diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go
index 3d8ff2de..38a88e4b 100644
--- a/federationapi/routing/keys.go
+++ b/federationapi/routing/keys.go
@@ -197,6 +197,10 @@ func localKeys(cfg *config.FederationAPI, serverName spec.ServerName) (*gomatrix
return &keys, err
}
+type NotaryKeysResponse struct {
+ ServerKeys []json.RawMessage `json:"server_keys"`
+}
+
func NotaryKeys(
httpReq *http.Request, cfg *config.FederationAPI,
fsAPI federationAPI.FederationInternalAPI,
@@ -217,10 +221,9 @@ func NotaryKeys(
}
}
- var response struct {
- ServerKeys []json.RawMessage `json:"server_keys"`
+ response := NotaryKeysResponse{
+ ServerKeys: []json.RawMessage{},
}
- response.ServerKeys = []json.RawMessage{}
for serverName, kidToCriteria := range req.ServerKeys {
var keyList []gomatrixserverlib.ServerKeys