aboutsummaryrefslogtreecommitdiff
path: root/clientapi/auth
diff options
context:
space:
mode:
Diffstat (limited to 'clientapi/auth')
-rw-r--r--clientapi/auth/login.go4
-rw-r--r--clientapi/auth/login_test.go31
-rw-r--r--clientapi/auth/password.go33
-rw-r--r--clientapi/auth/user_interactive.go4
-rw-r--r--clientapi/auth/user_interactive_test.go20
5 files changed, 55 insertions, 37 deletions
diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go
index 1c14c6fb..020731c9 100644
--- a/clientapi/auth/login.go
+++ b/clientapi/auth/login.go
@@ -33,7 +33,7 @@ import (
// called after authorization has completed, with the result of the authorization.
// If the final return value is non-nil, an error occurred and the cleanup function
// is nil.
-func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) {
+func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserAccountAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) {
reqBytes, err := ioutil.ReadAll(r)
if err != nil {
err := &util.JSONResponse{
@@ -58,7 +58,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountData
switch header.Type {
case authtypes.LoginTypePassword:
typ = &LoginTypePassword{
- GetAccountByPassword: accountDB.GetAccountByPassword,
+ GetAccountByPassword: useraccountAPI.QueryAccountByPassword,
Config: cfg,
}
case authtypes.LoginTypeToken:
diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go
index e295f8f0..d401469c 100644
--- a/clientapi/auth/login_test.go
+++ b/clientapi/auth/login_test.go
@@ -16,7 +16,6 @@ package auth
import (
"context"
- "database/sql"
"net/http"
"reflect"
"strings"
@@ -64,14 +63,13 @@ func TestLoginFromJSONReader(t *testing.T) {
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
- var accountDB fakeAccountDB
var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{
Matrix: &config.Global{
ServerName: serverName,
},
}
- login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg)
+ login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if err != nil {
t.Fatalf("LoginFromJSONReader failed: %+v", err)
}
@@ -143,14 +141,13 @@ func TestBadLoginFromJSONReader(t *testing.T) {
}
for _, tst := range tsts {
t.Run(tst.Name, func(t *testing.T) {
- var accountDB fakeAccountDB
var userAPI fakeUserInternalAPI
cfg := &config.ClientAPI{
Matrix: &config.Global{
ServerName: serverName,
},
}
- _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg)
+ _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg)
if errRes == nil {
cleanup(ctx, nil)
t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode)
@@ -161,24 +158,22 @@ func TestBadLoginFromJSONReader(t *testing.T) {
}
}
-type fakeAccountDB struct {
- AccountDatabase
-}
-
-func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) {
- if password == "invalidpassword" {
- return nil, sql.ErrNoRows
- }
-
- return &uapi.Account{}, nil
-}
-
type fakeUserInternalAPI struct {
UserInternalAPIForLogin
-
+ uapi.UserAccountAPI
DeletedTokens []string
}
+func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *uapi.QueryAccountByPasswordRequest, res *uapi.QueryAccountByPasswordResponse) error {
+ if req.PlaintextPassword == "invalidpassword" {
+ res.Account = nil
+ return nil
+ }
+ res.Exists = true
+ res.Account = &uapi.Account{}
+ return nil
+}
+
func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error {
ua.DeletedTokens = append(ua.DeletedTokens, req.Token)
return nil
diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go
index 046b36f0..bcb4ca97 100644
--- a/clientapi/auth/password.go
+++ b/clientapi/auth/password.go
@@ -16,7 +16,6 @@ package auth
import (
"context"
- "database/sql"
"net/http"
"strings"
@@ -29,7 +28,7 @@ import (
"github.com/matrix-org/util"
)
-type GetAccountByPassword func(ctx context.Context, localpart, password string) (*api.Account, error)
+type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error
type PasswordRequest struct {
Login
@@ -77,19 +76,33 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login,
}
}
// Squash username to all lowercase letters
- _, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password)
+ res := &api.QueryAccountByPasswordResponse{}
+ err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res)
if err != nil {
- if err == sql.ErrNoRows {
- _, err = t.GetAccountByPassword(ctx, localpart, r.Password)
- if err == nil {
- return &r.Login, nil
+ return nil, &util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: jsonerror.Unknown("unable to fetch account by password"),
+ }
+ }
+
+ if !res.Exists {
+ err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{
+ Localpart: localpart,
+ PlaintextPassword: r.Password,
+ }, res)
+ if err != nil {
+ return nil, &util.JSONResponse{
+ Code: http.StatusInternalServerError,
+ JSON: jsonerror.Unknown("unable to fetch account by password"),
}
}
// Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows
// but that would leak the existence of the user.
- return nil, &util.JSONResponse{
- Code: http.StatusForbidden,
- JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."),
+ if !res.Exists {
+ return nil, &util.JSONResponse{
+ Code: http.StatusForbidden,
+ JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."),
+ }
}
}
return &r.Login, nil
diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go
index 4db75809..22c430f9 100644
--- a/clientapi/auth/user_interactive.go
+++ b/clientapi/auth/user_interactive.go
@@ -110,9 +110,9 @@ type UserInteractive struct {
Sessions map[string][]string
}
-func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive {
+func NewUserInteractive(userAccountAPI api.UserAccountAPI, cfg *config.ClientAPI) *UserInteractive {
typePassword := &LoginTypePassword{
- GetAccountByPassword: accountDB.GetAccountByPassword,
+ GetAccountByPassword: userAccountAPI.QueryAccountByPassword,
Config: cfg,
}
return &UserInteractive{
diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go
index 76d161a7..a4b4587a 100644
--- a/clientapi/auth/user_interactive_test.go
+++ b/clientapi/auth/user_interactive_test.go
@@ -25,15 +25,25 @@ var (
)
type fakeAccountDatabase struct {
- AccountDatabase
+ api.UserAccountAPI
}
-func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) {
- acc, ok := lookup[localpart+" "+plaintextPassword]
+func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
+ return nil
+}
+
+func (d *fakeAccountDatabase) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
+ return nil
+}
+
+func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
+ acc, ok := lookup[req.Localpart+" "+req.PlaintextPassword]
if !ok {
- return nil, fmt.Errorf("unknown user/password")
+ return fmt.Errorf("unknown user/password")
}
- return acc, nil
+ res.Account = acc
+ res.Exists = true
+ return nil
}
func setup() *UserInteractive {