- static Function.jvp(ctx, *grad_inputs)
Define a formula for differentiating the operation with forward mode automatic differentiation.
This function is to be overridden by all subclasses. It must accept a context
ctx as the first argument, followed by as many inputs as
forward() got (None will be passed in for non-tensor inputs of the forward function), and it should return as many tensors as there were outputs to
forward(). Each argument is the gradient w.r.t. the given input, and each returned value should be the gradient w.r.t. the corresponding output. If an output is not a Tensor or the function is not differentiable with respect to that output, you can just return None as the gradient for that output.
You can use the
ctx object to pass any value from the forward to this function.
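As a minimal sketch of the contract described above, the following defines a custom function with a single tensor input and a single output, stores a value on ctx in forward(), and reuses it in jvp(). The class name ExpFn and the variable names are hypothetical and chosen only for illustration; the example assumes a PyTorch version that provides torch.autograd.forward_ad.

```python
import torch
import torch.autograd.forward_ad as fwAD


class ExpFn(torch.autograd.Function):
    # Hypothetical example class: elementwise exp with a forward-mode rule.

    @staticmethod
    def forward(ctx, x):
        result = torch.exp(x)
        # Stash the output on ctx so jvp() can reuse it.
        ctx.result = result
        return result

    @staticmethod
    def jvp(ctx, x_tangent):
        # One argument per forward() input, one return value per output.
        # d/dx exp(x) = exp(x), so the output tangent is x_tangent * exp(x).
        return x_tangent * ctx.result


torch.manual_seed(0)
primal = torch.randn(3, dtype=torch.double)
tangent = torch.randn(3, dtype=torch.double)

with fwAD.dual_level():
    # Attach the tangent to the primal, run the function, read the result.
    dual_input = fwAD.make_dual(primal, tangent)
    dual_output = ExpFn.apply(dual_input)
    tangent_out = fwAD.unpack_dual(dual_output).tangent
```

Here tangent_out should match tangent * torch.exp(primal), which is the Jacobian-vector product of exp at primal in the direction tangent. No backward() is defined because the sketch only exercises forward mode.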
- Return type