aboutsummaryrefslogtreecommitdiff
path: root/keyserver/storage/storage_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'keyserver/storage/storage_test.go')
-rw-r--r--keyserver/storage/storage_test.go197
1 files changed, 0 insertions, 197 deletions
diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go
deleted file mode 100644
index e7a2af7c..00000000
--- a/keyserver/storage/storage_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-package storage_test
-
-import (
- "context"
- "reflect"
- "sync"
- "testing"
-
- "github.com/matrix-org/dendrite/keyserver/api"
- "github.com/matrix-org/dendrite/keyserver/storage"
- "github.com/matrix-org/dendrite/keyserver/types"
- "github.com/matrix-org/dendrite/test"
- "github.com/matrix-org/dendrite/test/testrig"
-)
-
-var ctx = context.Background()
-
-func MustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
- base, close := testrig.CreateBaseDendrite(t, dbType)
- db, err := storage.NewDatabase(base, &base.Cfg.KeyServer.Database)
- if err != nil {
- t.Fatalf("failed to create new database: %v", err)
- }
- return db, close
-}
-
-func MustNotError(t *testing.T, err error) {
- t.Helper()
- if err == nil {
- return
- }
- t.Fatalf("operation failed: %s", err)
-}
-
-func TestKeyChanges(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- _, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
- MustNotError(t, err)
- deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeIDC {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
- }
- if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-func TestKeyChangesNoDupes(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- if deviceChangeIDA == deviceChangeIDB {
- t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
- }
- deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeID {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
- }
- if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-func TestKeyChangesUpperLimit(t *testing.T) {
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
- MustNotError(t, err)
- deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
- MustNotError(t, err)
- _, err = db.StoreKeyChange(ctx, "@charlie:localhost")
- MustNotError(t, err)
- userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
- if err != nil {
- t.Fatalf("Failed to KeyChanges: %s", err)
- }
- if latest != deviceChangeIDB {
- t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
- }
- if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
- t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
- }
- })
-}
-
-var dbLock sync.Mutex
-var deviceArray = []string{"AAA", "another_device"}
-
-// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
-// and that they are returned correctly when querying for device keys.
-func TestDeviceKeysStreamIDGeneration(t *testing.T) {
- var err error
- test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- db, clean := MustCreateDatabase(t, dbType)
- defer clean()
- alice := "@alice:TestDeviceKeysStreamIDGeneration"
- bob := "@bob:TestDeviceKeysStreamIDGeneration"
- msgs := []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 1
- },
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: bob,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 1 as this is a different user
- },
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "another_device",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v1"}`),
- },
- // StreamID: 2 as this is a 2nd device key
- },
- }
- MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
- if msgs[0].StreamID != 1 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
- }
- if msgs[1].StreamID != 1 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
- }
- if msgs[2].StreamID != 2 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
- }
-
- // updating a device sets the next stream ID for that user
- msgs = []api.DeviceMessage{
- {
- Type: api.TypeDeviceKeyUpdate,
- DeviceKeys: &api.DeviceKeys{
- DeviceID: "AAA",
- UserID: alice,
- KeyJSON: []byte(`{"key":"v2"}`),
- },
- // StreamID: 3
- },
- }
- MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
- if msgs[0].StreamID != 3 {
- t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
- }
-
- dbLock.Lock()
- defer dbLock.Unlock()
- // Querying for device keys returns the latest stream IDs
- msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
-
- if err != nil {
- t.Fatalf("DeviceKeysForUser returned error: %s", err)
- }
- wantStreamIDs := map[string]int64{
- "AAA": 3,
- "another_device": 2,
- }
- if len(msgs) != len(wantStreamIDs) {
- t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
- }
- for _, m := range msgs {
- if m.StreamID != wantStreamIDs[m.DeviceID] {
- t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
- }
- }
- })
-}