set_default_dtype
- torchtune.utils.set_default_dtype(dtype: torch.dtype) → Generator[None, None, None]
Context manager that sets torch’s default floating-point dtype for the duration of the `with` block and restores the previous default on exit.
- Parameters:
dtype (torch.dtype) – The desired default dtype inside the context manager.
- Returns:
Context manager for setting the default dtype.
- Return type:
ContextManager
Example
>>> import torch
>>> from torchtune.utils import set_default_dtype
>>> with set_default_dtype(torch.bfloat16):
...     x = torch.tensor([1.0, 2.0, 3.0])
>>> x.dtype
torch.bfloat16
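For readers wondering how such a context manager behaves, here is a minimal sketch, assuming it saves the current default with `torch.get_default_dtype()`, switches it with `torch.set_default_dtype()`, and restores it in a `finally` clause. `set_default_dtype_sketch` is a hypothetical name used only for illustration; it is not torchtune's source. The demo uses `torch.float64` so the effect is easy to see, and also shows that integer literals still produce `torch.int64`, because the default dtype only governs floating-point inference.

```python
import contextlib
from typing import Generator

import torch


@contextlib.contextmanager
def set_default_dtype_sketch(dtype: torch.dtype) -> Generator[None, None, None]:
    """Hypothetical sketch: temporarily change torch's default floating-point dtype."""
    previous = torch.get_default_dtype()  # remember the caller's default
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if the body raised an exception.
        torch.set_default_dtype(previous)


before = torch.get_default_dtype()

with set_default_dtype_sketch(torch.float64):
    floats = torch.tensor([1.0, 2.0])  # inferred from the default dtype
    ints = torch.tensor([1, 2])        # integer literals are unaffected

print(floats.dtype, ints.dtype)  # prints "torch.float64 torch.int64"
assert torch.get_default_dtype() == before

# The previous default is also restored when the block raises.
try:
    with set_default_dtype_sketch(torch.float64):
        raise RuntimeError("boom")
except RuntimeError:
    pass
assert torch.get_default_dtype() == before
```

Using `try`/`finally` around the `yield` is what makes the restore exception-safe; without it, an error inside the `with` block would leave the whole process running with the temporary dtype.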