aboutsummaryrefslogtreecommitdiff
path: root/clientapi/routing
diff options
context:
space:
mode:
authorS7evinK <2353100+S7evinK@users.noreply.github.com>2022-02-25 14:33:02 +0100
committerGitHub <noreply@github.com>2022-02-25 14:33:02 +0100
commitcf27e26712c5aa655377f21a2efe34237c3d681f (patch)
treed24d40b9bf21223347c35c1619a24d49b978d933 /clientapi/routing
parent4c07374c42e5d671cfb137634475f43f84f9db0e (diff)
Remember parameters on registration (#2225)
* Remember parameters for sessions Cleanup sessions on successfully registering or after a while * Add flakey test * Update to use time.AfterFunc, add more tests * Try to drain the channel, if possible
Diffstat (limited to 'clientapi/routing')
-rw-r--r--clientapi/routing/auth_fallback.go2
-rw-r--r--clientapi/routing/key_crosssigning.go2
-rw-r--r--clientapi/routing/password.go2
-rw-r--r--clientapi/routing/register.go116
-rw-r--r--clientapi/routing/register_test.go45
5 files changed, 141 insertions, 26 deletions
diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go
index 839ca9e5..abfe830f 100644
--- a/clientapi/routing/auth_fallback.go
+++ b/clientapi/routing/auth_fallback.go
@@ -162,7 +162,7 @@ func AuthFallback(
}
// Success. Add recaptcha as a completed login flow
- AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
+ sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
serveSuccess()
return nil
diff --git a/clientapi/routing/key_crosssigning.go b/clientapi/routing/key_crosssigning.go
index 7ecab9d4..4426b7fd 100644
--- a/clientapi/routing/key_crosssigning.go
+++ b/clientapi/routing/key_crosssigning.go
@@ -70,7 +70,7 @@ func UploadCrossSigningDeviceKeys(
if _, authErr := typePassword.Login(req.Context(), &uploadReq.Auth.PasswordRequest); authErr != nil {
return *authErr
}
- AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
+ sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
uploadReq.UserID = device.UserID
keyserverAPI.PerformUploadDeviceKeys(req.Context(), &uploadReq.PerformUploadDeviceKeysRequest, uploadRes)
diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go
index 49951019..acac60fa 100644
--- a/clientapi/routing/password.go
+++ b/clientapi/routing/password.go
@@ -74,7 +74,7 @@ func Password(
if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
return *authErr
}
- AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
+ sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
// Check the new password strength.
if resErr = validatePassword(r.NewPassword); resErr != nil {
diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go
index d00d9886..10cfa432 100644
--- a/clientapi/routing/register.go
+++ b/clientapi/routing/register.go
@@ -72,14 +72,19 @@ func init() {
// sessionsDict keeps track of completed auth stages for each session.
// It shouldn't be passed by value because it contains a mutex.
type sessionsDict struct {
- sync.Mutex
+ sync.RWMutex
sessions map[string][]authtypes.LoginType
+ params map[string]registerRequest
+ timer map[string]*time.Timer
}
-// GetCompletedStages returns the completed stages for a session.
-func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginType {
- d.Lock()
- defer d.Unlock()
+// defaultTimeout is the timeout used to clean up sessions
+const defaultTimeOut = time.Minute * 5
+
+// getCompletedStages returns the completed stages for a session.
+func (d *sessionsDict) getCompletedStages(sessionID string) []authtypes.LoginType {
+ d.RLock()
+ defer d.RUnlock()
if completedStages, ok := d.sessions[sessionID]; ok {
return completedStages
@@ -88,28 +93,79 @@ func (d *sessionsDict) GetCompletedStages(sessionID string) []authtypes.LoginTyp
return make([]authtypes.LoginType, 0)
}
+// addParams adds a registerRequest to a sessionID and starts a timer to delete that registerRequest
+func (d *sessionsDict) addParams(sessionID string, r registerRequest) {
+ d.startTimer(defaultTimeOut, sessionID)
+ d.Lock()
+ defer d.Unlock()
+ d.params[sessionID] = r
+}
+
+func (d *sessionsDict) getParams(sessionID string) (registerRequest, bool) {
+ d.RLock()
+ defer d.RUnlock()
+ r, ok := d.params[sessionID]
+ return r, ok
+}
+
+// deleteSession cleans up a given session, either because the registration completed
+// successfully, or because a given timeout (default: 5min) was reached.
+func (d *sessionsDict) deleteSession(sessionID string) {
+ d.Lock()
+ defer d.Unlock()
+ delete(d.params, sessionID)
+ delete(d.sessions, sessionID)
+ // stop the timer, e.g. because the registration was completed
+ if t, ok := d.timer[sessionID]; ok {
+ if !t.Stop() {
+ select {
+ case <-t.C:
+ default:
+ }
+ }
+ delete(d.timer, sessionID)
+ }
+}
+
func newSessionsDict() *sessionsDict {
return &sessionsDict{
sessions: make(map[string][]authtypes.LoginType),
+ params: make(map[string]registerRequest),
+ timer: make(map[string]*time.Timer),
}
}
-// AddCompletedSessionStage records that a session has completed an auth stage.
-func AddCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
- sessions.Lock()
- defer sessions.Unlock()
+func (d *sessionsDict) startTimer(duration time.Duration, sessionID string) {
+ d.Lock()
+ defer d.Unlock()
+ t, ok := d.timer[sessionID]
+ if ok {
+ if !t.Stop() {
+ <-t.C
+ }
+ t.Reset(duration)
+ return
+ }
+ d.timer[sessionID] = time.AfterFunc(duration, func() {
+ d.deleteSession(sessionID)
+ })
+}
- for _, completedStage := range sessions.sessions[sessionID] {
+// addCompletedSessionStage records that a session has completed an auth stage
+// also starts a timer to delete the session once done.
+func (d *sessionsDict) addCompletedSessionStage(sessionID string, stage authtypes.LoginType) {
+ d.startTimer(defaultTimeOut, sessionID)
+ d.Lock()
+ defer d.Unlock()
+ for _, completedStage := range d.sessions[sessionID] {
if completedStage == stage {
return
}
}
- sessions.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
+ d.sessions[sessionID] = append(sessions.sessions[sessionID], stage)
}
var (
- // TODO: Remove old sessions. Need to do so on a session-specific timeout.
- // sessions stores the completed flow stages for all sessions. Referenced using their sessionID.
sessions = newSessionsDict()
validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`)
)
@@ -167,7 +223,7 @@ func newUserInteractiveResponse(
params map[string]interface{},
) userInteractiveResponse {
return userInteractiveResponse{
- fs, sessions.GetCompletedStages(sessionID), params, sessionID,
+ fs, sessions.getCompletedStages(sessionID), params, sessionID,
}
}
@@ -645,12 +701,12 @@ func handleRegistrationFlow(
}
// Add Recaptcha to the list of completed registration stages
- AddCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
+ sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha)
case authtypes.LoginTypeDummy:
// there is nothing to do
// Add Dummy to the list of completed registration stages
- AddCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
+ sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeDummy)
case "":
// An empty auth type means that we want to fetch the available
@@ -666,7 +722,7 @@ func handleRegistrationFlow(
// Check if the user's registration flow has been completed successfully
// A response with current registration flow and remaining available methods
// will be returned if a flow has not been successfully completed yet
- return checkAndCompleteFlow(sessions.GetCompletedStages(sessionID),
+ return checkAndCompleteFlow(sessions.getCompletedStages(sessionID),
req, r, sessionID, cfg, userAPI)
}
@@ -708,7 +764,7 @@ func handleApplicationServiceRegistration(
// Don't need to worry about appending to registration stages as
// application service registration is entirely separate.
return completeRegistration(
- req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(),
+ req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService,
)
}
@@ -727,11 +783,11 @@ func checkAndCompleteFlow(
if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) {
// This flow was completed, registration can continue
return completeRegistration(
- req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(),
+ req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID,
r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser,
)
}
-
+ sessions.addParams(sessionID, r)
// There are still more stages to complete.
// Return the flows and those that have been completed.
return util.JSONResponse{
@@ -750,11 +806,25 @@ func checkAndCompleteFlow(
func completeRegistration(
ctx context.Context,
userAPI userapi.UserInternalAPI,
- username, password, appserviceID, ipAddr, userAgent string,
+ username, password, appserviceID, ipAddr, userAgent, sessionID string,
inhibitLogin eventutil.WeakBoolean,
displayName, deviceID *string,
accType userapi.AccountType,
) util.JSONResponse {
+ var registrationOK bool
+ defer func() {
+ if registrationOK {
+ sessions.deleteSession(sessionID)
+ }
+ }()
+
+ if data, ok := sessions.getParams(sessionID); ok {
+ username = data.Username
+ password = data.Password
+ deviceID = data.DeviceID
+ displayName = data.InitialDisplayName
+ inhibitLogin = data.InhibitLogin
+ }
if username == "" {
return util.JSONResponse{
Code: http.StatusBadRequest,
@@ -795,6 +865,7 @@ func completeRegistration(
// Check whether inhibit_login option is set. If so, don't create an access
// token or a device for this user
if inhibitLogin {
+ registrationOK = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: registerResponse{
@@ -828,6 +899,7 @@ func completeRegistration(
}
}
+ registrationOK = true
return util.JSONResponse{
Code: http.StatusOK,
JSON: registerResponse{
@@ -976,5 +1048,5 @@ func handleSharedSecretRegistration(userAPI userapi.UserInternalAPI, sr *SharedS
if ssrr.Admin {
accType = userapi.AccountTypeAdmin
}
- return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), false, &ssrr.User, &deviceID, accType)
+ return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType)
}
diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go
index 1f615dc2..c6b7e61c 100644
--- a/clientapi/routing/register_test.go
+++ b/clientapi/routing/register_test.go
@@ -17,6 +17,7 @@ package routing
import (
"regexp"
"testing"
+ "time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/setup/config"
@@ -140,7 +141,7 @@ func TestFlowCheckingExtraneousIncorrectInput(t *testing.T) {
func TestEmptyCompletedFlows(t *testing.T) {
fakeEmptySessions := newSessionsDict()
fakeSessionID := "aRandomSessionIDWhichDoesNotExist"
- ret := fakeEmptySessions.GetCompletedStages(fakeSessionID)
+ ret := fakeEmptySessions.getCompletedStages(fakeSessionID)
// check for []
if ret == nil || len(ret) != 0 {
@@ -208,3 +209,45 @@ func TestValidationOfApplicationServices(t *testing.T) {
t.Errorf("user_id should not have been valid: @_something_else:localhost")
}
}
+
+func TestSessionCleanUp(t *testing.T) {
+ s := newSessionsDict()
+
+ t.Run("session is cleaned up after a while", func(t *testing.T) {
+ t.Parallel()
+ dummySession := "helloWorld"
+ // manually added, as s.addParams() would start the timer with the default timeout
+ s.params[dummySession] = registerRequest{Username: "Testing"}
+ s.startTimer(time.Millisecond, dummySession)
+ time.Sleep(time.Millisecond * 2)
+ if data, ok := s.getParams(dummySession); ok {
+ t.Errorf("expected session to be deleted: %+v", data)
+ }
+ })
+
+ t.Run("session is deleted, once the registration completed", func(t *testing.T) {
+ t.Parallel()
+ dummySession := "helloWorld2"
+ s.startTimer(time.Minute, dummySession)
+ s.deleteSession(dummySession)
+ if data, ok := s.getParams(dummySession); ok {
+ t.Errorf("expected session to be deleted: %+v", data)
+ }
+ })
+
+ t.Run("session timer is restarted after second call", func(t *testing.T) {
+ t.Parallel()
+ dummySession := "helloWorld3"
+ // the following will start a timer with the default timeout of 5min
+ s.addParams(dummySession, registerRequest{Username: "Testing"})
+ s.addCompletedSessionStage(dummySession, authtypes.LoginTypeRecaptcha)
+ s.addCompletedSessionStage(dummySession, authtypes.LoginTypeDummy)
+ s.getCompletedStages(dummySession)
+ // reset the timer with a lower timeout
+ s.startTimer(time.Millisecond, dummySession)
+ time.Sleep(time.Millisecond * 2)
+ if data, ok := s.getParams(dummySession); ok {
+ t.Errorf("expected session to be deleted: %+v", data)
+ }
+ })
+} \ No newline at end of file