aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal/api.go
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-11-11 16:41:37 +0000
committerGitHub <noreply@github.com>2022-11-11 16:41:37 +0000
commit529df30b5649e67a2f98114e6640d259cba53566 (patch)
treebcb994ce79916f14c9a11cd11f32063411332585 /userapi/internal/api.go
parente177e0ae73d7cc34ffb9869681a6bf177f805205 (diff)
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
Diffstat (limited to 'userapi/internal/api.go')
-rw-r--r--userapi/internal/api.go101
1 files changed, 57 insertions, 44 deletions
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 9ca76965..3f256457 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
- if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
+ if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
return fmt.Errorf("failed to save account data: %w", err)
}
@@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
- deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
+ deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
if err != nil {
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
return err
@@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
return nil
}
- if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
+ if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
return err
}
@@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
- // XXXX: Use the server name here
- acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ return fmt.Errorf("server name %s is not local", serverName)
+ }
+ acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
- if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
- return err
+ if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
+ return fmt.Errorf("a.DB.SetDisplayName: %w", err)
}
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
@@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
- if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
+ return fmt.Errorf("server name %s is not local", req.ServerName)
+ }
+ if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
return err
}
if req.LogoutDevices {
- if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
+ if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
return err
}
}
@@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
if serverName == "" {
serverName = a.Config.Matrix.ServerName
}
- _ = serverName
- // XXXX: Use the server name here
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ return fmt.Errorf("server name %s is not local", serverName)
+ }
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
- dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
+ dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
+ devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
- err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
+ err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
}
if err != nil {
return err
@@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
req *api.PerformLastSeenUpdateRequest,
res *api.PerformLastSeenUpdateResponse,
) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
- if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("server name %s is not local", domain)
+ }
+ if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
}
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
- dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("server name %s is not local", domain)
+ }
+ dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
- err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
+ err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
}
- prof, err := a.DB.GetProfileByLocalpart(ctx, local)
+ prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
if err != nil {
if err == sql.ErrNoRows {
return nil
@@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if !a.Config.Matrix.IsLocalServerName(domain) {
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
}
- devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
+ devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
if err != nil {
return err
}
@@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
- data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
+ data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
if err != nil {
return err
}
@@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
- global, rooms, err := a.DB.GetAccountData(ctx, local)
+ global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
if err != nil {
return err
}
@@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if !a.Config.Matrix.IsLocalServerName(domain) {
return nil
}
- acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
+ acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
if err != nil {
return err
}
@@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
AccountType: api.AccountTypeAppService,
}
- localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
+ localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
if err != nil {
return nil, err
}
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
- account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
+ account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
return err
}
- err := a.DB.DeactivateAccount(ctx, req.Localpart)
+ err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
res.AccountDeactivated = err == nil
return err
}
@@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
if req.Only == "highlight" {
filter = tables.HighlightNotifications
}
- notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
+ notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
if err != nil {
return err
}
@@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
}
}
if req.Pusher.Kind == "" {
- return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
+ return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
}
if req.Pusher.PushKeyTS == 0 {
req.Pusher.PushKeyTS = int64(time.Now().Unix())
}
- return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
+ return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
}
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
- pushers, err := a.DB.GetPushers(ctx, req.Localpart)
+ pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
for i := range pushers {
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
if pushers[i].SessionID != req.SessionID {
- err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
+ err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
if err != nil {
return err
}
@@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
var err error
- res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
+ res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
return err
}
@@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
}
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
- localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
+ localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
if err != nil {
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
}
- pushRules, err := a.DB.QueryPushRules(ctx, localpart)
+ pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
if err != nil {
return fmt.Errorf("failed to query push rules: %w", err)
}
@@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
}
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
- profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
+ profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
res.Profile = profile
res.Changed = changed
return err
}
-func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
- id, err := a.DB.GetNewNumericLocalpart(ctx)
+func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
+ id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
if err != nil {
return err
}
@@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
var err error
- res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
+ res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
return err
}
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
- acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
+ acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
switch err {
case sql.ErrNoRows: // user does not exist
return nil
@@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
}
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
- profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
+ profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
res.Profile = profile
res.Changed = changed
return err
}
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
- localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
+ localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
if err != nil {
return err
}
res.Localpart = localpart
+ res.ServerName = domain
return nil
}
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
- r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
+ r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
if err != nil {
return err
}
@@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
- return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
+ return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
}
const pushRulesAccountDataType = "m.push_rules"