From f023cdf8c42cc1a4bb850b478dbbf7d901b5e1bd Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 27 Apr 2022 15:05:49 +0200 Subject: Add UserAPI storage tests (#2384) * Add tests for parts of the userapi storage * Add tests for keybackup * Add LoginToken tests * Add OpenID tests * Add profile tests * Add pusher tests * Add ThreePID tests * Add notification tests * Add more device tests, fix numeric localpart query * Fix failing CI * Fix numeric local part query --- userapi/storage/postgres/accounts_table.go | 6 ++---- userapi/storage/postgres/devices_table.go | 22 +++++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) (limited to 'userapi/storage/postgres') diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 92311d56..f86812f1 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); --- Create sequence for autogenerated numeric usernames -CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1; ` const insertAccountSQL = "" + @@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" + "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT nextval('numeric_username_seq')" + "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart( stmt = sqlutil.TxStmt(txn, stmt) } err = stmt.QueryRowContext(ctx).Scan(&id) - return + return id + 1, err } diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 7bc5dc69..fe8c54e0 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" + "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" @@ -93,7 +93,7 @@ const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" + "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" @@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s } defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") var devices []api.Device + var dev api.Device + var localpart string + var lastseents sql.NullInt64 + var displayName sql.NullString for rows.Next() { - var dev api.Device - var localpart string - var displayName sql.NullString - if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { dev.DisplayName = displayName.String } + if lastseents.Valid { + dev.LastSeenTS = lastseents.Int64 + } dev.UserID = userutil.MakeUserID(localpart, s.serverName) devices = append(devices, dev) } @@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart( } defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") + var dev api.Device + var lastseents sql.NullInt64 + var id, displayname, ip, useragent sql.NullString for rows.Next() { - var dev api.Device - var lastseents sql.NullInt64 - var id, displayname, ip, useragent sql.NullString err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) if err != nil { return devices, err -- cgit v1.2.3