# Source code for torch.autograd.gradcheck

import collections
import functools
import warnings
from itertools import product
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.testing
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors

# Note: the `get_*_jacobian` functions are listed here even though we never intended
# to make them public: they were already exposed before `__all__` was added, and we
# maintain backward compatibility (BC) for them.
# We should eventually deprecate them and remove them from `__all__`.
__all__ = [
    "gradcheck",
    "gradgradcheck",
    "GradcheckError",
    "get_numerical_jacobian",
    "get_analytical_jacobian",
    "get_numerical_jacobian_wrt_specific_input",
]


class GradcheckError(RuntimeError):
    r"""Error raised by :func:`gradcheck` and :func:`gradgradcheck`."""
def _is_sparse_compressed_tensor(obj: torch.Tensor): return obj.layout in { torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc, } def _is_sparse_any_tensor(obj: torch.Tensor): return _is_sparse_compressed_tensor(obj) or obj.layout is torch.sparse_coo def _is_float_or_complex_tensor(obj): return is_tensor_like(obj) and (obj.is_floating_point() or obj.is_complex()) def _allocate_jacobians_with_inputs( input_tensors: Tuple, numel_output ) -> Tuple[torch.Tensor, ...]: # Makes zero-filled tensors from inputs. If `numel_output` is not None, for # each tensor in `input_tensors`, returns a new zero-filled tensor with height # of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns # a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have # the same dtype and device as those of the corresponding input. out: List[torch.Tensor] = [] for t in input_tensors: if _is_float_or_complex_tensor(t) and t.requires_grad: out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided)) return tuple(out) def _allocate_jacobians_with_outputs( output_tensors: Tuple, numel_input, dtype=None, device=None ) -> Tuple[torch.Tensor, ...]: # Makes zero-filled tensors from outputs. If `dim` is not None, for each tensor # in `output_tensors`, returns a new zero-filled tensor with height of `dim` and # width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size # (t.numel,). 
out: List[torch.Tensor] = [] options = {"dtype": dtype, "device": device, "layout": torch.strided} for t in output_tensors: if _is_float_or_complex_tensor(t): out.append(t.new_zeros((numel_input, t.numel()), **options)) return tuple(out) def _iter_tensors( x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False ) -> Iterable[torch.Tensor]: if is_tensor_like(x): # mypy doesn't narrow type of `x` to torch.Tensor if x.requires_grad or not only_requiring_grad: # type: ignore[union-attr] yield x # type: ignore[misc] elif isinstance(x, collections.abc.Iterable) and not isinstance(x, str): for elem in x: yield from _iter_tensors(elem, only_requiring_grad) def _densify(x): # return a copy of sparse x with all unspecified elements # "replaced" with zero-valued elements if isinstance(x, (list, tuple)): return type(x)(map(_densify, x)) elif not is_tensor_like(x) or x.layout in {torch.strided, torch._mkldnn}: # type: ignore[attr-defined] # no attr _mkldnn return x elif x.layout is torch.sparse_coo: device = x.device indices_dtype = x._indices().dtype tmp = torch.ones(x.shape[: x.sparse_dim()], dtype=torch.int8, device=device) indices = tmp.nonzero().t().to(dtype=indices_dtype) values = torch.zeros( (tmp.numel(), *x.shape[x.sparse_dim() :]), dtype=x.dtype, device=device ) x_coalesced = x.detach().coalesce() if x_coalesced.numel() > 0: stride = tmp.stride() flat_indices = ( x_coalesced.indices() .mul( torch.tensor(stride, dtype=indices_dtype, device=device).unsqueeze( 1 ) ) .sum(0) ) values[flat_indices] = x_coalesced.values() return ( torch.sparse_coo_tensor(indices, values, x.shape) ._coalesced_(True) .requires_grad_(x.requires_grad) ) elif _is_sparse_compressed_tensor(x): blocksize = ( x.values().shape[1:3] if x.layout in {torch.sparse_bsr, torch.sparse_bsc} else None ) compressed_indices = ( x.crow_indices() if x.layout in {torch.sparse_csr, torch.sparse_bsr} else x.ccol_indices() ) # We'll use intermediate sparse COO for simplicity r = 
_densify(x.detach().to_sparse(layout=torch.sparse_coo)).to_sparse( layout=x.layout, blocksize=blocksize ) # Check that all elements are specified also after `to_sparse` op: dense_numel = r.values().numel() // max(1, r.values().shape[0]) batch_numel = compressed_indices.numel() // compressed_indices.shape[-1] sparse_numel = r.numel() // max(1, dense_numel * batch_numel) if sparse_numel != r._nnz(): raise AssertionError( f"{x.layout} densify failed: expected nnz={sparse_numel} but got {r._nnz()}" ) return r.requires_grad_(x.requires_grad) elif _is_sparse_any_tensor(x): raise NotImplementedError(x.layout) return x def _iter_tensor(x_tensor): # (Only used for slow gradcheck) Returns a generator that yields the following # elements at each iteration: # 1) a tensor: the same tensor is returned across all iterations. The tensor # is not the same as the original x_tensor as given as input - it is # prepared so that it can be modified in-place. Depending on whether the # input tensor is strided, sparse, or dense, the returned tensor may or may # not share storage with x_tensor. # 2) a tuple of indices that can be used with advanced indexing (yielded in # dictionary order) # 3) flattened index that will be used to index into the Jacobian tensor # # For a tensor t with size (2, 2), _iter_tensor yields: # `x, (0, 0), 0`, `x, (0, 1), 1`, `x, (1, 0), 2`, `x, (1, 1), 3` # # where x is the t.data of the original tensor. Perturbing the entry of x # at index (1, 1) yields the 3rd column of the overall Jacobian matrix. 
if _is_sparse_any_tensor(x_tensor): def get_stride(size): dim = len(size) tmp = 1 stride = [0] * dim for i in reversed(range(dim)): stride[i] = tmp tmp *= size[i] return stride x_nnz = x_tensor._nnz() x_size = list(x_tensor.size()) if x_tensor.layout is torch.sparse_coo: x_indices = x_tensor._indices().t() x_values = x_tensor._values() elif x_tensor.layout is torch.sparse_csr: x_indices = torch._convert_indices_from_csr_to_coo( x_tensor.crow_indices(), x_tensor.col_indices() ).t() x_values = x_tensor.values() elif x_tensor.layout is torch.sparse_csc: x_indices = torch._convert_indices_from_csr_to_coo( x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True ).t() x_values = x_tensor.values() elif x_tensor.layout is torch.sparse_bsr: x_block_values = x_tensor.values() x_blocksize = x_block_values.size()[1:3] x_indices = ( torch._convert_indices_from_csr_to_coo( x_tensor.crow_indices(), x_tensor.col_indices() ) .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) .add_( torch.stack( torch.where(torch.ones(x_blocksize, device=x_tensor.device)) ).repeat(1, x_nnz) ) .t() ) x_values = x_block_values.flatten(0, 2) x_nnz = x_values.size(0) elif x_tensor.layout is torch.sparse_bsc: x_block_values = x_tensor.values() x_blocksize = x_block_values.size()[1:3] x_indices = ( torch._convert_indices_from_csr_to_coo( x_tensor.ccol_indices(), x_tensor.row_indices(), transpose=True ) .repeat_interleave(x_blocksize[0] * x_blocksize[1], 1) .mul_(torch.tensor(x_blocksize, device=x_tensor.device).reshape(2, 1)) .add_( torch.stack( torch.where(torch.ones(x_blocksize, device=x_tensor.device)) ).repeat(1, x_nnz) ) .t() ) x_values = x_block_values.flatten(0, 2) x_nnz = x_values.size(0) else: raise NotImplementedError(f"_iter_tensor for {x_tensor.layout} input") x_stride = get_stride(x_size) # Use .data here to get around the version check x_values = x_values.data for i in range(x_nnz): x_value = x_values[i] for 
x_idx in product(*[range(m) for m in x_values.size()[1:]]): indices = x_indices[i].tolist() + list(x_idx) d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size))) yield x_value, x_idx, d_idx elif x_tensor.layout == torch._mkldnn: # type: ignore[attr-defined] for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): # this is really inefficient, but without indexing implemented, there's # not really a better way than converting back and forth x_tensor_dense = x_tensor.to_dense() yield x_tensor_dense, x_idx, d_idx else: # Use .data here to get around the version check x_tensor = x_tensor.data for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): yield x_tensor, x_idx, d_idx def _get_numerical_jacobian( fn, inputs, outputs=None, target=None, eps=1e-3, is_forward_ad=False ) -> List[Tuple[torch.Tensor, ...]]: """Compute the numerical Jacobian of `fn(inputs)` with respect to `target`. If not specified, targets are the input. Returns M * N Jacobians where N is the number of tensors in target that require grad and M is the number of non-integral outputs. Args: fn: the function to compute the jacobian for inputs: inputs to `fn` outputs: provide precomputed outputs to avoid one extra invocation of fn target: the Tensors wrt whom Jacobians are calculated (default=`inputs`) eps: the magnitude of the perturbation during finite differencing (default=`1e-3`) is_forward_ad: if this numerical jacobian is computed to be checked wrt forward AD gradients (this is used for error checking only) Returns: A list of M N-tuples of tensors Note that `target` may not even be part of `input` to `fn`, so please be **very careful** in this to not clone `target`. """ jacobians: List[Tuple[torch.Tensor, ...]] = [] if outputs is None: outputs = _as_tuple(fn(*_as_tuple(inputs))) if not is_forward_ad and any(o.is_complex() for o in outputs): raise ValueError( "Expected output to be non-complex. 
get_numerical_jacobian no " "longer supports functions that return complex outputs." ) if target is None: target = inputs inp_indices = [ i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad ] for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)): jacobians += [ get_numerical_jacobian_wrt_specific_input( fn, inp_idx, inputs, outputs, eps, input=inp, is_forward_ad=is_forward_ad, ) ] return jacobians def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0): """Compute the numerical Jacobian for a given fn and its inputs. This is a Deprecated API. Args: fn: the function to compute the Jacobian for (must take inputs as a tuple) input: input to `fn` target: the Tensors wrt whom Jacobians are calculated (default=`input`) eps: the magnitude of the perturbation during finite differencing (default=`1e-3`) Returns: A list of Jacobians of `fn` (restricted to its first output) with respect to each input or target, if provided. Note that `target` may not even be part of `input` to `fn`, so please be **very careful** in this to not clone `target`. """ warnings.warn( "get_numerical_jacobian was part of PyTorch's private API and not " "meant to be exposed. We are deprecating it and it will be removed " "in a future version of PyTorch. If you have a specific use for " "this or feature request for this to be a stable API, please file " "us an issue at https://github.com/pytorch/pytorch/issues/new" ) if ( grad_out != 1.0 ): # grad_out param is only kept for backward compatibility reasons raise ValueError( "Expected grad_out to be 1.0. get_numerical_jacobian no longer " "supports values of grad_out != 1.0." 
) def fn_pack_inps(*inps): return fn(inps) jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps) return tuple(jacobian_for_each_output[0] for jacobian_for_each_output in jacobians) def _compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn): # Computes numerical directional derivative as finite difference # of function `fn` at input `entry`, perturbed by vector `v`. if _is_sparse_compressed_tensor(entry): # sparse compressed tensors don't implement sub/add/copy_ # yet. However, in non-masked semantics context entry and v # have the same sparse indices ... assert entry.layout == v.layout, (entry.layout, v.layout) assert entry._nnz() == v._nnz(), (entry._nnz(), v._nnz(), entry.shape) # ... the finite differencing can be performed on values only: entry = entry.values() v = v.values() # we'll detach to avoid backward computations that sparse # tensors have limited support for. entry = entry.detach() orig = entry.clone() entry.copy_(orig - v) outa = fn() entry.copy_(orig + v) outb = fn() entry.copy_(orig) def compute(a, b): nbhd_checks_fn(a, b) ret = (b - a) / (2 * norm_v) # use central difference approx return ret.detach().reshape(-1) return tuple(compute(a, b) for (a, b) in zip(outa, outb)) def _compute_numerical_jvps_wrt_specific_input( jvp_fn, delta, input_is_complex, is_forward_ad=False ) -> List[torch.Tensor]: # Computing the jacobian only works for real delta # For details on the algorithm used here, refer: # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf # s = fn(z) where z = x for real valued input # and z = x + yj for complex valued input jvps: List[torch.Tensor] = [] ds_dx_tup = jvp_fn(delta[0] if isinstance(delta, tuple) else delta) if input_is_complex: # C -> R ds_dy_tup = ( jvp_fn(delta[1] * 1j) if isinstance(delta, tuple) else jvp_fn(delta * 1j) ) for ds_dx, ds_dy in zip(ds_dx_tup, ds_dy_tup): assert not ds_dx.is_complex() # conjugate wirtinger derivative conj_w_d = ds_dx + ds_dy * 1j jvps.append(conj_w_d) else: for 
ds_dx in ds_dx_tup: # R -> R or (R -> C for the forward AD case) assert is_forward_ad or not ds_dx.is_complex() jvps.append(ds_dx) return jvps def _combine_jacobian_cols( jacobians_cols: Dict[int, List[torch.Tensor]], outputs, input, numel ) -> Tuple[torch.Tensor, ...]: # jacobian_cols maps column_idx -> output_idx -> single column of jacobian Tensor # we return a list that maps output_idx -> full jacobian Tensor jacobians = _allocate_jacobians_with_outputs( outputs, numel, dtype=input.dtype if input.dtype.is_complex else None ) for i, jacobian in enumerate(jacobians): for k, v in jacobians_cols.items(): jacobian[k] = v[i] return jacobians def _prepare_input( input: torch.Tensor, maybe_perturbed_input: Optional[torch.Tensor], fast_mode=False ) -> torch.Tensor: # Prepares the inputs to be passed into the function while including the new # modified input. if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn # Convert back to mkldnn if maybe_perturbed_input is not None: return maybe_perturbed_input.to_mkldnn() else: return input elif _is_sparse_any_tensor(input): if fast_mode and maybe_perturbed_input is not None: # entry is already a "cloned" version of the original tensor # thus changes to entry are not reflected in the input return maybe_perturbed_input else: return input else: # We cannot use entry (input.data) if we want gradgrad to work because # fn (in the gradgrad case) needs to compute grad wrt input return input def _check_outputs_same_dtype_and_shape(output1, output2, eps, idx=None) -> None: # Check that the returned outputs don't have different dtype or shape when you # perturb the input on_index = "on index {idx} " if idx is not None else "" assert output1.shape == output2.shape, ( f"Expected `func` to return outputs with the same shape" f" when inputs are perturbed {on_index}by {eps}, but got:" f" shapes {output1.shape} and {output2.shape}." 
) assert output1.dtype == output2.dtype, ( f"Expected `func` to return outputs with the same dtype" f" when inputs are perturbed {on_index}by {eps}, but got:" f" dtypes {output1.dtype} and {output2.dtype}." ) def get_numerical_jacobian_wrt_specific_input( fn, input_idx, inputs, outputs, eps, input=None, is_forward_ad=False ) -> Tuple[torch.Tensor, ...]: # Computes the numerical jacobians wrt to a single input. Returns N jacobian # tensors, where N is the number of outputs. We use a dictionary for # jacobian_cols because indices aren't necessarily consecutive for sparse inputs # When we perturb only a single element of the input tensor at a time, the jvp # is equivalent to a single col of the Jacobian matrix of fn. jacobian_cols: Dict[int, List[torch.Tensor]] = {} input = inputs[input_idx] if input is None else input assert input.requires_grad for x, idx, d_idx in _iter_tensor(input): wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, x) input_to_perturb = x[idx] nbhd_checks_fn = functools.partial( _check_outputs_same_dtype_and_shape, idx=idx, eps=eps ) jvp_fn = _get_numerical_jvp_fn( wrapped_fn, input_to_perturb, eps, nbhd_checks_fn ) jacobian_cols[d_idx] = _compute_numerical_jvps_wrt_specific_input( jvp_fn, eps, x.is_complex(), is_forward_ad ) return _combine_jacobian_cols(jacobian_cols, outputs, input, input.numel()) def _get_analytical_jacobian_forward_ad( fn, inputs, outputs, *, check_grad_dtypes=False, all_u=None ) -> Tuple[Tuple[torch.Tensor, ...], ...]: """Compute the analytical Jacobian using forward mode AD of `fn(inputs)` using forward mode AD with respect to `target`. Return N * M Jacobians where N is the number of tensors in target that require grad and M is the number of non-integral outputs. Contrary to other functions here, this function requires "inputs" to actually be used by the function. 
The computed value is expected to be wrong if the function captures the inputs by side effect instead of using the passed ones (many torch.nn tests do this). Args: fn: the function to compute the jacobian for inputs: inputs to `fn` outputs: provide precomputed outputs to avoid one extra invocation of fn check_grad_dtypes: if True, will check that the gradient dtype are valid all_u (optional): if provided, the Jacobian will be right multiplied with this vector Returns: A tuple of M N-tuples of tensors """ # To avoid early import issues fwAD = torch.autograd.forward_ad tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad) if any(i.is_complex() for i in tensor_inputs): raise ValueError( "Expected inputs to be non-complex for _get_analytical_jacobian_forward_ad." ) if all_u: jacobians = tuple( _allocate_jacobians_with_outputs(outputs, 1) for i in tensor_inputs ) else: jacobians = tuple( _allocate_jacobians_with_outputs(outputs, i.numel()) for i in tensor_inputs ) with fwAD.dual_level(): fw_grads = [] dual_inputs = [] for i, inp in enumerate(inputs): if is_tensor_like(inp) and inp.requires_grad: if inp.layout == torch._mkldnn: # type: ignore[attr-defined] raise ValueError( "MKLDNN inputs are not support for forward AD gradcheck." 
) inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) # If inp is a differentiable view, the dual might not be the tangent given to # make_dual, so read it explicitly from the dual tensor fw_grads.append(fwAD.unpack_dual(inp)[1]) dual_inputs.append(inp) if all_u: # Do the full reduction in one pass # To be consistent with numerical evaluation, we actually compute one reduction per input for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)): fw_grad.copy_(u.view_as(fw_grad)) raw_outputs = _as_tuple(fn(*dual_inputs)) dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs) for index_o, d_o in enumerate(dual_outputs): val, res = fwAD.unpack_dual(d_o) if ( check_grad_dtypes and res is not None and val.is_complex() != res.is_complex() ): raise GradcheckError("Forward AD gradient has dtype mismatch.") # Remove extra dimension of size 1 corresponding to the reduced input jacobians[i][index_o].squeeze_(0) if res is None: jacobians[i][index_o].zero_() else: jacobians[i][index_o].copy_(res.reshape(-1)) fw_grad.zero_() else: # Reconstruct the full Jacobian column by column for i, fw_grad in enumerate(fw_grads): for lin_idx, grad_idx in enumerate( product(*[range(m) for m in fw_grad.size()]) ): fw_grad[grad_idx] = 1.0 raw_outputs = _as_tuple(fn(*dual_inputs)) dual_outputs = filter(_is_float_or_complex_tensor, raw_outputs) for index_o, d_o in enumerate(dual_outputs): val, res = fwAD.unpack_dual(d_o) if ( check_grad_dtypes and res is not None and val.is_complex() != res.is_complex() ): raise GradcheckError( "Forward AD gradient has dtype mismatch." ) if res is None: jacobians[i][index_o][lin_idx].zero_() else: jacobians[i][index_o][lin_idx].copy_(res.reshape(-1)) fw_grad[grad_idx] = 0.0 return jacobians def _get_input_to_perturb(input): # Prepare the input so that it can be modified in-place and do certain # operations that require the tensor to have strides. 
If fast_mode=False, # _iter_tensor would handle the below cases: if input.layout == torch._mkldnn: # type: ignore[attr-defined] # no attr _mkldnn # Convert to dense so we can perform operations that require strided tensors input_to_perturb = input.to_dense() elif _is_sparse_any_tensor(input): # Clone because input may require grad, and copy_ calls resize_, # which is not allowed for .data input_to_perturb = input.clone() else: input_to_perturb = input.data return input_to_perturb def _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, fast_mode=False): # Wraps `fn` so that its inputs are already supplied def wrapped_fn(): inp = tuple( _prepare_input(a, input_to_perturb if i == input_idx else None, fast_mode) if is_tensor_like(a) else a for i, a in enumerate(_as_tuple(inputs)) ) return tuple(a.clone() for a in _as_tuple(fn(*inp))) return wrapped_fn def _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn): # Wraps jvp_fn so that certain arguments are already supplied def jvp_fn(delta): return _compute_numerical_gradient( wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn ) return jvp_fn def _reshape_tensor_or_tuple(u, shape): # We don't need to reshape when input corresponding to u is sparse if isinstance(u, tuple): if not _is_sparse_any_tensor(u[0]): return (u[0].reshape(shape), u[1].reshape(shape)) else: if not _is_sparse_any_tensor(u): return u.reshape(shape) return u def _mul_tensor_or_tuple(u, k): if isinstance(u, tuple): return (k * u[0], k * u[1]) else: return k * u def _get_numerical_jvp_wrt_specific_input( fn, input_idx, inputs, u, eps, is_forward_ad=False ) -> List[torch.Tensor]: input = inputs[input_idx] input_to_perturb = _get_input_to_perturb(input) wrapped_fn = _with_prepare_inputs(fn, inputs, input_idx, input_to_perturb, True) nbhd_checks_fn = functools.partial(_check_outputs_same_dtype_and_shape, eps=eps) jvp_fn = _get_numerical_jvp_fn(wrapped_fn, input_to_perturb, eps, nbhd_checks_fn) u = 
_reshape_tensor_or_tuple(u, input_to_perturb.shape) u = _mul_tensor_or_tuple(u, eps) return _compute_numerical_jvps_wrt_specific_input( jvp_fn, u, input.is_complex(), is_forward_ad ) def _get_numerical_vJu( fn, inputs, inp_indices, func_out, all_u, all_v, eps, is_forward_ad ): # Note that all_v can also be None, in that case, this function only computes Ju. reduced_jacobians: List[List[torch.Tensor]] = [] for i, (inp_idx, u) in enumerate(zip(inp_indices, all_u)): all_Ju = _get_numerical_jvp_wrt_specific_input( fn, inp_idx, inputs, u, eps, is_forward_ad ) # Filter out the Ju for non floating point outputs filtered_Ju = [] func_out = _as_tuple(func_out) assert len(all_Ju) == len(func_out) for Ju, output in zip(all_Ju, func_out): if _is_float_or_complex_tensor(output): filtered_Ju.append(Ju) else: # TODO: handle the other Ju pass if all_v is not None: jacobian_scalars: List[torch.Tensor] = [] for v, Ju in zip(all_v, filtered_Ju): jacobian_scalars.append(_dot_with_type_promotion(v, Ju)) reduced_jacobians.append(jacobian_scalars) else: reduced_jacobians.append(filtered_Ju) return reduced_jacobians def _check_jacobians_equal(j1, j2, atol): # Check whether the max difference between two Jacobian tensors are within some # tolerance `atol`. for j1_x, j2_x in zip(j1, j2): if j1_x.numel() != 0 and (j1_x - j2_x).abs().max() > atol: return False return True def _stack_and_check_tensors( list_of_list_of_tensors, inputs, numel_outputs ) -> Tuple[Tuple[torch.Tensor, ...], bool, bool]: # For the ith tensor in the inner list checks whether it has the same size and # dtype as the ith differentiable input. 
out_jacobians = _allocate_jacobians_with_inputs(inputs, numel_outputs) diff_input_list = list(_iter_tensors(inputs, True)) correct_grad_sizes = True correct_grad_types = True for i, tensor_list in enumerate(list_of_list_of_tensors): inp = diff_input_list[i] out_jacobian = out_jacobians[i] for j, tensor in enumerate(tensor_list): if tensor is not None and tensor.size() != inp.size(): correct_grad_sizes = False elif tensor is not None and tensor.dtype != inp.dtype: correct_grad_types = False if tensor is None: out_jacobian[:, j].zero_() else: dense = ( tensor.to_dense() if not tensor.layout == torch.strided else tensor ) assert out_jacobian[:, j].numel() == dense.numel() out_jacobian[:, j] = dense.reshape(-1) return out_jacobians, correct_grad_sizes, correct_grad_types FAILED_NONDET_MSG = """\n NOTE: If your op relies on non-deterministic operations i.e., it is listed here: https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html this failure might be expected. If you are adding a new operator, please file an issue and then use one of the workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck. If the test - manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck with `nondet_tol=<tol>` as a keyword argument. - is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test to have `gradcheck_nondet_tol=<tol>`. - is a Module test (e.g., in common_nn.py), then modify the corresponding module_test entry to have `gradcheck_nondet_tol=<tol>` """ def _check_analytical_jacobian_attributes( inputs, output, nondet_tol, check_grad_dtypes, fast_mode=False, v=None ) -> Tuple[torch.Tensor, ...]: # This is used by both fast and slow mode: # - For slow mode, vjps[i][j] is the jth row of the Jacobian wrt the ith # input. 
# - For fast mode, vjps[i][0] is a linear combination of the rows # of the Jacobian wrt the ith input diff_input_list = list(_iter_tensors(inputs, True)) def vjp_fn(grad_output): return torch.autograd.grad( output, diff_input_list, grad_output, retain_graph=True, allow_unused=True ) # Compute everything twice to check for nondeterminism (which we call reentrancy) if fast_mode: vjps1 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v) vjps2 = _get_analytical_vjps_wrt_specific_output(vjp_fn, output.clone(), v) else: vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) output_numel = output.numel() if not fast_mode else 1 jacobians1, types_ok, sizes_ok = _stack_and_check_tensors( vjps1, inputs, output_numel ) jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel) reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol) if not types_ok and check_grad_dtypes: raise GradcheckError("Gradient has dtype mismatch") if not sizes_ok: raise GradcheckError("Analytical gradient has incorrect size") if not reentrant: raise GradcheckError( "Backward is not reentrant, i.e., running backward with " "same input and grad_output multiple times gives different values, " "although analytical gradient matches numerical gradient." f"The tolerance for nondeterminism was {nondet_tol}." + FAILED_NONDET_MSG ) return jacobians1 def _get_analytical_vJu_backward_mode( inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u ): reduced_jacobians: List[List[torch.Tensor]] = [] for output, v in zip(outputs, all_v): all_vJ = _check_analytical_jacobian_attributes( inputs, output, nondet_tol, check_grad_dtypes, fast_mode=True, v=v ) jacobian_scalars: List[torch.Tensor] = [] for vJ, u in zip(all_vJ, all_u): # Why do we need squeeze here? 
vJ is a 2-d tensor so that we can reuse # the error checking logic from slow mode vJ = vJ.T.squeeze(0) if vJ.is_complex(): # C -> R tv = torch.view_as_real(vJ.resolve_conj()) tr = tv.select(-1, 0) ti = tv.select(-1, 1) jacobian_scalars.append(tr.dot(u[0]) + 1j * ti.dot(u[1])) else: # R -> R jacobian_scalars.append(vJ.dot(u)) reduced_jacobians.append(jacobian_scalars) return reduced_jacobians def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0): # Replicates the behavior of the old get_analytical_jacobian before the refactor # This shares much of its code with _check_analytical_jacobian_attributes warnings.warn( "get_analytical_jacobian was part of PyTorch's private API and not " "meant to be exposed. We are deprecating it and it will be removed " "in a future version of PyTorch. If you have a specific use for " "this or feature request for this to be a stable API, please file " "us an issue at https://github.com/pytorch/pytorch/issues/new" ) if ( grad_out != 1.0 ): # grad_out param is only kept for backward compatibility reasons raise ValueError( "Expected grad_out to be 1.0. get_analytical_jacobian no longer " "supports values of grad_out != 1.0." ) if output.is_complex(): raise ValueError( "Expected output to be non-complex. get_analytical_jacobian no " "longer supports functions that return complex outputs." 
) diff_input_list = list(_iter_tensors(inputs, True)) def vjp_fn(grad_output): return torch.autograd.grad( output, diff_input_list, grad_output, retain_graph=True, allow_unused=True ) # Compute everything twice to check for nondeterminism (which we call reentrancy) vjps1 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) vjps2 = _compute_analytical_jacobian_rows(vjp_fn, output.clone()) output_numel = output.numel() jacobians1, types_ok, sizes_ok = _stack_and_check_tensors( vjps1, inputs, output_numel ) jacobians2, _, _ = _stack_and_check_tensors(vjps2, inputs, output_numel) reentrant = _check_jacobians_equal(jacobians1, jacobians2, nondet_tol) return jacobians1, reentrant, sizes_ok, types_ok def _get_analytical_jacobian(inputs, outputs, input_idx, output_idx): # Computes the analytical Jacobian in slow mode for a single input-output pair. # Forgoes performing checks on dtype, shape, and reentrancy. jacobians = _check_analytical_jacobian_attributes( inputs, outputs[output_idx], nondet_tol=float("inf"), check_grad_dtypes=False ) return jacobians[input_idx] def _compute_analytical_jacobian_rows( vjp_fn, sample_output ) -> List[List[Optional[torch.Tensor]]]: # Computes Jacobian row-by-row by projecting `vjp_fn` = v^T J on standard basis # vectors: vjp_fn(e) = e^T J is a corresponding row of the Jacobian. # NB: this function does not assume vjp_fn(v) to return tensors with the same # number of elements for different v. This is checked when we later combine the # rows into a single tensor. 
grad_out_base = torch.zeros_like( sample_output, memory_format=torch.legacy_contiguous_format ) flat_grad_out = grad_out_base.view(-1) # jacobians_rows[i][j] is the Jacobian jth row for the ith input jacobians_rows: List[List[Optional[torch.Tensor]]] = [] for j in range(flat_grad_out.numel()): flat_grad_out.zero_() flat_grad_out[j] = 1.0 # projection for jth row of Jacobian grad_inputs = vjp_fn(grad_out_base) for i, d_x in enumerate(grad_inputs): if j == 0: jacobians_rows.append([]) jacobians_rows[i] += [ d_x.clone() if isinstance(d_x, torch.Tensor) else None ] return jacobians_rows def _get_analytical_vjps_wrt_specific_output( vjp_fn, sample_output, v ) -> List[List[Optional[torch.Tensor]]]: vjps: List[List[Optional[torch.Tensor]]] = [] grad_inputs = vjp_fn(v.reshape(sample_output.shape)) for vjp in grad_inputs: vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None]) return vjps def _check_inputs(tupled_inputs) -> bool: # Make sure that gradients are saved for at least one input any_input_requiring_grad = False for idx, inp in enumerate(tupled_inputs): if is_tensor_like(inp) and inp.requires_grad: if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128): warnings.warn( f"Input #{idx} requires gradient and " "is not a double precision floating point or complex. " "This check will likely fail if all the inputs are " "not of double precision floating point or complex. " ) if inp.is_sparse: content = inp._values() elif _is_sparse_compressed_tensor(inp): content = inp.values() else: content = inp # TODO: To cover more problematic cases, replace stride = 0 check with # "any overlap in memory" once we have a proper function to check it. if content.layout is not torch._mkldnn: # type: ignore[attr-defined] if not all( st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size()) ): raise RuntimeError( f"The {idx}th input has a dimension with stride 0. 
gradcheck only " "supports inputs that are non-overlapping to be able to " "compute the numerical gradients correctly. You should call " ".contiguous on the input before passing it to gradcheck." ) any_input_requiring_grad = True if not any_input_requiring_grad: raise ValueError( "gradcheck expects at least one input tensor to require gradient, " "but none of the them have requires_grad=True." ) return True def _check_outputs(outputs) -> None: if any(_is_sparse_any_tensor(t) for t in outputs if isinstance(t, torch.Tensor)): # it is easier to call to_dense() on the sparse output than # to modify analytical jacobian raise ValueError( "Sparse output is not supported at gradcheck yet. " "Please call to_dense(masked_grad=...) on the output of fn for gradcheck." ) if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined] raise ValueError( "MKLDNN output is not supported at gradcheck yet. " "Please call to_dense(masked_grad=...) on the output of fn for gradcheck." ) def _check_no_differentiable_outputs( func, inputs, func_out, eps, *, is_forward_ad ) -> bool: # When there are no differentiable outputs, numerical gradient for a function is # expected to be zero. 
jacobians_all_inputs_outputs = _get_numerical_jacobian( func, inputs, func_out, eps=eps, is_forward_ad=is_forward_ad ) for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs: for jacobian in jacobians_all_outputs_and_fixed_input: if torch.ne(jacobian, 0).sum() > 0: raise GradcheckError( "Numerical gradient for function expected to be zero" ) return True def _check_no_differentiable_outputs_fast( func, func_out, all_inputs, inputs_indices, all_u, eps, nondet_tol ): for inp_idx, u in zip(inputs_indices, all_u): jvps = _get_numerical_jvp_wrt_specific_input(func, inp_idx, all_inputs, u, eps) for jvp in jvps: if jvp.numel() == 0: continue if (jvp - torch.zeros_like(jvp)).abs().max() > nondet_tol: raise GradcheckError( "Numerical gradient for function expected to be zero" ) return True FAILED_BATCHED_GRAD_MSG = """ gradcheck or gradgradcheck failed while testing batched gradient computation. This could have been invoked in a number of ways (via a test that calls gradcheck/gradgradcheck directly or via an autogenerated test). If you are adding a new operator, please file an issue and then use one of the workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck. If the test - manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck with `check_batched_grad=False` as a keyword argument. - is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test to have `check_batched_grad=False` and/or `check_batched_gradgrad=False`. If you're modifying an existing operator that supports batched grad computation, or wish to make a new operator work with batched grad computation, please read the following. To compute batched grads (e.g., jacobians, hessians), we vmap over the backward computation. The most common failure case is if there is a 'vmap-incompatible operation' in the backward pass. 
Please see NOTE: [How to write vmap-compatible backward formulas] in the codebase for an explanation of how to fix this. """.strip() FAILED_BATCHED_GRAD_MSG_FWD_AD = """ gradcheck failed while testing batched gradient computation with forward-mode AD. This test is enabled automatically when both `check_batched_grad=True` and `check_forward_ad=True`, but can be disabled in the following ways dependong on how the test was invoked (via a test that calls gradcheck directly or via an autogenerated test). If you are adding a new operator, please file an issue and then use one of the workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck. If the test - manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck with `check_batched_forward_grad=False` as a keyword argument. - is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test to have `check_batched_forward_grad=False` """ def _get_failed_batched_grad_test_msg( output_idx, input_idx, res, exp, is_forward_ad=False ): return f""" For output {output_idx} and input {input_idx}: {FAILED_BATCHED_GRAD_MSG_FWD_AD if is_forward_ad else FAILED_BATCHED_GRAD_MSG} Got: {res} Expected: {exp} """.strip() def _test_batched_grad_forward_ad(func, inputs) -> bool: fwAD = torch.autograd.forward_ad # To avoid early import issues (do we need this?) 
assert isinstance(inputs, tuple) for input_idx, current_input in enumerate(inputs): if not (is_tensor_like(current_input) and current_input.requires_grad): continue def jvp(tangent: torch.Tensor): with fwAD.dual_level(): dual = fwAD.make_dual(current_input.detach(), tangent) inputs_with_dual = tuple( dual if idx == input_idx else (inp.detach() if is_tensor_like(inp) else inp) for idx, inp in enumerate(inputs) ) dual_outputs = _as_tuple(func(*inputs_with_dual)) ret = [] for dual_output in dual_outputs: if dual_output is None: continue primal_out, tangent_out = fwAD.unpack_dual(dual_output) if tangent_out is not None: ret.append(tangent_out) else: ret.append( torch.zeros( [], dtype=primal_out.dtype, device=primal_out.device ).expand(primal_out.shape) ) return tuple(ret) if not _is_float_or_complex_tensor(current_input): continue tangents = [torch.randn_like(current_input) for _ in range(2)] expected = [jvp(t) for t in tangents] expected = [torch.stack(shards) for shards in zip(*expected)] try: result = _vmap(jvp)(torch.stack(tangents)) except RuntimeError as ex: # Rethrow to provide a better error message raise GradcheckError( f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG_FWD_AD}" ) from ex for input_idx, (res, exp) in enumerate(zip(result, expected)): if torch.allclose(res, exp): continue raise GradcheckError( _get_failed_batched_grad_test_msg( input_idx, input_idx, res, exp, is_forward_ad=True ) ) return True def _test_batched_grad(input, output, output_idx) -> bool: # NB: _test_batched_grad compares two autograd.grad invocations with a single # vmap(autograd.grad) invocation. 
It's not exactly a "gradcheck" in the # sense that we're not comparing an analytical jacobian with a numeric one, # but it is morally similar (we could have computed a full analytic jac # via vmap, but that is potentially slow) diff_input_list = list(_iter_tensors(input, True)) grad = functools.partial( torch.autograd.grad, output, diff_input_list, retain_graph=True, allow_unused=True, ) def vjp(v): results = grad(v) results = tuple( grad if grad is not None else torch.zeros([], dtype=inp.dtype, device=inp.device).expand(inp.shape) for grad, inp in zip(results, diff_input_list) ) return results grad_outputs = [torch.randn_like(output) for _ in range(2)] expected = [vjp(gO) for gO in grad_outputs] expected = [torch.stack(shards) for shards in zip(*expected)] # Squash warnings since these are expected to happen in most cases # NB: this doesn't work for CUDA tests: https://github.com/pytorch/pytorch/issues/50209 with warnings.catch_warnings(): warnings.filterwarnings("ignore", message="There is a performance drop") warnings.filterwarnings("ignore", message="Please use torch.vmap") try: result = vmap(vjp)(torch.stack(grad_outputs)) except RuntimeError as ex: # It's OK that we're not raising the error at the correct callsite. 
# That's because the callsite is always going to inside the Python # autograd.grad instead of the C++ traceback of what line in the # backward formula raise GradcheckError( f"While computing batched gradients, got: {ex}\n\n{FAILED_BATCHED_GRAD_MSG}" ) from ex for input_idx, (res, exp) in enumerate(zip(result, expected)): if torch.allclose(res, exp): continue raise GradcheckError( _get_failed_batched_grad_test_msg(output_idx, input_idx, res, exp) ) return True def _test_backward_mul_by_grad_output(outputs, inputs, masked) -> bool: # Tests that backward is multiplied by grad_output diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True)) if not diff_input_list: raise GradcheckError("no Tensors requiring grad found in input") grads_input = torch.autograd.grad( outputs, diff_input_list, [ torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in outputs ], allow_unused=True, ) for gi, di in zip(grads_input, diff_input_list): if gi is None: continue if isinstance(gi, torch.Tensor) and gi.layout != torch.strided: if gi.layout != di.layout: raise GradcheckError( "grad is incorrect layout (" + str(gi.layout) + " is not " + str(di.layout) + ")" ) if _is_sparse_any_tensor(gi): sparse_kind = str(gi.layout).replace("torch.", "").replace("_coo", "") if gi.sparse_dim() != di.sparse_dim(): raise GradcheckError( f"grad is {sparse_kind} tensor, but has incorrect sparse_dim" f" {gi.sparse_dim()}, expected {di.sparse_dim()}" ) if gi.dense_dim() != di.dense_dim(): raise GradcheckError( f"grad is {sparse_kind} tensor, but has incorrect dense_dim" f" {gi.dense_dim()}, expected {di.dense_dim()}" ) gi = gi.to_dense() di = di.to_dense() if masked: if not torch.allclose(gi, torch.zeros_like(gi)): raise GradcheckError("backward not multiplied by grad_output") elif not gi.eq(0).all(): raise GradcheckError("backward not multiplied by grad_output") if gi.dtype != di.dtype: raise GradcheckError("grad is incorrect type") if gi.device != di.device: raise 
GradcheckError("grad is incorrect device") if gi.size() != di.size(): raise GradcheckError("grad is incorrect size") return True def _test_undefined_forward_mode(func, outputs, inputs): fwAD = torch.autograd.forward_ad inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) all_v, all_u, all_u_dense = _make_vectors(inp_tensors, outputs, use_forward_ad=True) tensor_inputs = tuple(i for i in inputs if is_tensor_like(i) and i.requires_grad) with fwAD.dual_level(): fw_grads = [] dual_inputs = [] tensor_indices = set() for i, inp in enumerate(inputs): if is_tensor_like(inp) and inp.requires_grad: if inp.layout == torch._mkldnn: # type: ignore[attr-defined] raise ValueError( "MKLDNN inputs are not support for forward AD gradcheck." ) inp = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) # If inp is a differentiable view, the dual might not be the tangent given to # make_dual, so read it explicitly from the dual tensor fw_grads.append(fwAD.unpack_dual(inp)[1]) tensor_indices.add(i) dual_inputs.append(inp) for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)): fw_grad.copy_(u.view_as(fw_grad)) for idx, inp in enumerate(inputs): if idx not in tensor_indices: continue dual_inp_obj = dual_inputs[idx] # case 1 (Materialized Zero Tensor Tangent) dual_inputs[idx] = fwAD.make_dual(inp.detach(), torch.zeros_like(inp)) raw_outputs = _as_tuple(func(*dual_inputs)) dual_outputs1 = filter(_is_float_or_complex_tensor, raw_outputs) # case 2 (Efficient Zero Tensor Tangent since we don't make a dual object and pass a regular tensor) dual_inputs[idx] = inp.detach() raw_outputs = _as_tuple(func(*dual_inputs)) dual_outputs2 = filter(_is_float_or_complex_tensor, raw_outputs) # reset dual_inputs[idx] = dual_inp_obj for index_o, (d_o1, d_o2) in enumerate(zip(dual_outputs1, dual_outputs2)): val1, res1 = fwAD.unpack_dual(d_o1) val2, res2 = fwAD.unpack_dual(d_o2) if not (res1 is None or res2 is None): if not torch.allclose(res1, res2): raise GradcheckError( "Mismatch in tangent values for 
output with index: ", index_o, " when input: ", inp, " has an undefined tangent value. ", " Got: ", res1, " but expected: ", res2, ) return True def _test_undefined_backward_mode(func, outputs, inputs) -> bool: diff_input_list: List[torch.Tensor] = list(_iter_tensors(inputs, True)) if not diff_input_list: raise GradcheckError("no Tensors requiring grad found in input") def warn_bc_breaking(): warnings.warn( "Backwards compatibility: New undefined gradient support checking " "feature is enabled by default, but it may break existing callers " "of this function. If this is true for you, you can call this " 'function with "check_undefined_grad=False" to disable the feature' ) def check_undefined_grad_support(output_to_check): grads_output = [ torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output_to_check ] try: grads_input = torch.autograd.grad( output_to_check, diff_input_list, grads_output, allow_unused=True ) except RuntimeError as e: warn_bc_breaking() raise GradcheckError( "Expected backward function to handle undefined output grads. " 'Please look at "Notes about undefined output gradients" in ' '"tools/autograd/derivatives.yaml"' ) from e for gi, i in zip(grads_input, diff_input_list): if (gi is not None) and (not gi.eq(0).all()): warn_bc_breaking() raise GradcheckError( "Expected all input grads to be undefined or zero when all output grads are undefined " 'or zero. Please look at "Notes about undefined output gradients" in ' '"tools/autograd/derivatives.yaml"' ) return True # All backward functions must work properly if all output grads are undefined outputs_to_check = [ [ torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*inputs)) # This check filters out Tensor-likes that aren't instances of Tensor. 
if isinstance(o, torch.Tensor) ] ] # If there are multiple output grads, we should be able to undef one at a time without error if len(outputs_to_check[0]) > 1: for undef_grad_idx in range(len(outputs)): output_to_check = _differentiable_outputs(func(*inputs)) outputs_to_check.append( [ torch._C._functions.UndefinedGrad()(o) if idx == undef_grad_idx else o for idx, o in enumerate(output_to_check) ] ) return all(check_undefined_grad_support(output) for output in outputs_to_check) def _as_tuple(x): if isinstance(x, tuple): return x elif isinstance(x, list): return tuple(x) else: return (x,) def _differentiable_outputs(x): return tuple(o for o in _as_tuple(x) if o.requires_grad) def _get_notallclose_msg( analytical, numerical, output_idx, input_idx, complex_indices, test_imag=False, is_forward_ad=False, ) -> str: out_is_complex = ( (not is_forward_ad) and complex_indices and output_idx in complex_indices ) inp_is_complex = is_forward_ad and complex_indices and input_idx in complex_indices part = "imaginary" if test_imag else "real" element = "inputs" if is_forward_ad else "outputs" prefix = ( "" if not (out_is_complex or inp_is_complex) else f"While considering the {part} part of complex {element} only, " ) mode = "computed with forward mode " if is_forward_ad else "" return ( prefix + "Jacobian %smismatch for output %d with respect to input %d,\n" "numerical:%s\nanalytical:%s\n" % (mode, output_idx, input_idx, numerical, analytical) ) def _transpose(matrix_of_tensors): # returns list of tuples return list(zip(*matrix_of_tensors)) def _real_and_imag_output(fn): # returns new functions real(fn), and imag(fn) where real(fn) and imag(fn) behave the same as # the original fn, except torch.real or torch.imag are applied to the complex outputs def apply_to_c_outs(fn, fn_to_apply): def wrapped_fn(*inputs): outs = _as_tuple(fn(*inputs)) return tuple(fn_to_apply(o) if o.is_complex() else o for o in outs) return wrapped_fn return apply_to_c_outs(fn, torch.real), 
apply_to_c_outs(fn, torch.imag) def _real_and_imag_input(fn, complex_inp_indices, tupled_inputs): # returns new functions that take real inputs instead of complex inputs as # (x, y) -> fn(x + y * 1j). And it computes: inp -> fn(inp + y * 1j) and inp -> fn(x + inp * 1j). # In each case, the other part is considered constant. # We do not use 0 for the constant here to make sure we always call the user function with a valid input. def apply_to_c_inps(fn, fn_to_apply): def wrapped_fn(*inputs): new_inputs = list(inputs) for should_be_complex in complex_inp_indices: new_inputs[should_be_complex] = fn_to_apply( new_inputs[should_be_complex], tupled_inputs[should_be_complex] ) return _as_tuple(fn(*new_inputs)) return wrapped_fn real_fn = apply_to_c_inps(fn, lambda inp, orig: inp + orig.imag * 1j) imag_fn = apply_to_c_inps(fn, lambda inp, orig: orig.real + inp * 1j) return real_fn, imag_fn def _gradcheck_real_imag( gradcheck_fn, func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, check_forward_ad, check_backward_ad, nondet_tol, check_undefined_grad, ): complex_out_indices = [i for i, o in enumerate(outputs) if o.is_complex()] has_any_complex_output = any(o.is_complex() for o in _as_tuple(func_out)) if check_backward_ad: if has_any_complex_output: real_fn, imag_fn = _real_and_imag_output(func) imag_func_out = imag_fn(*tupled_inputs) imag_outputs = _differentiable_outputs(imag_func_out) gradcheck_fn( imag_fn, imag_func_out, tupled_inputs, imag_outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_out_indices, test_imag=True, ) real_func_out = real_fn(*tupled_inputs) real_outputs = _differentiable_outputs(real_func_out) gradcheck_fn( real_fn, real_func_out, tupled_inputs, real_outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_out_indices, ) else: gradcheck_fn( func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, ) if check_forward_ad: complex_inp_indices = [ i 
for i, inp in enumerate(tupled_inputs) if is_tensor_like(inp) and inp.is_complex() ] if complex_inp_indices: real_fn, imag_fn = _real_and_imag_input( func, complex_inp_indices, tupled_inputs ) imag_inputs = [ inp.imag if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs ] imag_func_out = imag_fn(*imag_inputs) diff_imag_func_out = _differentiable_outputs(imag_func_out) gradcheck_fn( imag_fn, imag_func_out, imag_inputs, diff_imag_func_out, eps, rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_inp_indices, test_imag=True, use_forward_ad=True, ) real_inputs = [ inp.real if is_tensor_like(inp) and inp.is_complex() else inp for inp in tupled_inputs ] real_func_out = real_fn(*real_inputs) diff_real_func_out = _differentiable_outputs(real_func_out) gradcheck_fn( real_fn, real_func_out, real_inputs, diff_real_func_out, eps, rtol, atol, check_grad_dtypes, nondet_tol, complex_indices=complex_inp_indices, use_forward_ad=True, ) if check_undefined_grad: _test_undefined_forward_mode(imag_fn, imag_func_out, imag_inputs) _test_undefined_forward_mode(real_fn, real_func_out, real_inputs) else: gradcheck_fn( func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, use_forward_ad=True, ) if check_undefined_grad: _test_undefined_forward_mode(func, outputs, tupled_inputs) def _slow_gradcheck( func, func_out, tupled_inputs, outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False, masked=False, ): func_out = _as_tuple(func_out) if not outputs: return _check_no_differentiable_outputs( func, tupled_inputs, func_out, eps=eps, is_forward_ad=use_forward_ad ) tupled_inputs_numerical = tupled_inputs if masked else _densify(tupled_inputs) numerical = _transpose( _get_numerical_jacobian( func, tupled_inputs_numerical, func_out, eps=eps, is_forward_ad=use_forward_ad, ) ) # Note: [numerical vs analytical output length] # The numerical path returns jacobian 
quantity for all outputs, even if requires_grad of that # output is False. This behavior is necessary for _check_no_differentiable_outputs to work. numerical = [nj for o, nj in zip(func_out, numerical) if o.requires_grad] if use_forward_ad: analytical_forward = _get_analytical_jacobian_forward_ad( func, tupled_inputs, func_out, check_grad_dtypes=check_grad_dtypes ) for i, n_per_out in enumerate(numerical): for j, n in enumerate(n_per_out): a = analytical_forward[j][i] if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): raise GradcheckError( _get_notallclose_msg( a, n, i, j, complex_indices, test_imag, is_forward_ad=True ) ) else: for i, o in enumerate(outputs): analytical = _check_analytical_jacobian_attributes( tupled_inputs, o, nondet_tol, check_grad_dtypes ) for j, (a, n) in enumerate(zip(analytical, numerical[i])): if not _allclose_with_type_promotion(a, n.to(a.device), rtol, atol): raise GradcheckError( _get_notallclose_msg(a, n, i, j, complex_indices, test_imag) ) return True def _dot_with_type_promotion(u, v): assert u.dim() == 1 and v.dim() == 1 return (u * v).sum() def _allclose_with_type_promotion(a, b, rtol, atol): promoted_type = torch.promote_types(a.dtype, b.dtype) a = a.to(dtype=promoted_type) b = b.to(dtype=promoted_type) return torch.allclose(a, b, rtol, atol) def _to_real_dtype(dtype): if dtype == torch.complex128: return torch.float64 elif dtype == torch.complex64: return torch.float32 else: return dtype def _vec_from_tensor(x, generator, downcast_complex=False): # Create a random vector with the same number of elements as x and the same # dtype/device. If x is complex and downcast_complex is False, we create a # complex tensor with only real component. if x.layout == torch.sparse_coo: # For sparse, create a random sparse vec with random values in the same # indices. Make sure size is set so that it isn't inferred to be smaller. 
x_values = x._values() dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype values = ( torch.rand(x_values.numel(), generator=generator) .to(dtype=dtype, device=x.device) .view(x_values.shape) ) values /= values.norm() vec = torch.sparse_coo_tensor(x._indices(), values, x.size(), device=x.device) elif _is_sparse_compressed_tensor(x): if x.layout in {torch.sparse_csr, torch.sparse_bsr}: compressed_indices, plain_indices = x.crow_indices(), x.col_indices() else: compressed_indices, plain_indices = x.ccol_indices(), x.row_indices() x_values = x.values() dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype values = ( torch.rand(x_values.numel(), generator=generator) .to(dtype=dtype, device=x.device) .view(x_values.shape) ) values /= values.norm() vec = torch.sparse_compressed_tensor( compressed_indices, plain_indices, values, x.size(), layout=x.layout, device=x.device, ) else: dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype vec = torch.rand(x.numel(), generator=generator).to( dtype=dtype, device=x.device ) vec /= vec.norm() return vec def _get_inp_tensors(tupled_inputs): inp_idx_tup = [ (i, t) for i, t in enumerate(tupled_inputs) if is_tensor_like(t) and t.requires_grad ] return [tup[0] for tup in inp_idx_tup], [tup[1] for tup in inp_idx_tup] def _adjusted_atol(atol, u, v): # In slow gradcheck, we compare A and B element-wise, i.e., for some a, b we # allow: |a - b| < atol + rtol * b. But since we now compare q1 = v^T A u and # q2 = v^T B u, we must allow |q1 - q2| < v^T E u + rtol * v^T B u, where E is # the correctly sized matrix in which each entry is atol. 
# # We see that atol needs to be scaled by v^T M u (where M is an all-ones M x N # matrix): v^T M u = \sum_{i} \sum_{j} u_i * v_j = (\sum_{i} u_i)(\sum_{i} v_i) # TODO: properly handle case when u is tuple instead of only taking first element u = u[0] if isinstance(u, tuple) else u sum_u = u.sum() sum_v = 1.0 if v is None else v.sum() return atol * float(sum_u) * float(sum_v) FAST_FAIL_SLOW_OK_MSG = """ Fast gradcheck failed but element-wise differences are small. This means that the test might've passed in slow_mode! If you are adding a new operator, please file an issue and then use one of the workarounds. The workaround depends on how your test invokes gradcheck/gradgradcheck: If the test - manually invokes gradcheck/gradgradcheck, then call gradcheck/gradgradcheck with `fast_mode=False` as a keyword argument. - is OpInfo-based (e.g., in test_ops_gradients.py), then modify the OpInfo for the test to have `gradcheck_fast_mode=False` - is a Module test (e.g., in common_nn.py), then modify the corresponding module_test entry to have `gradcheck_fast_mode=False` """.strip() def _run_slow_mode_and_get_error( func, tupled_inputs, outputs, input_idx, output_idx, rtol, atol, eps, is_forward_ad ): # Compute jacobians in slow mode for better error message slow_numerical = _get_numerical_jacobian( func, tupled_inputs, outputs, eps=eps, is_forward_ad=is_forward_ad )[input_idx][output_idx] if is_forward_ad: def new_fn(inp): new_inputs = list(tupled_inputs) new_inputs[input_idx] = inp return _as_tuple(func(*new_inputs))[output_idx] slow_analytical = _get_analytical_jacobian_forward_ad( new_fn, (tupled_inputs[input_idx],), (outputs[output_idx],) )[0][0] else: slow_analytical = _get_analytical_jacobian( tupled_inputs, outputs, input_idx, output_idx ) # Assume jacobians are non-empty and have the same shape slow_max_diff = (slow_numerical - slow_analytical).abs().max() slow_allclose = torch.allclose(slow_analytical, slow_numerical, rtol, atol) msg = ( "\nThe above quantities 
relating the numerical and analytical jacobians are computed \n" "in fast mode. See: https://github.com/pytorch/pytorch/issues/53876 for more background \n" "about fast mode. Below, we recompute numerical and analytical jacobians in slow mode:\n\n" f"Numerical:\n {slow_numerical}\n" f"Analytical:\n{slow_analytical}\n\n" f"The max per-element difference (slow mode) is: {slow_max_diff}.\n" ) if slow_allclose: # Slow gradcheck would've passed! msg += FAST_FAIL_SLOW_OK_MSG return msg def _to_flat_dense_if_sparse(tensor): if _is_sparse_any_tensor(tensor): return tensor.to_dense().reshape(-1) else: return tensor def _make_vectors(inp_tensors, outputs, *, use_forward_ad): # Use our own generator to avoid messing with the user's RNG state g_cpu = torch.Generator() def _vec_from_tensor_cpu(*args): # Default allocate all tensors on CPU, so they are on the same device as the generator # even if the user specified a default device with torch.device("cpu"): return _vec_from_tensor(*args) all_u = [] all_u_dense = [] for inp in inp_tensors: ur = _vec_from_tensor_cpu(inp, g_cpu, True) ur_dense = _to_flat_dense_if_sparse(ur) if inp.is_complex(): ui = _vec_from_tensor_cpu(inp, g_cpu, True) all_u.append((ur, ui)) ui_dense = _to_flat_dense_if_sparse(ui) all_u_dense.append((ur_dense, ui_dense)) else: all_u.append(ur) all_u_dense.append(ur_dense) all_v = ( None if use_forward_ad else [_vec_from_tensor_cpu(out, g_cpu) for out in outputs] ) return all_v, all_u, all_u_dense def _check_analytical_numerical_equal( all_analytical, all_numerical, complex_indices, tupled_inputs, outputs, func, all_v, all_u, rtol, atol, eps, test_imag, *, is_forward_ad=False, ): for i, all_numerical_for_input_i in enumerate(all_numerical): for j, n in enumerate(all_numerical_for_input_i): # Forward AD generates the transpose of what this function expects if is_forward_ad: a = all_analytical[i][j] else: a = all_analytical[j][i] n = n.to(device=a.device) updated_atol = _adjusted_atol(atol, all_u[i], all_v[j] if 
all_v else None) if not _allclose_with_type_promotion(a, n.to(a.device), rtol, updated_atol): jacobians_str = _run_slow_mode_and_get_error( func, tupled_inputs, outputs, i, j, rtol, atol, eps, is_forward_ad ) raise GradcheckError( _get_notallclose_msg( a, n, j, i, complex_indices, test_imag, is_forward_ad ) + jacobians_str ) def _fast_gradcheck( func, func_out, inputs, outputs, eps, rtol, atol, check_grad_dtypes, nondet_tol, *, use_forward_ad=False, complex_indices=None, test_imag=False, masked=False, ): # See https://github.com/pytorch/pytorch/issues/53876 for details inp_tensors_idx, inp_tensors = _get_inp_tensors(inputs) # Backward mode computes v^T * J (VJP) # Since we computed J * u (JVP) through finite difference method, we perform an equality check # between VJP * u, v * JVP # ---- # Forward mode computes J * u (JVP) # Since we already compute JVP through finite difference method, # we don't need v for correctness check here as asserted below all_v, all_u, all_u_dense = _make_vectors( inp_tensors, outputs, use_forward_ad=use_forward_ad ) inputs_numerical, all_u_numerical, all_v_numerical = ( (inputs, all_u, all_v) if masked else _densify((inputs, all_u, all_v)) ) numerical_vJu = _get_numerical_vJu( func, inputs_numerical, inp_tensors_idx, func_out, all_u_numerical, all_v_numerical, eps, is_forward_ad=use_forward_ad, ) # TODO: replicate https://github.com/pytorch/pytorch/pull/77743 for fast gradcheck as well if use_forward_ad: assert all_v is None analytical_vJu = _get_analytical_jacobian_forward_ad( func, inputs, _as_tuple(func_out), all_u=all_u, check_grad_dtypes=check_grad_dtypes, ) else: if not outputs: _check_no_differentiable_outputs_fast( func, func_out, inputs, inp_tensors_idx, all_u, eps, nondet_tol ) analytical_vJu = _get_analytical_vJu_backward_mode( inputs, outputs, nondet_tol, check_grad_dtypes, all_v, all_u_dense ) _check_analytical_numerical_equal( analytical_vJu, numerical_vJu, complex_indices, inputs, outputs, func, all_v, all_u, rtol, atol, 
eps, test_imag, is_forward_ad=use_forward_ad, ) return True # Note [VarArg of Tensors] # ~~~~~~~~~~~~~~~~~~~~~~~~ # 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment. # If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted, # the '...' first argument of Callable can be replaced with VarArg(Tensor). # For now, we permit any input.
def gradcheck(
    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
    inputs: _TensorOrTensors,
    *,
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    raise_exception: bool = True,
    nondet_tol: float = 0.0,
    check_undefined_grad: bool = True,
    check_grad_dtypes: bool = False,
    check_batched_grad: bool = False,
    check_batched_forward_grad: bool = False,
    check_forward_ad: bool = False,
    check_backward_ad: bool = True,
    fast_mode: bool = False,
    masked: Optional[bool] = None,
) -> bool:  # noqa: D400,D205
    r"""Check gradients computed via small finite differences against analytical
    gradients wrt tensors in :attr:`inputs` that are of floating point or complex type
    and with ``requires_grad=True``.

    The check between numerical and analytical gradients uses :func:`~torch.allclose`.

    For most of the complex functions we consider for optimization purposes, no notion of
    Jacobian exists. Instead, gradcheck verifies if the numerical and analytical values of
    the Wirtinger and Conjugate Wirtinger derivatives are consistent. Because the gradient
    computation is done under the assumption that the overall function has a real-valued
    output, we treat functions with complex output in a special way. For these functions,
    gradcheck is applied to two real-valued functions corresponding to taking the real
    components of the complex outputs for the first, and taking the imaginary components
    of the complex outputs for the second. For more details, check out
    :ref:`complex_autograd-doc`.

    .. note::
        The default values are designed for :attr:`inputs` of double precision.
        This check will likely fail if :attr:`inputs` is of less precision, e.g.,
        ``FloatTensor``.

    .. note::
        Gradcheck may fail when evaluated on non-differentiable points
        because the numerically computed gradients via finite differencing may differ
        from those computed analytically (not necessarily because either is incorrect).
        For more context, see :ref:`non-differentiable-func-grad`.

    .. warning::
       If any checked tensor in :attr:`inputs` has overlapping memory, i.e.,
       different indices pointing to the same memory address (e.g., from
       :func:`torch.expand`), this check will likely fail because the numerical
       gradients computed by point perturbation at such indices will change
       values at all other indices that share the same memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance.
        check_undefined_grad (bool, optional): if ``True``, check if undefined output grads
            are supported and treated as zeros, for ``Tensor`` outputs.
        check_grad_dtypes (bool, optional): if ``True``, also check the dtypes of the
            analytically computed gradients. Defaults to ``False``.
        check_batched_grad (bool, optional): if ``True``, check if we can compute
            batched gradients using prototype vmap support. Defaults to ``False``.
        check_batched_forward_grad (bool, optional): if ``True``, checks if we can compute
            batched forward gradients using forward ad and prototype vmap support.
            Defaults to ``False``.
        check_forward_ad (bool, optional): if ``True``, check that the gradients computed
            with forward mode AD match the numerical ones. Defaults to ``False``.
        check_backward_ad (bool, optional): if ``False``, do not perform any checks that
            rely on backward mode AD to be implemented. Defaults to ``True``.
        fast_mode (bool, optional): if ``True``, run a faster implementation of gradcheck
            that compares random projections of the numerical and analytical Jacobians
            instead of computing the entire Jacobians. Defaults to ``False``.
        masked (bool, optional): if ``True``, the gradients of unspecified elements of
            sparse tensors are ignored. Defaults to ``None``, which is treated as
            ``False``.

    Returns:
        ``True`` if all differences satisfy the allclose condition. ``False`` if a
        gradient check fails and :attr:`raise_exception` is ``False``.

    Raises:
        GradcheckError: if a gradient check fails and :attr:`raise_exception` is ``True``.
        AssertionError: if the combination of ``check_*`` flags is invalid.

        Errors raised while validating :attr:`inputs` or the outputs of :attr:`func`
        (for example, when no input requires grad) are not ``GradcheckError`` and are
        propagated regardless of :attr:`raise_exception`.
    """
    # Validate flag combinations before doing any work. These are raised explicitly
    # rather than via ``assert`` so that they are not stripped under ``python -O``.
    # AssertionError is kept as the type so existing callers that catch it still work.
    if not (check_forward_ad or check_backward_ad):
        raise AssertionError(
            "Expected at least one of check_forward_ad or check_backward_ad to be True"
        )
    if check_batched_grad and not check_backward_ad:
        raise AssertionError(
            "Setting check_batched_grad=True requires check_backward_ad to be True"
        )
    if check_batched_forward_grad and not check_forward_ad:
        raise AssertionError(
            "Setting check_batched_forward_grad=True requires check_forward_ad to be True"
        )

    try:
        # Forward every option except ``raise_exception`` explicitly. Snapshotting
        # ``locals()`` here would silently pick up any local added in the future.
        return _gradcheck_helper(
            func=func,
            inputs=inputs,
            eps=eps,
            atol=atol,
            rtol=rtol,
            nondet_tol=nondet_tol,
            check_undefined_grad=check_undefined_grad,
            check_grad_dtypes=check_grad_dtypes,
            check_batched_grad=check_batched_grad,
            check_batched_forward_grad=check_batched_forward_grad,
            check_forward_ad=check_forward_ad,
            check_backward_ad=check_backward_ad,
            fast_mode=fast_mode,
            masked=masked,
        )
    except GradcheckError:
        # Only gradcheck failures are turned into a ``False`` return. Any other
        # error, such as invalid inputs, always propagates.
        if raise_exception:
            raise
        return False
def _gradcheck_helper(
    func,
    inputs,
    eps,
    atol,
    rtol,
    nondet_tol,
    check_undefined_grad,
    check_grad_dtypes,
    check_batched_grad,
    check_batched_forward_grad,
    check_forward_ad,
    check_backward_ad,
    fast_mode,
    masked,
):
    """Run every check requested by :func:`gradcheck` and return ``True``.

    :func:`gradcheck` forwards all of its arguments except ``raise_exception``
    to this helper (via ``locals()``). Each parameter therefore has the same
    meaning as the identically named :func:`gradcheck` argument.

    A failed check is reported by raising. When ``raise_exception=False``,
    :func:`gradcheck` catches a ``GradcheckError`` raised from here and
    returns ``False`` instead.
    """
    tupled_inputs = _as_tuple(inputs)
    _check_inputs(tupled_inputs)

    func_out = func(*tupled_inputs)
    outputs = _differentiable_outputs(func_out)
    _check_outputs(outputs)

    # Select the fast or slow implementation according to `fast_mode`;
    # `masked` is bound here so the driver below only sees a single callable.
    gradcheck_fn = functools.partial(
        _fast_gradcheck if fast_mode else _slow_gradcheck, masked=masked
    )
    _gradcheck_real_imag(
        gradcheck_fn,
        func,
        func_out,
        tupled_inputs,
        outputs,
        eps,
        rtol,
        atol,
        check_grad_dtypes,
        check_forward_ad=check_forward_ad,
        check_backward_ad=check_backward_ad,
        nondet_tol=nondet_tol,
        check_undefined_grad=check_undefined_grad,
    )

    if check_batched_forward_grad:
        _test_batched_grad_forward_ad(func, tupled_inputs)

    # Short circuit because the remaining tests rely on backward AD being implemented.
    if not check_backward_ad:
        return True

    # The flag does not depend on the output, so test it once rather than per iteration.
    if check_batched_grad:
        for i, o in enumerate(outputs):
            _test_batched_grad(tupled_inputs, o, i)

    _test_backward_mul_by_grad_output(outputs, tupled_inputs, masked)

    # check_backward_ad is necessarily True here because of the early return above.
    if check_undefined_grad:
        _test_undefined_backward_mode(func, outputs, tupled_inputs)
    return True
def gradgradcheck(
    func: Callable[..., _TensorOrTensors],  # See Note [VarArg of Tensors]
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    *,
    eps: float = 1e-6,
    atol: float = 1e-5,
    rtol: float = 1e-3,
    gen_non_contig_grad_outputs: bool = False,
    raise_exception: bool = True,
    nondet_tol: float = 0.0,
    check_undefined_grad: bool = True,
    check_grad_dtypes: bool = False,
    check_batched_grad: bool = False,
    check_fwd_over_rev: bool = False,
    check_rev_over_rev: bool = True,
    fast_mode: bool = False,
    masked: bool = False,
) -> bool:  # noqa: D400,D205
    r"""Check gradients of gradients computed via small finite differences
    against analytical gradients wrt tensors in :attr:`inputs` and
    :attr:`grad_outputs` that are of floating point or complex type and with
    ``requires_grad=True``.

    This function checks that backpropagating through the gradients computed
    to the given :attr:`grad_outputs` is correct. It does so by running
    :func:`gradcheck` on a function that maps ``inputs + grad_outputs`` to the
    first-order gradients of :attr:`func` with respect to its differentiable
    inputs.

    The check between numerical and analytical gradients uses :func:`~torch.allclose`.

    .. note::
        The default values are designed for :attr:`inputs` and
        :attr:`grad_outputs` of double precision. This check will likely fail if
        they are of less precision, e.g., ``FloatTensor``.

    .. warning::
       If any checked tensor in :attr:`inputs` and :attr:`grad_outputs` has
       overlapping memory, i.e., different indices pointing to the same memory
       address (e.g., from :func:`torch.expand`), this check will likely fail
       because the numerical gradients computed by point perturbation at such
       indices will change values at all other indices that share the same
       memory address.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a Tensor or a tuple of Tensors
        inputs (tuple of Tensor or Tensor): inputs to the function
        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
            respect to the function's outputs. If ``None``, random tensors with
            values in ``[-1, 1]`` and ``requires_grad=True`` are generated,
            matching the shape and device of each differentiable output. They
            use the output's dtype when it is floating point or complex, and
            ``torch.double`` otherwise.
        eps (float, optional): perturbation for finite differences
        atol (float, optional): absolute tolerance
        rtol (float, optional): relative tolerance
        gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
            ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
            randomly generated gradient outputs are made to be noncontiguous
        raise_exception (bool, optional): indicating whether to raise an exception if
            the check fails. The exception gives more information about the
            exact nature of the failure. This is helpful when debugging gradchecks.
        nondet_tol (float, optional): tolerance for non-determinism. When running
            identical inputs through the differentiation, the results must either match
            exactly (default, 0.0) or be within this tolerance. Note that a small amount
            of nondeterminism in the gradient will lead to larger inaccuracies in
            the second derivative.
        check_undefined_grad (bool, optional): if True, check if undefined output grads
            are supported and treated as zeros. Requires ``check_rev_over_rev=True``.
        check_grad_dtypes (bool, optional): forwarded unchanged to
            :func:`gradcheck` as ``check_grad_dtypes``. Defaults to False.
        check_batched_grad (bool, optional): if True, check if we can compute
            batched gradients using prototype vmap support. Requires
            ``check_rev_over_rev=True``. Defaults to False.
        check_fwd_over_rev (bool, optional): if True, check the second-order
            gradients with forward-mode AD (passed to :func:`gradcheck` as
            ``check_forward_ad``). Defaults to False.
        check_rev_over_rev (bool, optional): if False, skip the checks that rely on
            backward-mode AD (passed to :func:`gradcheck` as
            ``check_backward_ad``). Defaults to True.
        fast_mode (bool, optional): if True, run a faster implementation of gradgradcheck that
            no longer computes the entire jacobian.
        masked (bool, optional): if True, the gradients of unspecified elements of
            sparse tensors are ignored (default, False).

    Returns:
        True if all differences satisfy allclose condition. If
        ``raise_exception=False``, a failed check returns False instead of raising.

    Raises:
        AssertionError: if neither ``check_fwd_over_rev`` nor ``check_rev_over_rev``
            is set, or if ``check_undefined_grad``/``check_batched_grad`` is
            requested without ``check_rev_over_rev``.
    """
    assert (
        check_fwd_over_rev or check_rev_over_rev
    ), "Expected at least one of check_fwd_over_rev or check_rev_over_rev to be True"
    assert not (
        check_undefined_grad and not check_rev_over_rev
    ), "Setting check_undefined_grad=True requires check_rev_over_rev to be True"
    assert not (
        check_batched_grad and not check_rev_over_rev
    ), "Setting check_batched_grad=True requires check_rev_over_rev to be True"
    # TODO: do we want to test this too?
    # assert not (check_batched_forward_grad and not check_fwd_over_rev), (
    #     "Setting check_batched_forward_grad=True requires check_fwd_over_rev to be True")
    tupled_inputs = _as_tuple(inputs)

    if grad_outputs is None:
        # If grad_outputs is not specified, create random Tensors of the same
        # shape, type, and device as the outputs.
        outputs = _differentiable_outputs(func(*tupled_inputs))
        tupled_grad_outputs = tuple(
            torch.testing.make_tensor(
                x.shape,
                dtype=x.dtype
                if x.is_floating_point() or x.is_complex()
                else torch.double,
                device=x.device,
                low=-1,
                high=1,
                requires_grad=True,
                noncontiguous=gen_non_contig_grad_outputs,
            )
            for x in outputs
        )
    else:
        tupled_grad_outputs = _as_tuple(grad_outputs)

    num_outputs = len(tupled_grad_outputs)

    # NB: We need to save the requires_grad information about the inputs here
    # because gradcheck detaches inputs before running forward mode AD.
    diff_input_args_indices = {
        i for i, x in enumerate(tupled_inputs) if is_tensor_like(x) and x.requires_grad
    }
    diff_grad_output_indices = {
        i for i, x in enumerate(tupled_grad_outputs) if x.requires_grad
    }

    def new_func(*args):
        # `args` is the flat tuple `tupled_inputs + tupled_grad_outputs` passed
        # to gradcheck below. Split it at an explicit index: the previous
        # `args[:-num_outputs]` / `args[-num_outputs:]` form breaks when
        # num_outputs == 0, since `args[:-0]` is empty and `args[-0:]` is everything.
        split = len(args) - num_outputs

        # Restore the requires_grad information
        input_args = tuple(
            x.requires_grad_() if i in diff_input_args_indices else x
            for i, x in enumerate(args[:split])
        )
        outputs = _differentiable_outputs(func(*input_args))
        grad_output_args = tuple(
            x.requires_grad_() if i in diff_grad_output_indices else x
            for i, x in enumerate(args[split:])
        )
        diff_input_args = tuple(
            x for i, x in enumerate(input_args) if i in diff_input_args_indices
        )
        # create_graph=True keeps these first-order gradients differentiable so
        # gradcheck can differentiate through them a second time.
        grad_inputs = torch.autograd.grad(
            outputs,
            diff_input_args,
            grad_output_args,
            create_graph=True,
            allow_unused=True,
        )
        # Inputs unused by `func` yield None (allow_unused=True); drop them.
        return tuple(g for g in grad_inputs if g is not None)

    return gradcheck(
        new_func,
        tupled_inputs + tupled_grad_outputs,
        eps=eps,
        atol=atol,
        rtol=rtol,
        raise_exception=raise_exception,
        nondet_tol=nondet_tol,
        check_undefined_grad=check_undefined_grad,
        check_grad_dtypes=check_grad_dtypes,
        check_batched_grad=check_batched_grad,
        fast_mode=fast_mode,
        check_forward_ad=check_fwd_over_rev,
        check_backward_ad=check_rev_over_rev,
        masked=masked,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources