aboutsummaryrefslogtreecommitdiff
path: root/federationapi/storage/tables/server_key_table_test.go
blob: 322169bd00f486daa65590f474b19a55083e35e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package tables_test

import (
	"context"
	"testing"
	"time"

	"github.com/element-hq/dendrite/federationapi/storage/postgres"
	"github.com/element-hq/dendrite/federationapi/storage/sqlite3"
	"github.com/element-hq/dendrite/federationapi/storage/tables"
	"github.com/element-hq/dendrite/internal/sqlutil"
	"github.com/element-hq/dendrite/setup/config"
	"github.com/element-hq/dendrite/test"
	"github.com/matrix-org/gomatrixserverlib"
	"github.com/matrix-org/gomatrixserverlib/spec"
	"github.com/stretchr/testify/assert"
)

// mustCreateServerKeyDB opens a fresh database of the given type and creates
// the server signing keys table on it, failing the test on any error.
//
// It returns the table together with a cleanup function. The cleanup closes
// the *sql.DB opened here and then runs the cleanup returned by
// test.PrepareDBConnectionString. Callers should invoke it, typically from
// t.Cleanup.
func mustCreateServerKeyDB(t *testing.T, dbType test.DBType) (tables.FederationServerSigningKeys, func()) {
	t.Helper()
	connStr, closeConn := test.PrepareDBConnectionString(t, dbType)
	db, err := sqlutil.Open(&config.DatabaseOptions{
		ConnectionString: config.DataSource(connStr),
	}, sqlutil.NewExclusiveWriter())
	if err != nil {
		closeConn()
		t.Fatalf("failed to open database: %s", err)
	}

	// Close our own handle before the connection-string cleanup runs, so the
	// pool opened above is not leaked past the end of the test.
	closeAll := func() {
		if cerr := db.Close(); cerr != nil {
			t.Errorf("failed to close database: %s", cerr)
		}
		closeConn()
	}

	var tab tables.FederationServerSigningKeys
	switch dbType {
	case test.DBTypePostgres:
		tab, err = postgres.NewPostgresServerSigningKeysTable(db)
	case test.DBTypeSQLite:
		tab, err = sqlite3.NewSQLiteServerSigningKeysTable(db)
	default:
		// Without this, an unknown type would silently yield a nil table.
		closeAll()
		t.Fatalf("unsupported database type: %v", dbType)
	}
	if err != nil {
		closeAll()
		t.Fatalf("failed to create table: %s", err)
	}
	return tab, closeAll
}

// TestServerKeysTable exercises the server signing keys table on every
// database backend provided by test.WithAllDatabases. It:
//   - upserts a key and reads it back,
//   - overwrites that key with an "expired" variant and reads it back,
//   - upserts a second key and bulk-selects both at once.
func TestServerKeysTable(t *testing.T) {
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		ctx, cancel := context.WithCancel(context.Background())
		tab, closeDB := mustCreateServerKeyDB(t, dbType)
		t.Cleanup(func() {
			closeDB()
			cancel()
		})

		req := gomatrixserverlib.PublicKeyLookupRequest{
			ServerName: "localhost",
			KeyID:      "ed25519:test",
		}
		expectedTimestamp := spec.AsTimestamp(time.Now().Add(time.Hour))
		res := gomatrixserverlib.PublicKeyLookupResult{
			VerifyKey:    gomatrixserverlib.VerifyKey{Key: make(spec.Base64Bytes, 0)},
			ExpiredTS:    0,
			ValidUntilTS: expectedTimestamp,
		}

		// Insert the key
		err := tab.UpsertServerKeys(ctx, nil, req, res)
		assert.NoError(t, err)

		selectKeys := map[gomatrixserverlib.PublicKeyLookupRequest]spec.Timestamp{
			req: spec.AsTimestamp(time.Now()),
		}
		gotKeys, err := tab.BulkSelectServerKeys(ctx, nil, selectKeys)
		assert.NoError(t, err)

		// Now we should have a key for the req above. Map values are structs,
		// so a NotNil check would always pass; check key membership instead.
		assert.Contains(t, gotKeys, req)
		assert.Equal(t, res, gotKeys[req])

		// "Expire" the key by setting ExpiredTS to a non-zero value and ValidUntilTS to 0
		expectedTimestamp = spec.AsTimestamp(time.Now())
		res.ExpiredTS = expectedTimestamp
		res.ValidUntilTS = 0

		// Update the key
		err = tab.UpsertServerKeys(ctx, nil, req, res)
		assert.NoError(t, err)

		gotKeys, err = tab.BulkSelectServerKeys(ctx, nil, selectKeys)
		assert.NoError(t, err)

		// The key should still be returned, now in its expired form
		assert.Contains(t, gotKeys, req)
		assert.Equal(t, res, gotKeys[req])

		// Upsert a different key to validate querying multiple keys
		req2 := gomatrixserverlib.PublicKeyLookupRequest{
			ServerName: "notlocalhost",
			KeyID:      "ed25519:test2",
		}
		expectedTimestamp2 := spec.AsTimestamp(time.Now().Add(time.Hour))
		res2 := gomatrixserverlib.PublicKeyLookupResult{
			VerifyKey:    gomatrixserverlib.VerifyKey{Key: make(spec.Base64Bytes, 0)},
			ExpiredTS:    0,
			ValidUntilTS: expectedTimestamp2,
		}

		err = tab.UpsertServerKeys(ctx, nil, req2, res2)
		assert.NoError(t, err)

		// Select multiple keys
		selectKeys[req2] = spec.AsTimestamp(time.Now())

		gotKeys, err = tab.BulkSelectServerKeys(ctx, nil, selectKeys)
		assert.NoError(t, err)

		// We now should receive two keys, one of which is expired
		assert.Len(t, gotKeys, 2)
		assert.Contains(t, gotKeys, req)
		assert.Contains(t, gotKeys, req2)
		assert.Equal(t, res2, gotKeys[req2])
		assert.Equal(t, res, gotKeys[req])
	})
}