"""Timer class based on the timeit.Timer class, but torch aware."""importenumimporttimeitimporttextwrapfromtypingimportoverload,Any,Callable,Dict,List,NoReturn,Optional,Tuple,Type,Unionimporttorchfromtorch.utils.benchmark.utilsimportcommon,cpp_jitfromtorch.utils.benchmark.utils._stubsimportTimerClass,TimeitModuleTypefromtorch.utils.benchmark.utils.valgrind_wrapperimporttimer_interfaceasvalgrind_timer_interface__all__=["Timer","timer","Language"]iftorch.backends.cuda.is_built()andtorch.cuda.is_available():# type: ignore[no-untyped-call]deftimer()->float:torch.cuda.synchronize()returntimeit.default_timer()eliftorch._C._get_privateuse1_backend_name()!="privateuseone":privateuse1_device_handler=getattr(torch,torch._C._get_privateuse1_backend_name(),None) \
iftorch._C._get_privateuse1_backend_name()!="cpu"elseNonedeftimer()->float:ifprivateuse1_device_handler:privateuse1_device_handler.synchronize()returntimeit.default_timer()else:timer=timeit.default_timerclassLanguage(enum.Enum):PYTHON=0CPP=1classCPPTimer:def__init__(self,stmt:str,setup:str,global_setup:str,timer:Callable[[],float],globals:Dict[str,Any],)->None:iftimerisnottimeit.default_timer:raiseNotImplementedError("PyTorch was built with CUDA and a GPU is present; however ""Timer does not yet support GPU measurements. If your ""code is CPU only, pass `timer=timeit.default_timer` to the ""Timer's constructor to indicate this. (Note that this will ""produce incorrect results if the GPU is in fact used, as ""Timer will not synchronize CUDA.)")ifglobals:raiseValueError("C++ timing does not support globals.")self._stmt:str=textwrap.dedent(stmt)self._setup:str=textwrap.dedent(setup)self._global_setup:str=textwrap.dedent(global_setup)self._timeit_module:Optional[TimeitModuleType]=Nonedeftimeit(self,number:int)->float:ifself._timeit_moduleisNone:self._timeit_module=cpp_jit.compile_timeit_template(stmt=self._stmt,setup=self._setup,global_setup=self._global_setup,)returnself._timeit_module.timeit(number)
class Timer:
    """Helper class for measuring execution time of PyTorch statements.

    For a full tutorial on how to use this class, see:
    https://pytorch.org/tutorials/recipes/recipes/benchmark.html

    The PyTorch Timer is based on `timeit.Timer` (and in fact uses
    `timeit.Timer` internally), but with several key differences:

    1) Runtime aware:
        Timer will perform warmups (important as some elements of PyTorch are
        lazily initialized), set threadpool size so that comparisons are
        apples-to-apples, and synchronize asynchronous CUDA functions when
        necessary.

    2) Focus on replicates:
        When measuring code, and particularly complex kernels / models,
        run-to-run variation is a significant confounding factor. It is
        expected that all measurements should include replicates to quantify
        noise and allow median computation, which is more robust than mean.
        To that effect, this class deviates from the `timeit` API by
        conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`.
        (Exact algorithms are discussed in method docstrings.) The `timeit`
        method is replicated for cases where an adaptive strategy is not
        desired.

    3) Optional metadata:
        When defining a Timer, one can optionally specify `label`, `sub_label`,
        `description`, and `env`. (Defined later) These fields are included in
        the representation of result object and by the `Compare` class to group
        and display results for comparison.

    4) Instruction counts
        In addition to wall times, Timer can run a statement under Callgrind
        and report instructions executed.

    Directly analogous to `timeit.Timer` constructor arguments:

        `stmt`, `setup`, `timer`, `globals`

    PyTorch Timer specific constructor arguments:

        `label`, `sub_label`, `description`, `env`, `num_threads`

    Args:
        stmt: Code snippet to be run in a loop and timed.

        setup: Optional setup code. Used to define variables used in `stmt`

        global_setup: (C++ only)
            Code which is placed at the top level of the file for things like
            `#include` statements.

        timer:
            Callable which returns the current time. The default is the
            module-level `timer`: if PyTorch was built with CUDA and a GPU is
            present it synchronizes CUDA before reading the clock; otherwise,
            if a PrivateUse1 backend has been registered, it synchronizes that
            backend's device module (when one is found); otherwise it is
            simply `timeit.default_timer`. C++ timing only accepts
            `timeit.default_timer`.

        globals:
            A dict which defines the global variables when `stmt` is being
            executed. This is the other method for providing variables which
            `stmt` needs. The dict is copied, so the caller's dict is not
            mutated.

        label:
            String which summarizes `stmt`. For instance, if `stmt` is
            "torch.nn.functional.relu(torch.add(x, 1, out=out))"
            one might set label to "ReLU(x + 1)" to improve readability.

        sub_label:
            Provide supplemental information to disambiguate measurements
            with identical stmt or label. For instance, in our example
            above sub_label might be "float" or "int", so that it is easy
            to differentiate:
            "ReLU(x + 1): (float)"

            "ReLU(x + 1): (int)"
            when printing Measurements or summarizing using `Compare`.

        description:
            String to distinguish measurements with identical label and
            sub_label. The principal use of `description` is to signal to
            `Compare` the columns of data. For instance one might set it
            based on the input size to create a table of the form: ::

                                        | n=1 | n=4 | ...
                                        ------------- ...
                ReLU(x + 1): (float)    | ... | ... | ...
                ReLU(x + 1): (int)      | ... | ... | ...


            using `Compare`. It is also included when printing a Measurement.

        env:
            This tag indicates that otherwise identical tasks were run in
            different environments, and are therefore not equivalent, for
            instance when A/B testing a change to a kernel. `Compare` will
            treat Measurements with different `env` specification as distinct
            when merging replicate runs.

        num_threads:
            The size of the PyTorch threadpool when executing `stmt`. Single
            threaded performance is important as both a key inference workload
            and a good indicator of intrinsic algorithmic efficiency, so the
            default is set to one. This is in contrast to the default PyTorch
            threadpool size which tries to utilize all cores.

        language:
            `Language.PYTHON` (also accepted as "py" or "python") or
            `Language.CPP` (also accepted as "cpp" or "c++"). Any other value
            raises `ValueError`.
    """

    _timer_cls: Type[TimerClass] = timeit.Timer

    def __init__(
        self,
        stmt: str = "pass",
        setup: str = "pass",
        global_setup: str = "",
        timer: Callable[[], float] = timer,
        globals: Optional[Dict[str, Any]] = None,
        label: Optional[str] = None,
        sub_label: Optional[str] = None,
        description: Optional[str] = None,
        env: Optional[str] = None,
        num_threads: int = 1,
        language: Union[Language, str] = Language.PYTHON,
    ):
        if not isinstance(stmt, str):
            raise ValueError("Currently only a `str` stmt is supported.")

        # We copy `globals` to prevent mutations from leaking.
        # (For instance, `eval` adds the `__builtins__` key)
        self._globals = dict(globals or {})

        timer_kwargs: Dict[str, Any] = {}
        if language in (Language.PYTHON, "py", "python"):
            # Include `torch` if not specified as a convenience feature.
            self._globals.setdefault("torch", torch)
            self._language: Language = Language.PYTHON
            if global_setup:
                raise ValueError(
                    f"global_setup is C++ only, got `{global_setup}`. Most "
                    "likely this code can simply be moved to `setup`."
                )

        elif language in (Language.CPP, "cpp", "c++"):
            # Swap the timer implementation on this instance only.
            assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped."
            self._timer_cls = CPPTimer
            # "pass" is the Python no-op default; it is not valid C++.
            setup = ("" if setup == "pass" else setup)
            self._language = Language.CPP
            timer_kwargs["global_setup"] = global_setup

        else:
            raise ValueError(f"Invalid language `{language}`.")

        # Convenience adjustment so that multi-line code snippets defined in
        # functions do not IndentationError (Python) or look odd (C++). The
        # leading newline removal is for the initial newline that appears when
        # defining block strings. For instance:
        #   textwrap.dedent("""
        #       print("This is a stmt")
        #   """)
        # produces '\nprint("This is a stmt")\n'.
        #
        # Stripping this down to 'print("This is a stmt")' doesn't change
        # what gets executed, but it makes __repr__'s nicer.
        stmt = textwrap.dedent(stmt)
        stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip()
        setup = textwrap.dedent(setup)
        setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()

        self._timer = self._timer_cls(
            stmt=stmt,
            setup=setup,
            timer=timer,
            globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
            **timer_kwargs,
        )
        self._task_spec = common.TaskSpec(
            stmt=stmt,
            setup=setup,
            global_setup=global_setup,
            label=label,
            sub_label=sub_label,
            description=description,
            env=env,
            num_threads=num_threads,
        )

    def _timeit(self, number: int) -> float:
        # Even calling a timer in C++ takes ~50 ns, so no real operation should
        # take less than 1 ns. (And this prevents divide by zero errors.)
        return max(self._timer.timeit(number), 1e-9)

    def timeit(self, number: int = 1000000) -> common.Measurement:
        """Mirrors the semantics of timeit.Timer.timeit().

        Execute the main statement (`stmt`) `number` times.
        https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit

        A warmup of ``max(number // 100, 2)`` iterations runs first and is
        discarded; the returned `Measurement` holds a single raw time.
        """
        with common.set_torch_threads(self._task_spec.num_threads):
            # Warmup
            self._timeit(number=max(int(number // 100), 2))

            return common.Measurement(
                number_per_run=number,
                raw_times=[self._timeit(number=number)],
                task_spec=self._task_spec
            )

    def repeat(self, repeat: int = -1, number: int = -1) -> None:
        raise NotImplementedError("See `Timer.blocked_autorange.`")

    def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None:
        raise NotImplementedError("See `Timer.blocked_autorange.`")

    def _threaded_measurement_loop(
        self,
        number: int,
        time_hook: Callable[[], float],
        stop_hook: Callable[[List[float]], bool],
        min_run_time: float,
        max_run_time: Optional[float] = None,
        callback: Optional[Callable[[int, float], NoReturn]] = None
    ) -> List[float]:
        # Despite the name, this runs on the calling thread. It measures one
        # block per iteration (via `time_hook`) until the accumulated time
        # reaches `min_run_time` AND `stop_hook` agrees, or until the
        # accumulated time exceeds `max_run_time` (when a truthy one is given).
        total_time = 0.0
        can_stop = False
        times: List[float] = []
        with common.set_torch_threads(self._task_spec.num_threads):
            while (total_time < min_run_time) or (not can_stop):
                time_spent = time_hook()
                times.append(time_spent)
                total_time += time_spent
                if callback:
                    callback(number, time_spent)
                can_stop = stop_hook(times)
                if max_run_time and total_time > max_run_time:
                    break
        return times

    def _estimate_block_size(self, min_run_time: float) -> int:
        with common.set_torch_threads(self._task_spec.num_threads):
            # Estimate the block size needed for measurement to be negligible
            # compared to the inner loop. This also serves as a warmup.
            overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item()
            number = 1
            while True:
                time_taken = self._timeit(number)
                relative_overhead = overhead / time_taken
                # Overhead is at most 0.01% and the block is not trivially short.
                if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000:
                    break
                if time_taken > min_run_time:
                    break
                # Avoid overflow in C++ pybind11 interface
                if number * 10 > 2147483647:
                    break
                number *= 10
        return number

    def blocked_autorange(
        self,
        callback: Optional[Callable[[int, float], NoReturn]] = None,
        min_run_time: float = 0.2,
    ) -> common.Measurement:
        """Measure many replicates while keeping timer overhead to a minimum.

        At a high level, blocked_autorange executes the following pseudo-code::

            `setup`

            total_time = 0
            while total_time < min_run_time
                start = timer()
                for _ in range(block_size):
                    `stmt`
                total_time += (timer() - start)

        Note the variable `block_size` in the inner loop. The choice of block
        size is important to measurement quality, and must balance two
        competing objectives:

            1) A small block size results in more replicates and generally
               better statistics.

            2) A large block size better amortizes the cost of `timer`
               invocation, and results in a less biased measurement. This is
               important because CUDA synchronization time is non-trivial
               (order single to low double digit microseconds) and would
               otherwise bias the measurement.

        blocked_autorange sets block_size by running a warmup period,
        increasing block size (by factors of 10) until timer overhead is at
        most 0.01% of the overall computation and a single block takes at
        least ``min_run_time / 1000``. Growth also stops early once a single
        block exceeds ``min_run_time`` or the block size would overflow a
        32-bit signed integer. This value is then used for the main
        measurement loop.

        Args:
            callback: Optional callable invoked as ``callback(block_size, t)``
                after each measured block.
            min_run_time: Minimum total measured time before stopping.

        Returns:
            A `Measurement` object that contains measured runtimes and
            repetition counts, and can be used to compute statistics.
            (mean, median, etc.)
        """
        number = self._estimate_block_size(min_run_time)

        def time_hook() -> float:
            return self._timeit(number)

        def stop_hook(times: List[float]) -> bool:
            return True

        times = self._threaded_measurement_loop(
            number, time_hook, stop_hook,
            min_run_time=min_run_time,
            callback=callback)

        return common.Measurement(
            number_per_run=number,
            raw_times=times,
            task_spec=self._task_spec
        )

    def adaptive_autorange(
            self,
            threshold: float = 0.1,
            *,
            min_run_time: float = 0.01,
            max_run_time: float = 10.0,
            callback: Optional[Callable[[int, float], NoReturn]] = None,
    ) -> common.Measurement:
        """Similar to `blocked_autorange` but also checks for variability in measurements and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached.

        At a high level, adaptive_autorange executes the following pseudo-code::

            `setup`

            times = []
            while times.sum < max_run_time
                start = timer()
                for _ in range(block_size):
                    `stmt`
                times.append(timer() - start)

                enough_data = len(times)>3 and times.sum > min_run_time
                small_iqr=times.iqr/times.median<threshold

                if enough_data and small_iqr:
                    break

        The block size is estimated as in `blocked_autorange`, but always
        with a fixed ``min_run_time`` of 0.05 for the estimate.

        Args:
            threshold: value of iqr/median threshold for stopping

            min_run_time: total runtime needed before checking `threshold`

            max_run_time: total runtime for all measurements regardless of `threshold`

            callback: Optional callable invoked as ``callback(block_size, t)``
                after each measured block.

        Returns:
            A `Measurement` object that contains measured runtimes and
            repetition counts, and can be used to compute statistics.
            (mean, median, etc.)
        """
        number = self._estimate_block_size(min_run_time=0.05)

        def time_hook() -> float:
            return self._timeit(number)

        def stop_hook(times: List[float]) -> bool:
            # Require more than 3 replicates before trusting iqr/median.
            if len(times) > 3:
                return common.Measurement(
                    number_per_run=number,
                    raw_times=times,
                    task_spec=self._task_spec
                ).meets_confidence(threshold=threshold)
            return False

        times = self._threaded_measurement_loop(
            number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback)

        return common.Measurement(
            number_per_run=number,
            raw_times=times,
            task_spec=self._task_spec
        )

    def collect_callgrind(
        self,
        number: int = 100,
        *,
        repeats: Optional[int] = None,
        collect_baseline: bool = True,
        retain_out_file: bool = False,
    ) -> Any:
        """Collect instruction counts using Callgrind.

        Unlike wall times, instruction counts are deterministic
        (modulo non-determinism in the program itself and small amounts of
        jitter from the Python interpreter.) This makes them ideal for detailed
        performance analysis. This method runs `stmt` in a separate process
        so that Valgrind can instrument the program. Performance is severely
        degraded due to the instrumentation, however this is ameliorated by
        the fact that a small number of iterations is generally sufficient to
        obtain good measurements.

        In order to use this method `valgrind`, `callgrind_control`, and
        `callgrind_annotate` must be installed.

        Because there is a process boundary between the caller (this process)
        and the `stmt` execution, `globals` cannot contain arbitrary in-memory
        data structures. (Unlike timing methods) Instead, globals are
        restricted to builtins, `nn.Modules`'s, and TorchScripted functions/modules
        to reduce the surprise factor from serialization and subsequent
        deserialization. The `GlobalsBridge` class provides more detail on this
        subject. Take particular care with nn.Modules: they rely on pickle and
        you may need to add an import to `setup` for them to transfer properly.

        By default, a profile for an empty statement will be collected and
        cached to indicate how many instructions are from the Python loop which
        drives `stmt`. (The baseline is only collected for Python statements.)

        Raises:
            ValueError: if `repeats` is given and is less than 1.

        Returns:
            A `CallgrindStats` object which provides instruction counts and
            some basic facilities for analyzing and manipulating results.
            If `repeats` is specified, the full sequence of results (one per
            repeat) is returned instead of a single object.
        """
        if not isinstance(self._task_spec.stmt, str):
            raise ValueError("`collect_callgrind` currently only supports string `stmt`")

        if repeats is not None and repeats < 1:
            raise ValueError("If specified, `repeats` must be >= 1")

        # Check that the statement is valid. It doesn't guarantee success, but it's much
        # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in
        # the parent process rather than the valgrind subprocess.
        self._timeit(1)
        is_python = (self._language == Language.PYTHON)
        # C++ timers reject globals in CPPTimer.__init__, so this is an invariant.
        assert is_python or not self._globals
        result = valgrind_timer_interface.wrapper_singleton().collect_callgrind(
            task_spec=self._task_spec,
            globals=self._globals,
            number=number,
            repeats=repeats or 1,
            collect_baseline=collect_baseline and is_python,
            is_python=is_python,
            retain_out_file=retain_out_file,
        )

        return (result[0] if repeats is None else result)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.