aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal/key_api_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/internal/key_api_test.go')
-rw-r--r--userapi/internal/key_api_test.go161
1 files changed, 161 insertions, 0 deletions
diff --git a/userapi/internal/key_api_test.go b/userapi/internal/key_api_test.go
new file mode 100644
index 00000000..fc7e7e0d
--- /dev/null
+++ b/userapi/internal/key_api_test.go
@@ -0,0 +1,161 @@
+package internal_test
+
+import (
+ "context"
+ "reflect"
+ "testing"
+
+ "github.com/matrix-org/dendrite/setup/config"
+ "github.com/matrix-org/dendrite/test"
+ "github.com/matrix-org/dendrite/test/testrig"
+ "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/internal"
+ "github.com/matrix-org/dendrite/userapi/storage"
+)
+
+func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
+ t.Helper()
+ connStr, close := test.PrepareDBConnectionString(t, dbType)
+ base, _, _ := testrig.Base(nil)
+ db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to create new user db: %v", err)
+ }
+ return db, func() {
+ base.Close()
+ close()
+ }
+}
+
+func Test_QueryDeviceMessages(t *testing.T) {
+ alice := test.NewUser(t)
+ type args struct {
+ req *api.QueryDeviceMessagesRequest
+ res *api.QueryDeviceMessagesResponse
+ }
+ tests := []struct {
+ name string
+ args args
+ wantErr bool
+ want *api.QueryDeviceMessagesResponse
+ }{
+ {
+ name: "no existing keys",
+ args: args{
+ req: &api.QueryDeviceMessagesRequest{
+ UserID: "@doesNotExist:localhost",
+ },
+ res: &api.QueryDeviceMessagesResponse{},
+ },
+ want: &api.QueryDeviceMessagesResponse{},
+ },
+ {
+ name: "existing user returns devices",
+ args: args{
+ req: &api.QueryDeviceMessagesRequest{
+ UserID: alice.ID,
+ },
+ res: &api.QueryDeviceMessagesResponse{},
+ },
+ want: &api.QueryDeviceMessagesResponse{
+ StreamID: 6,
+ Devices: []api.DeviceMessage{
+ {
+ Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ DisplayName: "first device",
+ UserID: alice.ID,
+ KeyJSON: []byte("ghi"),
+ },
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "mySecondDevice",
+ DisplayName: "second device",
+ UserID: alice.ID,
+ KeyJSON: []byte("jkl"),
+ }, // streamID 6
+ },
+ },
+ },
+ },
+ }
+
+ deviceMessages := []api.DeviceMessage{
+ { // not the user we're looking for
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ UserID: "@doesNotExist:localhost",
+ },
+ // streamID 1 for this user
+ },
+ { // empty keyJSON will be ignored
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ }, // streamID 1
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("abc"),
+ }, // streamID 2
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("def"),
+ }, // streamID 3
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte(""),
+ }, // streamID 4
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "myDevice",
+ DisplayName: "first device",
+ UserID: alice.ID,
+ KeyJSON: []byte("ghi"),
+ }, // streamID 5
+ },
+ {
+ Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
+ DeviceID: "mySecondDevice",
+ UserID: alice.ID,
+ KeyJSON: []byte("jkl"),
+ DisplayName: "second device",
+ }, // streamID 6
+ },
+ }
+ ctx := context.Background()
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ db, closeDB := mustCreateDatabase(t, dbType)
+ defer closeDB()
+ if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
+ t.Fatalf("failed to store local devicesKeys")
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &internal.UserInternalAPI{
+ KeyDatabase: db,
+ }
+ if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
+ t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ got := tt.args.res
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
+ }
+ })
+ }
+ })
+}