aboutsummaryrefslogtreecommitdiff
path: root/userapi/storage/devices/postgres/devices_table.go
diff options
context:
space:
mode:
Diffstat (limited to 'userapi/storage/devices/postgres/devices_table.go')
-rw-r--r--userapi/storage/devices/postgres/devices_table.go36
1 files changed, 35 insertions, 1 deletions
diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go
index 1d036d1b..03bf7c72 100644
--- a/userapi/storage/devices/postgres/devices_table.go
+++ b/userapi/storage/devices/postgres/devices_table.go
@@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" +
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
+const selectDevicesByIDSQL = "" +
+ "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)"
+
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
+ selectDevicesByIDStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
@@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
+ if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
+ return
+ }
s.serverName = server
return
}
@@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
+ var displayName sql.NullString
stmt := s.selectDeviceByIDStmt
- err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
+ err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ if displayName.Valid {
+ dev.DisplayName = displayName.String
+ }
}
return &dev, err
}
+func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
+ rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
+ if err != nil {
+ return nil, err
+ }
+ defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
+ var devices []api.Device
+ for rows.Next() {
+ var dev api.Device
+ var localpart string
+ var displayName sql.NullString
+ if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
+ return nil, err
+ }
+ if displayName.Valid {
+ dev.DisplayName = displayName.String
+ }
+ dev.UserID = userutil.MakeUserID(localpart, s.serverName)
+ devices = append(devices, dev)
+ }
+ return devices, rows.Err()
+}
+
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {