aboutsummaryrefslogtreecommitdiff
path: root/roomserver/storage/tables/user_room_keys_table_test.go
blob: 2809771b4c47184ecfe33cdd51d3bd6c51c32704 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package tables_test

import (
	"context"
	"crypto/ed25519"
	"database/sql"
	"testing"

	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/roomserver/storage/postgres"
	"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
	"github.com/matrix-org/dendrite/roomserver/storage/tables"
	"github.com/matrix-org/dendrite/roomserver/types"
	"github.com/matrix-org/dendrite/setup/config"
	"github.com/matrix-org/dendrite/test"
	"github.com/matrix-org/gomatrixserverlib/spec"
	"github.com/stretchr/testify/assert"
	ed255192 "golang.org/x/crypto/ed25519"
)

// mustCreateUserRoomKeysTable opens a database of the requested type, creates
// the user room keys table in it and returns the prepared table, the
// underlying connection and the cleanup function obtained from
// test.PrepareDBConnectionString.
//
// Every failure stops the calling test immediately, so callers never receive
// a nil table or a nil connection. Because the caller cannot run the cleanup
// function when the helper aborts, it is invoked here before the test stops.
func mustCreateUserRoomKeysTable(t *testing.T, dbType test.DBType) (tab tables.UserRoomKeys, db *sql.DB, close func()) {
	t.Helper()
	connStr, close := test.PrepareDBConnectionString(t, dbType)

	// fail records the error first (so it is reported even if the cleanup
	// itself fails), releases the prepared database and aborts the test.
	fail := func(format string, args ...interface{}) {
		t.Helper()
		t.Errorf(format, args...)
		close()
		t.FailNow()
	}

	db, err := sqlutil.Open(&config.DatabaseOptions{
		ConnectionString: config.DataSource(connStr),
	}, sqlutil.NewExclusiveWriter())
	if err != nil {
		fail("opening database: %v", err)
	}

	switch dbType {
	case test.DBTypePostgres:
		if err = postgres.CreateUserRoomKeysTable(db); err != nil {
			fail("creating postgres user room keys table: %v", err)
		}
		tab, err = postgres.PrepareUserRoomKeysTable(db)
	case test.DBTypeSQLite:
		if err = sqlite3.CreateUserRoomKeysTable(db); err != nil {
			fail("creating sqlite3 user room keys table: %v", err)
		}
		tab, err = sqlite3.PrepareUserRoomKeysTable(db)
	default:
		// Without this case an unknown type would return a nil table.
		fail("unsupported database type: %v", dbType)
	}
	if err != nil {
		fail("preparing user room keys table: %v", err)
	}

	return tab, db, close
}

// TestUserRoomKeysTable runs the UserRoomKeys table through its operations on
// every database type supplied by test.WithAllDatabases. Inside a single
// transaction it checks that:
//   - inserting a private key returns that key, and inserting a second key for
//     the same user/room pair returns the first key again,
//   - private and public keys can be selected for an existing user/room pair,
//   - inserting a public key for an existing pair returns the new public key,
//   - selecting keys for a room without entries yields nil and no error,
//   - BulkSelectUserNIDs only returns entries for matching room/key pairs,
//   - a public key can be inserted for a pair that has no private key.
func TestUserRoomKeysTable(t *testing.T) {
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		tab, db, close := mustCreateUserRoomKeysTable(t, dbType)
		defer close()

		ctx := context.Background()
		userNID := types.EventStateKeyNID(1)
		roomNID := types.RoomNID(1)
		// No keys are stored for this room until the final insert below.
		otherRoomNID := types.RoomNID(2)

		_, key, err := ed25519.GenerateKey(nil)
		assert.NoError(t, err)

		err = sqlutil.WithTransaction(db, func(txn *sql.Tx) error {
			var gotKey, key2, key3 ed25519.PrivateKey
			var pubKey ed25519.PublicKey
			gotKey, err = tab.InsertUserRoomPrivatePublicKey(ctx, txn, userNID, roomNID, key)
			assert.NoError(t, err)
			assert.Equal(t, key, gotKey)

			// again, this shouldn't result in an error, but return the existing key
			_, key2, err = ed25519.GenerateKey(nil)
			assert.NoError(t, err)
			gotKey, err = tab.InsertUserRoomPrivatePublicKey(ctx, txn, userNID, roomNID, key2)
			assert.NoError(t, err)
			assert.Equal(t, key, gotKey)

			// add another user
			_, key3, err = ed25519.GenerateKey(nil)
			assert.NoError(t, err)
			userNID2 := types.EventStateKeyNID(2)
			_, err = tab.InsertUserRoomPrivatePublicKey(ctx, txn, userNID2, roomNID, key3)
			assert.NoError(t, err)

			gotKey, err = tab.SelectUserRoomPrivateKey(ctx, txn, userNID, roomNID)
			assert.NoError(t, err)
			assert.Equal(t, key, gotKey)
			pubKey, err = tab.SelectUserRoomPublicKey(ctx, txn, userNID, roomNID)
			assert.NoError(t, err)
			assert.Equal(t, key.Public(), pubKey)

			// try to update an existing key, this should only be done for users NOT on this homeserver
			var gotPubKey ed25519.PublicKey
			gotPubKey, err = tab.InsertUserRoomPublicKey(ctx, txn, userNID, roomNID, key2.Public().(ed25519.PublicKey))
			assert.NoError(t, err)
			assert.Equal(t, key2.Public(), gotPubKey)

			// Key doesn't exist
			gotKey, err = tab.SelectUserRoomPrivateKey(ctx, txn, userNID, otherRoomNID)
			assert.NoError(t, err)
			assert.Nil(t, gotKey)
			pubKey, err = tab.SelectUserRoomPublicKey(ctx, txn, userNID, otherRoomNID)
			assert.NoError(t, err)
			assert.Nil(t, pubKey)

			// query user NIDs for senderKeys
			var gotKeys map[string]types.UserRoomKeyPair
			query := map[types.RoomNID][]ed25519.PublicKey{
				roomNID:      {key2.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)},
				otherRoomNID: {key.Public().(ed25519.PublicKey), key3.Public().(ed25519.PublicKey)}, // doesn't exist
			}
			gotKeys, err = tab.BulkSelectUserNIDs(ctx, txn, query)
			assert.NoError(t, err)
			assert.NotNil(t, gotKeys)

			// Only the pairs stored for roomNID are expected back; the result
			// is keyed by the encoded public key.
			wantKeys := map[string]types.UserRoomKeyPair{
				string(spec.Base64Bytes(key2.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID},
				string(spec.Base64Bytes(key3.Public().(ed25519.PublicKey)).Encode()): {RoomNID: roomNID, EventStateKeyNID: userNID2},
			}
			assert.Equal(t, wantKeys, gotKeys)

			// insert key that came in over federation
			var gotPublicKey, key4 ed255192.PublicKey
			key4, _, err = ed25519.GenerateKey(nil)
			assert.NoError(t, err)
			gotPublicKey, err = tab.InsertUserRoomPublicKey(ctx, txn, userNID, otherRoomNID, key4)
			assert.NoError(t, err)
			assert.Equal(t, key4, gotPublicKey)

			return nil
		})
		assert.NoError(t, err)
	})
}