aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go10
-rw-r--r--userapi/api/api_trace.go6
-rw-r--r--userapi/internal/api.go5
-rw-r--r--userapi/inthttp/client.go12
-rw-r--r--userapi/inthttp/server.go5
-rw-r--r--userapi/userapi_test.go61
6 files changed, 99 insertions, 0 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index d3f5aefc..4ea2e91c 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -50,6 +50,7 @@ type KeyserverUserAPI interface {
type RoomserverUserAPI interface {
QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
+ QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}
// api functions required by the media api
@@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct {
ServerName gomatrixserverlib.ServerName
Medium string
}
+
+type QueryAccountByLocalpartRequest struct {
+ Localpart string
+ ServerName gomatrixserverlib.ServerName
+}
+
+type QueryAccountByLocalpartResponse struct {
+ Account *Account
+}
diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go
index ce661770..d10b5767 100644
--- a/userapi/api/api_trace.go
+++ b/userapi/api/api_trace.go
@@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex
return err
}
+func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error {
+ err := t.Impl.QueryAccountByLocalpart(ctx, req, res)
+ util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res))
+ return err
+}
+
func js(thing interface{}) string {
b, err := json.Marshal(thing)
if err != nil {
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index 3f256457..0bb480da 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
+func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) {
+ res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName)
+ return
+}
+
// Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem
// creating a 'device'.
func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) {
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 87ae058c..51b0fe3e 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -60,6 +60,7 @@ const (
QueryAccountByPasswordPath = "/userapi/queryAccountByPassword"
QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID"
QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart"
+ QueryAccountByLocalpartPath = "/userapi/queryAccountType"
)
// NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API.
@@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation(
h.httpClient, ctx, request, response,
)
}
+
+func (h *httpUserInternalAPI) QueryAccountByLocalpart(
+ ctx context.Context,
+ req *api.QueryAccountByLocalpartRequest,
+ res *api.QueryAccountByLocalpartResponse,
+) error {
+ return httputil.CallInternalRPCAPI(
+ "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath,
+ h.httpClient, ctx, req, res,
+ )
+}
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index f0579079..b40b507c 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics
PerformSaveThreePIDAssociationPath,
httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation),
)
+
+ internalAPIMux.Handle(
+ QueryAccountByLocalpartPath,
+ httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart),
+ )
}
diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go
index 8a19af19..dada56de 100644
--- a/userapi/userapi_test.go
+++ b/userapi/userapi_test.go
@@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) {
})
})
}
+
+func TestQueryAccountByLocalpart(t *testing.T) {
+ alice := test.NewUser(t)
+
+ localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID)
+
+ ctx := context.Background()
+ test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
+ intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
+ defer close()
+
+ createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
+ if err != nil {
+ t.Error(err)
+ }
+
+ testCases := func(t *testing.T, internalAPI api.UserInternalAPI) {
+ // Query existing account
+ queryAccResp := &api.QueryAccountByLocalpartResponse{}
+ if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
+ Localpart: localpart,
+ ServerName: userServername,
+ }, queryAccResp); err != nil {
+ t.Error(err)
+ }
+ if !reflect.DeepEqual(createdAcc, queryAccResp.Account) {
+ t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account)
+ }
+
+ // Query non-existent account, this should result in an error
+ err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{
+ Localpart: "doesnotexist",
+ ServerName: userServername,
+ }, queryAccResp)
+
+ if err == nil {
+ t.Fatalf("expected an error, but got none: %+v", queryAccResp)
+ }
+ }
+
+ t.Run("Monolith", func(t *testing.T) {
+ testCases(t, intAPI)
+ // also test tracing
+ testCases(t, &api.UserInternalAPITrace{Impl: intAPI})
+ })
+
+ t.Run("HTTP API", func(t *testing.T) {
+ router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
+ userapi.AddInternalRoutes(router, intAPI, false)
+ apiURL, cancel := test.ListenAndServe(t, router, false)
+ defer cancel()
+
+ userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5})
+ if err != nil {
+ t.Fatalf("failed to create HTTP client: %s", err)
+ }
+ testCases(t, userHTTPApi)
+
+ })
+ })
+}