torch.use_deterministic_algorithms

torch.use_deterministic_algorithms(mode)

Sets whether PyTorch operations must use “deterministic” algorithms. That is, algorithms that, given the same input and run on the same software and hardware, always produce the same output. When enabled, operations will use deterministic algorithms when available, and if only nondeterministic algorithms are available they will throw a RuntimeError when called.

The following normally-nondeterministic operations will act deterministically when mode=True (a usage sketch follows the list):

- torch.nn.Conv1d when called on CUDA tensor
- torch.nn.Conv2d when called on CUDA tensor
- torch.nn.Conv3d when called on CUDA tensor
- torch.nn.ConvTranspose1d when called on CUDA tensor
- torch.nn.ConvTranspose2d when called on CUDA tensor
- torch.nn.ConvTranspose3d when called on CUDA tensor
- torch.bmm() when called on sparse-dense CUDA tensors
- torch.Tensor.__getitem__() when attempting to differentiate a CPU tensor and the index is a list of tensors
- torch.Tensor.index_put() with accumulate=False
- torch.Tensor.index_put() with accumulate=True when called on a CPU tensor
- torch.Tensor.put_() with accumulate=True when called on a CPU tensor
- torch.Tensor.scatter_add_() when input dimension is one and called on a CUDA tensor
- torch.gather() when input dimension is one and called on a CUDA tensor that requires grad
- torch.index_add() when called on CUDA tensor
- torch.index_select() when attempting to differentiate a CUDA tensor
- torch.repeat_interleave() when attempting to differentiate a CUDA tensor
- torch.Tensor.index_copy() when called on a CPU or CUDA tensor
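For example, torch.Tensor.scatter_add_() with duplicate indices normally relies on atomic adds on CUDA, whose ordering can vary between runs; with mode=True it must produce bit-identical results every time. A minimal sketch (not from the original documentation), assuming a CUDA device is available:

    import torch

    torch.use_deterministic_algorithms(True)

    # Duplicate indices force many additions into the same output slots;
    # the usual atomic-add implementation could reorder them between runs.
    torch.manual_seed(0)
    src = torch.randn(10000, device='cuda')
    index = torch.randint(0, 100, (10000,), device='cuda')

    out1 = torch.zeros(100, device='cuda').scatter_add_(0, index, src)
    out2 = torch.zeros(100, device='cuda').scatter_add_(0, index, src)
    assert torch.equal(out1, out2)  # bit-identical, not merely allclose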
The following normally-nondeterministic operations will throw a RuntimeError when mode=True (an error-handling sketch follows the list):

- torch.nn.AvgPool3d when attempting to differentiate a CUDA tensor
- torch.nn.AdaptiveAvgPool2d when attempting to differentiate a CUDA tensor
- torch.nn.AdaptiveAvgPool3d when attempting to differentiate a CUDA tensor
- torch.nn.MaxPool3d when attempting to differentiate a CUDA tensor
- torch.nn.AdaptiveMaxPool2d when attempting to differentiate a CUDA tensor
- torch.nn.FractionalMaxPool2d when attempting to differentiate a CUDA tensor
- torch.nn.FractionalMaxPool3d when attempting to differentiate a CUDA tensor
- torch.nn.functional.interpolate() when attempting to differentiate a CUDA tensor and one of the following modes is used:
  - linear
  - bilinear
  - bicubic
  - trilinear
- torch.nn.ReflectionPad1d when attempting to differentiate a CUDA tensor
- torch.nn.ReflectionPad2d when attempting to differentiate a CUDA tensor
- torch.nn.ReflectionPad3d when attempting to differentiate a CUDA tensor
- torch.nn.ReplicationPad1d when attempting to differentiate a CUDA tensor
- torch.nn.ReplicationPad2d when attempting to differentiate a CUDA tensor
- torch.nn.ReplicationPad3d when attempting to differentiate a CUDA tensor
- torch.nn.NLLLoss when called on a CUDA tensor
- torch.nn.CTCLoss when attempting to differentiate a CUDA tensor
- torch.nn.EmbeddingBag when attempting to differentiate a CUDA tensor when mode='max'
- torch.Tensor.scatter_add_() when input dimension is larger than one and called on a CUDA tensor
- torch.gather() when input dimension is larger than one and called on a CUDA tensor that requires grad
- torch.Tensor.put_() when accumulate=False
- torch.Tensor.put_() when accumulate=True and called on a CUDA tensor
- torch.histc() when called on a CUDA tensor
- torch.bincount() when called on a CUDA tensor
- torch.kthvalue() when called on a CUDA tensor
- torch.median() with indices output when called on a CUDA tensor
- torch.nn.functional.grid_sample() when attempting to differentiate a CUDA tensor
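When one of these operations is reached with mode=True, the error can be handled like any other RuntimeError. A short sketch (assumes a CUDA device), using torch.histc() from the list above:

    import torch

    torch.use_deterministic_algorithms(True)

    try:
        # histc has no deterministic CUDA implementation, so this raises
        # instead of silently returning run-dependent bin counts.
        torch.histc(torch.randn(100, device='cuda'), bins=10)
    except RuntimeError as err:
        print(err)  # message says histc does not have a deterministic implementation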
A handful of CUDA operations are nondeterministic if the CUDA version is 10.2 or greater, unless the environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8 is set. See the CUDA documentation for more details: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility If one of these environment variable configurations is not set, a RuntimeError will be raised from these operations when called with CUDA tensors:

- torch.mm()
- torch.mv()
- torch.bmm()

Note that deterministic operations tend to have worse performance than nondeterministic operations.
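A sketch of the cuBLAS workaround described above (assumes CUDA 10.2 or greater and a CUDA device). The variable must be in the environment before the first cuBLAS call, so the safest place to set it is before torch is imported:

    import os

    # Set the cuBLAS workspace config before importing torch, so it is
    # already in place when CUDA/cuBLAS initialize.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    import torch

    torch.use_deterministic_algorithms(True)
    a = torch.randn(128, 128, device='cuda')
    b = torch.randn(128, 128, device='cuda')
    c = torch.mm(a, b)  # would raise a RuntimeError if the variable were unset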
Note

This flag does not detect or prevent nondeterministic behavior caused by calling an inplace operation on a tensor with an internal memory overlap or by giving such a tensor as the out argument for an operation. In these cases, multiple writes of different data may target a single memory location, and the order of writes is not guaranteed.
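A sketch of the kind of aliasing the note above warns about; none of it is checked by use_deterministic_algorithms(True):

    import torch

    torch.use_deterministic_algorithms(True)

    base = torch.zeros(1)
    aliased = base.expand(3)  # three elements backed by one memory location

    # In-place writes through `aliased`, or passing it as an `out=` argument,
    # would issue several writes to the same storage slot; this flag neither
    # detects nor orders those writes, so the final value is unspecified.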
Parameters

mode (bool) – If True, makes potentially nondeterministic operations switch to a deterministic algorithm or throw a runtime error. If False, allows nondeterministic operations.
Example:
>>> torch.use_deterministic_algorithms(True)

# Forward mode nondeterministic error
>>> torch.randn(10).index_copy(0, torch.tensor([0]), torch.randn(1))
...
RuntimeError: index_copy does not have a deterministic implementation...

# Backward mode nondeterministic error
>>> torch.randn(10, requires_grad=True, device='cuda').index_select(0, torch.tensor([0], device='cuda')).backward()
...
RuntimeError: index_add_cuda_ does not have a deterministic implementation...