aboutsummaryrefslogtreecommitdiff
path: root/userapi/internal
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2022-10-26 12:59:19 +0100
committerGitHub <noreply@github.com>2022-10-26 12:59:19 +0100
commitf6dea712d2e9c71f6ebe61f90e45a142852432e8 (patch)
tree981b818ec9ece4e67f1b27ed52f82510aecc465d /userapi/internal
parent2a4c7f45b37a9bcd1a37d42b0668e0c3dfb29762 (diff)
Initial support for multiple server names (#2829)
This PR is the first step towards virtual hosting by laying the groundwork for multiple server names being configured.
Diffstat (limited to 'userapi/internal')
-rw-r--r--userapi/internal/api.go58
-rw-r--r--userapi/internal/api_logintoken.go8
2 files changed, 44 insertions, 22 deletions
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 7b94b3da..9ca76965 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -46,9 +46,9 @@ import (
type UserInternalAPI struct {
DB storage.Database
SyncProducer *producers.SyncAPI
+ Config *config.UserAPI
DisableTLSValidation bool
- ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
KeyAPI keyapi.UserKeyAPI
@@ -62,8 +62,8 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot update account data of remote users (server name %s)", domain)
}
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
@@ -104,7 +104,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: SplitID failure")
return nil
}
- if domain != a.ServerName {
+ if !a.Config.Matrix.IsLocalServerName(domain) {
return nil
}
@@ -171,6 +171,11 @@ func addUserToRoom(
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
+ serverName := req.ServerName
+ if serverName == "" {
+ serverName = a.Config.Matrix.ServerName
+ }
+ // XXXX: Use the server name here
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
@@ -188,8 +193,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = &api.Account{
AppServiceID: req.AppServiceID,
Localpart: req.Localpart,
- ServerName: a.ServerName,
- UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
+ ServerName: serverName,
+ UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
AccountType: req.AccountType,
}
return nil
@@ -235,6 +240,12 @@ func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.Pe
}
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
+ serverName := req.ServerName
+ if serverName == "" {
+ serverName = a.Config.Matrix.ServerName
+ }
+ _ = serverName
+ // XXXX: Use the server name here
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
"device_id": req.DeviceID,
@@ -259,8 +270,8 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot PerformDeviceDeletion of remote users: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot PerformDeviceDeletion of remote users (server name %s)", domain)
}
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
@@ -392,8 +403,8 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
}
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
@@ -443,8 +454,8 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
}
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
@@ -460,8 +471,8 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot query account data of remote users: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot query account data of remote users (server name %s)", domain)
}
if req.DataType != "" {
var data json.RawMessage
@@ -509,10 +520,13 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
}
return err
}
- localPart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
+ localPart, domain, err := gomatrixserverlib.SplitID('@', device.UserID)
if err != nil {
return err
}
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return nil
+ }
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
@@ -547,7 +561,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
AccountType: api.AccountTypeAppService,
}
- localpart, err := userutil.ParseUsernameParam(appServiceUserID, &a.ServerName)
+ localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
if err != nil {
return nil, err
}
@@ -572,8 +586,16 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
+ serverName := req.ServerName
+ if serverName == "" {
+ serverName = a.Config.Matrix.ServerName
+ }
+ if !a.Config.Matrix.IsLocalServerName(serverName) {
+ return fmt.Errorf("server name %q not locally configured", serverName)
+ }
+
evacuateReq := &rsapi.PerformAdminEvacuateUserRequest{
- UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
+ UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
}
evacuateRes := &rsapi.PerformAdminEvacuateUserResponse{}
if err := a.RSAPI.PerformAdminEvacuateUser(ctx, evacuateReq, evacuateRes); err != nil {
@@ -584,7 +606,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
}
deviceReq := &api.PerformDeviceDeletionRequest{
- UserID: fmt.Sprintf("@%s:%s", req.Localpart, a.ServerName),
+ UserID: fmt.Sprintf("@%s:%s", req.Localpart, serverName),
}
deviceRes := &api.PerformDeviceDeletionResponse{}
if err := a.PerformDeviceDeletion(ctx, deviceReq, deviceRes); err != nil {
diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go
index f1bf391e..87f25e5e 100644
--- a/userapi/internal/api_logintoken.go
+++ b/userapi/internal/api_logintoken.go
@@ -31,8 +31,8 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot create a login token for a remote user (server name %s)", domain)
}
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil {
@@ -63,8 +63,8 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if err != nil {
return err
}
- if domain != a.ServerName {
- return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
+ if !a.Config.Matrix.IsLocalServerName(domain) {
+ return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
}
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil