Source code for torch.cuda.streams

import ctypes
import torch
from . import cudart, check_error, cudaStatus


class Stream(torch._C._CudaStreamBase):
    """Wrapper around a CUDA stream.

    Arguments:
        device (int, optional): a device on which to allocate the Stream.
        priority (int, optional): priority of the stream. Lower numbers
            represent higher priorities.
    """

    def __new__(cls, device=-1, priority=0, **kwargs):
        with torch.cuda.device(device):
            return super(Stream, cls).__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event):
        """Makes all future work submitted to the stream wait for an event.

        Arguments:
            event (Event): an event to wait for.
        """
        check_error(cudart().cudaStreamWaitEvent(self, event, ctypes.c_int(0)))

    def wait_stream(self, stream):
        """Synchronizes with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to the given stream at the time of the call have completed.

        Arguments:
            stream (Stream): a stream to synchronize with.
        """
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        """Records an event.

        Arguments:
            event (Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        check_error(cudart().cudaEventRecord(event, self))
        return event

    def query(self):
        """Checks if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        res = cudart().cudaStreamQuery(self)
        if res == cudaStatus.ERROR_NOT_READY:
            return False
        check_error(res)
        return True

    def synchronize(self):
        """Waits for all the kernels in this stream to complete."""
        check_error(cudart().cudaStreamSynchronize(self))

    @staticmethod
    def priority_range():
        least_priority = ctypes.c_int()
        greatest_priority = ctypes.c_int()
        check_error(cudart().cudaDeviceGetStreamPriorityRange(
            ctypes.byref(least_priority), ctypes.byref(greatest_priority)))
        return (least_priority.value, greatest_priority.value)

    @property
    def priority(self):
        priority = ctypes.c_int()
        check_error(cudart().cudaStreamGetPriority(self, ctypes.byref(priority)))
        return priority.value

    @property
    def _as_parameter_(self):
        return ctypes.c_void_p(self.cuda_stream)

    def __eq__(self, o):
        if isinstance(o, Stream):
            return o.device == self.device and o.cuda_stream == self.cuda_stream
        return False

    def __hash__(self):
        return hash((self.cuda_stream, self.device))

    def __repr__(self):
        return ('<torch.cuda.Stream device={0} cuda_stream={1:#x}>'
                .format(self.device, self.cuda_stream))

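A minimal usage sketch (illustrative only, not part of the module source) showing how the Stream API above is typically driven through the public ``torch.cuda`` entry points; it assumes a CUDA-capable build of torch:

    # Sketch only: assumes torch.cuda.is_available() is True.
    import torch

    side = torch.cuda.Stream()                     # stream on the current device
    with torch.cuda.stream(side):                  # enqueue work on the side stream
        y = torch.randn(1024).cuda() * 2

    torch.cuda.current_stream().wait_stream(side)  # default stream waits for `side`
    side.synchronize()                             # or block the host until `side` finishes
    assert side.query()                            # all submitted kernels have completed
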
class EventHandle(ctypes.Structure):
    IPC_HANDLE_SIZE = 64
    _fields_ = [('reserved', ctypes.c_char * IPC_HANDLE_SIZE)]

class Event(object):
    """Wrapper around a CUDA event.

    Arguments:
        enable_timing (bool): indicates if the event should measure time
            (default: ``False``)
        blocking (bool): if ``True``, :meth:`wait` will be blocking
            (default: ``False``)
        interprocess (bool): if ``True``, the event can be shared between
            processes (default: ``False``)
    """

    DEFAULT = 0x0
    BLOCKING_SYNC = 0x1
    DISABLE_TIMING = 0x2
    INTERPROCESS = 0x4

    def __init__(self, enable_timing=False, blocking=False, interprocess=False,
                 _handle=None):
        flags = Event.DEFAULT
        if not enable_timing:
            flags |= Event.DISABLE_TIMING
        if blocking:
            flags |= Event.BLOCKING_SYNC
        if interprocess:
            flags |= Event.INTERPROCESS

        ptr = ctypes.c_void_p()
        self._cudart = cudart()
        if _handle:
            check_error(self._cudart.cudaIpcOpenEventHandle(ctypes.byref(ptr), _handle))
        else:
            check_error(self._cudart.cudaEventCreateWithFlags(
                ctypes.byref(ptr), ctypes.c_uint(flags)))
        self._as_parameter_ = ptr

    def __del__(self):
        if hasattr(self, '_as_parameter_'):
            check_error(self._cudart.cudaEventDestroy(self._as_parameter_))
            del self._as_parameter_

    def record(self, stream=None):
        """Records the event in a given stream."""
        if stream is None:
            stream = torch.cuda.current_stream()
        stream.record_event(self)

    def wait(self, stream=None):
        """Makes a given stream wait for the event."""
        if stream is None:
            stream = torch.cuda.current_stream()
        stream.wait_event(self)

    def query(self):
        """Checks if the event has been recorded.

        Returns:
            A boolean indicating if the event has been recorded.
        """
        res = cudart().cudaEventQuery(self)
        if res == cudaStatus.ERROR_NOT_READY:
            return False
        check_error(res)
        return True

    def elapsed_time(self, end_event):
        """Returns the time elapsed in milliseconds between when this event
        and ``end_event`` were recorded."""
        time_ms = ctypes.c_float()
        check_error(cudart().cudaEventElapsedTime(
            ctypes.byref(time_ms), self, end_event))
        return time_ms.value

    def synchronize(self):
        """Synchronizes with the event."""
        check_error(cudart().cudaEventSynchronize(self))

    def ipc_handle(self):
        """Returns an IPC handle of this event."""
        handle = EventHandle()
        check_error(cudart().cudaIpcGetEventHandle(ctypes.byref(handle), self))
        return handle

    def __repr__(self):
        return '<torch.cuda.Event {0:#x}>'.format(self._as_parameter_.value)
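A minimal timing sketch (illustrative only, not part of the module source) that pairs two of the events defined above through the public ``torch.cuda.Event`` wrapper; it assumes a CUDA-capable build of torch:

    # Sketch only: assumes torch.cuda.is_available() is True.
    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()                         # recorded on the current stream
    a = torch.randn(4096, 4096).cuda()
    b = a.mm(a)                            # some GPU work to measure
    end.record()

    end.synchronize()                      # wait until `end` has completed
    print(start.elapsed_time(end))         # elapsed time in milliseconds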