diff options
Diffstat (limited to 'userapi/internal')
-rw-r--r-- | userapi/internal/api.go | 42 |
1 files changed, 38 insertions, 4 deletions
diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 738023dd..b9d18822 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -104,7 +104,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev - return nil + // create empty device keys and upload them to trigger device list changes + return a.deviceListUpdate(dev.UserID, []string{dev.ID}) } func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error { @@ -121,10 +122,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return err } // create empty device keys and upload them to delete what was once there and trigger device list changes - deviceKeys := make([]keyapi.DeviceKeys, len(req.DeviceIDs)) - for i, did := range req.DeviceIDs { + return a.deviceListUpdate(req.UserID, req.DeviceIDs) +} + +func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { + deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + for i, did := range deviceIDs { deviceKeys[i] = keyapi.DeviceKeys{ - UserID: req.UserID, + UserID: userID, DeviceID: did, KeyJSON: nil, } @@ -143,6 +148,35 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe return nil } +func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { + localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") + return err + } + dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID) + if err == sql.ErrNoRows { + res.DeviceExists = false + return nil + } else if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed") + return err + } + res.DeviceExists = true + + if dev.UserID != req.RequestingUserID { + res.Forbidden = true + return nil + } + + err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") + return err + } + return nil +} + func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error { local, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { |