aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-02-20 15:26:09 +0100
committerGitHub <noreply@github.com>2023-02-20 15:26:09 +0100
commit7f114cc5387f04d748270d48f92708f137df38a7 (patch)
tree6ade197dd74999c3661486b632789f938f2a4b7b /userapi
parent4594233f89f8531fca8f696ab0ece36909130c2a (diff)
Fix issue where device keys are removed if a device ID is reused (#2982)
Fixes https://github.com/matrix-org/dendrite/issues/2980
Diffstat (limited to 'userapi')
-rw-r--r--userapi/internal/user_api.go13
-rw-r--r--userapi/userapi_test.go58
2 files changed, 55 insertions, 16 deletions
diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go
index 1cbd9719..8977697b 100644
--- a/userapi/internal/user_api.go
+++ b/userapi/internal/user_api.go
@@ -254,6 +254,17 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
if !a.Config.Matrix.IsLocalServerName(serverName) {
return fmt.Errorf("server name %s is not local", serverName)
}
+ // If a device ID was specified, check if it already exists and
+ // avoid sending an empty device list update which would remove
+ // existing device keys.
+ isExisting := false
+ if req.DeviceID != nil && *req.DeviceID != "" {
+ existingDev, err := a.DB.GetDeviceByID(ctx, req.Localpart, req.ServerName, *req.DeviceID)
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return err
+ }
+ isExisting = existingDev.ID == *req.DeviceID
+ }
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
@@ -265,7 +276,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
}
res.DeviceCreated = true
res.Device = dev
- if req.NoDeviceListUpdate {
+ if req.NoDeviceListUpdate || isExisting {
return nil
}
// create empty device keys and upload them to trigger device list changes
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 08b1336b..01e491cb 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"reflect"
+ "sync"
"testing"
"time"
@@ -44,13 +45,25 @@ type apiTestOpts struct {
serverName string
}
-type dummyProducer struct{}
+type dummyProducer struct {
+ callCount sync.Map
+ t *testing.T
+}
-func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
+func (d *dummyProducer) PublishMsg(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) {
+ count, loaded := d.callCount.LoadOrStore(msg.Subject, 1)
+ if loaded {
+ c, ok := count.(int)
+ if !ok {
+ d.t.Fatalf("unexpected type: %T with value %q", c, c)
+ }
+ d.callCount.Store(msg.Subject, c+1)
+ d.t.Logf("Incrementing call counter for %s", msg.Subject)
+ }
return &nats.PubAck{}, nil
}
-func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
+func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, publisher producers.JetStreamPublisher) (api.UserInternalAPI, storage.UserDatabase, func()) {
if opts.loginTokenLifetime == 0 {
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
}
@@ -82,8 +95,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
},
}
- syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
- keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
+ if publisher == nil {
+ publisher = &dummyProducer{t: t}
+ }
+
+ syncProducer := producers.NewSyncAPI(accountDB, publisher, "client_data", "notification_data")
+ keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: publisher, Topic: "keychange"}
return &internal.UserInternalAPI{
DB: accountDB,
KeyDatabase: keyDB,
@@ -150,7 +167,7 @@ func TestQueryProfile(t *testing.T) {
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
defer close()
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
if err != nil {
@@ -173,7 +190,7 @@ func TestQueryProfile(t *testing.T) {
func TestPasswordlessLoginFails(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
defer close()
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
if err != nil {
@@ -199,7 +216,7 @@ func TestLoginToken(t *testing.T) {
t.Run("tokenLoginFlow", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
defer close()
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
if err != nil {
@@ -249,7 +266,7 @@ func TestLoginToken(t *testing.T) {
t.Run("expiredTokenIsNotReturned", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType)
+ userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType, nil)
defer close()
creq := api.PerformLoginTokenCreationRequest{
@@ -274,7 +291,7 @@ func TestLoginToken(t *testing.T) {
t.Run("deleteWorks", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
defer close()
creq := api.PerformLoginTokenCreationRequest{
@@ -305,7 +322,7 @@ func TestLoginToken(t *testing.T) {
t.Run("deleteUnknownIsNoOp", func(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
defer close()
dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"}
var dresp api.PerformLoginTokenDeletionResponse
@@ -323,7 +340,7 @@ func TestQueryAccountByLocalpart(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
defer close()
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
@@ -402,7 +419,7 @@ func TestAccountData(t *testing.T) {
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil)
defer close()
for _, tc := range testCases {
@@ -518,7 +535,7 @@ func TestDevices(t *testing.T) {
}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil)
defer close()
for _, tc := range creationTests {
@@ -623,7 +640,8 @@ func TestDevices(t *testing.T) {
func TestDeviceIDReuse(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
- intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
+ publisher := &dummyProducer{t: t}
+ intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, publisher)
defer close()
res := api.PerformDeviceCreationResponse{}
@@ -637,6 +655,9 @@ func TestDeviceIDReuse(t *testing.T) {
// Do the same request again, we expect a different sessionID
res2 := api.PerformDeviceCreationResponse{}
+ // Set NoDeviceListUpdate to false, to verify we don't send device list updates when
+ // reusing the same device ID
+ req.NoDeviceListUpdate = false
err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
if err != nil {
t.Fatalf("expected no error, but got: %v", err)
@@ -645,5 +666,12 @@ func TestDeviceIDReuse(t *testing.T) {
if res2.Device.SessionID == res.Device.SessionID {
t.Fatalf("expected a different session ID, but they are the same")
}
+
+ publisher.callCount.Range(func(key, value any) bool {
+ if value != nil {
+ t.Fatalf("expected publisher to not get called, but got value %d for subject %s", value, key)
+ }
+ return true
+ })
})
}