diff options
Diffstat (limited to 'userapi/storage/devices/postgres/devices_table.go')
-rw-r--r-- | userapi/storage/devices/postgres/devices_table.go | 36 |
1 files changed, 35 insertions, 1 deletions
diff --git a/userapi/storage/devices/postgres/devices_table.go b/userapi/storage/devices/postgres/devices_table.go index 1d036d1b..03bf7c72 100644 --- a/userapi/storage/devices/postgres/devices_table.go +++ b/userapi/storage/devices/postgres/devices_table.go @@ -84,11 +84,15 @@ const deleteDevicesByLocalpartSQL = "" + const deleteDevicesSQL = "" + "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" +const selectDevicesByIDSQL = "" + + "SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" + type devicesStatements struct { insertDeviceStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt selectDeviceByIDStmt *sql.Stmt selectDevicesByLocalpartStmt *sql.Stmt + selectDevicesByIDStmt *sql.Stmt updateDeviceNameStmt *sql.Stmt deleteDeviceStmt *sql.Stmt deleteDevicesByLocalpartStmt *sql.Stmt @@ -125,6 +129,9 @@ func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerN if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil { return } + if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil { + return + } s.serverName = server return } @@ -207,15 +214,42 @@ func (s *devicesStatements) selectDeviceByID( ctx context.Context, localpart, deviceID string, ) (*api.Device, error) { var dev api.Device + var displayName sql.NullString stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName) + err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName) if err == nil { dev.ID = deviceID dev.UserID = userutil.MakeUserID(localpart, s.serverName) + if displayName.Valid { + dev.DisplayName = displayName.String + } } return &dev, err } +func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { + rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") + var devices []api.Device + for rows.Next() { + var dev api.Device + var localpart string + var displayName sql.NullString + if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil { + return nil, err + } + if displayName.Valid { + dev.DisplayName = displayName.String + } + dev.UserID = userutil.MakeUserID(localpart, s.serverName) + devices = append(devices, dev) + } + return devices, rows.Err() +} + func (s *devicesStatements) selectDevicesByLocalpart( ctx context.Context, localpart string, ) ([]api.Device, error) { |