diff options
Diffstat (limited to 'userapi')
-rw-r--r-- | userapi/api/api.go | 10 | ||||
-rw-r--r-- | userapi/api/api_trace.go | 6 | ||||
-rw-r--r-- | userapi/internal/api.go | 5 | ||||
-rw-r--r-- | userapi/inthttp/client.go | 12 | ||||
-rw-r--r-- | userapi/inthttp/server.go | 5 | ||||
-rw-r--r-- | userapi/userapi_test.go | 61 |
6 files changed, 99 insertions, 0 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc..4ea2e91c 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,6 +50,7 @@ type KeyserverUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index ce661770..d10b5767 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { + err := t.Impl.QueryAccountByLocalpart(ctx, req, res) + util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3f256457..0bb480da 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 87ae058c..51b0fe3e 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -60,6 +60,7 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryAccountByLocalpartPath = "/userapi/queryAccountType" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( h.httpClient, ctx, request, response, ) } + +func (h *httpUserInternalAPI) QueryAccountByLocalpart( + ctx context.Context, + req *api.QueryAccountByLocalpartRequest, + res *api.QueryAccountByLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, + h.httpClient, ctx, req, res, + ) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index f0579079..b40b507c 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics PerformSaveThreePIDAssociationPath, httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) + + internalAPIMux.Handle( + QueryAccountByLocalpartPath, + httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), + ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8a19af19..dada56de 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + t.Run("Monolith", func(t *testing.T) { + testCases(t, intAPI) + // also test tracing + testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) + }) + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, intAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + testCases(t, userHTTPApi) + + }) + }) +} |