aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-09-04 15:16:13 +0100
committerGitHub <noreply@github.com>2020-09-04 15:16:13 +0100
commit5076925c184998414c3691e97fc21b554abf4a55 (patch)
treeb4da4d8a3015166e16eb9a6009f733653f0485bc
parentca8dcf46b746686e213b184c3ae42ba0be17b46b (diff)
Password changes (#1397)
* User API support for password changes * Password changes in client API * Update sytest-whitelist * Remove debug logging * Default logout_devices to true * Fix deleting devices by local part
-rw-r--r--clientapi/auth/authtypes/logintypes.go1
-rw-r--r--clientapi/routing/password.go127
-rw-r--r--clientapi/routing/routing.go9
-rw-r--r--sytest-whitelist5
-rw-r--r--userapi/api/api.go17
-rw-r--r--userapi/internal/api.go11
-rw-r--r--userapi/inthttp/client.go13
-rw-r--r--userapi/inthttp/server.go13
-rw-r--r--userapi/storage/accounts/interface.go1
-rw-r--r--userapi/storage/accounts/postgres/accounts_table.go16
-rw-r--r--userapi/storage/accounts/postgres/storage.go11
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go16
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go12
-rw-r--r--userapi/storage/devices/interface.go2
-rw-r--r--userapi/storage/devices/postgres/devices_table.go12
-rw-r--r--userapi/storage/devices/postgres/storage.go8
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go12
-rw-r--r--userapi/storage/devices/sqlite3/storage.go8
18 files changed, 268 insertions, 26 deletions
diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go
index 087e4504..da032425 100644
--- a/clientapi/auth/authtypes/logintypes.go
+++ b/clientapi/auth/authtypes/logintypes.go
@@ -5,6 +5,7 @@ type LoginType string
// The relevant login types implemented in Dendrite
const (
+ LoginTypePassword = "m.login.password"
LoginTypeDummy = "m.login.dummy"
LoginTypeSharedSecret = "org.matrix.login.shared_secret"
LoginTypeRecaptcha = "m.login.recaptcha"
diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go
new file mode 100644
index 00000000..8b81b9f0
--- /dev/null
+++ b/clientapi/routing/password.go
@@ -0,0 +1,127 @@
+package routing
+
+import (
+ "net/http"
+
+ "github.com/matrix-org/dendrite/clientapi/auth"
+ "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
+ "github.com/matrix-org/dendrite/clientapi/httputil"
+ "github.com/matrix-org/dendrite/clientapi/jsonerror"
+ "github.com/matrix-org/dendrite/internal/config"
+ "github.com/matrix-org/dendrite/userapi/api"
+ userapi "github.com/matrix-org/dendrite/userapi/api"
+ "github.com/matrix-org/dendrite/userapi/storage/accounts"
+ "github.com/matrix-org/gomatrixserverlib"
+ "github.com/matrix-org/util"
+)
+
+type newPasswordRequest struct {
+ NewPassword string `json:"new_password"`
+ LogoutDevices bool `json:"logout_devices"`
+ Auth newPasswordAuth `json:"auth"`
+}
+
+type newPasswordAuth struct {
+ Type string `json:"type"`
+ Session string `json:"session"`
+ auth.PasswordRequest
+}
+
+func Password(
+ req *http.Request,
+ userAPI userapi.UserInternalAPI,
+ accountDB accounts.Database,
+ device *api.Device,
+ cfg *config.ClientAPI,
+) util.JSONResponse {
+ // Check that the existing password is right.
+ var r newPasswordRequest
+ r.LogoutDevices = true
+
+ // Unmarshal the request.
+ resErr := httputil.UnmarshalJSONRequest(req, &r)
+ if resErr != nil {
+ return *resErr
+ }
+
+ // Retrieve or generate the sessionID
+ sessionID := r.Auth.Session
+ if sessionID == "" {
+ // Generate a new, random session ID
+ sessionID = util.RandomString(sessionIDLength)
+ }
+
+ // Require password auth to change the password.
+ if r.Auth.Type != authtypes.LoginTypePassword {
+ return util.JSONResponse{
+ Code: http.StatusUnauthorized,
+ JSON: newUserInteractiveResponse(
+ sessionID,
+ []authtypes.Flow{
+ {
+ Stages: []authtypes.LoginType{authtypes.LoginTypePassword},
+ },
+ },
+ nil,
+ ),
+ }
+ }
+
+ // Check if the existing password is correct.
+ typePassword := auth.LoginTypePassword{
+ GetAccountByPassword: accountDB.GetAccountByPassword,
+ Config: cfg,
+ }
+ if _, authErr := typePassword.Login(req.Context(), &r.Auth.PasswordRequest); authErr != nil {
+ return *authErr
+ }
+ AddCompletedSessionStage(sessionID, authtypes.LoginTypePassword)
+
+ // Check the new password strength.
+ if resErr = validatePassword(r.NewPassword); resErr != nil {
+ return *resErr
+ }
+
+ // Get the local part.
+ localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
+ if err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
+ return jsonerror.InternalServerError()
+ }
+
+ // Ask the user API to perform the password change.
+ passwordReq := &userapi.PerformPasswordUpdateRequest{
+ Localpart: localpart,
+ Password: r.NewPassword,
+ }
+ passwordRes := &userapi.PerformPasswordUpdateResponse{}
+ if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("PerformPasswordUpdate failed")
+ return jsonerror.InternalServerError()
+ }
+ if !passwordRes.PasswordUpdated {
+ util.GetLogger(req.Context()).Error("Expected password to have been updated but wasn't")
+ return jsonerror.InternalServerError()
+ }
+
+ // If the request asks us to log out all other devices then
+ // ask the user API to do that.
+ if r.LogoutDevices {
+ logoutReq := &userapi.PerformDeviceDeletionRequest{
+ UserID: device.UserID,
+ DeviceIDs: nil,
+ ExceptDeviceID: device.ID,
+ }
+ logoutRes := &userapi.PerformDeviceDeletionResponse{}
+ if err := userAPI.PerformDeviceDeletion(req.Context(), logoutReq, logoutRes); err != nil {
+ util.GetLogger(req.Context()).WithError(err).Error("PerformDeviceDeletion failed")
+ return jsonerror.InternalServerError()
+ }
+ }
+
+ // Return a success code.
+ return util.JSONResponse{
+ Code: http.StatusOK,
+ JSON: struct{}{},
+ }
+}
diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go
index 708f6fee..b29fccf2 100644
--- a/clientapi/routing/routing.go
+++ b/clientapi/routing/routing.go
@@ -417,6 +417,15 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
+ r0mux.Handle("/account/password",
+ httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
+ if r := rateLimits.rateLimit(req); r != nil {
+ return *r
+ }
+ return Password(req, userAPI, accountDB, device, cfg)
+ }),
+ ).Methods(http.MethodPost, http.MethodOptions)
+
// Stub endpoints required by Riot
r0mux.Handle("/login",
diff --git a/sytest-whitelist b/sytest-whitelist
index 7ce59fef..93d2de59 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -460,3 +460,8 @@ If user leaves room, remote user changes device and rejoins we see update in /sy
Can search public room list
Can get remote public room list
Asking for a remote rooms list, but supplying the local server's name, returns the local rooms list
+After changing password, can't log in with old password
+After changing password, can log in with new password
+After changing password, existing session still works
+After changing password, different sessions can optionally be kept
+After changing password, a different session no longer works by default
diff --git a/userapi/api/api.go b/userapi/api/api.go
index e6d05c33..3baaa100 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -26,6 +26,7 @@ import (
type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
+ PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string
// The devices to delete. An empty slice means delete all devices.
DeviceIDs []string
+ // The requesting device ID to exclude from deletion. This is needed
+ // so that a password change doesn't cause that client to be logged
+ // out. Only specify when DeviceIDs is empty.
+ ExceptDeviceID string
}
type PerformDeviceDeletionResponse struct {
@@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account
}
+// PerformAccountCreationRequest is the request for PerformAccountCreation
+type PerformPasswordUpdateRequest struct {
+ Localpart string // Required: The localpart for this account.
+ Password string // Required: The new password to set.
+}
+
+// PerformAccountCreationResponse is the response for PerformAccountCreation
+type PerformPasswordUpdateResponse struct {
+ PasswordUpdated bool
+ Account *Account
+}
+
// PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct {
Localpart string
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index b97f148e..461c548c 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc
return nil
}
+
+func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
+ if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ return err
+ }
+ res.PasswordUpdated = true
+ return nil
+}
+
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
@@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
+ devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 5f4df0eb..6dcaf756 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
+ PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
+func (h *httpUserInternalAPI) PerformPasswordUpdate(
+ ctx context.Context,
+ request *api.PerformPasswordUpdateRequest,
+ response *api.PerformPasswordUpdateResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPasswordUpdatePath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context,
request *api.PerformDeviceCreationRequest,
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index 47d68ff2..d2674678 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(PerformAccountCreationPath,
+ httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPasswordUpdateRequest{}
+ response := api.PerformPasswordUpdateResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go
index 86b91e60..49446f11 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/accounts/interface.go
@@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
+ SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go
index 931ffb73..8c8d32cf 100644
--- a/userapi/storage/accounts/postgres/accounts_table.go
+++ b/userapi/storage/accounts/postgres/accounts_table.go
@@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+const updatePasswordSQL = "" +
+ "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
+
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
-// TODO: Update password
-
type accountsStatements struct {
insertAccountStmt *sql.Stmt
+ updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
+ if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
+ return
+ }
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
+func (s *accountsStatements) updatePassword(
+ ctx context.Context, localpart, passwordHash string,
+) (err error) {
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ return
+}
+
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index b36264dd..8b9ebef8 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
+// SetPassword sets the account password to the given hash.
+func (d *Database) SetPassword(
+ ctx context.Context, localpart, plaintextPassword string,
+) error {
+ hash, err := hashPassword(plaintextPassword)
+ if err != nil {
+ return err
+ }
+ return d.accounts.updatePassword(ctx, localpart, hash)
+}
+
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 798a6de9..fbbdc337 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+const updatePasswordSQL = "" +
+ "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
+
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
-// TODO: Update password
-
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
+ updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
+ if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
+ return
+ }
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
+func (s *accountsStatements) updatePassword(
+ ctx context.Context, localpart, passwordHash string,
+) (err error) {
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ return
+}
+
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 46106297..4b66304c 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
})
}
+// SetPassword sets the account password to the given hash.
+func (d *Database) SetPassword(
+ ctx context.Context, localpart, plaintextPassword string,
+) error {
+ hash, err := hashPassword(plaintextPassword)
+ if err != nil {
+ return err
+ }
+ err = d.accounts.updatePassword(ctx, localpart, hash)
+ return err
+}
+
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go
index 9b4261c9..168c84c5 100644
--- a/userapi/storage/devices/interface.go
+++ b/userapi/storage/devices/interface.go
@@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
- RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
+ RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}
diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go
index 282466f8..c06af754 100644
--- a/userapi/storage/devices/postgres/devices_table.go
+++ b/userapi/storage/devices/postgres/devices_table.go
@@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
+ "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM device_devices WHERE localpart = $1"
+ "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
+ _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
}
func (s *devicesStatements) selectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err
diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go
index 04dae986..c5bd5b6c 100644
--- a/userapi/storage/devices/postgres/storage.go
+++ b/userapi/storage/devices/postgres/storage.go
@@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index ecf43524..c75e1982 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
+ "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM device_devices WHERE localpart = $1"
+ "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
}
func (s *devicesStatements) deleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
+ _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
}
func (s *devicesStatements) selectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err
diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go
index f775fb66..7c6645dd 100644
--- a/userapi/storage/devices/sqlite3/storage.go
+++ b/userapi/storage/devices/sqlite3/storage.go
@@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil