torch.testing

Warning

This module is in a PROTOTYPE state. New functions are still being added, and the available functions may change in future PyTorch releases. We are actively looking for feedback for UI/UX improvements or missing functionalities.

torch.testing.assert_close(actual, expected, *, allow_subclasses=True, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=False, check_is_coalesced=True, msg=None)[source]

Asserts that actual and expected are close.

If actual and expected are strided, non-quantized, real-valued, and finite, they are considered close if

\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert

and they have the same device (if check_device is True), the same dtype (if check_dtype is True), and the same stride (if check_stride is True). Non-finite values (-inf and inf) are only considered close if they are equal. NaNs are only considered equal to each other if equal_nan is True.
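
For illustration, here is a minimal sketch of how this bound plays out for a single pair of values; the numbers are chosen purely for this example:

>>> # |1.0 - 1.001| = 1e-3 is within atol=1e-2 when rtol=0 ...
>>> torch.testing.assert_close(1.0, 1.001, rtol=0.0, atol=1e-2)
>>> # ... and also within rtol=1e-2 * |1.001| when atol=0
>>> torch.testing.assert_close(1.0, 1.001, rtol=1e-2, atol=0.0)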

If actual and expected are sparse (either having COO or CSR layout), their strided members are checked individually. Indices, namely indices for COO or crow_indices and col_indices for CSR layout, are always checked for equality, whereas the values are checked for closeness according to the definition above. If check_is_coalesced is True, sparse COO tensors are only considered close if both are coalesced or both are uncoalesced.
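
A minimal sketch of a sparse COO comparison, assuming a build with sparse support; the tensors are constructed only for illustration:

>>> indices = torch.tensor([[0, 1], [0, 1]])
>>> expected = torch.sparse_coo_tensor(indices, torch.tensor([1.0, 2.0]), size=(2, 2))
>>> actual = torch.sparse_coo_tensor(indices, torch.tensor([1.0, 2.0]) + 1e-7, size=(2, 2))
>>> # The indices are checked for equality and the values for closeness; both tensors
>>> # are uncoalesced, so the coalescence check passes as well
>>> torch.testing.assert_close(actual, expected)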

If actual and expected are quantized, they are considered close if they have the same qscheme() and the result of dequantize() is close according to the definition above.
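
A minimal sketch of a quantized comparison, assuming a build with quantization support:

>>> # Per-tensor affine quantization with identical parameters and values
>>> expected = torch.quantize_per_tensor(torch.tensor([1.0, 2.0]), 0.1, 0, torch.qint8)
>>> actual = torch.quantize_per_tensor(torch.tensor([1.0, 2.0]), 0.1, 0, torch.qint8)
>>> # Same qscheme() and equal dequantized values
>>> torch.testing.assert_close(actual, expected)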

actual and expected can be Tensors or any tensor-or-scalar-likes from which torch.Tensors can be constructed with torch.as_tensor(). Except for Python scalars, the input types have to be directly related. In addition, actual and expected can be Sequences or Mappings, in which case they are considered close if their structure matches and all their elements are considered close according to the above definition.

Note

Python scalars are an exception to the type relation requirement, because their type(), i.e. int, float, and complex, is equivalent to the dtype of a tensor-like. Thus, Python scalars of different types can be checked, but require check_dtype to be set to False.

Parameters
  • actual (Any) – Actual input.

  • expected (Any) – Expected input.

  • allow_subclasses (bool) – If True (default), inputs of directly related types are allowed; otherwise type equality is required. Python scalars are exempt from this check in either case.

  • rtol (Optional[float]) – Relative tolerance. If specified, atol must also be specified. If omitted, default values based on the dtype are selected with the table below.

  • atol (Optional[float]) – Absolute tolerance. If specified, rtol must also be specified. If omitted, default values based on the dtype are selected with the table below.

  • equal_nan (Union[bool, str]) – If True, two NaN values will be considered equal.

  • check_device (bool) – If True (default), asserts that corresponding tensors are on the same device. If this check is disabled, tensors on different devices are moved to the CPU before being compared.

  • check_dtype (bool) – If True (default), asserts that corresponding tensors have the same dtype. If this check is disabled, tensors with different dtypes are promoted to a common dtype (according to torch.promote_types()) before being compared.

  • check_stride (bool) – If True and corresponding tensors are strided, asserts that they have the same stride (disabled by default; see the sketch after this list).

  • check_is_coalesced (bool) – If True (default) and corresponding tensors are sparse COO, checks that both actual and expected are either coalesced or uncoalesced. If this check is disabled, tensors are coalesced with coalesce() before being compared.

  • msg (Optional[Union[str, Callable[[Tensor, Tensor, Diagnostics], str]]]) – Optional error message to use if the values of corresponding tensors mismatch. Can be passed as a callable, in which case it will be called with the mismatching tensors and a namespace of diagnostics about the mismatches. See below for details.
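
The following is a small sketch of the check_stride behavior referenced above; the tensors are chosen only for illustration:

>>> expected = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> # Same values and shape as expected, but with different strides
>>> actual = expected.t().contiguous().t()
>>> # check_stride defaults to False, so the stride mismatch is ignored
>>> torch.testing.assert_close(actual, expected)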

Raises
  • ValueError – If no torch.Tensor can be constructed from an input.

  • ValueError – If only rtol or atol is specified.

  • AssertionError – If corresponding inputs are not Python scalars and are not directly related.

  • AssertionError – If allow_subclasses is False, but corresponding inputs are not Python scalars and have different types.

  • AssertionError – If the inputs are Sequences, but their lengths do not match.

  • AssertionError – If the inputs are Mappings, but their sets of keys do not match.

  • AssertionError – If corresponding tensors do not have the same shape.

  • AssertionError – If corresponding tensors do not have the same layout.

  • AssertionError – If corresponding tensors are quantized, but do not have the same qscheme().

  • AssertionError – If check_device is True, but corresponding tensors are not on the same device.

  • AssertionError – If check_dtype is True, but corresponding tensors do not have the same dtype.

  • AssertionError – If check_stride is True, but corresponding strided tensors do not have the same stride.

  • AssertionError – If check_is_coalesced is True, but corresponding sparse COO tensors are not both either coalesced or uncoalesced.

  • AssertionError – If the values of corresponding tensors are not close according to the definition above.

The following table displays the default rtol and atol for different dtypes. In case of mismatching dtypes, the maximum of both tolerances is used.

dtype        rtol     atol
----------   ------   -----
float16      1e-3     1e-5
bfloat16     1.6e-2   1e-5
float32      1.3e-6   1e-5
float64      1e-7     1e-7
complex32    1e-3     1e-5
complex64    1.3e-6   1e-5
complex128   1e-7     1e-7
other        0.0      0.0
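
For example, when a float16 tensor is compared against a float32 tensor (which requires check_dtype=False), the larger float16 tolerances apply. The values below are chosen only to illustrate this:

>>> expected = torch.tensor([1.0], dtype=torch.float32)
>>> actual = torch.tensor([1.0009765625], dtype=torch.float16)
>>> # The difference of ~9.8e-4 would exceed the float32 defaults, but it is within
>>> # the float16 defaults (rtol=1e-3, atol=1e-5), i.e. the maximum of the two
>>> torch.testing.assert_close(actual, expected, check_dtype=False)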

The namespace of diagnostics that will be passed to msg if it is a callable has the following attributes:

  • number_of_elements (int): Number of elements in each tensor being compared.

  • total_mismatches (int): Total number of mismatches.

  • max_abs_diff (Union[int, float]): Greatest absolute difference of the inputs.

  • max_abs_diff_idx (Union[int, Tuple[int, …]]): Index of greatest absolute difference.

  • atol (float): Allowed absolute tolerance.

  • max_rel_diff (Union[int, float]): Greatest relative difference of the inputs.

  • max_rel_diff_idx (Union[int, Tuple[int, …]]): Index of greatest relative difference.

  • rtol (float): Allowed relative tolerance.

For max_abs_diff and max_rel_diff the type depends on the dtype of the inputs.

Note

assert_close() is highly configurable with strict default settings. Users are encouraged to partial() it to fit their use case. For example, if an equality check is needed, one might define an assert_equal that uses zero tolerances for every dtype by default:

>>> import functools
>>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
>>> assert_equal(1e-9, 1e-10)
Traceback (most recent call last):
...
AssertionError: Scalars are not equal!

Absolute difference: 8.999999703829253e-10
Relative difference: 8.999999583666371

Examples

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.clone()
>>> # By default, directly related instances can be compared
>>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
>>> # This check can be made more strict with allow_subclasses=False
>>> torch.testing.assert_close(
...     torch.nn.Parameter(actual), expected, allow_subclasses=False
... )
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, type equality is required if
allow_subclasses=False, but got <class 'torch.nn.parameter.Parameter'> and
<class 'torch.Tensor'> instead.
>>> # If the inputs are not directly related, they are never considered close
>>> torch.testing.assert_close(actual.numpy(), expected)
Traceback (most recent call last):
...
AssertionError: Except for Python scalars, input types need to be directly
related, but got <class 'numpy.ndarray'> and <class 'torch.Tensor'> instead.
>>> # Exceptions to these rules are Python scalars. They can be checked regardless of
>>> # their type if check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("Nan"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
Traceback (most recent call last):
...
AssertionError: Scalars are not close!

Absolute difference: nan (up to 1e-05 allowed)
Relative difference: nan (up to 1.3e-06 allowed)
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default mismatch message can be overwritten.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
Traceback (most recent call last):
...
AssertionError: Argh, the tensors are not close!
>>> # The error message can also be created at runtime by passing a callable.
>>> def custom_msg(actual, expected, diagnostics):
...     ratio = diagnostics.total_mismatches / diagnostics.number_of_elements
...     return (
...         f"Argh, we found {diagnostics.total_mismatches} mismatches! "
...         f"That is {ratio:.1%}!"
...     )
>>> torch.testing.assert_close(actual, expected, msg=custom_msg)
Traceback (most recent call last):
...
AssertionError: Argh, we found 2 mismatches! That is 66.7%!
torch.testing.make_tensor(shape, device, dtype, *, low=None, high=None, requires_grad=False, noncontiguous=False, exclude_zero=False)[source]

Creates a tensor with the given shape, device, and dtype, filled with values uniformly drawn from [low, high).

If low or high is specified and lies outside the range of the dtype’s representable finite values, it is clamped to the lowest or highest representable finite value, respectively. If None, the following table describes the default values for low and high, which depend on dtype.

dtype                    low   high
----------------------   ---   ----
boolean type             0     2
unsigned integral type   0     10
signed integral types    -9    10
floating types           -9    9
complex types            -9    9
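
As a quick sketch of these defaults: with low and high left as None, a signed integral tensor always falls in [-9, 10).

>>> from torch.testing import make_tensor
>>> t = make_tensor((10000,), device='cpu', dtype=torch.int64)
>>> bool((t >= -9).all() and (t < 10).all())
True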

Parameters
  • shape (Tuple[int, ...]) – A sequence of integers defining the shape of the output tensor.

  • device (Union[str, torch.device]) – The device of the returned tensor.

  • dtype (torch.dtype) – The data type of the returned tensor.

  • low (Optional[Number]) – Sets the lower limit (inclusive) of the given range. If a number is provided it is clamped to the least representable finite value of the given dtype. When None (default), this value is determined based on the dtype (see the table above). Default: None.

  • high (Optional[Number]) – Sets the upper limit (exclusive) of the given range. If a number is provided it is clamped to the greatest representable finite value of the given dtype. When None (default) this value is determined based on the dtype (see the table above). Default: None.

  • requires_grad (Optional[bool]) – If autograd should record operations on the returned tensor. Default: False.

  • noncontiguous (Optional[bool]) – If True, the returned tensor will be noncontiguous. This argument is ignored if the constructed tensor has fewer than two elements. Default: False. See the last example below.

  • exclude_zero (Optional[bool]) – If True, zeros are replaced with a small positive value that depends on the dtype: for bool and integer types, zero is replaced with one; for floating point types, it is replaced with the dtype’s smallest positive normal number (the “tiny” value of the dtype’s finfo() object); and for complex types, it is replaced with a complex number whose real and imaginary parts are both that smallest positive normal number. Default: False.

Examples

>>> from torch.testing import make_tensor
>>> # Creates a float tensor with values in [-1, 1)
>>> make_tensor((3,), device='cpu', dtype=torch.float32, low=-1, high=1)
tensor([ 0.1205, 0.2282, -0.6380])
>>> # Creates a bool tensor on CUDA
>>> make_tensor((2, 2), device='cuda', dtype=torch.bool)
tensor([[False, False],
        [False, True]], device='cuda:0')
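
The following sketch additionally illustrates the noncontiguous and exclude_zero flags; since the values are random, only the relevant properties are checked:

>>> # noncontiguous=True yields a noncontiguous tensor (ignored for < 2 elements)
>>> t = make_tensor((2, 3), device='cpu', dtype=torch.float32, noncontiguous=True)
>>> t.is_contiguous()
False
>>> # exclude_zero=True replaces zeros; for integral dtypes they become one
>>> t = make_tensor((100,), device='cpu', dtype=torch.int32, exclude_zero=True)
>>> bool((t != 0).all())
True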
