torchtnt.utils.device.record_data_in_stream
torchtnt.utils.device.record_data_in_stream(data: T, stream: Stream) → None

Records the tensor elements of data on the given stream, to prevent their memory from being reused for another tensor while the stream still needs it.

As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, PyTorch uses a "caching allocator" for tensor memory allocation. When a tensor is freed, its memory is likely to be reused by newly constructed tensors. By default, the allocator only tracks whether a tensor is still in use on the CUDA stream where it was created. When a tensor is used by additional CUDA streams, record_stream must be called to tell the allocator about those streams; otherwise, the allocator may free the tensor's underlying memory once the creating stream no longer uses it. This is an important consideration when writing programs that use multiple CUDA streams.
Parameters:
- data – The data on which to call record_stream
- stream – The CUDA stream with which to call record_stream
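For illustration, here is a minimal sketch of the multi-stream pattern described above, assuming a CUDA-capable machine. The prefetch setup and tensor shapes are assumptions for this example, not part of the torchtnt API: a batch is copied to the GPU on a side stream, and record_data_in_stream tells the caching allocator that the default stream also uses the tensor's memory.

    # Minimal sketch (assumes a CUDA device is available).
    import torch
    from torchtnt.utils.device import record_data_in_stream

    copy_stream = torch.cuda.Stream()
    cpu_batch = torch.randn(32, 128, pin_memory=True)  # hypothetical batch

    with torch.cuda.stream(copy_stream):
        # The copy is issued on copy_stream, so the allocator initially
        # associates gpu_batch's memory with copy_stream only.
        gpu_batch = cpu_batch.to("cuda", non_blocking=True)

    # Make the default stream wait for the copy before consuming the data.
    torch.cuda.current_stream().wait_stream(copy_stream)

    # Record the tensor on the consuming (default) stream so its memory is
    # not reused for another tensor while that stream may still need it.
    record_data_in_stream(gpu_batch, torch.cuda.current_stream())

    out = gpu_batch * 2  # compute runs on the default stream

Without the record_data_in_stream call, freeing gpu_batch could let the allocator hand its memory to a new tensor as soon as copy_stream is done with it, even if the default stream's kernels are still reading it.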