diff options
Diffstat (limited to 'clientapi/auth')
-rw-r--r-- | clientapi/auth/login.go | 4 | ||||
-rw-r--r-- | clientapi/auth/login_test.go | 31 | ||||
-rw-r--r-- | clientapi/auth/password.go | 33 | ||||
-rw-r--r-- | clientapi/auth/user_interactive.go | 4 | ||||
-rw-r--r-- | clientapi/auth/user_interactive_test.go | 20 |
5 files changed, 55 insertions, 37 deletions
diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go index 1c14c6fb..020731c9 100644 --- a/clientapi/auth/login.go +++ b/clientapi/auth/login.go @@ -33,7 +33,7 @@ import ( // called after authorization has completed, with the result of the authorization. // If the final return value is non-nil, an error occurred and the cleanup function // is nil. -func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { +func LoginFromJSONReader(ctx context.Context, r io.Reader, useraccountAPI uapi.UserAccountAPI, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { reqBytes, err := ioutil.ReadAll(r) if err != nil { err := &util.JSONResponse{ @@ -58,7 +58,7 @@ func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountData switch header.Type { case authtypes.LoginTypePassword: typ = &LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, + GetAccountByPassword: useraccountAPI.QueryAccountByPassword, Config: cfg, } case authtypes.LoginTypeToken: diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index e295f8f0..d401469c 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -16,7 +16,6 @@ package auth import ( "context" - "database/sql" "net/http" "reflect" "strings" @@ -64,14 +63,13 @@ func TestLoginFromJSONReader(t *testing.T) { } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { - var accountDB fakeAccountDB var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ ServerName: serverName, }, } - login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) if err != nil { t.Fatalf("LoginFromJSONReader failed: %+v", err) } @@ -143,14 +141,13 @@ func TestBadLoginFromJSONReader(t *testing.T) { } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { - var accountDB fakeAccountDB var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ ServerName: serverName, }, } - _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) if errRes == nil { cleanup(ctx, nil) t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) @@ -161,24 +158,22 @@ func TestBadLoginFromJSONReader(t *testing.T) { } } -type fakeAccountDB struct { - AccountDatabase -} - -func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) { - if password == "invalidpassword" { - return nil, sql.ErrNoRows - } - - return &uapi.Account{}, nil -} - type fakeUserInternalAPI struct { UserInternalAPIForLogin - + uapi.UserAccountAPI DeletedTokens []string } +func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req *uapi.QueryAccountByPasswordRequest, res *uapi.QueryAccountByPasswordResponse) error { + if req.PlaintextPassword == "invalidpassword" { + res.Account = nil + return nil + } + res.Exists = true + res.Account = &uapi.Account{} + return nil +} + func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { ua.DeletedTokens = append(ua.DeletedTokens, req.Token) return nil diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 046b36f0..bcb4ca97 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -16,7 +16,6 @@ package auth import ( "context" - "database/sql" "net/http" "strings" @@ -29,7 +28,7 @@ import ( "github.com/matrix-org/util" ) -type GetAccountByPassword func(ctx context.Context, localpart, password string) (*api.Account, error) +type GetAccountByPassword func(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error type PasswordRequest struct { Login @@ -77,19 +76,33 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } // Squash username to all lowercase letters - _, err = t.GetAccountByPassword(ctx, strings.ToLower(localpart), r.Password) + res := &api.QueryAccountByPasswordResponse{} + err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) if err != nil { - if err == sql.ErrNoRows { - _, err = t.GetAccountByPassword(ctx, localpart, r.Password) - if err == nil { - return &r.Login, nil + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("unable to fetch account by password"), + } + } + + if !res.Exists { + err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + Localpart: localpart, + PlaintextPassword: r.Password, + }, res) + if err != nil { + return nil, &util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("unable to fetch account by password"), } } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows // but that would leak the existence of the user. - return nil, &util.JSONResponse{ - Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + if !res.Exists { + return nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("The username or password was incorrect or the account does not exist."), + } } } return &r.Login, nil diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 4db75809..22c430f9 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -110,9 +110,9 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(userAccountAPI api.UserAccountAPI, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, + GetAccountByPassword: userAccountAPI.QueryAccountByPassword, Config: cfg, } return &UserInteractive{ diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 76d161a7..a4b4587a 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -25,15 +25,25 @@ var ( ) type fakeAccountDatabase struct { - AccountDatabase + api.UserAccountAPI } -func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { - acc, ok := lookup[localpart+" "+plaintextPassword] +func (d *fakeAccountDatabase) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { + return nil +} + +func (d *fakeAccountDatabase) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error { + return nil +} + +func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { + acc, ok := lookup[req.Localpart+" "+req.PlaintextPassword] if !ok { - return nil, fmt.Errorf("unknown user/password") + return fmt.Errorf("unknown user/password") } - return acc, nil + res.Account = acc + res.Exists = true + return nil } func setup() *UserInteractive { |