Shortcuts

torch.autograd.function.FunctionCtx.set_materialize_grads

FunctionCtx.set_materialize_grads(value)[source]

Sets whether to materialize output grad tensors. Default is True.

This should be called only from inside the forward() method

If True, undefined output grad tensors will be expanded to tensors full of zeros prior to calling the backward() method.

Example::
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources