import collections
import functools
from typing import Any

import torch

try:
    import numpy as np

    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]
    HAS_NUMPY = False


class autocast(torch.amp.autocast_mode.autocast):
    r"""
    See :class:`torch.autocast`.
    ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
    """

    def __init__(
        self,
        enabled: bool = True,
        dtype: torch.dtype = torch.float16,
        cache_enabled: bool = True,
    ):
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        super().__init__(
            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if torch._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return super().__call__(func)
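# A minimal usage sketch (the model and tensor names below are illustrative
# assumptions, not part of this module). Because the class above is a thin
# CUDA-specific wrapper around torch.autocast, the two spellings are
# interchangeable:
#
#     model = torch.nn.Linear(8, 8).cuda()
#     x = torch.randn(4, 8, device="cuda")
#     with torch.cuda.amp.autocast():  # equivalent to torch.autocast("cuda")
#         y = model(x)                 # the linear layer runs in float16
#     assert y.dtype == torch.float16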
# Casts Tensors and containers of Tensors.  Special-cases passthroughs for strings
# and np.ndarrays, which may be falsely detected as "Iterables."
def _cast(value, dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.is_cuda
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif HAS_NUMPY and isinstance(value, np.ndarray):
        return value
    elif isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        else:
            return iterable
    else:
        return value
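# A quick sketch of _cast's behavior (the tensor values are assumptions, not
# part of the original module): eligible floating-point CUDA Tensors are
# converted, float64 and non-floating-point Tensors pass through, and
# containers are cast recursively.
#
#     a = torch.randn(2, device="cuda")                        # float32: cast
#     b = torch.randn(2, device="cuda", dtype=torch.float64)   # float64: kept
#     c = torch.ones(2, device="cuda", dtype=torch.int64)      # int: kept
#     out = _cast({"a": a, "b": b, "c": c}, torch.float16)
#     assert out["a"].dtype == torch.float16
#     assert out["b"].dtype == torch.float64 and out["c"].dtype == torch.int64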
# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
#     @custom_fwd
#     def forward(...):
# this also works:
#     @custom_fwd(cast_inputs=torch.float)
#     def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    Helper decorator for ``forward`` methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`).  See the :ref:`example page<amp-custom-examples>`
    for more detail.

    Args:
        cast_inputs (:class:`torch.dtype` or None, optional, default=None):  If not ``None``,
            when ``forward`` runs in an autocast-enabled region, casts incoming
            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors
            are not affected), then executes ``forward`` with autocast disabled.
            If ``None``, ``forward``'s internal ops execute with the current autocast state.

    .. note::
        If the decorated ``forward`` is called outside an autocast-enabled region,
        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
    """
    if fwd is None:
        return functools.partial(custom_fwd, cast_inputs=cast_inputs)

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        # args[0] is the ctx object passed to Function.forward; stash the
        # autocast state on it so custom_bwd can replay it in backward.
        args[0]._dtype = torch.get_autocast_gpu_dtype()
        if cast_inputs is None:
            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
            return fwd(*args, **kwargs)
        else:
            autocast_context = torch.is_autocast_enabled()
            args[0]._fwd_used_autocast = False
            if autocast_context:
                with autocast(enabled=False):
                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
            else:
                return fwd(*args, **kwargs)

    return decorate_fwd
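# Sketch of the cast_inputs semantics described above (Probe is a hypothetical
# Function, not part of this module): inside an autocast region, floating-point
# CUDA Tensor arguments are cast to the requested dtype and the body runs with
# autocast disabled.
#
#     class Probe(torch.autograd.Function):
#         @staticmethod
#         @custom_fwd(cast_inputs=torch.float32)
#         def forward(ctx, x):
#             assert x.dtype == torch.float32          # cast on entry
#             assert not torch.is_autocast_enabled()   # autocast off in body
#             return x + 1
#
#         @staticmethod
#         @custom_bwd
#         def backward(ctx, grad):
#             return grad
#
#     with torch.cuda.amp.autocast():
#         y = Probe.apply(torch.randn(2, device="cuda", dtype=torch.float16))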
# Autograd ensures incoming gradients are the same type as forward outputs.  Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
    """
    Helper decorator for backward methods of custom autograd functions (subclasses of
    :class:`torch.autograd.Function`).
    Ensures that ``backward`` executes with the same autocast state as ``forward``.
    See the :ref:`example page<amp-custom-examples>` for more detail.
    """

    @functools.wraps(bwd)
    def decorate_bwd(*args, **kwargs):
        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
            return bwd(*args, **kwargs)

    return decorate_bwd
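# End-to-end sketch pairing custom_fwd with custom_bwd (MyMM is a hypothetical
# Function, not part of this module): the matmul always runs in float32 with
# autocast disabled, and backward replays the autocast state forward saw.
#
#     class MyMM(torch.autograd.Function):
#         @staticmethod
#         @custom_fwd(cast_inputs=torch.float32)
#         def forward(ctx, a, b):
#             ctx.save_for_backward(a, b)
#             return a.mm(b)
#
#         @staticmethod
#         @custom_bwd
#         def backward(ctx, grad):
#             a, b = ctx.saved_tensors
#             return grad.mm(b.t()), a.t().mm(grad)
#
#     a = torch.randn(2, 2, device="cuda", requires_grad=True)
#     b = torch.randn(2, 2, device="cuda", requires_grad=True)
#     with torch.cuda.amp.autocast():
#         out = MyMM.apply(a, b)   # out.dtype is torch.float32
#     out.sum().backward()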