set_grad_enabled

class torch.set_grad_enabled(mode)

Context manager that sets gradient calculation on or off.

set_grad_enabled enables or disables gradient calculation based on its argument mode. It can be used as a context manager or as a function.

This context manager is thread local; it will not affect computation in other threads.
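The thread-local behavior can be checked with a small sketch like the one below. The `worker` function and the `results` dictionary are hypothetical names introduced only for this illustration; the sketch assumes a plain Python `threading.Thread` and uses `torch.is_grad_enabled()` to read the current state.

```python
import threading

import torch

results = {}  # hypothetical container for the observed grad modes


def worker():
    # Disable gradient calculation in this thread only.
    torch.set_grad_enabled(False)
    results["worker"] = torch.is_grad_enabled()


# Make the main thread's state explicit before starting the worker.
torch.set_grad_enabled(True)

t = threading.Thread(target=worker)
t.start()
t.join()

# The main thread's grad mode is unaffected by the worker's call.
results["main"] = torch.is_grad_enabled()
print(results)
```

Here the worker sees gradients disabled after its own call, while the main thread still reports them as enabled.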

Parameters

mode (bool) – Whether to enable gradient calculation (True) or disable it (False). This can be used to conditionally enable gradients.
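One common way to use the flag conditionally is to share a single forward pass between training and evaluation. The following is a minimal sketch, not a prescribed pattern: `forward_pass`, `model`, and `inputs` are hypothetical names chosen for this illustration.

```python
import torch


def forward_pass(model, inputs, is_train):
    # Track gradients only when is_train is True; the previous
    # grad mode is restored when the with-block exits.
    with torch.set_grad_enabled(is_train):
        return model(inputs)


model = torch.nn.Linear(3, 1)  # hypothetical toy model
inputs = torch.randn(4, 3)

train_out = forward_pass(model, inputs, is_train=True)
eval_out = forward_pass(model, inputs, is_train=False)

print(train_out.requires_grad, eval_out.requires_grad)
```

Because the model's parameters require gradients, the output computed with `is_train=True` is attached to the autograd graph, while the output computed with `is_train=False` is not.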

Note

set_grad_enabled is one of several mechanisms that can enable or disable gradients locally. See Locally disabling gradient computation for more information on how they compare.

Note

This API does not apply to forward-mode AD.

Example::
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> _ = torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> _ = torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False