get_dtype

torchtune.utils.get_dtype(dtype: Optional[str] = None, device: Optional[torch.device] = None) -> torch.dtype

Get the torch.dtype corresponding to the given precision string. If no string is passed, we will default to torch.float32.

Note

If bf16 precision is requested with a CUDA device, we verify whether the device indeed supports bf16 kernels. If not, a RuntimeError is raised.

Parameters:
  • dtype (Optional[str]) – The precision dtype. Default: None, in which case torch.float32 is used.

  • device (Optional[torch.device]) – Device in use for training. Only CUDA and CPU devices are supported. If a CUDA device is passed in, additional checking is done to ensure that the device supports the requested precision. Default: None, in which case a CUDA device is assumed.

Raises:
  • ValueError – if the requested precision isn’t supported by the library.

  • RuntimeError – if bf16 precision is requested but not available on this hardware.

Returns:

The corresponding torch.dtype.

Return type:

torch.dtype
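The contract described above (float32 default, ValueError for an unknown precision, RuntimeError for unsupported bf16 on CUDA) can be modelled in plain Python without importing torch. This is a minimal sketch, not torchtune's implementation: the function name resolve_precision, the PRECISION_TO_DTYPE_NAME table and its "fp32" and "fp16" keys, and the bf16_supported flag are all hypothetical stand-ins, and the function returns a dtype name rather than a real torch.dtype.

```python
from typing import Optional

# Hypothetical precision table. The documentation above only names "bf16"
# and the torch.float32 default; the other keys are assumptions.
PRECISION_TO_DTYPE_NAME = {
    "fp32": "torch.float32",
    "fp16": "torch.float16",
    "bf16": "torch.bfloat16",
}


def resolve_precision(
    precision: Optional[str] = None,
    device_type: str = "cuda",
    bf16_supported: bool = True,
) -> str:
    """Model the documented behavior, returning a dtype name as a string.

    bf16_supported is a hypothetical flag standing in for the hardware
    check that the real function performs on a CUDA device.
    """
    # No precision string given: fall back to float32, as documented.
    if precision is None:
        return "torch.float32"

    # Unknown precision string: the documentation specifies a ValueError.
    if precision not in PRECISION_TO_DTYPE_NAME:
        raise ValueError(f"Unsupported precision: {precision!r}")

    # bf16 on a CUDA device without bf16 kernels: the documentation
    # specifies a RuntimeError.
    if precision == "bf16" and device_type == "cuda" and not bf16_supported:
        raise RuntimeError(
            "bf16 precision was requested but is not available on this hardware"
        )

    return PRECISION_TO_DTYPE_NAME[precision]
```

With torchtune installed, the corresponding real call would be torchtune.utils.get_dtype("bf16", device=torch.device("cuda")). Wrapping that call in a try/except RuntimeError block is one way to fall back to another precision on hardware without bf16 support.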
