This note presents an overview of how the gradcheck() and gradgradcheck() functions work.
It will cover both forward and backward mode AD for both real and complex-valued functions as well as higher-order derivatives.
This note also covers both the default behavior of gradcheck and the case where the
fast_mode=True argument is passed (referred to as fast gradcheck below).
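As a quick orientation, the sketch below calls the public torch.autograd.gradcheck entry point in both modes on a small, made-up function f. The function itself is only an illustrative assumption. Both calls return True when the analytical and numerical derivatives agree, and raise an error otherwise.

```python
import torch
from torch.autograd import gradcheck

def f(x):
    # Illustrative f: R^3 -> R^2 (an arbitrary choice for this example).
    return torch.stack([x.sin().sum(), (x ** 2).prod()])

torch.manual_seed(0)
# gradcheck is meant to be run on double-precision inputs that require gradients.
x = torch.randn(3, dtype=torch.float64, requires_grad=True)

slow_ok = gradcheck(f, (x,))                  # default: reconstructs full Jacobians
fast_ok = gradcheck(f, (x,), fast_mode=True)  # fast gradcheck
```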
Notations and background information
Throughout this note, we will use the following convention:
$x$, $y$, $a$, $b$, $v$, $u$, $ur$ and $ui$ are real-valued vectors and $z$ is a complex-valued vector that can be rewritten in terms of two real-valued vectors as $z = a + i b$.
$N$ and $M$ are two integers that we will use for the dimension of the input and output space respectively.
$f: \mathbb{R}^N \to \mathbb{R}^M$ is our basic real-to-real function such that $y = f(x)$.
$g: \mathbb{C}^N \to \mathbb{R}^M$ is our basic complex-to-real function such that $y = g(z)$.
For the simple real-to-real case, we write as $J_f$ the Jacobian matrix associated with $f$, of size $M \times N$. This matrix contains all the partial derivatives such that the entry at position $(i, j)$ contains $\frac{\partial y_i}{\partial x_j}$. Backward mode AD then computes, for a given vector $v$ of size $M$, the quantity $v^T J_f$. Forward mode AD, on the other hand, computes, for a given vector $u$ of size $N$, the quantity $J_f u$.
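A minimal sketch of these two products, assuming a small hypothetical function f: torch.autograd.functional.vjp returns $v^T J_f$ and torch.autograd.functional.jvp returns $J_f u$. The full Jacobian is built here only to confirm both results.

```python
import torch
from torch.autograd.functional import jacobian, jvp, vjp

def f(x):
    # Illustrative f: R^3 -> R^2, with Jacobian [[x1, x0, 0], [0, 1, 2*x2]].
    return torch.stack([x[0] * x[1], x[1] + x[2] ** 2])

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
J = jacobian(f, x)  # full M x N = 2 x 3 matrix, only used for comparison

v = torch.tensor([1.0, -1.0], dtype=torch.float64)      # size M
u = torch.tensor([0.5, 0.0, 2.0], dtype=torch.float64)  # size N

_, vJ = vjp(f, x, v)  # backward mode: v^T J_f, a vector of size N
_, Ju = jvp(f, x, u)  # forward mode:  J_f u,   a vector of size M

assert torch.allclose(vJ, v @ J)  # [2., 0., -6.]
assert torch.allclose(Ju, J @ u)  # [1., 12.]
```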
For functions that contain complex values, the story is a lot more complex. We only provide the gist here and the full description can be found at Autograd for Complex Numbers.
The constraints to satisfy complex differentiability (Cauchy-Riemann equations) are too restrictive for real-valued loss functions, since a non-constant real-valued function can never satisfy them, so we instead opted to use Wirtinger calculus. In a basic setting of Wirtinger calculus, the chain rule requires access to both the Wirtinger derivative (called $W$ below) and the Conjugate Wirtinger derivative (called $CW$ below). Both $W$ and $CW$ need to be propagated because in general, despite their name, one is not the complex conjugate of the other.
To avoid having to propagate both values, for backward mode AD, we always work under the assumption that the function whose derivative is being calculated is either a real-valued function or is part of a bigger real-valued function. This assumption means that all the intermediary gradients we compute during the backward pass are also associated with real-valued functions. In practice, this assumption is not restrictive when doing optimization, as such problems require real-valued objectives (there is no natural ordering of the complex numbers).
Under this assumption, using the $W$ and $CW$ definitions, we can show that $W = CW^*$ (we use $*$ to denote complex conjugation here), and so only one of the two values actually needs to be “backwarded through the graph” as the other one can easily be recovered. To simplify internal computations, PyTorch uses $2 \cdot CW$ as the value it backwards and returns when the user asks for gradients. Similarly to the real case, when the output is actually in $\mathbb{R}^M$, backward mode AD does not compute $2 \cdot CW$ but only $v^T (2 \cdot CW)$ for a given vector $v \in \mathbb{R}^M$.
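To make $2 \cdot CW$ concrete, here is a hand-worked sketch for the illustrative function $y = \mathrm{real}(z) \cdot \mathrm{imag}(z) = a b$, chosen for this example only. It uses the standard definitions $CW = \frac{1}{2}(\partial y/\partial a + i\, \partial y/\partial b)$ and $W = \frac{1}{2}(\partial y/\partial a - i\, \partial y/\partial b)$. Since $\partial y/\partial a = b$ and $\partial y/\partial b = a$, at $z = 1 + 2i$ we expect $W = CW^*$ and a returned gradient of $2 \cdot CW = 2 + 1i$.

```python
import torch

# Illustrative complex-to-real function: y = real(z) * imag(z) = a * b.
z = torch.tensor(1.0 + 2.0j, dtype=torch.complex128, requires_grad=True)
y = z.real * z.imag
y.backward()

# Hand-derived partials at a = 1, b = 2: dy/da = b, dy/db = a.
a, b = 1.0, 2.0
dy_da, dy_db = b, a
CW = 0.5 * complex(dy_da, dy_db)   # conjugate Wirtinger derivative dy/dz*
W = 0.5 * complex(dy_da, -dy_db)   # Wirtinger derivative dy/dz

assert W == CW.conjugate()         # W = CW^* because y is real-valued
# Backward mode AD stores 2 * CW = 2 + 1j in z.grad.
assert torch.allclose(z.grad, torch.tensor(2 * CW, dtype=torch.complex128))
```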
For forward mode AD, we use a similar logic, in this case assuming that the function is part of a larger function whose input is in $\mathbb{R}$. Under this assumption, we can make a similar claim that every intermediary result corresponds to a function whose input is in $\mathbb{R}$, and in this case, using the $W$ and $CW$ definitions, we can show that $W = CW$ for the intermediary functions. To make sure the forward and backward modes compute the same quantities in the elementary case of a one-dimensional function, the forward mode also computes $2 \cdot CW$. Similarly to the real case, when the input is actually in $\mathbb{R}^N$, forward mode AD does not compute $2 \cdot CW$ but only $(2 \cdot CW) u$ for a given vector $u \in \mathbb{R}^N$.
Default backward mode gradcheck behavior
To test a function $f: \mathbb{R}^N \to \mathbb{R}^M, x \mapsto y$, we reconstruct the full Jacobian matrix $J_f$ of size $M \times N$ in two ways: analytically and numerically. The analytical version uses our backward mode AD while the numerical version uses finite differences. The two reconstructed Jacobian matrices are then compared elementwise for equality, up to a numerical tolerance.
Default real input numerical evaluation
If we consider the elementary case of a one-dimensional function ($N = M = 1$), then we can use the basic finite difference formula from the Wikipedia article. We use the “central difference” for better numerical properties:
$$\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 \cdot eps}$$
This formula easily generalizes to multiple outputs ($M > 1$) by having $\frac{\partial y}{\partial x}$ be a column vector of size $M \times 1$ like $f(x + eps)$. In that case, the above formula can be reused as-is and approximates the full Jacobian matrix with only two evaluations of the user function (namely $f(x + eps)$ and $f(x - eps)$).
It is more computationally expensive to handle the case with multiple inputs ($N > 1$). In this scenario, we loop over all the inputs one after the other and apply the $eps$ perturbation to each element of $x$ in turn. This allows us to reconstruct the $J_f$ matrix column by column.
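The column-by-column loop can be sketched as follows. numerical_jacobian is a hypothetical helper written for this note, not the function gradcheck uses internally. It assumes a 1-D float64 input.

```python
import torch

def numerical_jacobian(f, x, eps=1e-6):
    """Central-difference Jacobian of f: R^N -> R^M, reconstructed column by column.

    Hypothetical helper: assumes x is a 1-D float64 tensor.
    """
    x = x.detach()
    M = f(x).numel()
    J = torch.zeros(M, x.numel(), dtype=x.dtype)
    for j in range(x.numel()):
        # Perturb only the j-th input element by +/- eps.
        e = torch.zeros_like(x)
        e[j] = eps
        # Two evaluations of f give the entire j-th column (all M outputs at once).
        J[:, j] = (f(x + e).reshape(-1) - f(x - e).reshape(-1)) / (2 * eps)
    return J

def f(x):
    # Illustrative f: R^3 -> R^2.
    return torch.stack([x[0] * x[1], x[1] + x[2] ** 2])

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
J_num = numerical_jacobian(f, x)  # close to [[2, 1, 0], [0, 1, 6]]
```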
Default real input analytical evaluation
For the analytical evaluation, we use the fact, as described above, that backward mode AD computes $v^T J_f$. For functions with a single output, we simply use $v = 1$ to recover the full Jacobian matrix with a single backward pass.
For functions with more than one output, we resort to a for-loop which iterates over the outputs, where each $v$ is a one-hot vector corresponding to each output one after the other. This allows us to reconstruct the $J_f$ matrix row by row.
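Likewise, here is a sketch of the row-by-row analytical reconstruction, using one backward pass per output with a one-hot $v$. analytical_jacobian is a hypothetical helper. torch.autograd.functional.jacobian is used only to cross-check it.

```python
import torch
from torch.autograd.functional import jacobian

def analytical_jacobian(f, x):
    """Jacobian of f: R^N -> R^M from backward mode AD, reconstructed row by row."""
    x = x.detach().requires_grad_(True)
    y = f(x).reshape(-1)
    rows = []
    for i in range(y.numel()):
        # A one-hot v selects output i, so v^T J_f is exactly the i-th row of J_f.
        v = torch.zeros_like(y)
        v[i] = 1.0
        (row,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)
        rows.append(row)
    return torch.stack(rows)

def f(x):
    # Illustrative f: R^3 -> R^2.
    return torch.stack([x[0] * x[1], x[1] + x[2] ** 2])

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
J_ana = analytical_jacobian(f, x)
assert torch.allclose(J_ana, jacobian(f, x))
```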
Complex-to-real functions
To test a function $g: \mathbb{C}^N \to \mathbb{R}^M, z \mapsto y$ with $z = a + i b$, we reconstruct the (complex-valued) matrix that contains $2 \cdot CW$.
Default complex input numerical evaluation
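Consider first the elementary case where $N = M = 1$. We know from (chapter 3 of) this research paper that:
$$CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} \left(\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}\right)$$
Note that $\frac{\partial y}{\partial a}$ and $\frac{\partial y}{\partial b}$, in the above equation, are $\mathbb{R} \to \mathbb{R}$ derivatives. To evaluate these numerically, we use the method described above for the real-to-real case. This allows us to compute the $CW$ matrix and then multiply it by $2$.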
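Here is a minimal numerical sketch for $N = M = 1$, assuming the illustrative function $y = |z|^2 = a^2 + b^2$. Two real central differences, one along $a$ and one along $b$, give $\partial y/\partial a + i\,\partial y/\partial b = 2 \cdot CW$. This should match what backward mode AD stores in z.grad, which here is $2z$. numerical_2cw is a hypothetical helper.

```python
import torch

def g(z):
    # Illustrative complex-to-real function: y = |z|^2 = a^2 + b^2.
    return (z * z.conj()).real

def numerical_2cw(g, z, eps=1e-6):
    # dy/da and dy/db are ordinary R -> R derivatives, each from a central difference.
    dy_da = (g(z + eps) - g(z - eps)) / (2 * eps)            # perturb the real part a
    dy_db = (g(z + 1j * eps) - g(z - 1j * eps)) / (2 * eps)  # perturb the imaginary part b
    # CW = (dy/da + i dy/db) / 2, so this sum is already 2 * CW.
    return dy_da + 1j * dy_db

z = torch.tensor(1.0 + 2.0j, dtype=torch.complex128, requires_grad=True)
g(z).backward()                 # analytical 2 * CW lands in z.grad
estimate = numerical_2cw(g, z.detach())
assert torch.allclose(estimate, z.grad, atol=1e-6)  # both close to 2z = 2 + 4j
```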
Note that the code, at the time of writing, computes this value in a slightly convoluted way:
```python
# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other,
# the last line ends up computing exactly
# `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.
```
Default complex input analytical evaluation
Since backward mode AD already computes exactly twice the $CW$ derivative, we simply use the same trick as for the real-to-real case here and reconstruct the matrix row by row when there are multiple real outputs.
Functions with complex outputs
In this case, the user-provided function does not satisfy autograd's assumption that the function we compute backward AD for is real-valued. This means that using autograd directly on this function is not well defined. To solve this, we replace the test of the function $h: \mathcal{P}^N \to \mathbb{C}^M$ (where $\mathcal{P}$ can be either $\mathbb{R}$ or $\mathbb{C}$) with two functions, $hr$ and $hi$, such that:
$$\begin{aligned}
hr(q) &:= \mathrm{real}(h(q)) \\
hi(q) &:= \mathrm{imag}(h(q))
\end{aligned}$$
where $q \in \mathcal{P}^N$. We then do a basic gradcheck for both $hr$ and $hi$ using either the real-to-real or complex-to-real case described above, depending on $\mathcal{P}$.
Note that the code, at the time of writing, does not create these functions explicitly but performs the chain rule with the $\mathrm{real}$ or $\mathrm{imag}$ functions manually by passing the grad_out argument to the different functions. When grad_out = 1, we are considering $hr$. When grad_out = 1j, we are considering $hi$.
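The grad_out trick can be checked directly on an illustrative function $h(x) = e^{ix}$ with real input, an assumption made for this example, for which $hr(x) = \cos x$ and $hi(x) = \sin x$. Passing grad_outputs filled with 1 or 1j to torch.autograd.grad should yield the gradients of $hr$ and $hi$ respectively. The last two lines run the public gradcheck on each real-valued part separately.

```python
import torch

def h(x):
    # Illustrative h: R^N -> C^N, h(x) = exp(i x), so hr = cos(x) and hi = sin(x).
    return torch.exp(1j * x)

x = torch.tensor([0.3, 1.2], dtype=torch.float64, requires_grad=True)
y = h(x)

# grad_out = 1 selects hr = real(h), whose derivative is -sin(x).
(g_real,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), retain_graph=True)
# grad_out = 1j selects hi = imag(h), whose derivative is cos(x).
(g_imag,) = torch.autograd.grad(y, x, grad_outputs=torch.full_like(y, 1j))

assert torch.allclose(g_real, -torch.sin(x.detach()))
assert torch.allclose(g_imag, torch.cos(x.detach()))

# Equivalently, each real-valued part can be checked on its own.
assert torch.autograd.gradcheck(lambda t: h(t).real, (x,))
assert torch.autograd.gradcheck(lambda t: h(t).imag, (x,))
```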
Fast backward mode gradcheck
While the above formulation of gradcheck is great for both correctness and debuggability, it is very slow because it reconstructs the full Jacobian matrices. This section presents a way to perform gradcheck faster without affecting its correctness. Debuggability can be recovered by adding special logic when we detect an error: in that case, we can run the default version that reconstructs the full matrix to give full details to the user.
The high level strategy here is to find a scalar quantity that can be computed efficiently by both the numerical and analytical methods and that represents the full matrix computed by the slow gradcheck well enough to ensure that it will catch any discrepancy in the Jacobians.
Fast gradcheck for real-to-real functions
The scalar quantity that we want to compute here is $v^T J_f u$ for a given random vector $v \in \mathbb{R}^M$ and a random unit-norm vector $u \in \mathbb{R}^N$.
For the numerical evaluation, we can efficiently compute
$$J_f u \approx \frac{f(x + u \cdot eps) - f(x - u \cdot eps)}{2 \cdot eps}.$$
We then perform the dot product between this vector and $v$ to get the scalar value of interest.
For the analytical version, we can use backward mode AD to compute $v^T J_f$ directly. We then perform the dot product with $u$ to get the expected value.
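Here is a sketch of the fast real-to-real comparison on an illustrative f. One directional central difference gives the numerical side and one backward pass gives the analytical side. The tolerance is an arbitrary choice for this example, not gradcheck's default.

```python
import torch

def f(x):
    # Illustrative f: R^3 -> R^2.
    return torch.stack([x[0] * x[1], x[1] + x[2] ** 2])

torch.manual_seed(0)
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
v = torch.randn(2, dtype=torch.float64)  # random v in R^M
u = torch.randn(3, dtype=torch.float64)
u = u / u.norm()                         # random unit-norm u in R^N
eps = 1e-6

# Numerical: one directional central difference gives J_f u, then dot with v.
Ju = (f(x + eps * u) - f(x - eps * u)) / (2 * eps)
numerical = torch.dot(v, Ju)

# Analytical: one backward pass gives v^T J_f, then dot with u.
xr = x.clone().requires_grad_(True)
(vJ,) = torch.autograd.grad(f(xr), xr, grad_outputs=v)
analytical = torch.dot(vJ, u)

assert torch.allclose(numerical, analytical, atol=1e-6)
```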
Fast gradcheck for complex-to-real functions
Similar to the real-to-real case, we want to perform a reduction of the full matrix. But the $2 \cdot CW$ matrix is complex-valued, so in this case we will compare two complex scalars.
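Due to some constraints on what we can compute efficiently in the numerical case and to keep the number of numerical evaluations to a minimum, we compute the following (albeit surprising) scalar value:
$$s := 2 \cdot v^T \left(\mathrm{real}(CW)\, ur + i \cdot \mathrm{imag}(CW)\, ui\right)$$
where $v \in \mathbb{R}^M$, $ur \in \mathbb{R}^N$ and $ui \in \mathbb{R}^N$.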
Fast complex input numerical evaluation
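We first consider how to compute $s$ with a numerical method. To do so, keeping in mind that we're considering $g: \mathbb{C}^N \to \mathbb{R}^M, z \mapsto y$ with $z = a + i b$, and that $CW = \frac{1}{2} \left(\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}\right)$, where both partial derivatives are real because $y$ is real-valued, we rewrite it as follows:
$$\begin{aligned}
s &= 2 \cdot v^T \left(\mathrm{real}(CW)\, ur + i \cdot \mathrm{imag}(CW)\, ui\right) \\
  &= 2 \cdot v^T \left(\frac{1}{2} \frac{\partial y}{\partial a} ur + i \cdot \frac{1}{2} \frac{\partial y}{\partial b} ui\right) \\
  &= v^T \left(\frac{\partial y}{\partial a} ur + i \cdot \frac{\partial y}{\partial b} ui\right) \\
  &= v^T \left(\left(\frac{\partial y}{\partial a} ur\right) + i \cdot \left(\frac{\partial y}{\partial b} ui\right)\right)
\end{aligned}$$
In this formula, we can see that $\frac{\partial y}{\partial a} ur$ and $\frac{\partial y}{\partial b} ui$ can be evaluated the same way as the fast version for the real-to-real case. Once these real-valued quantities have been computed, we can reconstruct the complex vector on the right side and do a dot product with the real-valued vector $v$.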
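Here is a minimal numerical sketch of $s$, assuming the illustrative function $y = P\,(a \odot a) + Q\, b$ with random real matrices P and Q. These are hypothetical choices that make $\partial y/\partial a = P\,\mathrm{diag}(2a)$ and $\partial y/\partial b = Q$ known in closed form. Only two real-to-real directional differences are needed.

```python
import torch

torch.manual_seed(0)
N, M = 3, 2
P = torch.randn(M, N, dtype=torch.float64)
Q = torch.randn(M, N, dtype=torch.float64)

def g(z):
    # Illustrative g: C^N -> R^M, y = P (a*a) + Q b with z = a + i b.
    return P @ (z.real ** 2) + Q @ z.imag

z = torch.randn(N, dtype=torch.complex128)
v = torch.randn(M, dtype=torch.float64)
ur = torch.randn(N, dtype=torch.float64)
ui = torch.randn(N, dtype=torch.float64)
eps = 1e-6

# Two real-to-real directional differences: along the real part, then the imaginary part.
dy_da_ur = (g(z + eps * ur) - g(z - eps * ur)) / (2 * eps)
dy_db_ui = (g(z + 1j * eps * ui) - g(z - 1j * eps * ui)) / (2 * eps)
# Reconstruct the complex vector and take the dot product with the real vector v.
s_numerical = (v * (dy_da_ur + 1j * dy_db_ui)).sum()

# Closed form for this g: dy/da = P diag(2a) and dy/db = Q.
a = z.real
s_exact = (v * (P @ (2 * a * ur))).sum() + 1j * (v * (Q @ ui)).sum()
assert torch.allclose(s_numerical, s_exact, atol=1e-6)
```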
Fast complex input analytical evaluation
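For the analytical case, things are simpler and we rewrite the formula as:
$$\begin{aligned}
s &= 2 \cdot v^T \left(\mathrm{real}(CW)\, ur + i \cdot \mathrm{imag}(CW)\, ui\right) \\
  &= v^T \mathrm{real}(2 \cdot CW)\, ur + i \cdot v^T \mathrm{imag}(2 \cdot CW)\, ui \\
  &= \mathrm{real}\left(v^T (2 \cdot CW)\right) ur + i \cdot \mathrm{imag}\left(v^T (2 \cdot CW)\right) ui
\end{aligned}$$
The last step holds because $v$ is real. We can thus use the fact that backward mode AD provides us with an efficient way to compute $v^T (2 \cdot CW)$ and then perform a dot product of the real part with $ur$ and the imaginary part with $ui$ before reconstructing the final complex scalar $s$.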
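Here is a sketch of the analytical side, assuming the illustrative function $y = P\,(a \odot a) + Q\, b$ with random real P and Q. This is a hypothetical choice with a known closed form. A single torch.autograd.grad call with grad_outputs=v returns the complex vector $v^T (2 \cdot CW)$, whose real and imaginary parts are then reduced with $ur$ and $ui$.

```python
import torch

torch.manual_seed(0)
N, M = 3, 2
P = torch.randn(M, N, dtype=torch.float64)
Q = torch.randn(M, N, dtype=torch.float64)

def g(z):
    # Illustrative g: C^N -> R^M, y = P (a*a) + Q b with z = a + i b.
    return P @ (z.real ** 2) + Q @ z.imag

z = torch.randn(N, dtype=torch.complex128, requires_grad=True)
v = torch.randn(M, dtype=torch.float64)
ur = torch.randn(N, dtype=torch.float64)
ui = torch.randn(N, dtype=torch.float64)

# One backward pass: PyTorch returns v^T (2 * CW) as a complex vector of size N.
(vJ,) = torch.autograd.grad(g(z), z, grad_outputs=v)
s_analytical = (vJ.real * ur).sum() + 1j * (vJ.imag * ui).sum()

# Closed form for this g: 2 * CW = dy/da + i dy/db with dy/da = P diag(2a), dy/db = Q.
a = z.detach().real
s_exact = (v * (P @ (2 * a * ur))).sum() + 1j * (v * (Q @ ui)).sum()
assert torch.allclose(s_analytical, s_exact)
```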
Why not use a complex $u$
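At this point, you might be wondering why we did not select a complex $u$ and just perform the reduction $2 \cdot v^T CW u'$. To dive into this, in this paragraph, we will use the complex version of $u$, noted $u' = ur' + i\, ui'$. Using such a complex $u'$, the problem is that when doing the numerical evaluation, we would need to compute:
$$\begin{aligned}
2 \cdot CW u' &= \left(\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}\right)\left(ur' + i\, ui'\right) \\
  &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui'
\end{aligned}$$
This would require four evaluations of real-to-real finite differences (twice as many as the approach proposed above). Since this approach does not have more degrees of freedom (the same number of real-valued variables) and we want the fastest possible evaluation here, we use the other formulation above.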
Fast gradcheck for functions with complex outputs
Just like in the slow case, we consider two real-valued functions and use the appropriate rule from above for each function.
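From the user's side, all of this sits behind the same entry point. Here is a sketch assuming an illustrative holomorphic function $h(z) = z \sin z$ with complex input and output, checked in both default and fast mode.

```python
import torch
from torch.autograd import gradcheck

def h(z):
    # Illustrative complex-to-complex function.
    return z * z.sin()

torch.manual_seed(0)
z = torch.randn(4, dtype=torch.complex128, requires_grad=True)

# Complex outputs are handled as two real-valued functions, as described above.
assert gradcheck(h, (z,))
assert gradcheck(h, (z,), fast_mode=True)
```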
Gradgradcheck implementation
PyTorch also provides a utility to verify second-order gradients. The goal here is to make sure that the backward implementation is also properly differentiable and computes the right thing.
This feature is implemented by considering the function $F: x, v \mapsto v^T J_f$ and using the gradcheck defined above on this function. Note that $v$ in this case is just a random vector with the same type as $f(x)$.
The fast version of gradgradcheck is implemented by using the fast version of gradcheck on that same function $F$.
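Here is a sketch of both routes, assuming a small illustrative f. The hypothetical helper F builds $v^T J_f$ with create_graph=True, so a first-order gradcheck on F exercises the second derivatives of f. The last two lines call the built-in gradgradcheck in default and fast mode.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

def f(x):
    # Illustrative f: R^3 -> R^2 with non-trivial second derivatives.
    return torch.stack([x[0] * x[1], x[1] + x[2] ** 2]).sin()

def F(x, v):
    # Hypothetical helper: F(x, v) = v^T J_f, kept differentiable via create_graph=True.
    # Ensure both inputs are tracked by autograd before differentiating.
    x, v = x.requires_grad_(), v.requires_grad_()
    (vJ,) = torch.autograd.grad(f(x), x, grad_outputs=v, create_graph=True)
    return vJ

torch.manual_seed(0)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
v = torch.randn(2, dtype=torch.float64, requires_grad=True)  # same type as f(x)

# First-order gradcheck on F checks the second-order gradients of f.
assert gradcheck(F, (x, v))
# Built-in utility, default and fast mode.
assert gradgradcheck(f, (x,))
assert gradgradcheck(f, (x,), fast_mode=True)
```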