autosim.device

Torch device helpers for AutoSim simulators.

exception TorchDeviceError(device)

Bases: NotImplementedError

Exception raised when the device is not implemented in torch.

Parameters:

device (str) – The device that is not implemented in torch.

get_torch_device(device)

Get the torch device, returning the default torch device if device is None.

Parameters:

device (str | device | None) – The device to get. If None, the default torch device is returned.

Returns:

The resolved torch device.

Raises:

TorchDeviceError – If the device is not a valid torch device.

Return type:

device
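A minimal sketch of a helper with this contract, written only against public torch APIs. It is not AutoSim's implementation: the stand-in `TorchDeviceError`, its message, and the use of `torch.empty(()).device` to read the current default device are assumptions made for illustration.

```python
from __future__ import annotations

import torch


class TorchDeviceError(NotImplementedError):
    """Stand-in for autosim.device.TorchDeviceError (sketch only)."""

    def __init__(self, device: str) -> None:
        super().__init__(f"Device {device!r} is not implemented in torch.")
        self.device = device


def get_torch_device(device: str | torch.device | None) -> torch.device:
    """Resolve `device` to a torch.device (hypothetical re-implementation)."""
    if device is None:
        # A freshly created tensor lands on torch's current default device,
        # which also reflects torch.set_default_device() when it has been used.
        return torch.empty(()).device
    if isinstance(device, torch.device):
        return device
    try:
        return torch.device(device)
    except RuntimeError as err:
        # torch rejects unknown device strings (e.g. "gpu") with a RuntimeError.
        raise TorchDeviceError(device) from err
```

Because the exception's base class is `NotImplementedError`, callers that already catch `NotImplementedError` will also catch an invalid-device error.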

move_tensors_to_device(*args, device)

Move the given tensors to the device.

Parameters:
  • *args (Tensor) – The tensors to move.

  • device (device) – The device to move the tensors to.

Returns:

A tuple of the tensors, each on the device.

Return type:

tuple[Tensor, …]
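A minimal sketch of the documented behaviour, assuming the helper simply calls `Tensor.to` on each argument (this page does not confirm how AutoSim implements it). Since `device` follows `*args` in the signature, it can only be passed by keyword.

```python
from __future__ import annotations

import torch


def move_tensors_to_device(
    *args: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, ...]:
    """Return every tensor in `args` on `device` (hypothetical re-implementation)."""
    # Tensor.to returns the original tensor unchanged when it already lives
    # on `device`, so no copy is made in that case.
    return tuple(t.to(device) for t in args)


# Usage: `device` must be given by keyword because it follows *args.
x, y = move_tensors_to_device(
    torch.zeros(2), torch.ones(3), device=torch.device("cpu")
)
```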

check_torch_device_is_available(device)

Check if the given device type is available.

Parameters:

device (str | device) – The device to check.

Returns:

True if the device is available, False otherwise.

Raises:

TorchDeviceError – If the device is not a valid torch device.

Return type:

bool
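A minimal sketch of a type-level availability check. It assumes only the `cpu`, `cuda` and `mps` backends are probed and treats every other valid torch device type as unavailable; AutoSim may handle more backends. `TorchDeviceError` is again a local stand-in.

```python
from __future__ import annotations

import torch


class TorchDeviceError(NotImplementedError):
    """Stand-in for autosim.device.TorchDeviceError (sketch only)."""


def check_torch_device_is_available(device: str | torch.device) -> bool:
    """Return True if the device *type* can be used here (hypothetical sketch)."""
    if isinstance(device, torch.device):
        device_type = device.type
    else:
        try:
            device_type = torch.device(device).type
        except RuntimeError as err:
            # Not a device string torch understands at all.
            raise TorchDeviceError(str(device)) from err
    if device_type == "cpu":
        return True
    if device_type == "cuda":
        return torch.cuda.is_available()
    if device_type == "mps":
        return torch.backends.mps.is_available()
    # Assumption: other valid torch device types are reported as unavailable.
    return False
```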

check_model_device(model, expected_device)

Check if the model is on the expected device.

Parameters:
  • model (Module) – The model to check.

  • expected_device (str) – The expected device.

Returns:

True if the model is on the expected device (ignoring device index), False otherwise.

Return type:

bool
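A minimal sketch, assuming the check compares device *types* across every parameter and buffer; whether AutoSim inspects all tensors or only one parameter is not stated here. Under this sketch, a module with no parameters or buffers trivially passes.

```python
from itertools import chain

import torch
from torch import nn


def check_model_device(model: nn.Module, expected_device: str) -> bool:
    """Return True if all of `model`'s tensors are on `expected_device` (sketch)."""
    # Comparing only .type ignores the index, so "cuda:1" tensors match "cuda".
    expected_type = torch.device(expected_device).type
    return all(
        t.device.type == expected_type
        for t in chain(model.parameters(), model.buffers())
    )
```

Comparing types rather than full `torch.device` objects is what makes the check ignore the device index, matching the documented return value.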

class TorchDeviceMixin(device=None, cpu_only=False)

Bases: object

Mixin class to add device management to a PyTorch model.
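A minimal sketch of how a device-management mixin like this could be combined with `torch.nn.Module`. Everything here is hypothetical: the `TinyModel` subclass, the stand-in `TorchDeviceError`, and the use of `torch.empty(()).device` for the default device are illustrative, and `cpu_only` is omitted because its behaviour is not described in this section.

```python
from __future__ import annotations

import torch
from torch import nn


class TorchDeviceError(NotImplementedError):
    """Stand-in for autosim.device.TorchDeviceError (sketch only)."""


class TorchDeviceMixin:
    """Hypothetical device-management mixin; `cpu_only` is deliberately omitted."""

    def __init__(self, device: str | torch.device | None = None) -> None:
        if device is None:
            # Fall back to torch's current default device.
            self.device = torch.empty(()).device
        else:
            try:
                self.device = torch.device(device)
            except RuntimeError as err:
                raise TorchDeviceError(str(device)) from err


class TinyModel(TorchDeviceMixin, nn.Module):
    """Hypothetical model showing how the mixin might be used."""

    def __init__(self, device: str | torch.device | None = None) -> None:
        nn.Module.__init__(self)  # runs first so submodules can be registered
        TorchDeviceMixin.__init__(self, device=device)
        self.linear = nn.Linear(4, 2, device=self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inputs are moved to the model's device before the forward pass.
        return self.linear(x.to(self.device))
```

Placing the mixin before `nn.Module` in the bases and calling both initialisers explicitly keeps the device bookkeeping separate from the network definition.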

device

The device to use. If None, the default torch device is used.

Raises:

TorchDeviceError – If the device is not a valid torch device.

Parameters: