diff options
Diffstat (limited to 'appservice/query/query.go')
-rw-r--r-- | appservice/query/query.go | 98 |
1 files changed, 64 insertions, 34 deletions
diff --git a/appservice/query/query.go b/appservice/query/query.go index 5c736f37..7f33e17f 100644 --- a/appservice/query/query.go +++ b/appservice/query/query.go @@ -19,10 +19,10 @@ package query import ( "context" "encoding/json" + "fmt" "io" "net/http" "net/url" - "strings" "sync" log "github.com/sirupsen/logrus" @@ -32,9 +32,6 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) -const roomAliasExistsPath = "/rooms/" -const userIDExistsPath = "/users/" - // AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI type AppServiceQueryAPI struct { Cfg *config.AppServiceAPI @@ -55,14 +52,23 @@ func (a *AppServiceQueryAPI) RoomAliasExists( // Determine which application service should handle this request for _, appservice := range a.Cfg.Derived.ApplicationServices { if appservice.URL != "" && appservice.IsInterestedInRoomAlias(request.Alias) { + path := api.ASRoomAliasExistsPath + if a.Cfg.LegacyPaths { + path = api.ASRoomAliasExistsLegacyPath + } // The full path to the rooms API, includes hs token - URL, err := url.Parse(appservice.RequestUrl() + roomAliasExistsPath) + URL, err := url.Parse(appservice.RequestUrl() + path) if err != nil { return err } URL.Path += request.Alias - apiURL := URL.String() + "?access_token=" + appservice.HSToken + if a.Cfg.LegacyAuth { + q := URL.Query() + q.Set("access_token", appservice.HSToken) + URL.RawQuery = q.Encode() + } + apiURL := URL.String() // Send a request to each application service. If one responds that it has // created the room, immediately return. @@ -70,6 +76,7 @@ func (a *AppServiceQueryAPI) RoomAliasExists( if err != nil { return err } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", appservice.HSToken)) req = req.WithContext(ctx) resp, err := appservice.HTTPClient.Do(req) @@ -123,12 +130,21 @@ func (a *AppServiceQueryAPI) UserIDExists( for _, appservice := range a.Cfg.Derived.ApplicationServices { if appservice.URL != "" && appservice.IsInterestedInUserID(request.UserID) { // The full path to the rooms API, includes hs token - URL, err := url.Parse(appservice.RequestUrl() + userIDExistsPath) + path := api.ASUserExistsPath + if a.Cfg.LegacyPaths { + path = api.ASUserExistsLegacyPath + } + URL, err := url.Parse(appservice.RequestUrl() + path) if err != nil { return err } URL.Path += request.UserID - apiURL := URL.String() + "?access_token=" + appservice.HSToken + if a.Cfg.LegacyAuth { + q := URL.Query() + q.Set("access_token", appservice.HSToken) + URL.RawQuery = q.Encode() + } + apiURL := URL.String() // Send a request to each application service. If one responds that it has // created the user, immediately return. @@ -136,6 +152,7 @@ func (a *AppServiceQueryAPI) UserIDExists( if err != nil { return err } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", appservice.HSToken)) resp, err := appservice.HTTPClient.Do(req.WithContext(ctx)) if resp != nil { defer func() { @@ -176,25 +193,22 @@ type thirdpartyResponses interface { api.ASProtocolResponse | []api.ASUserResponse | []api.ASLocationResponse } -func requestDo[T thirdpartyResponses](client *http.Client, url string, response *T) (err error) { - origURL := url - // try v1 and unstable appservice endpoints - for _, version := range []string{"v1", "unstable"} { - var resp *http.Response - var body []byte - asURL := strings.Replace(origURL, "unstable", version, 1) - resp, err = client.Get(asURL) - if err != nil { - continue - } - defer resp.Body.Close() // nolint: errcheck - body, err = io.ReadAll(resp.Body) - if err != nil { - continue - } - return json.Unmarshal(body, &response) +func requestDo[T thirdpartyResponses](as *config.ApplicationService, url string, response *T) error { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return err } - return err + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", as.HSToken)) + resp, err := as.HTTPClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() // nolint: errcheck + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return json.Unmarshal(body, &response) } func (a *AppServiceQueryAPI) Locations( @@ -207,16 +221,22 @@ func (a *AppServiceQueryAPI) Locations( return err } + path := api.ASLocationPath + if a.Cfg.LegacyPaths { + path = api.ASLocationLegacyPath + } for _, as := range a.Cfg.Derived.ApplicationServices { var asLocations []api.ASLocationResponse - params.Set("access_token", as.HSToken) + if a.Cfg.LegacyAuth { + params.Set("access_token", as.HSToken) + } - url := as.RequestUrl() + api.ASLocationPath + url := as.RequestUrl() + path if req.Protocol != "" { url += "/" + req.Protocol } - if err := requestDo[[]api.ASLocationResponse](as.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil { + if err := requestDo[[]api.ASLocationResponse](&as, url+"?"+params.Encode(), &asLocations); err != nil { log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'locations' from application service") continue } @@ -242,16 +262,22 @@ func (a *AppServiceQueryAPI) User( return err } + path := api.ASUserPath + if a.Cfg.LegacyPaths { + path = api.ASUserLegacyPath + } for _, as := range a.Cfg.Derived.ApplicationServices { var asUsers []api.ASUserResponse - params.Set("access_token", as.HSToken) + if a.Cfg.LegacyAuth { + params.Set("access_token", as.HSToken) + } - url := as.RequestUrl() + api.ASUserPath + url := as.RequestUrl() + path if req.Protocol != "" { url += "/" + req.Protocol } - if err := requestDo[[]api.ASUserResponse](as.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil { + if err := requestDo[[]api.ASUserResponse](&as, url+"?"+params.Encode(), &asUsers); err != nil { log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'user' from application service") continue } @@ -272,6 +298,10 @@ func (a *AppServiceQueryAPI) Protocols( req *api.ProtocolRequest, resp *api.ProtocolResponse, ) error { + protocolPath := api.ASProtocolPath + if a.Cfg.LegacyPaths { + protocolPath = api.ASProtocolLegacyPath + } // get a single protocol response if req.Protocol != "" { @@ -289,7 +319,7 @@ func (a *AppServiceQueryAPI) Protocols( response := api.ASProtocolResponse{} for _, as := range a.Cfg.Derived.ApplicationServices { var proto api.ASProtocolResponse - if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+req.Protocol, &proto); err != nil { + if err := requestDo[api.ASProtocolResponse](&as, as.RequestUrl()+protocolPath+req.Protocol, &proto); err != nil { log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'protocol' from application service") continue } @@ -319,7 +349,7 @@ func (a *AppServiceQueryAPI) Protocols( for _, as := range a.Cfg.Derived.ApplicationServices { for _, p := range as.Protocols { var proto api.ASProtocolResponse - if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+p, &proto); err != nil { + if err := requestDo[api.ASProtocolResponse](&as, as.RequestUrl()+protocolPath+p, &proto); err != nil { log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'protocol' from application service") continue } |