torch.use_deterministic_algorithms

torch.use_deterministic_algorithms(mode)

Sets whether PyTorch operations must use “deterministic” algorithms, that is, algorithms that, given the same input and run on the same software and hardware, always produce the same output. When this mode is enabled, operations use deterministic algorithms where they are available; if only nondeterministic algorithms are available, they throw a RuntimeError when called.
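
Deterministic algorithms alone do not make a program repeatable across runs; the random number generators must also be seeded. A minimal sketch of the usual pairing (torch.manual_seed is the standard seeding API; the seed value here is arbitrary):

# Seed the RNG so random inputs repeat across runs, then require
# deterministic algorithms so identical inputs give identical outputs
>>> import torch
>>> torch.manual_seed(0)
>>> torch.use_deterministic_algorithms(True)
>>> x = torch.randn(4, 4)  # same values every run, thanks to the fixed seed
>>> y = x @ x              # deterministic given the flag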

The following normally-nondeterministic operations will act deterministically when mode=True:

The following normally-nondeterministic operations will throw a RuntimeError when mode=True:

A handful of CUDA operations are nondeterministic when the CUDA version is 10.2 or greater, unless the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8 is set. See the CUDA documentation for more details: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility. If neither of these environment variable configurations is set, these operations will raise a RuntimeError when called with CUDA tensors:
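
A sketch of setting one of these configurations from Python, assuming it runs before the first cuBLAS call; exporting the variable in the shell before launching the process is the more robust approach:

# Must be in the environment before cuBLAS initializes
>>> import os
>>> os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
>>> import torch
>>> torch.use_deterministic_algorithms(True)

# With the workspace config set, CUDA matmuls no longer raise
>>> if torch.cuda.is_available():
...     a = torch.randn(8, 8, device='cuda')
...     b = a @ a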

Note that deterministic operations tend to have worse performance than nondeterministic operations.

Note

This flag does not detect or prevent nondeterministic behavior caused by calling an in-place operation on a tensor with an internal memory overlap, or by passing such a tensor as the out argument of an operation. In these cases, multiple writes of different data may target a single memory location, and the order of the writes is not guaranteed.
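
To illustrate, a contrived sketch of such an overlap built with as_strided (ordinary code rarely creates these views). Elements (0, 2) and (1, 0) of the view below alias base[2]; because all strides are nonzero, PyTorch's overlap check cannot prove the aliasing, and which written value survives is unspecified even with the flag enabled:

>>> torch.use_deterministic_algorithms(True)
>>> base = torch.zeros(5)
>>> overlapped = base.as_strided(size=(2, 3), stride=(2, 1))
>>> src = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
# base[2] receives 3.0 or 4.0; the write order is not guaranteed
>>> _ = overlapped.copy_(src)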

Parameters

mode (bool) – If True, makes potentially nondeterministic operations switch to a deterministic algorithm or throw a RuntimeError. If False, allows nondeterministic operations.

Example:

>>> torch.use_deterministic_algorithms(True)

# Forward mode nondeterministic error
>>> torch.randn(10).index_copy(0, torch.tensor([0]), torch.randn(1))
...
RuntimeError: index_copy does not have a deterministic implementation...

# Backward mode nondeterministic error
>>> torch.randn(10, requires_grad=True, device='cuda').index_select(0, torch.tensor([0], device='cuda')).backward()
...
RuntimeError: index_add_cuda_ does not have a deterministic implementation...
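
The companion query torch.are_deterministic_algorithms_enabled(), available alongside this function in recent releases, makes it easy to toggle the global state temporarily and restore it afterwards; a small sketch:

# Save, set, and restore the global flag around a critical region
>>> prev = torch.are_deterministic_algorithms_enabled()
>>> torch.use_deterministic_algorithms(True)
>>> try:
...     out = torch.randn(3).cumsum(0)  # code that must run deterministically
... finally:
...     torch.use_deterministic_algorithms(prev)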
