get_dtype
- torchtune.utils.get_dtype(dtype: Optional[str] = None, device: Optional[device] = None) → dtype
Get the torch.dtype corresponding to the given precision string. If no string is passed, we will default to torch.float32.
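To illustrate the lookup described above, here is a minimal sketch of how a precision string might be mapped to a torch.dtype. It includes the torch.float32 default and a ValueError for unsupported strings. The helper name lookup_dtype, the table _PRECISION_TO_DTYPE, and the exact set of accepted strings ("fp16", "bf16", "fp32", "fp64") are assumptions for illustration, not torchtune's actual implementation. Hardware checks are omitted here.

```python
from typing import Optional

import torch

# Hypothetical table: the exact precision strings accepted by
# torchtune.utils.get_dtype are an assumption in this sketch.
_PRECISION_TO_DTYPE = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
    "fp64": torch.float64,
}


def lookup_dtype(dtype: Optional[str] = None) -> torch.dtype:
    """Map a precision string to a torch.dtype (hypothetical helper, no hardware checks)."""
    # No string passed: fall back to torch.float32, as documented.
    if dtype is None:
        return torch.float32
    try:
        return _PRECISION_TO_DTYPE[dtype]
    except KeyError:
        # Unsupported precision strings raise ValueError, mirroring the documented behavior.
        raise ValueError(
            f"Unsupported precision {dtype!r}; expected one of {sorted(_PRECISION_TO_DTYPE)}"
        ) from None
```

For example, lookup_dtype() returns torch.float32 and lookup_dtype("bf16") returns torch.bfloat16.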
Note
If bf16 precision is requested with a CUDA device, we verify whether the device indeed supports bf16 kernels. If not, a RuntimeError is raised.
- Parameters:
dtype (Optional[str]) – The precision dtype. Default: None, in which case we default to torch.float32.
device (Optional[torch.device]) – Device in use for training. Only CUDA and CPU devices are supported. If a CUDA device is passed in, additional checking is done to ensure that the device supports the requested precision. Default: None, in which case a CUDA device is assumed.
- Raises:
ValueError – if precision isn’t supported by the library
RuntimeError – if bf16 precision is requested but not available on this hardware.
- Returns:
The corresponding torch.dtype.
- Return type:
torch.dtype
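The CUDA bf16 check described in the note can be sketched as follows. check_bf16_support is a hypothetical helper, not part of torchtune. It assumes torch.cuda.is_bf16_supported() as the hardware probe and, following the device parameter's documented default, treats device=None as a CUDA device. CPU devices are passed through unchecked, since the documentation only mentions additional checking for CUDA.

```python
from typing import Optional

import torch


def check_bf16_support(dtype: torch.dtype, device: Optional[torch.device] = None) -> None:
    """Hypothetical helper: raise RuntimeError if bf16 is requested on CUDA without bf16 support."""
    # Only bf16 requests need a hardware check in this sketch.
    if dtype != torch.bfloat16:
        return
    # Per the documented default, device=None means a CUDA device is assumed.
    if device is None:
        device = torch.device("cuda")
    # Assumption: only CUDA devices are checked; CPU (and anything else) passes through.
    if device.type != "cuda":
        return
    # Guard with is_available() so is_bf16_supported() is never queried without CUDA.
    if not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
        raise RuntimeError("bf16 precision was requested but is not available on this hardware.")
```

On a CPU-only machine, check_bf16_support(torch.bfloat16) raises RuntimeError because the default device is assumed to be CUDA, while check_bf16_support(torch.bfloat16, torch.device("cpu")) returns without error.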