enable_grad

class torch.enable_grad(orig_func=None)[source]

Context-manager that enables gradient calculation.

Enables gradient calculation, if it has been disabled via no_grad or set_grad_enabled.
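
For instance, gradient calculation turned off with set_grad_enabled can be re-enabled locally (a minimal sketch; x is any tensor with requires_grad=True):

>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.set_grad_enabled(False):
...     with torch.enable_grad():
...         y = x * 2
>>> y.requires_grad
True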

This context manager is thread local; it will not affect computation in other threads.
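
A minimal sketch of the thread-local behavior, using Python's threading module: disabling gradients in the current thread does not change the grad mode seen by a newly spawned thread, which starts with gradients enabled by default.

>>> import threading
>>> def worker():
...     print(torch.is_grad_enabled())  # grad mode is thread local
>>> with torch.no_grad():
...     t = threading.Thread(target=worker)
...     t.start()
...     t.join()
True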

Also functions as a decorator.

Note

enable_grad is one of several mechanisms that can enable or disable gradients locally; see Locally disabling gradient computation for more information on how they compare.

Note

This API does not apply to forward-mode AD.
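
For example, forward-mode AD keeps computing tangents even while reverse-mode gradient recording is disabled; a minimal sketch using torch.autograd.forward_ad:

>>> import torch.autograd.forward_ad as fwAD
>>> primal = torch.tensor([1.])
>>> tangent = torch.tensor([1.])
>>> with torch.no_grad(), fwAD.dual_level():
...     dual = fwAD.make_dual(primal, tangent)
...     jvp = fwAD.unpack_dual(dual * 3).tangent
>>> jvp
tensor([3.])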

Example:
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
...     with torch.enable_grad():
...         y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @torch.enable_grad()
... def doubler(x):
...     return x * 2
>>> with torch.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True
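>>> # enable_grad can also be applied as a bare decorator, without calling it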
>>> @torch.enable_grad
... def tripler(x):
...     return x * 3
>>> with torch.no_grad():
...     z = tripler(x)
>>> z.requires_grad
True
