Gradcheck mechanics

This note presents an overview of how the gradcheck() and gradgradcheck() functions work.

It will cover both forward and backward mode AD for both real and complex-valued functions as well as higher-order derivatives. This note also covers both the default behavior of gradcheck and the case where the fast_mode=True argument is passed (referred to as fast gradcheck below).

Notations and background information

Throughout this note, we will use the following convention:

  1. $x$, $y$, $a$, $b$, $v$, $u$, $ur$ and $ui$ are real-valued vectors and $z$ is a complex-valued vector that can be rewritten in terms of two real-valued vectors as $z = a + i b$.

  2. $N$ and $M$ are two integers that we will use for the dimension of the input and output space respectively.

  3. $f: \mathcal{R}^N \to \mathcal{R}^M$ is our basic real-to-real function such that $y = f(x)$.

  4. $g: \mathcal{C}^N \to \mathcal{R}^M$ is our basic complex-to-real function such that $y = g(z)$.

For the simple real-to-real case, we write as $J_f$ the Jacobian matrix associated with $f$ of size $M \times N$. This matrix contains all the partial derivatives such that the entry at position $(i, j)$ contains $\frac{\partial y_i}{\partial x_j}$. Backward mode AD is then computing, for a given vector $v$ of size $M$, the quantity $v^T J_f$. Forward mode AD on the other hand is computing, for a given vector $u$ of size $N$, the quantity $J_f u$.
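
As a rough illustration, the sketch below computes both products for a small example function and checks them against the full Jacobian. It assumes PyTorch 2.x (for `torch.func.jvp`, which uses true forward mode AD); `f`, `x`, `v` and `u` are arbitrary values chosen only for this example.

```python
import torch
from torch.autograd.functional import jacobian, vjp
from torch.func import jvp

def f(x):
    # Illustrative f: R^3 -> R^2, so J_f has shape M x N = 2 x 3.
    return torch.stack([x[0] * x[1], torch.sin(x[2]) + x[0]])

x = torch.tensor([1.0, 2.0, 0.5], dtype=torch.float64)
J = jacobian(f, x)  # full Jacobian J_f, shape (2, 3)

v = torch.tensor([0.3, -1.2], dtype=torch.float64)      # v in R^M
u = torch.tensor([0.5, 0.1, 2.0], dtype=torch.float64)  # u in R^N

_, vT_J = vjp(f, x, v)       # backward mode AD: v^T J_f, shape (3,)
_, J_u = jvp(f, (x,), (u,))  # forward mode AD: J_f u, shape (2,)

print(torch.allclose(vT_J, v @ J), torch.allclose(J_u, J @ u))  # True True
```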

For functions that contain complex values, the story is a lot more complex. We only provide the gist here and the full description can be found at Autograd for Complex Numbers.

The constraints for complex differentiability (the Cauchy-Riemann equations) are too restrictive for real-valued loss functions (the only real-valued functions that satisfy them are constants), so we instead opted to use Wirtinger calculus. In a basic setting of Wirtinger calculus, the chain rule requires access to both the Wirtinger derivative (called $W$ below) and the Conjugate Wirtinger derivative (called $CW$ below). Both $W$ and $CW$ need to be propagated because, in general, despite their names, one is not the complex conjugate of the other.
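
The following sketch makes this concrete with plain Python complex numbers. It assumes the usual definitions $W = \frac{1}{2}(\frac{\partial f}{\partial a} - i \frac{\partial f}{\partial b})$ and $CW = \frac{1}{2}(\frac{\partial f}{\partial a} + i \frac{\partial f}{\partial b})$ and estimates both partial derivatives with central differences; `wirtinger` is a hypothetical helper written for this illustration. For the complex-valued $f(z) = z$, $W = 1$ and $CW = 0$, so they are not conjugates, whereas for the real-valued $f(z) = |z|^2$ they are.

```python
def wirtinger(f, z, eps=1e-6):
    """Hypothetical helper: estimate (W, CW) of a scalar function f at z = a + ib.

    df/da and df/db are central differences along the real and imaginary axes,
    then W = (df/da - i df/db) / 2 and CW = (df/da + i df/db) / 2.
    """
    df_da = (f(z + eps) - f(z - eps)) / (2 * eps)
    df_db = (f(z + 1j * eps) - f(z - 1j * eps)) / (2 * eps)
    w = 0.5 * (df_da - 1j * df_db)
    cw = 0.5 * (df_da + 1j * df_db)
    return w, cw

z0 = 1.0 + 2.0j

# Complex-valued f(z) = z: W is about 1 while CW is about 0, so W != conj(CW).
w_id, cw_id = wirtinger(lambda z: z, z0)

# Real-valued f(z) = |z|^2: W is about conj(z0), CW is about z0, so W == conj(CW).
w_sq, cw_sq = wirtinger(lambda z: abs(z) ** 2, z0)

print(w_id, cw_id)
print(w_sq, cw_sq)
```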

To avoid having to propagate both values, for backward mode AD, we always work under the assumption that the function whose derivative is being calculated is either a real-valued function or is part of a bigger real-valued function. This assumption means that all the intermediary gradients we compute during the backward pass are also associated with real-valued functions. In practice, this assumption is not restrictive when doing optimization, as such problems require real-valued objectives (there is no natural ordering of the complex numbers).

Under this assumption, using the $W$ and $CW$ definitions, we can show that $W = CW^*$ (we use $*$ to denote complex conjugation here), and so only one of the two values actually needs to be “backwarded through the graph”, as the other one can easily be recovered. To simplify internal computations, PyTorch uses $2 * CW$ as the value it backwards and returns when the user asks for gradients. Similarly to the real case, when the output is actually in $\mathcal{R}^M$, backward mode AD does not compute $2 * CW$ but only $v^T (2 * CW)$ for a given vector $v \in \mathcal{R}^M$.
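
A quick way to see this convention in practice, assuming a PyTorch version with complex autograd support: for the real-valued $y = \sum_k |z_k|^2$ we have $CW = \frac{\partial y}{\partial z^*} = z$, so the gradient that backward returns should be $2 * CW = 2z$. The input values are arbitrary.

```python
import torch

z = torch.tensor([1.0 + 2.0j, -0.5 + 0.25j], dtype=torch.complex128, requires_grad=True)

# Real-valued function of a complex input: y = sum_k |z_k|^2 = sum_k z_k * conj(z_k).
y = (z * z.conj()).real.sum()
y.backward()

# Here CW = dy/dz* = z, so the returned gradient should be 2 * CW = 2 * z.
print(z.grad)
print(torch.allclose(z.grad, 2 * z.detach()))  # True
```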

For forward mode AD, we use similar logic: in this case, we assume that the function is part of a larger function whose input is in $\mathcal{R}$. Under this assumption, we can make a similar claim that every intermediary result corresponds to a function whose input is in $\mathcal{R}$, and in this case, using the $W$ and $CW$ definitions, we can show that $W = CW$ for the intermediary functions. To make sure the forward and backward modes compute the same quantities in the elementary case of a one-dimensional function, the forward mode also computes $2 * CW$. Similarly to the real case, when the input is actually in $\mathcal{R}^N$, forward mode AD does not compute $2 * CW$ but only $(2 * CW) u$ for a given vector $u \in \mathcal{R}^N$.

Default backward mode gradcheck behavior

Real-to-real functions

To test a function $f: \mathcal{R}^N \to \mathcal{R}^M, x \to y$, we reconstruct the full Jacobian matrix $J_f$ of size $M \times N$ in two ways: analytically and numerically. The analytical version uses our backward mode AD while the numerical version uses finite differences. The two reconstructed Jacobian matrices are then compared elementwise for equality, up to the tolerances set by the atol and rtol arguments.
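
For reference, a minimal use of torch.autograd.gradcheck on this kind of function might look as follows; the function `f` and the `BadSquare` Function (with a deliberately wrong backward) are hypothetical examples. The eps, atol and rtol values are passed explicitly only to make the tolerances of the comparison visible.

```python
import torch
from torch.autograd import gradcheck

def f(x):
    # Illustrative f: R^3 -> R^2.
    return torch.stack([x[0] * x[1], torch.sin(x[2]) + x[0]])

# gradcheck is meant to be used with double precision inputs.
x = torch.tensor([1.0, 2.0, 0.5], dtype=torch.float64, requires_grad=True)
print(gradcheck(f, (x,), eps=1e-6, atol=1e-5, rtol=1e-3))  # True

class BadSquare(torch.autograd.Function):
    """x ** 2 with a deliberately wrong backward (x instead of 2 * x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * x  # bug: the correct gradient is grad_out * 2 * x

# With raise_exception=False, a Jacobian mismatch makes gradcheck return False.
print(gradcheck(BadSquare.apply, (x,), raise_exception=False))  # False
```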

Default real input numerical evaluation

If we consider the elementary case of a one-dimensional function ($N = M = 1$), then we can use the basic finite difference formula from the Wikipedia article. We use the “central difference” for better numerical properties:

$$\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}$$
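
A small sketch of why the central difference is preferred: for a smooth function its error shrinks like $eps^2$, while a one-sided difference shrinks like $eps$. The forward-difference helper is included only for comparison and is not what gradcheck uses.

```python
import math

def central_difference(f, x, eps):
    """Central difference: (f(x + eps) - f(x - eps)) / (2 * eps), error O(eps^2)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

def forward_difference(f, x, eps):
    """One-sided difference, shown only for comparison: error O(eps)."""
    return (f(x + eps) - f(x)) / eps

x0, eps = 0.3, 1e-4
exact = math.cos(x0)  # d/dx sin(x) at x0
central_err = abs(central_difference(math.sin, x0, eps) - exact)
forward_err = abs(forward_difference(math.sin, x0, eps) - exact)
print(central_err, forward_err)  # the central error is several orders of magnitude smaller
```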

This formula easily generalizes to multiple outputs ($M > 1$) by having $\frac{\partial y}{\partial x}$ be a column vector of size $M \times 1$ like $f(x + eps)$. In that case, the above formula can be re-used as-is and approximates the full Jacobian matrix with only two evaluations of the user function (namely $f(x + eps)$ and $f(x - eps)$).

It is more computationally expensive to handle the case with multiple inputs ($N > 1$). In this scenario, we loop over all the inputs one after the other and apply the $eps$ perturbation to each element of $x$ in turn. This allows us to reconstruct the $J_f$ matrix column by column, at the cost of two function evaluations per input element.
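
The column-by-column construction can be sketched as below. `numerical_jacobian` is a hypothetical helper, not the PyTorch implementation; it assumes a flat double-precision input and makes $2N$ perturbed calls to `f`.

```python
import torch

def numerical_jacobian(f, x, eps=1e-6):
    """Hypothetical helper: build the M x N Jacobian of f at x column by column,
    perturbing one input element at a time and using the central difference."""
    x = x.detach().reshape(-1)
    n = x.numel()
    m = f(x).numel()
    J = torch.zeros(m, n, dtype=x.dtype)
    for j in range(n):
        e = torch.zeros_like(x)
        e[j] = eps  # perturb only x_j
        J[:, j] = ((f(x + e) - f(x - e)) / (2 * eps)).reshape(-1)  # column j
    return J

def f(x):
    return torch.stack([x[0] * x[1], torch.sin(x[2]) + x[0]])

x = torch.tensor([1.0, 2.0, 0.5], dtype=torch.float64)
print(numerical_jacobian(f, x))  # approx [[2, 1, 0], [1, 0, cos(0.5)]]
```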

Default real input analytical evaluation

For the analytical evaluation, we use the fact, as described above, that backward mode AD computes $v^T J_f$. For functions with a single output, we simply use $v = 1$ to recover the full Jacobian matrix with a single backward pass.

For functions with more than one output, we resort to a for-loop that iterates over the outputs, where each $v$ is a one-hot vector corresponding to each output in turn. This allows us to reconstruct the $J_f$ matrix row by row.
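
Similarly, a hypothetical `analytical_jacobian` helper (again a sketch, not the PyTorch code) can build $J_f$ row by row with one backward pass per output, using torch.autograd.grad with a one-hot `grad_outputs`:

```python
import math
import torch

def analytical_jacobian(f, x):
    """Hypothetical helper: build the M x N Jacobian of f at x row by row,
    one backward pass per output with a one-hot vector v."""
    x = x.detach().clone().requires_grad_(True)
    y = f(x).reshape(-1)
    J = torch.zeros(y.numel(), x.numel(), dtype=x.dtype)
    for i in range(y.numel()):
        v = torch.zeros_like(y)
        v[i] = 1.0  # one-hot: selects output y_i
        (row,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)
        J[i] = row.reshape(-1)  # row i is v^T J_f
    return J

def f(x):
    return torch.stack([x[0] * x[1], torch.sin(x[2]) + x[0]])

x = torch.tensor([1.0, 2.0, 0.5], dtype=torch.float64)
print(analytical_jacobian(f, x))  # [[2, 1, 0], [1, 0, cos(0.5)]]
```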

Complex-to-real functions

To test a function $g: \mathcal{C}^N \to \mathcal{R}^M, z \to y$ with $z = a + i b$, we reconstruct the (complex-valued) matrix that contains $2 * CW$.

Default complex input numerical evaluation

Consider the elementary case where $N = M = 1$ first. We know from (chapter 3 of) this research paper that:

$$CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})$$

Note that $\frac{\partial y}{\partial a}$ and $\frac{\partial y}{\partial b}$, in the above equation, are $\mathcal{R} \to \mathcal{R}$ derivatives. To evaluate these numerically, we use the method described above for the real-to-real case. This allows us to compute the $CW$ matrix and then multiply it by $2$.
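
As a sketch of this numerical evaluation, the hypothetical helper below builds the $M \times N$ matrix $2 * CW = \frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}$ column by column, perturbing the real and the imaginary part of each input element separately. The example `g` is chosen so that the result can be checked by hand: differentiating with respect to $a$ and $b$ gives $2 * CW = \begin{pmatrix} 2 z_0 & 1 \\ i z_1^* & i z_0^* \end{pmatrix}$.

```python
import torch

def numerical_2cw(g, z, eps=1e-6):
    """Hypothetical helper: numerical 2 * CW (M x N, complex) for g: C^N -> R^M,
    using 2 * CW = dy/da + i * dy/db with two real-to-real central differences."""
    z = z.detach()
    m = g(z).numel()
    out = torch.zeros(m, z.numel(), dtype=z.dtype)
    for j in range(z.numel()):
        e = torch.zeros_like(z)
        e[j] = eps                                           # real perturbation of z_j
        dy_da = (g(z + e) - g(z - e)) / (2 * eps)            # derivative w.r.t. a_j
        dy_db = (g(z + 1j * e) - g(z - 1j * e)) / (2 * eps)  # derivative w.r.t. b_j
        out[:, j] = dy_da + 1j * dy_db
    return out

def g(z):
    # Illustrative g: C^2 -> R^2.
    return torch.stack([z[0].abs() ** 2 + z[1].real, (z[0] * z[1]).imag])

z = torch.tensor([1.0 + 2.0j, -0.5 + 0.25j], dtype=torch.complex128)
print(numerical_2cw(g, z))  # approx [[2+4j, 1], [0.25-0.5j, 2+1j]]
```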

Note that the code, as of the time of writing, computes this value in a slightly convoluted way:

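```python
# Code from
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

# compute_gradient is a helper from the original source (not shown here): it
# returns the central-difference derivative of s when the current input entry
# is perturbed by +/- the given delta.
ds_dx = compute_gradient(eps)       # ds/da: perturb the real part
ds_dy = compute_gradient(eps * 1j)  # ds/db: perturb the imaginary part
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
```

Since grad_out is always 1, and W and CW are complex conjugates of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.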

Default complex input analytical evaluation

Since backward mode AD already computes exactly twice the $CW$ derivative, we simply use the same trick as for the real-to-real case here and reconstruct the matrix row by row when there are multiple real outputs.
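
A matching sketch of the analytical side, with an illustrative `g` whose $2 * CW$ matrix works out by hand to $\begin{pmatrix} 2 z_0 & 1 \\ i z_1^* & i z_0^* \end{pmatrix}$: each backward pass with a real one-hot `grad_outputs` returns one row $v^T (2 * CW)$. `analytical_2cw` is a hypothetical helper, not the PyTorch code.

```python
import torch

def analytical_2cw(g, z):
    """Hypothetical helper: analytical 2 * CW for g: C^N -> R^M, row by row.
    Each backward pass with a real one-hot v returns v^T (2 * CW)."""
    z = z.detach().clone().requires_grad_(True)
    y = g(z)
    out = torch.zeros(y.numel(), z.numel(), dtype=z.dtype)
    for i in range(y.numel()):
        v = torch.zeros_like(y)
        v[i] = 1.0  # one-hot: selects the real output y_i
        (row,) = torch.autograd.grad(y, z, grad_outputs=v, retain_graph=True)
        out[i] = row
    return out

def g(z):
    # Illustrative g: C^2 -> R^2.
    return torch.stack([z[0].abs() ** 2 + z[1].real, (z[0] * z[1]).imag])

z = torch.tensor([1.0 + 2.0j, -0.5 + 0.25j], dtype=torch.complex128)
print(analytical_2cw(g, z))  # approx [[2+4j, 1], [0.25-0.5j, 2+1j]]
```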

Functions with complex outputs

In this case, the user-provided function does not follow autograd's assumption that the function we compute backward AD for is real-valued. This means that using autograd directly on this function is not well defined. To solve this, we replace the test of the function $h: \mathcal{P}^N \to \mathcal{C}^M$ (where $\mathcal{P}$ can be either $\mathcal{R}$ or $\mathcal{C}$) with two functions, $hr$ and $hi$, such that:

$$\begin{aligned} hr(q) &:= real(h(q)) \\ hi(q) &:= imag(h(q)) \end{aligned}$$

where $q \in \mathcal{P}$. We then do a basic gradcheck for both $hr$ and $hi$ using either the real-to-real or complex-to-real case described above, depending on $\mathcal{P}$.

Note that the code, as of the time of writing, does not create these functions explicitly but performs the chain rule with the $real$ or $imag$ functions manually by passing the $\text{grad\_out}$ arguments to the different functions. When $\text{grad\_out} = 1$, we are considering $hr$. When $\text{grad\_out} = 1j$, we are considering $hi$.
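
A sketch of the same idea from the user side: `hr` and `hi` below are written out explicitly and each is checked as a complex-to-real function, and gradcheck is also called on the complex-output function `h` directly. The choice of `h` and of the random input is illustrative.

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)

def h(z):
    # Illustrative complex-valued function h: C^3 -> C^3.
    return z * z + torch.exp(z)

def hr(z):
    return h(z).real  # real part: a complex-to-real function

def hi(z):
    return h(z).imag  # imaginary part: a complex-to-real function

z = torch.randn(3, dtype=torch.complex128, requires_grad=True)

# Check hr and hi separately, then let gradcheck handle h directly.
print(gradcheck(hr, (z,)), gradcheck(hi, (z,)), gradcheck(h, (z,)))  # True True True
```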

Fast backward mode gradcheck

While the above formulation of gradcheck is great for ensuring both correctness and debuggability, it is very slow because it reconstructs the full Jacobian matrices. This section presents a way to perform gradcheck faster without affecting its correctness. Debuggability can be recovered by adding special logic when we detect an error: in that case, we can run the default version that reconstructs the full matrix to give full details to the user.

The high level strategy here is to find a scalar quantity that can be computed efficiently by both the numerical and analytical methods and that represents the full matrix computed by the slow gradcheck well enough to ensure that it will catch any discrepancy in the Jacobians.
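
From the user side, switching between the two versions is a matter of passing fast_mode=True to gradcheck. The function below is an arbitrary example whose full Jacobian already has $16 \times 12$ entries.

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)

def f(x):
    # Illustrative f: R^{4x3} -> R^{4x4}; its full Jacobian is 16 x 12.
    return torch.tanh(x @ x.T)

x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)

print(gradcheck(f, (x,)))                  # default: rebuilds the full Jacobians
print(gradcheck(f, (x,), fast_mode=True))  # fast: compares a random scalar reduction
```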

Fast gradcheck for real-to-real functions

The scalar quantity that we want to compute here is $v^T J_f u$ for a given random vector $v \in \mathcal{R}^M$ and a random unit-norm vector $u \in \mathcal{R}^N$.

For the numerical evaluation, we can efficiently compute

$$J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.$$

We then perform the dot product between this vector and $v$ to get the scalar value of interest.

For the analytical version, we can use backward mode AD to compute $v^T J_f$ directly. We then perform the dot product with $u$ to get the expected value.
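
The two evaluations can be sketched as follows. `fast_check_real` is a hypothetical helper; drawing $v$ and $u$ from a standard normal distribution (with $u$ rescaled to unit norm) is an assumption of this sketch and not necessarily how PyTorch samples them. Whatever the size of $J_f$, the numerical side needs two perturbed evaluations of `f` and the analytical side one forward plus one backward pass.

```python
import torch

def fast_check_real(f, x, eps=1e-6, atol=1e-5, rtol=1e-3):
    """Hypothetical sketch: compare v^T J_f u computed numerically and analytically,
    with v and u drawn from a standard normal (u rescaled to unit norm)."""
    x = x.detach().clone().requires_grad_(True)
    y = f(x)
    v = torch.randn_like(y)  # random v in R^M
    u = torch.randn_like(x)
    u = u / u.norm()         # random unit-norm u in R^N

    # Numerical: one central difference along u gives J_f u, then dot with v.
    with torch.no_grad():
        J_u = (f(x + u * eps) - f(x - u * eps)) / (2 * eps)
    numerical = torch.dot(v.reshape(-1), J_u.reshape(-1))

    # Analytical: one backward pass gives v^T J_f, then dot with u.
    (vT_J,) = torch.autograd.grad(y, x, grad_outputs=v)
    analytical = torch.dot(vT_J.reshape(-1), u.reshape(-1))

    return torch.allclose(numerical, analytical, atol=atol, rtol=rtol)

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.float64)
print(fast_check_real(lambda t: torch.sin(t) * t, x))  # True
```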

Fast gradcheck for complex-to-real functions

Similar to the real-to-real case, we want to perform a reduction of the full matrix. But the $2 * CW$ matrix is complex-valued, so in this case we will compare complex scalars.

Due to some constraints on what we can compute efficiently in the numerical case and to keep the number of numerical evaluations to a minimum, we compute the following (albeit surprising) scalar value:

$$s := 2 * v^T (real(CW) ur + i * imag(CW) ui)$$

where $v \in \mathcal{R}^M$, $ur \in \mathcal{R}^N$ and $ui \in \mathcal{R}^N$.

Fast complex input numerical evaluation

We first consider how to compute $s$ with a numerical method. To do so, keeping in mind that we’re considering $g: \mathcal{C}^N \to \mathcal{R}^M, z \to y$ with $z = a + i b$, and that $CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})$, we rewrite it as follows:

$$\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}$$

In this formula, we can see that $\frac{\partial y}{\partial a} ur$ and $\frac{\partial y}{\partial b} ui$ can be evaluated the same way as the fast version for the real-to-real case. Once these real-valued quantities have been computed, we can reconstruct the complex vector on the right side and do a dot product with the real-valued $v$ vector.
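
A sketch of this numerical evaluation, where `numerical_s` is a hypothetical helper: $\frac{\partial y}{\partial a} ur$ comes from perturbing the real part of $z$ along $ur$, and $\frac{\partial y}{\partial b} ui$ from perturbing the imaginary part along $ui$. For the illustrative `g`, `v`, `ur` and `ui` below, working out $\frac{\partial y}{\partial a}$ and $\frac{\partial y}{\partial b}$ by hand gives $s = -0.525 - 3.34i$.

```python
import torch

def numerical_s(g, z, v, ur, ui, eps=1e-6):
    """Hypothetical sketch: s = v^T ((dy/da) ur + i * (dy/db) ui), where each
    product is one real-to-real central difference along a real direction."""
    with torch.no_grad():
        dy_da_ur = (g(z + ur * eps) - g(z - ur * eps)) / (2 * eps)            # move a along ur
        dy_db_ui = (g(z + 1j * ui * eps) - g(z - 1j * ui * eps)) / (2 * eps)  # move b along ui
    return torch.dot(v.to(z.dtype), dy_da_ur + 1j * dy_db_ui)

def g(z):
    # Illustrative g: C^2 -> R^2.
    return torch.stack([z[0].abs() ** 2 + z[1].real, (z[0] * z[1]).imag])

z = torch.tensor([1.0 + 2.0j, -0.5 + 0.25j], dtype=torch.complex128)
v = torch.tensor([0.7, -1.1], dtype=torch.float64)
ur = torch.tensor([0.6, 0.8], dtype=torch.float64)
ui = torch.tensor([-0.8, 0.6], dtype=torch.float64)
print(numerical_s(g, z, v, ur, ui))  # approx -0.525-3.34j
```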

Fast complex input analytical evaluation

For the analytical case, things are simpler and we rewrite the formula as:

$$\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}$$

We can thus use the fact that backward mode AD provides us with an efficient way to compute $v^T (2 * CW)$ and then perform a dot product of the real part with $ur$ and the imaginary part with $ui$ before reconstructing the final complex scalar $s$.
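
The analytical counterpart needs a single backward pass. The sketch below, with the hypothetical helper `analytical_s` and illustrative values for `g`, `v`, `ur` and `ui`, compares it with a two-finite-difference numerical estimate; for these values both should come out near $-0.525 - 3.34i$.

```python
import torch

def analytical_s(g, z, v, ur, ui):
    """Hypothetical sketch: one backward pass gives v^T (2 * CW); its real part
    is dotted with ur and its imaginary part with ui."""
    z = z.detach().clone().requires_grad_(True)
    (vT_2cw,) = torch.autograd.grad(g(z), z, grad_outputs=v)  # complex, shape (N,)
    return torch.dot(vT_2cw.real, ur) + 1j * torch.dot(vT_2cw.imag, ui)

def g(z):
    # Illustrative g: C^2 -> R^2.
    return torch.stack([z[0].abs() ** 2 + z[1].real, (z[0] * z[1]).imag])

z = torch.tensor([1.0 + 2.0j, -0.5 + 0.25j], dtype=torch.complex128)
v = torch.tensor([0.7, -1.1], dtype=torch.float64)
ur = torch.tensor([0.6, 0.8], dtype=torch.float64)
ui = torch.tensor([-0.8, 0.6], dtype=torch.float64)

# Numerical counterpart: two real-to-real central differences (along ur and i * ui).
eps = 1e-6
with torch.no_grad():
    d_a = (g(z + ur * eps) - g(z - ur * eps)) / (2 * eps)
    d_b = (g(z + 1j * ui * eps) - g(z - 1j * ui * eps)) / (2 * eps)
numerical = torch.dot(v.to(z.dtype), d_a + 1j * d_b)

analytical = analytical_s(g, z, v, ur, ui)
print(analytical, numerical)  # both approx -0.525-3.34j
```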

Why not use a complex $u$

At this point, you might be wondering why we did not select a complex $u$ and just perform the reduction $2 * v^T CW u'$. To dive into this, in this paragraph we will use the complex version of $u$, denoted $u' = ur' + i ui'$. With such a complex $u'$, the problem is that, when doing the numerical evaluation, we would need to compute:

$$\begin{aligned} 2 * CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}$$

This would require four evaluations of real-to-real finite differences (twice as many as the approach proposed above). Since this approach does not have more degrees of freedom (it uses the same number of real-valued variables) and we try to get the fastest possible evaluation here, we use the other formulation above.

Fast gradcheck for functions with complex outputs

Just like in the slow case, we consider two real-valued functions and use the appropriate rule from above for each function.

Gradgradcheck implementation

PyTorch also provides a utility to verify second-order gradients. The goal here is to make sure that the backward implementation is also properly differentiable and computes the right thing.

This feature is implemented by considering the function $F: x, v \to v^T J_f$ and using the gradcheck defined above on this function. Note that $v$ in this case is just a random vector with the same type as $f(x)$.

The fast version of gradgradcheck is implemented by using the fast version of gradcheck on that same function FF.
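
As a sketch of this construction, the code below writes $F$ by hand with torch.autograd.grad(..., create_graph=True) and runs gradcheck on it, then calls torch.autograd.gradgradcheck directly, with and without fast_mode=True. The function `f` and the random inputs are illustrative, and this hand-written `F` is a simplification rather than PyTorch's internal version.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

torch.manual_seed(0)

def f(x):
    # Illustrative f with a non-trivial second derivative.
    return torch.sin(x) * x

def F(x, v):
    # F(x, v) = v^T J_f, kept differentiable with create_graph=True.
    (vT_J,) = torch.autograd.grad(f(x), x, grad_outputs=v, create_graph=True)
    return vT_J

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
v = torch.randn(4, dtype=torch.float64, requires_grad=True)  # same type as f(x)

print(gradcheck(F, (x, v)))                    # gradcheck applied to F by hand
print(gradgradcheck(f, (x,)))                  # built-in second-order check
print(gradgradcheck(f, (x,), fast_mode=True))  # fast version
```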

