aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTill <2353100+S7evinK@users.noreply.github.com>2023-02-17 11:39:46 +0100
committerGitHub <noreply@github.com>2023-02-17 11:39:46 +0100
commitf0805071d54b3fdcc8e8d4c5f4238ae0af1bad77 (patch)
tree9166c87eb19ad84b66ff05a42312d07ca28f7977
parent11d9b9db0e96c51c1430d451d23cf5ae9f36e4ee (diff)
Fix SQLite `session_id` (#2977)
This fixes an issue with device_id/session_ids. If a `device_id` is reused, we would reuse the same `session_id`, since we delete one device and insert a new one directly, resulting in the query to get a new `session_id` to return the previous session_id. (`SELECT count(access_token)`)
-rw-r--r--userapi/storage/postgres/devices_table.go16
-rw-r--r--userapi/storage/shared/storage.go44
-rw-r--r--userapi/storage/sqlite3/devices_table.go29
-rw-r--r--userapi/storage/tables/interface.go1
4 files changed, 73 insertions, 17 deletions
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index 2dd21618..7481ac5b 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -81,7 +81,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
@@ -96,7 +96,7 @@ const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
+ "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
@@ -171,6 +171,14 @@ func (s *devicesStatements) InsertDevice(
}, nil
}
+func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ accessToken string, displayName *string, ipAddr, userAgent string,
+ sessionID int64,
+) (*api.Device, error) {
+ return s.InsertDevice(ctx, txn, id, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
+}
+
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id string,
@@ -271,7 +279,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var lastseents sql.NullInt64
var displayName sql.NullString
for rows.Next() {
- if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil {
return nil, err
}
if displayName.Valid {
@@ -303,7 +311,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
for rows.Next() {
- err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
+ err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID)
if err != nil {
return devices, err
}
diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go
index f549dcef..bf94f14d 100644
--- a/userapi/storage/shared/storage.go
+++ b/userapi/storage/shared/storage.go
@@ -588,16 +588,42 @@ func (d *Database) CreateDevice(
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
- returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
- var err error
- // Revoke existing tokens for this device
- if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
+ _, ok := d.Writer.(*sqlutil.ExclusiveWriter)
+ if ok { // we're using most likely using SQLite, so do things a little different
+ returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ var err error
+
+ devices, err := d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, "")
+ if err != nil && !errors.Is(err, sql.ErrNoRows) {
+ return err
+ }
+ // No devices yet, only create a new one
+ if len(devices) == 0 {
+ dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
+ return err
+ }
+ sessionID := devices[0].SessionID + 1
+
+ // Revoke existing tokens for this device
+ if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
+ return err
+ }
+ // Create a new device with the session ID incremented
+ dev, err = d.Devices.InsertDeviceWithSessionID(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent, sessionID)
return err
- }
+ })
+ } else {
+ returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
+ var err error
+ // Revoke existing tokens for this device
+ if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
+ return err
+ }
- dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
- return err
- })
+ dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
+ return err
+ })
+ }
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
@@ -618,7 +644,7 @@ func (d *Database) CreateDevice(
}
}
}
- return
+ return dev, returnErr
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index c5db34bd..449e4549 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -65,7 +65,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
const selectDevicesByLocalpartSQL = "" +
- "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
+ "SELECT device_id, display_name, last_seen_ts, ip, user_agent, session_id FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" +
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
@@ -80,7 +80,7 @@ const deleteDevicesSQL = "" +
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
const selectDevicesByIDSQL = "" +
- "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
+ "SELECT device_id, localpart, server_name, display_name, last_seen_ts, session_id FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" +
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
@@ -162,6 +162,27 @@ func (s *devicesStatements) InsertDevice(
}, nil
}
+func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
+ localpart string, serverName gomatrixserverlib.ServerName,
+ accessToken string, displayName *string, ipAddr, userAgent string,
+ sessionID int64,
+) (*api.Device, error) {
+ createdTimeMS := time.Now().UnixNano() / 1000000
+ insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
+ if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
+ return nil, err
+ }
+ return &api.Device{
+ ID: id,
+ UserID: userutil.MakeUserID(localpart, serverName),
+ AccessToken: accessToken,
+ SessionID: sessionID,
+ LastSeenTS: createdTimeMS,
+ LastSeenIP: ipAddr,
+ UserAgent: userAgent,
+ }, nil
+}
+
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id string,
localpart string, serverName gomatrixserverlib.ServerName,
@@ -271,7 +292,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
for rows.Next() {
- err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
+ err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent, &dev.SessionID)
if err != nil {
return devices, err
}
@@ -317,7 +338,7 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
var displayName sql.NullString
var lastseents sql.NullInt64
for rows.Next() {
- if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
+ if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents, &dev.SessionID); err != nil {
return nil, err
}
if displayName.Valid {
diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go
index e14776cf..9221e571 100644
--- a/userapi/storage/tables/interface.go
+++ b/userapi/storage/tables/interface.go
@@ -44,6 +44,7 @@ type AccountsTable interface {
type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
+ InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string, sessionID int64) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error