save_for_backward(*tensors)
Saves given tensors for a future call to backward().

This should be called at most once, and only from inside the forward() method. It should only be called with input or output tensors.

In backward(), saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren't used in any in-place operation that modified their content.

Arguments can also be None. This is a no-op.

See Extending torch.autograd for more details on how to use this method.
Example::

    >>> import torch
    >>> from torch.autograd import Function
    >>>
    >>> class Func(Function):
    >>>     @staticmethod
    >>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
    >>>         w = x * y * z
    >>>         out = x * y + y * z + w
    >>>         ctx.save_for_backward(x, y, out)
    >>>         ctx.z = z  # z is not a tensor
    >>>         ctx.w = w  # w is neither input nor output
    >>>         return out
    >>>
    >>>     @staticmethod
    >>>     def backward(ctx, grad_out):
    >>>         x, y, out = ctx.saved_tensors
    >>>         z = ctx.z
    >>>         gx = grad_out * (y + y * z)
    >>>         gy = grad_out * (x + z + x * z)
    >>>         gz = None
    >>>         return gx, gy, gz
    >>>
    >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
    >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
    >>> c = 4
    >>> d = Func.apply(a, b, c)
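
The in-place check described above can be observed directly. The following is a minimal sketch, not part of the original documentation: `Square` is a hypothetical Function introduced only for illustration, and the sketch assumes the check surfaces as a `RuntimeError` when the saved tensors are unpacked during the backward pass.

```python
import torch
from torch.autograd import Function


class Square(Function):
    """Hypothetical Function computing x * x; it saves its input for backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # Accessing saved_tensors is where the in-place check runs.
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


a = torch.tensor(3.0, requires_grad=True)

# Normal use: the saved input is untouched between forward and backward.
x = a * 1.0  # non-leaf tensor, so in-place edits on it are permitted
out = Square.apply(x)
out.backward()
grad_ok = a.grad.item()  # derivative of x * x at x = 3

# Modified use: the saved tensor is changed in place after forward.
x = a * 1.0
out = Square.apply(x)
x.add_(1.0)  # in-place edit of a tensor that was saved for backward
try:
    out.backward()
    check_fired = False
except RuntimeError:
    check_fired = True
```

In the first run the gradient is computed as usual; in the second, the backward pass is expected to fail rather than silently use the modified value of `x`.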