aboutsummaryrefslogtreecommitdiff
path: root/appservice/query/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'appservice/query/query.go')
-rw-r--r--appservice/query/query.go98
1 files changed, 64 insertions, 34 deletions
diff --git a/appservice/query/query.go b/appservice/query/query.go
index 5c736f37..7f33e17f 100644
--- a/appservice/query/query.go
+++ b/appservice/query/query.go
@@ -19,10 +19,10 @@ package query
import (
"context"
"encoding/json"
+ "fmt"
"io"
"net/http"
"net/url"
- "strings"
"sync"
log "github.com/sirupsen/logrus"
@@ -32,9 +32,6 @@ import (
"github.com/matrix-org/dendrite/setup/config"
)
-const roomAliasExistsPath = "/rooms/"
-const userIDExistsPath = "/users/"
-
// AppServiceQueryAPI is an implementation of api.AppServiceQueryAPI
type AppServiceQueryAPI struct {
Cfg *config.AppServiceAPI
@@ -55,14 +52,23 @@ func (a *AppServiceQueryAPI) RoomAliasExists(
// Determine which application service should handle this request
for _, appservice := range a.Cfg.Derived.ApplicationServices {
if appservice.URL != "" && appservice.IsInterestedInRoomAlias(request.Alias) {
+ path := api.ASRoomAliasExistsPath
+ if a.Cfg.LegacyPaths {
+ path = api.ASRoomAliasExistsLegacyPath
+ }
// The full path to the rooms API, includes hs token
- URL, err := url.Parse(appservice.RequestUrl() + roomAliasExistsPath)
+ URL, err := url.Parse(appservice.RequestUrl() + path)
if err != nil {
return err
}
URL.Path += request.Alias
- apiURL := URL.String() + "?access_token=" + appservice.HSToken
+ if a.Cfg.LegacyAuth {
+ q := URL.Query()
+ q.Set("access_token", appservice.HSToken)
+ URL.RawQuery = q.Encode()
+ }
+ apiURL := URL.String()
// Send a request to each application service. If one responds that it has
// created the room, immediately return.
@@ -70,6 +76,7 @@ func (a *AppServiceQueryAPI) RoomAliasExists(
if err != nil {
return err
}
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", appservice.HSToken))
req = req.WithContext(ctx)
resp, err := appservice.HTTPClient.Do(req)
@@ -123,12 +130,21 @@ func (a *AppServiceQueryAPI) UserIDExists(
for _, appservice := range a.Cfg.Derived.ApplicationServices {
if appservice.URL != "" && appservice.IsInterestedInUserID(request.UserID) {
// The full path to the rooms API, includes hs token
- URL, err := url.Parse(appservice.RequestUrl() + userIDExistsPath)
+ path := api.ASUserExistsPath
+ if a.Cfg.LegacyPaths {
+ path = api.ASUserExistsLegacyPath
+ }
+ URL, err := url.Parse(appservice.RequestUrl() + path)
if err != nil {
return err
}
URL.Path += request.UserID
- apiURL := URL.String() + "?access_token=" + appservice.HSToken
+ if a.Cfg.LegacyAuth {
+ q := URL.Query()
+ q.Set("access_token", appservice.HSToken)
+ URL.RawQuery = q.Encode()
+ }
+ apiURL := URL.String()
// Send a request to each application service. If one responds that it has
// created the user, immediately return.
@@ -136,6 +152,7 @@ func (a *AppServiceQueryAPI) UserIDExists(
if err != nil {
return err
}
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", appservice.HSToken))
resp, err := appservice.HTTPClient.Do(req.WithContext(ctx))
if resp != nil {
defer func() {
@@ -176,25 +193,22 @@ type thirdpartyResponses interface {
api.ASProtocolResponse | []api.ASUserResponse | []api.ASLocationResponse
}
-func requestDo[T thirdpartyResponses](client *http.Client, url string, response *T) (err error) {
- origURL := url
- // try v1 and unstable appservice endpoints
- for _, version := range []string{"v1", "unstable"} {
- var resp *http.Response
- var body []byte
- asURL := strings.Replace(origURL, "unstable", version, 1)
- resp, err = client.Get(asURL)
- if err != nil {
- continue
- }
- defer resp.Body.Close() // nolint: errcheck
- body, err = io.ReadAll(resp.Body)
- if err != nil {
- continue
- }
- return json.Unmarshal(body, &response)
+func requestDo[T thirdpartyResponses](as *config.ApplicationService, url string, response *T) error {
+ req, err := http.NewRequest(http.MethodGet, url, nil)
+ if err != nil {
+ return err
}
- return err
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", as.HSToken))
+ resp, err := as.HTTPClient.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close() // nolint: errcheck
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(body, &response)
}
func (a *AppServiceQueryAPI) Locations(
@@ -207,16 +221,22 @@ func (a *AppServiceQueryAPI) Locations(
return err
}
+ path := api.ASLocationPath
+ if a.Cfg.LegacyPaths {
+ path = api.ASLocationLegacyPath
+ }
for _, as := range a.Cfg.Derived.ApplicationServices {
var asLocations []api.ASLocationResponse
- params.Set("access_token", as.HSToken)
+ if a.Cfg.LegacyAuth {
+ params.Set("access_token", as.HSToken)
+ }
- url := as.RequestUrl() + api.ASLocationPath
+ url := as.RequestUrl() + path
if req.Protocol != "" {
url += "/" + req.Protocol
}
- if err := requestDo[[]api.ASLocationResponse](as.HTTPClient, url+"?"+params.Encode(), &asLocations); err != nil {
+ if err := requestDo[[]api.ASLocationResponse](&as, url+"?"+params.Encode(), &asLocations); err != nil {
log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'locations' from application service")
continue
}
@@ -242,16 +262,22 @@ func (a *AppServiceQueryAPI) User(
return err
}
+ path := api.ASUserPath
+ if a.Cfg.LegacyPaths {
+ path = api.ASUserLegacyPath
+ }
for _, as := range a.Cfg.Derived.ApplicationServices {
var asUsers []api.ASUserResponse
- params.Set("access_token", as.HSToken)
+ if a.Cfg.LegacyAuth {
+ params.Set("access_token", as.HSToken)
+ }
- url := as.RequestUrl() + api.ASUserPath
+ url := as.RequestUrl() + path
if req.Protocol != "" {
url += "/" + req.Protocol
}
- if err := requestDo[[]api.ASUserResponse](as.HTTPClient, url+"?"+params.Encode(), &asUsers); err != nil {
+ if err := requestDo[[]api.ASUserResponse](&as, url+"?"+params.Encode(), &asUsers); err != nil {
log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'user' from application service")
continue
}
@@ -272,6 +298,10 @@ func (a *AppServiceQueryAPI) Protocols(
req *api.ProtocolRequest,
resp *api.ProtocolResponse,
) error {
+ protocolPath := api.ASProtocolPath
+ if a.Cfg.LegacyPaths {
+ protocolPath = api.ASProtocolLegacyPath
+ }
// get a single protocol response
if req.Protocol != "" {
@@ -289,7 +319,7 @@ func (a *AppServiceQueryAPI) Protocols(
response := api.ASProtocolResponse{}
for _, as := range a.Cfg.Derived.ApplicationServices {
var proto api.ASProtocolResponse
- if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+req.Protocol, &proto); err != nil {
+ if err := requestDo[api.ASProtocolResponse](&as, as.RequestUrl()+protocolPath+req.Protocol, &proto); err != nil {
log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'protocol' from application service")
continue
}
@@ -319,7 +349,7 @@ func (a *AppServiceQueryAPI) Protocols(
for _, as := range a.Cfg.Derived.ApplicationServices {
for _, p := range as.Protocols {
var proto api.ASProtocolResponse
- if err := requestDo[api.ASProtocolResponse](as.HTTPClient, as.RequestUrl()+api.ASProtocolPath+p, &proto); err != nil {
+ if err := requestDo[api.ASProtocolResponse](&as, as.RequestUrl()+protocolPath+p, &proto); err != nil {
log.WithError(err).WithField("application_service", as.ID).Error("unable to get 'protocol' from application service")
continue
}