aboutsummaryrefslogtreecommitdiff
path: root/userapi
diff options
context:
space:
mode:
authorTill Faelligen <tfaelligen@gmail.com>2022-04-28 15:06:34 +0200
committerTill Faelligen <tfaelligen@gmail.com>2022-04-28 15:06:34 +0200
commit8683ff78b1bee6b7c35e7befb9903c794a17510c (patch)
treea8ebbf43f9b73cc26c2224be25b6b5a53119cd4f /userapi
parent65034d1f227de45e88d39ec5a3e83d854e840875 (diff)
Make tests more reliable
Diffstat (limited to 'userapi')
-rw-r--r--userapi/storage/postgres/devices_table.go13
-rw-r--r--userapi/storage/sqlite3/devices_table.go13
-rw-r--r--userapi/storage/storage_test.go8
3 files changed, 24 insertions, 10 deletions
diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go
index fe8c54e0..6c777982 100644
--- a/userapi/storage/postgres/devices_table.go
+++ b/userapi/storage/postgres/devices_table.go
@@ -75,7 +75,7 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
- "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
+ "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
@@ -215,15 +215,22 @@ func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
- var displayName sql.NullString
+ var displayName, ip sql.NullString
+ var lastseenTS sql.NullInt64
stmt := s.selectDeviceByIDStmt
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
+ err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
+ if lastseenTS.Valid {
+ dev.LastSeenTS = lastseenTS.Int64
+ }
+ if ip.Valid {
+ dev.LastSeenIP = ip.String
+ }
}
return &dev, err
}
diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go
index 7860bd6a..b86ed1cc 100644
--- a/userapi/storage/sqlite3/devices_table.go
+++ b/userapi/storage/sqlite3/devices_table.go
@@ -60,7 +60,7 @@ const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
- "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
+ "SELECT display_name, last_seen_ts, ip FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
@@ -212,15 +212,22 @@ func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
- var displayName sql.NullString
+ var displayName, ip sql.NullString
stmt := s.selectDeviceByIDStmt
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
+ var lastseenTS sql.NullInt64
+ err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
if displayName.Valid {
dev.DisplayName = displayName.String
}
+ if lastseenTS.Valid {
+ dev.LastSeenTS = lastseenTS.Int64
+ }
+ if ip.Valid {
+ dev.LastSeenIP = ip.String
+ }
}
return &dev, err
}
diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go
index e6c7d35f..2eb57d0b 100644
--- a/userapi/storage/storage_test.go
+++ b/userapi/storage/storage_test.go
@@ -180,12 +180,12 @@ func Test_Devices(t *testing.T) {
deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1"
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
- devices, err = db.GetDevicesByLocalpart(ctx, localpart)
+ gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices))
- assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName)
- assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP)
- truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second)
+ assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName)
+ assert.Equal(t, deviceWithID.LastSeenIP, gotDevice.LastSeenIP)
+ truncatedTime := gomatrixserverlib.Timestamp(gotDevice.LastSeenTS).Time().Truncate(time.Second)
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
// create one more device and remove the devices step by step