import sys
import torch
import functools
import inspect
from typing import Any, Callable, TypeVar, cast

# Public API of this module.
__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled', 'inference_mode']

# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)


class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator"""

    def __call__(self, func: F) -> F:
        # Generators need special treatment: the context must be re-entered
        # on every resumption, not just once around the call (see
        # `_wrap_generator` below).
        if inspect.isgeneratorfunction(func):
            return self._wrap_generator(func)

        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            # A fresh instance per call: `self.__class__()` assumes subclasses
            # are constructible with no arguments.
            with self.__class__():
                return func(*args, **kwargs)
        return cast(F, decorate_context)

    def _wrap_generator(self, func):
        """Wrap each generator invocation with the context manager"""
        @functools.wraps(func)
        def generator_context(*args, **kwargs):
            gen = func(*args, **kwargs)

            # Generators are suspended and unsuspended at `yield`, hence we
            # make sure the grad mode is properly set every time the execution
            # flow returns into the wrapped generator and restored when it
            # returns through our `yield` to our caller (see PR #49017).
            cls = type(self)
            try:
                # Issuing `None` to a generator fires it up
                with cls():
                    response = gen.send(None)

                while True:
                    try:
                        # Forward the response to our caller and get its next request
                        request = yield response

                    except GeneratorExit:
                        # Inform the still active generator about its imminent closure
                        with cls():
                            gen.close()
                        raise

                    except BaseException:
                        # Propagate the exception thrown at us by the caller
                        with cls():
                            response = gen.throw(*sys.exc_info())

                    else:
                        # Pass the last request to the generator and get its response
                        with cls():
                            response = gen.send(request)

            # We let the exceptions raised above by the generator's `.throw` or
            # `.send` methods bubble up to our caller, except for StopIteration
            except StopIteration as e:
                # The generator informed us that it is done: take whatever its
                # returned value (if any) was and indicate that we're done too
                # by returning it (see docs for python's return-statement).
                return e.value

        return generator_context

    def __enter__(self) -> None:
        # Subclasses must install their mode on entry.
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Subclasses must restore the previous mode on exit.
        raise NotImplementedError
class no_grad(_DecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self):
        # TorchScript does not support calling super(); guard it when scripting.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        # Call the C binding directly, consistent with `enable_grad` in this
        # module, instead of `torch.set_grad_enabled(...)` which allocates a
        # throwaway context-manager object on every entry.
        torch._C._set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore whatever grad mode was active when the block was entered.
        torch._C._set_grad_enabled(self.prev)
class enable_grad(_DecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Turns gradient tracking back on inside a region where it was disabled via
    :class:`~no_grad` or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation in
    other threads.

    Also usable as a decorator. (Make sure to instantiate with parentheses.)

    .. note::
        enable_grad is one of several mechanisms that can enable or disable
        gradients locally see :ref:`locally-disable-grad-doc` for more
        information on how they compare.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...   with torch.enable_grad():
        ...     y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
    """
    def __enter__(self) -> None:
        # Remember the caller's grad mode so it can be restored exactly on exit.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Put back whatever mode was active when the block was entered.
        torch._C._set_grad_enabled(self.prev)
class set_grad_enabled(object):
    r"""Context-manager that sets gradient calculation to on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument
    :attr:`mode`. It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Args:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    .. note::
        set_grad_enabled is one of several mechanisms that can enable or
        disable gradients locally see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Example::

        >>> x = torch.tensor([1], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """
    def __init__(self, mode: bool) -> None:
        # The switch happens eagerly, at construction time — this is what lets
        # the class double as a plain function call and as a context manager.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(mode)

    def __enter__(self) -> None:
        # Nothing to do here: __init__ already flipped the grad mode.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore the mode captured before construction switched it.
        torch._C._set_grad_enabled(self.prev)
class inference_mode(_DecoratorContextManager):
    r"""Context-manager that enables or disables inference mode

    InferenceMode is a new context manager analogous to :class:`~no_grad`
    to be used when you are certain your operations will have no interactions
    with autograd (e.g., model training). Code run under this mode gets better
    performance by disabling view tracking and version counter bumps.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    .. note::
        Inference mode is one of several mechanisms that can enable or
        disable gradients locally see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    Args:
        mode (bool): Flag whether to enable or disable inference mode

    Example::

        >>> import torch
        >>> x = torch.ones(1, 2, 3, requires_grad=True)
        >>> with torch.inference_mode():
        ...   y = x * x
        >>> y.requires_grad
        False
        >>> y._version
        Traceback (most recent call last):
        File "<stdin>", line 1, in <module>
        RuntimeError: Inference tensors do not track version counter.
        >>> @torch.inference_mode()
        ... def func(x):
        ...   return x * x
        >>> out = func(x)
        >>> out.requires_grad
        False
    """
    def __init__(self, mode=True):
        # TorchScript does not support calling super(); guard it when scripting.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.mode = mode
        # Python handle to the C++ RAII guard; created in __enter__ and
        # destroyed in __exit__ so the guard's lifetime brackets the block.
        self._inference_mode_raii_guard = None

    def __enter__(self):
        # Constructing the guard switches inference mode for this thread.
        self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Dropping the guard lets its destructor restore the previous state.
        del self._inference_mode_raii_guard
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook currently maintains this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.