aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorNeil Alexander <neilalexander@users.noreply.github.com>2020-09-04 15:16:13 +0100
committerGitHub <noreply@github.com>2020-09-04 15:16:13 +0100
commit5076925c184998414c3691e97fc21b554abf4a55 (patch)
treeb4da4d8a3015166e16eb9a6009f733653f0485bc /userapi
parentca8dcf46b746686e213b184c3ae42ba0be17b46b (diff)
Password changes (#1397)
* User API support for password changes * Password changes in client API * Update sytest-whitelist * Remove debug logging * Default logout_devices to true * Fix deleting devices by local part
Diffstat (limited to 'userapi')
-rw-r--r--userapi/api/api.go17
-rw-r--r--userapi/internal/api.go11
-rw-r--r--userapi/inthttp/client.go13
-rw-r--r--userapi/inthttp/server.go13
-rw-r--r--userapi/storage/accounts/interface.go1
-rw-r--r--userapi/storage/accounts/postgres/accounts_table.go16
-rw-r--r--userapi/storage/accounts/postgres/storage.go11
-rw-r--r--userapi/storage/accounts/sqlite3/accounts_table.go16
-rw-r--r--userapi/storage/accounts/sqlite3/storage.go12
-rw-r--r--userapi/storage/devices/interface.go2
-rw-r--r--userapi/storage/devices/postgres/devices_table.go12
-rw-r--r--userapi/storage/devices/postgres/storage.go8
-rw-r--r--userapi/storage/devices/sqlite3/devices_table.go12
-rw-r--r--userapi/storage/devices/sqlite3/storage.go8
14 files changed, 126 insertions, 26 deletions
diff --git a/userapi/api/api.go b/userapi/api/api.go
index e6d05c33..3baaa100 100644
--- a/userapi/api/api.go
+++ b/userapi/api/api.go
@@ -26,6 +26,7 @@ import (
type UserInternalAPI interface {
InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error
PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error
+ PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error
PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
PerformDeviceDeletion(ctx context.Context, req *PerformDeviceDeletionRequest, res *PerformDeviceDeletionResponse) error
PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error
@@ -63,6 +64,10 @@ type PerformDeviceDeletionRequest struct {
UserID string
// The devices to delete. An empty slice means delete all devices.
DeviceIDs []string
+ // The requesting device ID to exclude from deletion. This is needed
+ // so that a password change doesn't cause that client to be logged
+ // out. Only specify when DeviceIDs is empty.
+ ExceptDeviceID string
}
type PerformDeviceDeletionResponse struct {
@@ -165,6 +170,18 @@ type PerformAccountCreationResponse struct {
Account *Account
}
+// PerformAccountCreationRequest is the request for PerformAccountCreation
+type PerformPasswordUpdateRequest struct {
+ Localpart string // Required: The localpart for this account.
+ Password string // Required: The new password to set.
+}
+
+// PerformAccountCreationResponse is the response for PerformAccountCreation
+type PerformPasswordUpdateResponse struct {
+ PasswordUpdated bool
+ Account *Account
+}
+
// PerformDeviceCreationRequest is the request for PerformDeviceCreation
type PerformDeviceCreationRequest struct {
Localpart string
diff --git a/userapi/internal/api.go b/userapi/internal/api.go
index b97f148e..461c548c 100644
--- a/userapi/internal/api.go
+++ b/userapi/internal/api.go
@@ -98,6 +98,15 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
res.Account = acc
return nil
}
+
+func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
+ if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
+ return err
+ }
+ res.PasswordUpdated = true
+ return nil
+}
+
func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.PerformDeviceCreationRequest, res *api.PerformDeviceCreationResponse) error {
util.GetLogger(ctx).WithFields(logrus.Fields{
"localpart": req.Localpart,
@@ -126,7 +135,7 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
- devices, err = a.DeviceDB.RemoveAllDevices(ctx, local)
+ devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go
index 5f4df0eb..6dcaf756 100644
--- a/userapi/inthttp/client.go
+++ b/userapi/inthttp/client.go
@@ -30,6 +30,7 @@ const (
PerformDeviceCreationPath = "/userapi/performDeviceCreation"
PerformAccountCreationPath = "/userapi/performAccountCreation"
+ PerformPasswordUpdatePath = "/userapi/performPasswordUpdate"
PerformDeviceDeletionPath = "/userapi/performDeviceDeletion"
PerformDeviceUpdatePath = "/userapi/performDeviceUpdate"
@@ -81,6 +82,18 @@ func (h *httpUserInternalAPI) PerformAccountCreation(
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
+func (h *httpUserInternalAPI) PerformPasswordUpdate(
+ ctx context.Context,
+ request *api.PerformPasswordUpdateRequest,
+ response *api.PerformPasswordUpdateResponse,
+) error {
+ span, ctx := opentracing.StartSpanFromContext(ctx, "PerformPasswordUpdate")
+ defer span.Finish()
+
+ apiURL := h.apiURL + PerformPasswordUpdatePath
+ return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
+}
+
func (h *httpUserInternalAPI) PerformDeviceCreation(
ctx context.Context,
request *api.PerformDeviceCreationRequest,
diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go
index 47d68ff2..d2674678 100644
--- a/userapi/inthttp/server.go
+++ b/userapi/inthttp/server.go
@@ -39,6 +39,19 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
+ internalAPIMux.Handle(PerformAccountCreationPath,
+ httputil.MakeInternalAPI("performPasswordUpdate", func(req *http.Request) util.JSONResponse {
+ request := api.PerformPasswordUpdateRequest{}
+ response := api.PerformPasswordUpdateResponse{}
+ if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
+ return util.MessageResponse(http.StatusBadRequest, err.Error())
+ }
+ if err := s.PerformPasswordUpdate(req.Context(), &request, &response); err != nil {
+ return util.ErrorResponse(err)
+ }
+ return util.JSONResponse{Code: http.StatusOK, JSON: &response}
+ }),
+ )
internalAPIMux.Handle(PerformDeviceCreationPath,
httputil.MakeInternalAPI("performDeviceCreation", func(req *http.Request) util.JSONResponse {
request := api.PerformDeviceCreationRequest{}
diff --git a/userapi/storage/accounts/interface.go b/userapi/storage/accounts/interface.go
index 86b91e60..49446f11 100644
--- a/userapi/storage/accounts/interface.go
+++ b/userapi/storage/accounts/interface.go
@@ -28,6 +28,7 @@ type Database interface {
internal.PartitionStorer
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
+ SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
diff --git a/userapi/storage/accounts/postgres/accounts_table.go b/userapi/storage/accounts/postgres/accounts_table.go
index 931ffb73..8c8d32cf 100644
--- a/userapi/storage/accounts/postgres/accounts_table.go
+++ b/userapi/storage/accounts/postgres/accounts_table.go
@@ -47,6 +47,9 @@ CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+const updatePasswordSQL = "" +
+ "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
+
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@@ -56,10 +59,9 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')"
-// TODO: Update password
-
type accountsStatements struct {
insertAccountStmt *sql.Stmt
+ updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@@ -74,6 +76,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
+ if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
+ return
+ }
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@@ -114,6 +119,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
+func (s *accountsStatements) updatePassword(
+ ctx context.Context, localpart, passwordHash string,
+) (err error) {
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ return
+}
+
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
diff --git a/userapi/storage/accounts/postgres/storage.go b/userapi/storage/accounts/postgres/storage.go
index b36264dd..8b9ebef8 100644
--- a/userapi/storage/accounts/postgres/storage.go
+++ b/userapi/storage/accounts/postgres/storage.go
@@ -112,6 +112,17 @@ func (d *Database) SetDisplayName(
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
+// SetPassword sets the account password to the given hash.
+func (d *Database) SetPassword(
+ ctx context.Context, localpart, plaintextPassword string,
+) error {
+ hash, err := hashPassword(plaintextPassword)
+ if err != nil {
+ return err
+ }
+ return d.accounts.updatePassword(ctx, localpart, hash)
+}
+
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
diff --git a/userapi/storage/accounts/sqlite3/accounts_table.go b/userapi/storage/accounts/sqlite3/accounts_table.go
index 798a6de9..fbbdc337 100644
--- a/userapi/storage/accounts/sqlite3/accounts_table.go
+++ b/userapi/storage/accounts/sqlite3/accounts_table.go
@@ -45,6 +45,9 @@ CREATE TABLE IF NOT EXISTS account_accounts (
const insertAccountSQL = "" +
"INSERT INTO account_accounts(localpart, created_ts, password_hash, appservice_id) VALUES ($1, $2, $3, $4)"
+const updatePasswordSQL = "" +
+ "UPDATE account_accounts SET password_hash = $1 WHERE localpart = $2"
+
const selectAccountByLocalpartSQL = "" +
"SELECT localpart, appservice_id FROM account_accounts WHERE localpart = $1"
@@ -54,11 +57,10 @@ const selectPasswordHashSQL = "" +
const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts"
-// TODO: Update password
-
type accountsStatements struct {
db *sql.DB
insertAccountStmt *sql.Stmt
+ updatePasswordStmt *sql.Stmt
selectAccountByLocalpartStmt *sql.Stmt
selectPasswordHashStmt *sql.Stmt
selectNewNumericLocalpartStmt *sql.Stmt
@@ -75,6 +77,9 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
if s.insertAccountStmt, err = db.Prepare(insertAccountSQL); err != nil {
return
}
+ if s.updatePasswordStmt, err = db.Prepare(updatePasswordSQL); err != nil {
+ return
+ }
if s.selectAccountByLocalpartStmt, err = db.Prepare(selectAccountByLocalpartSQL); err != nil {
return
}
@@ -115,6 +120,13 @@ func (s *accountsStatements) insertAccount(
}, nil
}
+func (s *accountsStatements) updatePassword(
+ ctx context.Context, localpart, passwordHash string,
+) (err error) {
+ _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
+ return
+}
+
func (s *accountsStatements) selectPasswordHash(
ctx context.Context, localpart string,
) (hash string, err error) {
diff --git a/userapi/storage/accounts/sqlite3/storage.go b/userapi/storage/accounts/sqlite3/storage.go
index 46106297..4b66304c 100644
--- a/userapi/storage/accounts/sqlite3/storage.go
+++ b/userapi/storage/accounts/sqlite3/storage.go
@@ -126,6 +126,18 @@ func (d *Database) SetDisplayName(
})
}
+// SetPassword sets the account password to the given hash.
+func (d *Database) SetPassword(
+ ctx context.Context, localpart, plaintextPassword string,
+) error {
+ hash, err := hashPassword(plaintextPassword)
+ if err != nil {
+ return err
+ }
+ err = d.accounts.updatePassword(ctx, localpart, hash)
+ return err
+}
+
// CreateGuestAccount makes a new guest account and creates an empty profile
// for this account.
func (d *Database) CreateGuestAccount(ctx context.Context) (acc *api.Account, err error) {
diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go
index 9b4261c9..168c84c5 100644
--- a/userapi/storage/devices/interface.go
+++ b/userapi/storage/devices/interface.go
@@ -36,5 +36,5 @@ type Database interface {
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
- RemoveAllDevices(ctx context.Context, localpart string) (devices []api.Device, err error)
+ RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}
diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go
index 282466f8..c06af754 100644
--- a/userapi/storage/devices/postgres/devices_table.go
+++ b/userapi/storage/devices/postgres/devices_table.go
@@ -70,7 +70,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
+ "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@@ -79,7 +79,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM device_devices WHERE localpart = $1"
+ "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
@@ -179,10 +179,10 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
+ _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@@ -251,10 +251,10 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
}
func (s *devicesStatements) selectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err
diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go
index 04dae986..c5bd5b6c 100644
--- a/userapi/storage/devices/postgres/storage.go
+++ b/userapi/storage/devices/postgres/storage.go
@@ -68,7 +68,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@@ -175,14 +175,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
diff --git a/userapi/storage/devices/sqlite3/devices_table.go b/userapi/storage/devices/sqlite3/devices_table.go
index ecf43524..c75e1982 100644
--- a/userapi/storage/devices/sqlite3/devices_table.go
+++ b/userapi/storage/devices/sqlite3/devices_table.go
@@ -59,7 +59,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
+ "SELECT device_id, display_name FROM device_devices WHERE localpart = $1 AND device_id != $2"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@@ -68,7 +68,7 @@ const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
- "DELETE FROM device_devices WHERE localpart = $1"
+ "DELETE FROM device_devices WHERE localpart = $1 AND device_id != $2"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
@@ -182,10 +182,10 @@ func (s *devicesStatements) deleteDevices(
}
func (s *devicesStatements) deleteDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
- _, err := stmt.ExecContext(ctx, localpart)
+ _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
return err
}
@@ -231,10 +231,10 @@ func (s *devicesStatements) selectDeviceByID(
}
func (s *devicesStatements) selectDevicesByLocalpart(
- ctx context.Context, txn *sql.Tx, localpart string,
+ ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
- rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart)
+ rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
if err != nil {
return devices, err
diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go
index f775fb66..7c6645dd 100644
--- a/userapi/storage/devices/sqlite3/storage.go
+++ b/userapi/storage/devices/sqlite3/storage.go
@@ -72,7 +72,7 @@ func (d *Database) GetDeviceByID(
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
- return d.devices.selectDevicesByLocalpart(ctx, nil, localpart)
+ return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
@@ -179,14 +179,14 @@ func (d *Database) RemoveDevices(
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
- ctx context.Context, localpart string,
+ ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
- devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart)
+ devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
- if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
+ if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil