aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/internal')
-rw-r--r--userapi/internal/api.go42
1 files changed, 38 insertions, 4 deletions
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 738023dd..b9d18822 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -104,7 +104,8 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
}
res.DeviceCreated = true
res.Device = dev
- return nil
+ // create empty device keys and upload them to trigger device list changes
+ return a.deviceListUpdate(dev.UserID, []string{dev.ID})
}
func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.PerformDeviceDeletionRequest, res *api.PerformDeviceDeletionResponse) error {
@@ -121,10 +122,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
return err
}
// create empty device keys and upload them to delete what was once there and trigger device list changes
- deviceKeys := make([]keyapi.DeviceKeys, len(req.DeviceIDs))
- for i, did := range req.DeviceIDs {
+ return a.deviceListUpdate(req.UserID, req.DeviceIDs)
+}
+
+func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
+ deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
+ for i, did := range deviceIDs {
deviceKeys[i] = keyapi.DeviceKeys{
- UserID: req.UserID,
+ UserID: userID,
DeviceID: did,
KeyJSON: nil,
}
@@ -143,6 +148,35 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
return nil
}
+func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
+ localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
+ return err
+ }
+ dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
+ if err == sql.ErrNoRows {
+ res.DeviceExists = false
+ return nil
+ } else if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("deviceDB.GetDeviceByID failed")
+ return err
+ }
+ res.DeviceExists = true
+
+ if dev.UserID != req.RequestingUserID {
+ res.Forbidden = true
+ return nil
+ }
+
+ err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
+ if err != nil {
+ util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
+ return err
+ }
+ return nil
+}
+
func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfileRequest, res *api.QueryProfileResponse) error {
local, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {