Saves given tensors for a future call to backward().
save_for_backward() should be called at most once, only from inside the
forward() method, and only with tensors.
All tensors intended to be used in the backward pass should be saved with
save_for_backward() (as opposed to directly on ctx) to prevent incorrect
gradients and memory leaks, and to enable the application of saved tensor
hooks; see torch.autograd.graph.saved_tensors_hooks.
In backward(), saved tensors can be accessed through the saved_tensors
attribute. Before returning them to the user, a check is made to ensure they
weren't used in any in-place operation that modified their content.
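A minimal sketch of that in-place check, assuming torch is available: the class name Square and the variable names below are illustrative, not part of the API. Mutating a tensor after it has been saved should cause the access to ctx.saved_tensors during backward() to raise a RuntimeError.

```python
import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors  # the in-place check runs on this access
        return grad_out * 2 * x

a = torch.tensor(3.0, requires_grad=True)
b = a * 1            # non-leaf copy, so in-place mutation is permitted
out = Square.apply(b)
b.add_(1)            # in-place change to a tensor saved for backward
err = None
try:
    out.backward()   # unpacking the saved tensor detects the modification
except RuntimeError as exc:
    err = exc
print(type(err).__name__ if err is not None else "no error")
```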
Arguments can also be None; saving a None is a no-op.
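A small sketch of the None no-op, assuming torch is available (the class name Double is illustrative). The None entry is accepted by save_for_backward and, as far as I can tell, comes back in its slot when saved_tensors is unpacked:

```python
import torch
from torch.autograd import Function

class Double(Function):
    @staticmethod
    def forward(ctx, x):
        # A None argument is accepted; saving it is a no-op.
        ctx.save_for_backward(x, None)
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        x, none = ctx.saved_tensors  # the None is returned in its slot
        assert none is None
        return grad_out * 2

t = torch.tensor(5.0, requires_grad=True)
out = Double.apply(t)
out.backward()
print(t.grad)
```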
See Extending torch.autograd for more details on how to use this method.
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * y * z
>>>         out = x * y + y * z + w
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + x * z)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
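To exercise the backward pass of the example above, one can call backward() on the output and inspect the accumulated gradients; torch.autograd.gradcheck can also verify the hand-written gradients numerically (the double-precision inputs above are what gradcheck expects). This is a self-contained restatement of the doctest plus that check:

```python
import torch
from torch.autograd import Function

class Func(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
        w = x * y * z
        out = x * y + y * z + w
        ctx.save_for_backward(x, y, w, out)
        ctx.z = z  # z is not a tensor
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, y, w, out = ctx.saved_tensors
        z = ctx.z
        gx = grad_out * (y + y * z)
        gy = grad_out * (x + z + x * z)
        return gx, gy, None

a = torch.tensor(1., requires_grad=True, dtype=torch.double)
b = torch.tensor(2., requires_grad=True, dtype=torch.double)
d = Func.apply(a, b, 4)
d.backward()
# With x=1, y=2, z=4: d/dx = y + y*z = 10, d/dy = x + z + x*z = 9
print(a.grad, b.grad)

# Numerically verify the custom backward against finite differences.
ok = torch.autograd.gradcheck(lambda x, y: Func.apply(x, y, 4), (a, b))
print(ok)
```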