- class torch.enable_grad(orig_func=None)
Context-manager that enables gradient calculation.
This context manager is thread local; it will not affect computation in other threads.
Also functions as a decorator.
enable_grad is one of several mechanisms that can enable or disable gradients locally; see Locally disabling gradient computation for more information on how they compare.
This API does not apply to forward-mode AD.
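The thread-local behavior described above can be checked with a small sketch. The names `worker`, `modes`, `inside`, and `release` are hypothetical and used only for illustration; the two events hold the worker thread inside `enable_grad` while the main thread, still inside `no_grad`, inspects its own grad mode.

```python
import threading

import torch

inside = threading.Event()   # set once the worker has entered enable_grad
release = threading.Event()  # set to let the worker leave the context
modes = {}                   # hypothetical container for the observed grad modes


def worker():
    # Enabling gradients here changes the mode for this thread only.
    with torch.enable_grad():
        modes["worker"] = torch.is_grad_enabled()
        inside.set()
        release.wait()


with torch.no_grad():
    t = threading.Thread(target=worker)
    t.start()
    inside.wait()
    # Read while the worker is still inside enable_grad.
    modes["main"] = torch.is_grad_enabled()
    release.set()
    t.join()
```

Under these assumptions, `modes["worker"]` should be `True` while `modes["main"]` stays `False`, since each thread keeps its own grad mode.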
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     with torch.enable_grad():
...         y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True
>>> @torch.enable_grad
... def tripler(x):
...     return x * 3
>>> with torch.no_grad():
...     z = tripler(x)
>>> z.requires_grad
True