Shortcuts

Source code for torch.amp.autocast_mode

import torch
import functools
import warnings

from typing import Any, Optional
from torch.types import _dtype

# Names exported by a star-import of this module.
__all__ = ["autocast_decorator", "autocast"]

def autocast_decorator(autocast_instance, func):
    """Return a wrapper that runs ``func`` inside ``autocast_instance``.

    Args:
        autocast_instance: context manager (normally an :class:`autocast`
            object) that is entered and exited around every call.
        func: the callable being decorated.

    Returns:
        A callable carrying ``func``'s metadata (via :func:`functools.wraps`)
        that forwards all positional and keyword arguments to ``func`` while
        ``autocast_instance`` is active, and returns ``func``'s result.
    """
    def run_with_autocast(*args, **kwargs):
        with autocast_instance:
            return func(*args, **kwargs)

    wrapper = functools.wraps(func)(run_with_autocast)
    # Marker whose text states that scripting this wrapper is unsupported;
    # presumably read by the TorchScript frontend — confirm against torch.jit.
    wrapper.__script_unsupported = '@autocast() decorator is not supported in script mode'  # type: ignore[attr-defined]
    return wrapper

class autocast:
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs
    when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network,
    including the loss computation(s). Backward passes under autocast are not
    recommended. Backward ops run in the same type that autocast used for
    corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with torch.autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>` for
    usage (along with gradient scaling) in more complex scenarios (e.g.,
    gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward``
    method of your model::

        class AutocastModel(nn.Module):
            ...
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be
    ``float16``. After returning to an autocast-disabled region, using them
    with floating-point Tensors of different dtypes may cause type mismatch
    errors. If so, cast the Tensor(s) produced in the autocast region back to
    ``float32`` (or other dtype if desired). If a Tensor from the autocast
    region is already ``float32``, the cast is a no-op, and incurs no
    additional overhead.

    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()

    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = nn.Linear(input_size, num_classes)
            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest to disable the Jit Autocast Pass,
        # As the issue: https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Models Run
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is
    what you observe, please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled
    regions. Locally disabling autocast can be useful, for example, if you want
    to force a subregion to run in a particular ``dtype``. Disabling autocast
    gives you explicit control over the execution type. In the subregion,
    inputs from the surrounding region should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local. If you want it enabled in a new thread,
    the context manager or decorator must be invoked in that thread. This
    affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than
    one GPU per process (see
    :ref:`Working with Multiple GPUs<amp-multigpu>`).

    Args:
        device_type(str, required): Device type to use. Accepted values are
            ``'cuda'``, ``'cpu'``, ``'xpu'`` and ``'hpu'``; anything else
            raises ``RuntimeError``.
        dtype(torch_dtype, optional): Whether to use torch.float16 or
            torch.bfloat16. If ``None``, the device's current autocast dtype
            is used (e.g. ``torch.get_autocast_gpu_dtype()`` for CUDA).
            CPU accepts only ``torch.bfloat16``; XPU and HPU accept
            ``torch.bfloat16`` and ``torch.float16``. An unsupported dtype
            emits a warning and disables autocast.
        enabled(bool, optional): Whether autocasting should be enabled in the
            region. Default: ``True``
        cache_enabled(bool, optional): Whether the weight cache inside
            autocast should be enabled. If ``None`` (the default), the current
            value of ``torch.is_autocast_cache_enabled()`` is used.
    """

    def __init__(self, device_type: str,
                 dtype: Optional[_dtype] = None,
                 enabled: bool = True,
                 cache_enabled: Optional[bool] = None):
        if torch._jit_internal.is_scripting():
            # Scripted code only records the arguments; no global state is read.
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            # TODO: support get_autocast_gpu/cpu_dtype
            assert dtype is not None
            return

        self.device = device_type
        # Start from the device's current autocast dtype; an explicit `dtype`
        # argument overrides it further down.
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
            self.fast_dtype = torch.get_autocast_cpu_dtype()
        elif self.device == 'xpu':
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
        else:
            raise RuntimeError(
                "User specified autocast device_type must be 'cuda', 'cpu', 'xpu' or 'hpu', "
                f"but got {self.device!r}"
            )

        self._cache_enabled = torch.is_autocast_cache_enabled()
        # Check the device first so non-CUDA devices never consult the CUDA
        # availability helper.
        if self.device == 'cuda' and enabled and torch.cuda.amp.common.amp_definitely_not_available():
            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
            enabled = False

        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        # Per-device dtype validation: unsupported dtypes warn and disable
        # autocast, except CUDA bf16 on an unsupported device, which raises.
        if self.device == 'cpu':
            supported_dtype = [torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'xpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'hpu':
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n'
                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'cuda':
            # Only query the device when autocast will actually run. When it is
            # disabled (by the caller, or above because CUDA is unavailable),
            # there may be no device to ask. In that case
            # `torch.cuda.is_bf16_supported()` may raise or report False, which
            # would wrongly turn a disabled autocast into an error.
            if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
        self._enabled = enabled

    def __enter__(self):
        """Activate this autocast state, saving the previous one for ``__exit__``."""
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        # Remember the global state so __exit__ can restore it (supports nesting).
        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        if self.device == 'cpu':
            self.prev = torch.is_autocast_cpu_enabled()
            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
            torch.set_autocast_cpu_enabled(self._enabled)
            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.autocast_increment_nesting()
        elif self.device == 'xpu':
            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        elif self.device == 'hpu':
            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        else:
            # CUDA path.
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        """Restore the state saved by ``__enter__``; never suppresses exceptions."""
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if self.device == 'cpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_cpu_enabled(self.prev)
            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
        elif self.device == 'xpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            torch.set_autocast_enabled(self.prev)
            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        """Use this instance as a decorator (a no-op under scripting)."""
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources