aboutsummaryrefslogtreecommitdiff
path: root/federationapi/federationapi_keys_test.go
blob: 9dda389ed9a82e590642620320a72219606ab053 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
package federationapi

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"testing"
	"time"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/setup/jetstream"
	"github.com/matrix-org/dendrite/setup/process"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/gomatrixserverlib/fclient"
	"github.com/matrix-org/gomatrixserverlib/spec"

	"github.com/matrix-org/dendrite/federationapi/api"
	"github.com/matrix-org/dendrite/federationapi/routing"
	"github.com/matrix-org/dendrite/internal/caching"
	"github.com/matrix-org/dendrite/setup/config"
)

// server describes one simulated homeserver taking part in the key tests.
// name and validity are set in the package-level declarations; the remaining
// fields are populated by TestMain.
type server struct {
	name      spec.ServerName           // server name
	validity  time.Duration             // key validity duration from now
	config    *config.FederationAPI     // skeleton config, from TestMain
	fedclient fclient.FederationClient  // uses MockRoundTripper
	cache     *caching.Caches           // server-specific cache
	api       api.FederationInternalAPI // server-specific server key API
}

// renew pushes the server's key validity period out to one hour and mirrors
// that value into the server's config. It matters most for servers that start
// with a zero or negative validity (server A and server C), whose keys are
// otherwise already expired.
func (s *server) renew() {
	const renewedValidity = time.Hour

	s.validity = renewedValidity
	s.config.Matrix.KeyValidityPeriod = renewedValidity
}

// serverKeyID is the key ID configured for every test server. serverA,
// serverB and serverC are the three simulated servers, each starting with a
// different key validity period.
var (
	serverKeyID = gomatrixserverlib.KeyID("ed25519:auto")
	serverA     = &server{name: "a.com", validity: time.Duration(0)} // expires now
	serverB     = &server{name: "b.com", validity: time.Hour}        // expires in an hour
	serverC     = &server{name: "c.com", validity: -time.Hour}       // expired an hour ago
)

// servers indexes the test servers by host name. TestMain iterates over it to
// build each server's API, and MockRoundTripper.RoundTrip uses it to resolve
// the host of an incoming request.
var servers = map[string]*server{
	"a.com": serverA,
	"b.com": serverB,
	"c.com": serverC,
}

// TestMain builds a server key API for every entry in servers and then runs
// the package's tests. Each server gets its own freshly generated ed25519
// identity key, cache, temporary JetStream directory and temporary database
// file.
//
// The setup runs inside a closure so that the deferred cleanups (removing the
// temporary directories, closing the database files) execute once m.Run has
// returned but before os.Exit ends the process; os.Exit itself does not run
// deferred functions. Any setup failure panics with the underlying error.
func TestMain(m *testing.M) {
	// Set up the server key API for each "server" that we
	// will use in our tests.
	os.Exit(func() int {
		for _, s := range servers {
			// Generate a new key.
			_, testPriv, err := ed25519.GenerateKey(nil)
			if err != nil {
				panic("can't generate identity key: " + err.Error())
			}

			// Create a new cache but don't enable prometheus!
			s.cache = caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
			natsInstance := jetstream.NATSInstance{}
			// Create a temporary directory for JetStream.
			d, err := os.MkdirTemp("./", "jetstream*")
			if err != nil {
				panic(err)
			}
			// Deliberately deferred to the end of the closure rather than
			// the end of this iteration: the directory must outlive the tests.
			defer os.RemoveAll(d)

			// Draw up just enough Dendrite config for the server key
			// API to work.
			cfg := &config.Dendrite{}
			cfg.Defaults(config.DefaultOpts{
				Generate:       true,
				SingleDatabase: false,
			})
			cfg.Global.ServerName = spec.ServerName(s.name)
			cfg.Global.PrivateKey = testPriv
			cfg.Global.JetStream.InMemory = true
			cfg.Global.JetStream.TopicPrefix = string(s.name[:1])
			cfg.Global.JetStream.StoragePath = config.Path(d)
			cfg.Global.KeyID = serverKeyID
			cfg.Global.KeyValidityPeriod = s.validity
			cfg.FederationAPI.KeyPerspectives = nil
			f, err := os.CreateTemp(d, "federation_keys_test*.db")
			if err != nil {
				// Fail loudly, like the other setup errors above, instead of
				// exiting with a bare status code and no explanation.
				panic("can't create temporary database file: " + err.Error())
			}
			defer f.Close()
			cfg.FederationAPI.Database.ConnectionString = config.DataSource("file:" + f.Name())
			s.config = &cfg.FederationAPI

			// Create a transport which redirects federation requests to
			// the mock round tripper. Since we're not *really* listening for
			// federation requests then this will return the key instead.
			transport := &http.Transport{}
			transport.RegisterProtocol("matrix", &MockRoundTripper{})

			// Create the federation client.
			s.fedclient = fclient.NewFederationClient(
				s.config.Matrix.SigningIdentities(),
				fclient.WithTransport(transport),
			)

			// Finally, build the server key APIs.
			processCtx := process.NewProcessContext()
			cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions)
			s.api = NewInternalAPI(processCtx, cfg, cm, &natsInstance, s.fedclient, nil, s.cache, nil, true)
		}

		// Now that we have built our server key APIs, start the
		// rest of the tests.
		return m.Run()
	}())
}

// MockRoundTripper is a stateless round tripper whose RoundTrip method builds
// server key responses from the in-memory test servers rather than making
// network calls.
type MockRoundTripper struct{}

// RoundTrip answers a request for /_matrix/key/v2/server on any host listed
// in servers, building the response body from that server's current config.
// A request for an unknown host or any other path yields an error.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
	// A request for a host outside the test's server set can only come
	// from a broken test, so reject it outright.
	target, known := servers[req.Host]
	if !known {
		return nil, fmt.Errorf("server not known: %s", req.Host)
	}

	// Only the server key endpoint is mocked.
	if path := req.URL.Path; path != "/_matrix/key/v2/server" {
		return nil, fmt.Errorf("unexpected request path: %s", path)
	}

	// Serialise the target server's local keys as the response body.
	localKeys := routing.LocalKeys(target.config, spec.ServerName(req.Host))
	payload, err := json.MarshalIndent(localKeys.JSON, "", "  ")
	if err != nil {
		return nil, err
	}

	return &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewReader(payload)),
	}, nil
}

// TestServersRequestOwnKeys has every test server look up its own signing
// key through its own API. The lookup must succeed and the result must
// contain the requested key.
func TestServersRequestOwnKeys(t *testing.T) {
	ctx := context.Background()

	for host, srv := range servers {
		lookup := gomatrixserverlib.PublicKeyLookupRequest{
			ServerName: srv.name,
			KeyID:      serverKeyID,
		}
		wanted := map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
			lookup: spec.AsTimestamp(time.Now()),
		}

		results, err := srv.api.FetchKeys(ctx, wanted)
		if err != nil {
			t.Fatalf("server could not fetch own key: %s", err)
		}

		key, found := results[lookup]
		if !found {
			t.Fatalf("server didn't return its own key in the results")
		}
		t.Logf("%s's key expires at %s\n", host, key.ValidUntilTS.Time())
	}
}

// TestRenewalBehaviour checks that server A picks up a renewed key from
// server C. Server C starts with a validity period an hour in the past; its
// key must still be retrievable. After server C's validity is renewed, a
// second fetch by server A must report a different valid-until timestamp.
func TestRenewalBehaviour(t *testing.T) {
	lookup := gomatrixserverlib.PublicKeyLookupRequest{
		ServerName: serverC.name,
		KeyID:      serverKeyID,
	}

	// fetchValidUntil asks server A for server C's key, insists that exactly
	// that one key comes back, and reports its valid-until timestamp.
	fetchValidUntil := func() spec.Timestamp {
		wanted := map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
			lookup: spec.AsTimestamp(time.Now()),
		}
		results, err := serverA.api.FetchKeys(context.Background(), wanted)
		if err != nil {
			t.Fatalf("server A failed to retrieve server C key: %s", err)
		}
		if count := len(results); count != 1 {
			t.Fatalf("server C should have returned one key but instead returned %d keys", count)
		}
		key, found := results[lookup]
		if !found {
			t.Fatalf("server C isn't included in the key fetch response")
		}
		return key.ValidUntilTS
	}

	// First fetch: the key is past its validity, but it is still returned.
	beforeRenewal := fetchValidUntil()

	// Kick server C into a fresh validity period. The key server A holds is
	// from the past, so repeating the fetch should make it try to renew.
	serverC.renew()

	afterRenewal := fetchValidUntil()

	if beforeRenewal == afterRenewal {
		t.Fatalf("server C key should have renewed but didn't")
	}
}