From f6dea712d2e9c71f6ebe61f90e45a142852432e8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 26 Oct 2022 12:59:19 +0100 Subject: Initial support for multiple server names (#2829) This PR is the first step towards virtual hosting by laying the groundwork for multiple server names being configured. --- userapi/api/api.go | 31 ++++++++++++++++---- userapi/internal/api.go | 58 ++++++++++++++++++++++++++------------ userapi/internal/api_logintoken.go | 8 +++--- userapi/userapi.go | 2 +- userapi/userapi_test.go | 4 +-- 5 files changed, 72 insertions(+), 31 deletions(-) (limited to 'userapi') diff --git a/userapi/api/api.go b/userapi/api/api.go index eef29144..8d7f783d 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -318,8 +318,9 @@ type QuerySearchProfilesResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformAccountCreationRequest struct { - AccountType AccountType // Required: whether this is a guest or user account - Localpart string // Required: The localpart for this account. Ignored if account type is guest. + AccountType AccountType // Required: whether this is a guest or user account + Localpart string // Required: The localpart for this account. Ignored if account type is guest. + ServerName gomatrixserverlib.ServerName // optional: if not specified, default server name used instead AppServiceID string // optional: the application service ID (not user ID) creating this account, if any. Password string // optional: if missing then this account will be a passwordless account @@ -360,7 +361,8 @@ type PerformLastSeenUpdateResponse struct { // PerformDeviceCreationRequest is the request for PerformDeviceCreation type PerformDeviceCreationRequest struct { Localpart string - AccessToken string // optional: if blank one will be made on your behalf + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used + AccessToken string // optional: if blank one will be made on your behalf // optional: if nil an ID is generated for you. If set, replaces any existing device session, // which will generate a new access token and invalidate the old one. DeviceID *string @@ -384,7 +386,8 @@ type PerformDeviceCreationResponse struct { // PerformAccountDeactivationRequest is the request for PerformAccountDeactivation type PerformAccountDeactivationRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName // optional: if blank, default server name used } // PerformAccountDeactivationResponse is the response for PerformAccountDeactivation @@ -434,6 +437,18 @@ type Device struct { AccountType AccountType } +func (d *Device) UserDomain() gomatrixserverlib.ServerName { + _, domain, err := gomatrixserverlib.SplitID('@', d.UserID) + if err != nil { + // This really is catastrophic because it means that someone + // managed to forge a malformed user ID for a device during + // login. + // TODO: Is there a better way to deal with this than panic? + panic(err) + } + return domain +} + // Account represents a Matrix account on this home server. type Account struct { UserID string @@ -577,7 +592,9 @@ type Notification struct { } type PerformSetAvatarURLRequest struct { - Localpart, AvatarURL string + Localpart string + ServerName gomatrixserverlib.ServerName + AvatarURL string } type PerformSetAvatarURLResponse struct { Profile *authtypes.Profile `json:"profile"` @@ -606,7 +623,9 @@ type QueryAccountByPasswordResponse struct { } type PerformUpdateDisplayNameRequest struct { - Localpart, DisplayName string + Localpart string + ServerName gomatrixserverlib.ServerName + DisplayName string } type PerformUpdateDisplayNameResponse struct { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 7b94b3da..9ca76965 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -46,9 +46,9 @@ import ( type UserInternalAPI struct { DB storage.Database SyncProducer *producers.SyncAPI + Config *config.UserAPI DisableTLSValidation bool - ServerName gomatrixserverlib.ServerName // AppServices is the list of all registered AS AppServices []config.ApplicationService KeyAPI keyapi.UserKeyAPI @@ -62,8 +62,8 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot update account data of remote users (server name %s)", domain) } if req.DataType == "" { return fmt.Errorf("data type must not be empty") @@ -104,7 +104,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure") return nil } - if domain != a.ServerName { + if !a.Config.Matrix.IsLocalServerName(domain) { return nil } @@ -171,6 +171,11 @@ func addUserToRoom( } func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + // XXXX: Use the server name here acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists @@ -188,8 +193,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P res.Account = &api.Account{ AppServiceID: req.AppServiceID, Localpart: req.Localpart, - ServerName: a.ServerName, - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + ServerName: serverName, + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), AccountType: req.AccountType, } return nil @@ -235,6 +240,12 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + _ = serverName + // XXXX: Use the server name here util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -259,8 +270,8 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain) } deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { @@ -392,8 +403,8 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } prof, err := a.DB.GetProfileByLocalpart(ctx, local) if err != nil { @@ -443,8 +454,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } devs, err := a.DB.GetDevicesByLocalpart(ctx, local) if err != nil { @@ -460,8 +471,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot query account data of remote users (server name %s)", domain) } if req.DataType != "" { var data json.RawMessage @@ -509,10 +520,13 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc } return err } - localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { return err } + if !a.Config.Matrix.IsLocalServerName(domain) { + return nil + } acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) if err != nil { return err @@ -547,7 +561,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName) + localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } @@ -572,8 +586,16 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe // PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again. func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { + serverName := req.ServerName + if serverName == "" { + serverName = a.Config.Matrix.ServerName + } + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %q not locally configured", serverName) + } + evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{} if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil { @@ -584,7 +606,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a } deviceReq := &api.PerformDeviceDeletionRequest{ - UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName), + UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName), } deviceRes := &api.PerformDeviceDeletionResponse{} if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil { diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index f1bf391e..87f25e5e 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -31,8 +31,8 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot create a login token for a remote user (server name %s)", domain) } tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data) if err != nil { @@ -63,8 +63,8 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if err != nil { return err } - if domain != a.ServerName { - return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { res.Data = nil diff --git a/userapi/userapi.go b/userapi/userapi.go index c077248e..e46a8e76 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -76,7 +76,7 @@ func NewInternalAPI( userAPI := &internal.UserInternalAPI{ DB: db, SyncProducer: syncProducer, - ServerName: cfg.Matrix.ServerName, + Config: cfg, AppServices: appServices, KeyAPI: keyAPI, RSAPI: rsAPI, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index aaa93f45..2a43c0bd 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -66,8 +66,8 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap } return &internal.UserInternalAPI{ - DB: accountDB, - ServerName: cfg.Matrix.ServerName, + DB: accountDB, + Config: cfg, }, accountDB, func() { close() baseclose() -- cgit v1.2.3