enable_grad
- class torch.enable_grad(orig_func=None)
Context-manager that enables gradient calculation.
Enables gradient calculation if it has been disabled via no_grad or set_grad_enabled.
This context manager is thread local; it will not affect computation in other threads.
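The thread-local behavior can be illustrated with a small sketch. The `results` dict, the two `threading.Event` objects, and the `worker` function are hypothetical names used only for this illustration; `torch.is_grad_enabled()` reports the grad mode of the calling thread.

```python
import threading

import torch

results = {}                        # hypothetical holder for the observed modes
worker_ready = threading.Event()    # set once the worker is inside no_grad
main_done = threading.Event()       # set once the main thread is inside enable_grad


def worker():
    # Disable gradients in this thread only.
    with torch.no_grad():
        worker_ready.set()
        # Wait until the main thread has entered enable_grad.
        main_done.wait()
        results["worker"] = torch.is_grad_enabled()


t = threading.Thread(target=worker)
t.start()
worker_ready.wait()

with torch.no_grad():
    with torch.enable_grad():
        # Re-enabled for the main thread only.
        results["main"] = torch.is_grad_enabled()
        main_done.set()
        t.join()

print(results)
```

The worker is expected to still see gradients disabled while the main thread is inside `enable_grad`, because each thread tracks its own grad mode.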
Also functions as a decorator.
Note
enable_grad is one of several mechanisms that can enable or disable gradients locally. See Locally disabling gradient computation for more information on how they compare.
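As a minimal sketch of how these mechanisms interact, the snippet below re-enables gradients inside a region disabled with `set_grad_enabled(False)` and shows that the previous mode is restored on exit. The variable names `a` through `d` are illustrative, and the snippet assumes it starts with the default grad mode (enabled).

```python
import torch

x = torch.tensor([1.0], requires_grad=True)

with torch.set_grad_enabled(False):
    a = x * 2                  # gradients disabled here
    with torch.enable_grad():
        b = x * 2              # gradients re-enabled for this block
    c = x * 2                  # previous (disabled) mode is restored on exit

d = x * 2                      # outside all contexts: default mode

print(a.requires_grad, b.requires_grad, c.requires_grad, d.requires_grad)
```

Only `b` and `d` should be tracked by autograd; `a` and `c` are computed while gradient calculation is disabled.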
Note
This API does not apply to forward-mode AD.
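A hedged sketch of what this note implies: forward-mode AD (via `torch.autograd.forward_ad`) keeps propagating tangents regardless of the grad mode set by `no_grad` or `enable_grad`. The tensors `primal` and `tangent` are made-up inputs for illustration.

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.tensor([1.0, 2.0])
tangent = torch.tensor([1.0, 0.0])

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    # Grad mode only governs backward-mode AD; the tangent is still propagated.
    with torch.no_grad():
        out = dual * 3
    jvp = fwAD.unpack_dual(out).tangent

print(jvp)
```

Under this assumption, `jvp` holds the tangent scaled by 3 even though the multiplication ran inside `no_grad`, so wrapping the computation in `enable_grad` would make no difference to it.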
- Example::
    >>> x = torch.tensor([1.], requires_grad=True)
    >>> with torch.no_grad():
    ...     with torch.enable_grad():
    ...         y = x * 2
    >>> y.requires_grad
    True
    >>> y.backward()
    >>> x.grad
    tensor([2.])
    >>> @torch.enable_grad()
    ... def doubler(x):
    ...     return x * 2
    >>> with torch.no_grad():
    ...     z = doubler(x)
    >>> z.requires_grad
    True
    >>> @torch.enable_grad()
    ... def tripler(x):
    ...     return x * 3
    >>> with torch.no_grad():
    ...     z = tripler(x)
    >>> z.requires_grad
    True