aboutsummaryrefslogtreecommitdiff
path: root/userapi/userapi_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/userapi_test.go')
-rw-r--r--userapi/userapi_test.go327
1 files changed, 321 insertions, 6 deletions
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 68d08c2f..08b1336b 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -21,7 +21,10 @@ import (
"testing"
"time"
+ "github.com/matrix-org/dendrite/userapi/producers"
"github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+ "github.com/nats-io/nats.go"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/setup/config"
@@ -38,32 +41,55 @@ const (
type apiTestOpts struct {
loginTokenLifetime time.Duration
+ serverName string
}
-func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
+type dummyProducer struct{}
+
+func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
+ return &nats.PubAck{}, nil
+}
+
+func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
- accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
+ sName := serverName
+ if opts.serverName != "" {
+ sName = gomatrixserverlib.ServerName(opts.serverName)
+ }
+ accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
- }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
+ }, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil {
t.Fatalf("failed to create account DB: %s", err)
}
+ keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
+ ConnectionString: config.DataSource(connStr),
+ })
+ if err != nil {
+ t.Fatalf("failed to create key DB: %s", err)
+ }
+
cfg := &config.UserAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
- ServerName: serverName,
+ ServerName: sName,
},
},
}
+ syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
+ keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
return &internal.UserInternalAPI{
- DB: accountDB,
- Config: cfg,
+ DB: accountDB,
+ KeyDatabase: keyDB,
+ Config: cfg,
+ SyncProducer: syncProducer,
+ KeyChangeProducer: keyChangeProducer,
}, accountDB, func() {
close()
baseclose()
@@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) {
testCases(t, intAPI)
})
}
+
+func TestAccountData(t *testing.T) {
+ ctx := context.Background()
+ alice := test.NewUser(t)
+
+ testCases := []struct {
+ name string
+ inputData *api.InputAccountDataRequest
+ wantErr bool
+ }{
+ {
+ name: "not a local user",
+ inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"},
+ wantErr: true,
+ },
+ {
+ name: "local user missing datatype",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID},
+ wantErr: true,
+ },
+ {
+ name: "missing json",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil},
+ wantErr: true,
+ },
+ {
+ name: "with json",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")},
+ },
+ {
+ name: "room data",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"},
+ },
+ {
+ name: "ignored users",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")},
+ },
+ {
+ name: "m.fully_read",
+ inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")},
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ res := api.InputAccountDataResponse{}
+ err := intAPI.InputAccountData(ctx, tc.inputData, &res)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+
+ // query the data again and compare
+ queryRes := api.QueryAccountDataResponse{}
+ queryReq := api.QueryAccountDataRequest{
+ UserID: tc.inputData.UserID,
+ DataType: tc.inputData.DataType,
+ RoomID: tc.inputData.RoomID,
+ }
+ err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes)
+ if err != nil && !tc.wantErr {
+ t.Fatal(err)
+ }
+ // verify global data
+ if tc.inputData.RoomID == "" {
+ if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) {
+ t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType]))
+ }
+ } else {
+ // verify room data
+ if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) {
+ t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]))
+ }
+ }
+ })
+ }
+ })
+}
+
+func TestDevices(t *testing.T) {
+ ctx := context.Background()
+
+ dupeAccessToken := util.RandomString(8)
+
+ displayName := "testing"
+
+ creationTests := []struct {
+ name string
+ inputData *api.PerformDeviceCreationRequest
+ wantErr bool
+ wantNewDevID bool
+ }{
+ {
+ name: "not a local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"},
+ wantErr: true,
+ },
+ {
+ name: "implicit local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName},
+ },
+ {
+ name: "explicit local user",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ {
+ name: "dupe token - ok",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
+ },
+ {
+ name: "dupe token - not ok",
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
+ wantErr: true,
+ },
+ {
+ name: "test3 second device", // used to test deletion later
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ {
+ name: "test3 third device", // used to test deletion later
+ wantNewDevID: true,
+ inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
+ },
+ }
+
+ deletionTests := []struct {
+ name string
+ inputData *api.PerformDeviceDeletionRequest
+ wantErr bool
+ wantDevices int
+ }{
+ {
+ name: "deletion - not a local user",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"},
+ wantErr: true,
+ },
+ {
+ name: "deleting not existing devices should not error",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}},
+ wantDevices: 1,
+ },
+ {
+ name: "delete all devices",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test1:test"},
+ wantDevices: 0,
+ },
+ {
+ name: "delete all devices",
+ inputData: &api.PerformDeviceDeletionRequest{UserID: "@test3:test"},
+ wantDevices: 0,
+ },
+ }
+
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ for _, tc := range creationTests {
+ t.Run(tc.name, func(t *testing.T) {
+ res := api.PerformDeviceCreationResponse{}
+ deviceID := util.RandomString(8)
+ tc.inputData.DeviceID = &deviceID
+ if tc.wantNewDevID {
+ tc.inputData.DeviceID = nil
+ }
+ err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+ if !res.DeviceCreated {
+ return
+ }
+
+ queryDevicesRes := api.QueryDevicesResponse{}
+ queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID}
+ if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
+ t.Fatal(err)
+ }
+ // We only want to verify one device
+ if len(queryDevicesRes.Devices) > 1 {
+ return
+ }
+ res.Device.AccessToken = ""
+
+ // At this point, there should only be one device
+ if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) {
+ t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0])
+ }
+
+ newDisplayName := "new name"
+ if tc.inputData.DeviceDisplayName == nil {
+ updateRes := api.PerformDeviceUpdateResponse{}
+ updateReq := api.PerformDeviceUpdateRequest{
+ RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"),
+ DeviceID: deviceID,
+ DisplayName: &newDisplayName,
+ }
+
+ if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ queryDeviceInfosRes := api.QueryDeviceInfosResponse{}
+ queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}}
+ if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil {
+ t.Fatal(err)
+ }
+ gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName
+ if tc.inputData.DeviceDisplayName != nil {
+ wantDisplayName := *tc.inputData.DeviceDisplayName
+ if wantDisplayName != gotDisplayName {
+ t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
+ }
+ } else {
+ wantDisplayName := newDisplayName
+ if wantDisplayName != gotDisplayName {
+ t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
+ }
+ }
+ })
+ }
+
+ for _, tc := range deletionTests {
+ t.Run(tc.name, func(t *testing.T) {
+ delRes := api.PerformDeviceDeletionResponse{}
+ err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes)
+ if tc.wantErr && err == nil {
+ t.Fatalf("expected an error, but got none")
+ }
+ if !tc.wantErr && err != nil {
+ t.Fatalf("expected no error, but got: %s", err)
+ }
+ if tc.wantErr {
+ return
+ }
+
+ queryDevicesRes := api.QueryDevicesResponse{}
+ queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID}
+ if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
+ t.Fatal(err)
+ }
+
+ if len(queryDevicesRes.Devices) != tc.wantDevices {
+ t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices))
+ }
+
+ })
+ }
+ })
+}
+
+// Tests that the session ID of a device is not reused when reusing the same device ID.
+func TestDeviceIDReuse(t *testing.T) {
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ defer close()
+
+ res := api.PerformDeviceCreationResponse{}
+ // create a first device
+ deviceID := util.RandomString(8)
+ req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true}
+ err := intAPI.PerformDeviceCreation(ctx, &req, &res)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Do the same request again, we expect a different sessionID
+ res2 := api.PerformDeviceCreationResponse{}
+ err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
+ if err != nil {
+ t.Fatalf("expected no error, but got: %v", err)
+ }
+
+ if res2.Device.SessionID == res.Device.SessionID {
+ t.Fatalf("expected a different session ID, but they are the same")
+ }
+ })
+}