.. _cuda-semantics:

CUDA semantics
==============

:mod:`torch.cuda` is used to set up and run CUDA operations. It keeps track of
the currently selected GPU, and all CUDA tensors you allocate will by default be
created on that device. The selected device can be changed with a
:any:`torch.cuda.device` context manager.

However, once a tensor is allocated, you can do operations on it irrespective
of the selected device, and the results will always be placed on the same
device as the tensor.

Cross-GPU operations are not allowed by default, with the exception of
:meth:`~torch.Tensor.copy_` and other methods with copy-like functionality
such as :meth:`~torch.Tensor.to` and :meth:`~torch.Tensor.cuda`.
Unless you enable peer-to-peer memory access, any attempts to launch ops on
tensors spread across different devices will raise an error.

Below you can find a small example showcasing this::

    import torch

    cuda = torch.device('cuda')     # Default CUDA device
    cuda0 = torch.device('cuda:0')
    cuda2 = torch.device('cuda:2')  # GPU 2 (these are 0-indexed)

    x = torch.tensor([1., 2.], device=cuda0)
    # x.device is device(type='cuda', index=0)
    y = torch.tensor([1., 2.]).cuda()
    # y.device is device(type='cuda', index=0)

    with torch.cuda.device(1):
        # allocates a tensor on GPU 1
        a = torch.tensor([1., 2.], device=cuda)

        # transfers a tensor from CPU to GPU 1
        b = torch.tensor([1., 2.]).cuda()
        # a.device and b.device are device(type='cuda', index=1)

        # You can also use ``Tensor.to`` to transfer a tensor:
        b2 = torch.tensor([1., 2.]).to(device=cuda)
        # b.device and b2.device are device(type='cuda', index=1)

        c = a + b
        # c.device is device(type='cuda', index=1)

        z = x + y
        # z.device is device(type='cuda', index=0)

        # even within a context, you can specify the device
        # (or give a GPU index to the .cuda call)
        d = torch.randn(2, device=cuda2)
        e = torch.randn(2).to(cuda2)
        f = torch.randn(2).cuda(cuda2)
        # d.device, e.device, and f.device are all device(type='cuda', index=2)

.. _tf32_on_ampere:

TensorFloat-32 (TF32) on Ampere devices
---------------------------------------

Starting in PyTorch 1.7, there is a new flag called ``allow_tf32``. The matmul
version of this flag defaults to True in PyTorch 1.7 to PyTorch 1.11, and False
in PyTorch 1.12 and later. This flag controls whether PyTorch is allowed to use
the TensorFloat32 (TF32) tensor cores, available on NVIDIA GPUs starting with
the Ampere architecture, internally to compute matmul (matrix multiplies and
batched matrix multiplies) and convolutions.

TF32 tensor cores are designed to achieve better performance on matmul and
convolutions on ``torch.float32`` tensors by rounding input data to have 10 bits
of mantissa, and accumulating results with FP32 precision, maintaining FP32
dynamic range.

Matmuls and convolutions are controlled separately, and their corresponding
flags can be accessed at:

.. code:: python

    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
    # in PyTorch 1.12 and later.
    torch.backends.cuda.matmul.allow_tf32 = True

    # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
    torch.backends.cudnn.allow_tf32 = True

Note that besides matmuls and convolutions themselves, functions and nn modules
that internally use matmuls or convolutions are also affected. These include
``nn.Linear``, ``nn.Conv*``, cdist, tensordot, affine grid and grid sample,
adaptive log softmax, GRU and LSTM.

To get an idea of the precision and speed, see the example code below:

.. code:: python

    import torch

    a_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
    b_full = torch.randn(10240, 10240, dtype=torch.double, device='cuda')
    ab_full = a_full @ b_full
    mean = ab_full.abs().mean()  # 80.7277

    a = a_full.float()
    b = b_full.float()

    # Do matmul at TF32 mode.
    torch.backends.cuda.matmul.allow_tf32 = True
    ab_tf32 = a @ b  # takes 0.016s on GA100
    error = (ab_tf32 - ab_full).abs().max()  # 0.1747
    relative_error = error / mean  # 0.0022

    # Do matmul with TF32 disabled.
    torch.backends.cuda.matmul.allow_tf32 = False
    ab_fp32 = a @ b  # takes 0.11s on GA100
    error = (ab_fp32 - ab_full).abs().max()  # 0.0031
    relative_error = error / mean  # 0.000039

From the above example, we can see that with TF32 enabled, the matmul runs
about 7x faster, while the relative error compared to double precision is
approximately 2 orders of magnitude larger. If full FP32 precision is needed,
users can disable TF32 with:

.. code:: python

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

To toggle the TF32 flags off in C++, you can do:

.. code:: C++

    at::globalContext().setAllowTF32CuBLAS(false);
    at::globalContext().setAllowTF32CuDNN(false);

For more information about TF32, see:

- `TensorFloat-32`_
- `CUDA 11`_
- `Ampere architecture`_

.. _TensorFloat-32: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/
.. _CUDA 11: https://devblogs.nvidia.com/cuda-11-features-revealed/
.. _Ampere architecture: https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/

.. _fp16reducedprecision:

Reduced Precision Reduction in FP16 GEMMs
-----------------------------------------

fp16 GEMMs may be computed with some intermediate reductions done in reduced
precision (e.g., in fp16 rather than fp32). These selective reductions in
precision can allow for higher performance on certain workloads (particularly
those with a large ``k`` dimension) and GPU architectures, at the cost of
numerical precision and potential for overflow.

Some example benchmark data on V100:

.. code::

    [--------------------------- bench_gemm_transformer --------------------------]
          [  m ,  k  ,  n  ]    |  allow_fp16_reduc=True  |  allow_fp16_reduc=False
    1 threads: --------------------------------------------------------------------
          [4096, 4048, 4096]    |           1634.6        |           1639.8
          [4096, 4056, 4096]    |           1670.8        |           1661.9
          [4096, 4080, 4096]    |           1664.2        |           1658.3
          [4096, 4096, 4096]    |           1639.4        |           1651.0
          [4096, 4104, 4096]    |           1677.4        |           1674.9
          [4096, 4128, 4096]    |           1655.7        |           1646.0
          [4096, 4144, 4096]    |           1796.8        |           2519.6
          [4096, 5096, 4096]    |           2094.6        |           3190.0
          [4096, 5104, 4096]    |           2144.0        |           2663.5
          [4096, 5112, 4096]    |           2149.1        |           2766.9
          [4096, 5120, 4096]    |           2142.8        |           2631.0
          [4096, 9728, 4096]    |           3875.1        |           5779.8
          [4096, 16384, 4096]   |           6182.9        |           9656.5
    (times in microseconds).

If full precision reductions are needed, users can disable reduced precision
reductions in fp16 GEMMs with:

.. code:: python

    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

To toggle the reduced precision reduction flags in C++, you can do:

.. code:: C++

    at::globalContext().setAllowFP16ReductionCuBLAS(false);

Asynchronous execution
----------------------

By default, GPU operations are asynchronous. When you call a function that
uses the GPU, the operations are *enqueued* to the particular device, but not
necessarily executed until later. This allows us to execute more computations
in parallel, including operations on CPU or other GPUs.

In general, the effect of asynchronous computation is invisible to the caller,
because (1) each device executes operations in the order they are queued, and
(2) PyTorch automatically performs necessary synchronization when copying data
between CPU and GPU or between two GPUs. Hence, computation will proceed as if
every operation was executed synchronously.

You can force synchronous computation by setting the environment variable
``CUDA_LAUNCH_BLOCKING=1``. This can be handy when an error occurs on the GPU.
(With asynchronous execution, such an error isn't reported until after the
operation is actually executed, so the stack trace does not show where it was
requested.)

A consequence of the asynchronous computation is that time measurements without
synchronizations are not accurate. To get precise measurements, one should
either call :func:`torch.cuda.synchronize()` before measuring, or use
:class:`torch.cuda.Event` to record times as follows::

    import torch

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # Run some things here (for illustration, a single matmul)
    x = torch.randn(1024, 1024, device='cuda')
    y = x @ x

    end_event.record()
    torch.cuda.synchronize()  # Wait for the events to be recorded!
    elapsed_time_ms = start_event.elapsed_time(end_event)

As an exception, several functions such as :meth:`~torch.Tensor.to` and
:meth:`~torch.Tensor.copy_` admit an explicit :attr:`non_blocking` argument,
which lets the caller bypass synchronization when it is unnecessary.
Another exception is CUDA streams, explained below.

CUDA streams
^^^^^^^^^^^^

A `CUDA stream`_ is a linear sequence of execution that belongs to a specific
device. You normally do not need to create one explicitly: by default, each
device uses its own "default" stream.

Operations inside each stream are serialized in the order they are created,
but operations from different streams can execute concurrently in any
relative order, unless explicit synchronization functions (such as
:meth:`~torch.cuda.synchronize` or :meth:`~torch.cuda.Stream.wait_stream`) are
used. For example, the following code is incorrect::

    cuda = torch.device('cuda')
    s = torch.cuda.Stream()  # Create a new stream.
    A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
    with torch.cuda.stream(s):
        # sum() may start execution before normal_() finishes!
        B = torch.sum(A)

When the "current stream" is the default stream, PyTorch automatically performs
necessary synchronization when data is moved around, as explained above.
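In the incorrect example above, ``sum()`` is queued on the side stream ``s``
rather than on the default stream, so nothing orders it after ``normal_()``.
A minimal sketch of one way to fix it uses
:meth:`~torch.cuda.Stream.wait_stream` in both directions, assuming a CUDA
device is available; the helper name ``sum_on_side_stream`` is hypothetical:

.. code:: python

    import torch


    def sum_on_side_stream(A: torch.Tensor) -> torch.Tensor:
        """Sum ``A`` on a new side stream, ordered w.r.t. the current stream."""
        s = torch.cuda.Stream()  # Create a new stream.
        # Make s wait for all work already queued on the current stream,
        # including the normal_() call that filled A.
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            B = torch.sum(A)
        # Make the current stream wait for s before B is used there.
        torch.cuda.current_stream().wait_stream(s)
        return B


    if torch.cuda.is_available():
        cuda = torch.device('cuda')
        A = torch.empty((100, 100), device=cuda).normal_(0.0, 1.0)
        B = sum_on_side_stream(A)
        print(torch.allclose(B, A.sum()))

This sketch only handles execution order. If a tensor allocated on one stream
may be freed while another stream still uses it,
:meth:`~torch.Tensor.record_stream` is also needed, as in the backward-pass
examples below.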
Whenever non-default streams are used, it is the user's responsibility to
ensure proper synchronization.

.. _bwd-cuda-stream-semantics:

Stream semantics of backward passes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Each backward CUDA op runs on the same stream that was used for its
corresponding forward op. If your forward pass runs independent ops in
parallel on different streams, this helps the backward pass exploit that same
parallelism.

The stream semantics of a backward call with respect to surrounding ops are
the same as for any other call. The backward pass inserts internal syncs to
ensure this even when backward ops run on multiple streams as described in the
previous paragraph. More concretely, when calling :func:`autograd.backward`,
:func:`autograd.grad`, or :meth:`tensor.backward`, and optionally supplying
CUDA tensor(s) as the initial gradient(s) (e.g.,
:func:`autograd.backward(..., grad_tensors=initial_grads)`,
:func:`autograd.grad(..., grad_outputs=initial_grads)`, or
:meth:`tensor.backward(..., gradient=initial_grad)`),
the acts of

1. optionally populating initial gradient(s),
2. invoking the backward pass, and
3. using the gradients

have the same stream-semantics relationship as any group of ops::

    s = torch.cuda.Stream()

    # Safe, grads are used in the same stream context as backward()
    with torch.cuda.stream(s):
        loss.backward()
        use grads

    # Unsafe
    with torch.cuda.stream(s):
        loss.backward()
    use grads

    # Safe, with synchronization
    with torch.cuda.stream(s):
        loss.backward()
    torch.cuda.current_stream().wait_stream(s)
    use grads

    # Safe, populating initial grad and invoking backward are in the same stream context
    with torch.cuda.stream(s):
        loss.backward(gradient=torch.ones_like(loss))

    # Unsafe, populating initial_grad and invoking backward are in different stream contexts,
    # without synchronization
    initial_grad = torch.ones_like(loss)
    with torch.cuda.stream(s):
        loss.backward(gradient=initial_grad)

    # Safe, with synchronization
    initial_grad = torch.ones_like(loss)
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        initial_grad.record_stream(s)
        loss.backward(gradient=initial_grad)

BC note: Using grads on the default stream
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced
the default stream with all backward ops, so the following pattern::

    with torch.cuda.stream(s):
        loss.backward()
    use grads

was safe as long as ``use grads`` happened on the default stream.
In current PyTorch, that pattern is no longer safe. If ``backward()``
and ``use grads`` are in different stream contexts, you must sync the streams::

    with torch.cuda.stream(s):
        loss.backward()
    torch.cuda.current_stream().wait_stream(s)
    use grads

even if ``use grads`` is on the default stream.

.. _CUDA stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams

.. _cuda-memory-management:

Memory management
-----------------

PyTorch uses a caching memory allocator to speed up memory allocations. This
allows fast memory deallocation without device synchronizations.
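As a minimal sketch of this caching behavior, assuming a CUDA device is
available, the snippet below frees a 64 MiB tensor and compares
:meth:`~torch.cuda.memory_allocated` (memory occupied by tensors) with
:meth:`~torch.cuda.memory_reserved` (the total memory managed by the caching
allocator). The helper ``allocated_and_reserved`` is hypothetical. Freeing the
tensor lowers the first number, while the freed block stays in the cache, so
the second is expected not to drop:

.. code:: python

    import torch


    def allocated_and_reserved():
        """Return (memory_allocated, memory_reserved) in bytes for the current device."""
        return torch.cuda.memory_allocated(), torch.cuda.memory_reserved()


    if torch.cuda.is_available():
        x = torch.empty(64 * 1024 * 1024, dtype=torch.uint8, device='cuda')  # 64 MiB
        alloc_before, reserved_before = allocated_and_reserved()
        del x  # the block goes back to PyTorch's cache, not to the driver
        alloc_after, reserved_after = allocated_and_reserved()
        print(alloc_before - alloc_after)        # bytes no longer occupied by tensors
        print(reserved_before - reserved_after)  # expected 0: the block is still cached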
Memory cached by the allocator but not used by any tensor will still show as
used in ``nvidia-smi``. You can use :meth:`~torch.cuda.memory_allocated` and
:meth:`~torch.cuda.max_memory_allocated` to monitor memory occupied by
tensors, and use :meth:`~torch.cuda.memory_reserved` and
:meth:`~torch.cuda.max_memory_reserved` to monitor the total amount of memory
managed by the caching allocator. Calling :meth:`~torch.cuda.empty_cache`
releases all **unused** cached memory from PyTorch so that it can be used by
other GPU applications. However, GPU memory occupied by tensors is not freed,
so calling it does not increase the amount of GPU memory available to PyTorch.

For more advanced users, we offer more comprehensive memory benchmarking via
:meth:`~torch.cuda.memory_stats`. We also offer the capability to capture a
complete snapshot of the memory allocator state via
:meth:`~torch.cuda.memory_snapshot`, which can help you understand the
underlying allocation patterns produced by your code.

Use of a caching allocator can interfere with memory checking tools such as
``cuda-memcheck``. To debug memory errors using ``cuda-memcheck``, set
``PYTORCH_NO_CUDA_MEMORY_CACHING=1`` in your environment to disable caching.

The behavior of the caching allocator can be controlled via the environment
variable ``PYTORCH_CUDA_ALLOC_CONF``.
The format is ``PYTORCH_CUDA_ALLOC_CONF=