Source code for torch.utils.benchmark.utils.common
"""Base shared classes and utilities."""importcollectionsimportcontextlibimportdataclassesimportosimportshutilimporttempfileimporttextwrapimporttimefromtypingimportcast,Any,DefaultDict,Dict,Iterable,Iterator,List,Optional,Tupleimportuuidimporttorch__all__=["TaskSpec","Measurement","select_unit","unit_to_english","trim_sigfig","ordered_unique","set_torch_threads"]_MAX_SIGNIFICANT_FIGURES=4_MIN_CONFIDENCE_INTERVAL=25e-9# 25 ns# Measurement will include a warning if the distribution is suspect. All# runs are expected to have some variation; these parameters set the# thresholds._IQR_WARN_THRESHOLD=0.1_IQR_GROSS_WARN_THRESHOLD=0.25@dataclasses.dataclass(init=True,repr=False,eq=True,frozen=True)classTaskSpec:"""Container for information used to define a Timer. (except globals)"""stmt:strsetup:strglobal_setup:str=""label:Optional[str]=Nonesub_label:Optional[str]=Nonedescription:Optional[str]=Noneenv:Optional[str]=Nonenum_threads:int=1@propertydeftitle(self)->str:"""Best effort attempt at a string label for the measurement."""ifself.labelisnotNone:returnself.label+(f": {self.sub_label}"ifself.sub_labelelse"")elif"\n"notinself.stmt:returnself.stmt+(f": {self.sub_label}"ifself.sub_labelelse"")return(f"stmt:{f' ({self.sub_label})'ifself.sub_labelelse''}\n"f"{textwrap.indent(self.stmt,' ')}")defsetup_str(self)->str:return(""if(self.setup=="pass"ornotself.setup)elsef"setup:\n{textwrap.indent(self.setup,' ')}"if"\n"inself.setupelsef"setup: {self.setup}")defsummarize(self)->str:"""Build TaskSpec portion of repr string for other containers."""sections=[self.title,self.descriptionor"",self.setup_str(),]return"\n".join([f"{i}\n"if"\n"inielseiforiinsectionsifi])_TASKSPEC_FIELDS=tuple(i.nameforiindataclasses.fields(TaskSpec))
@dataclasses.dataclass(init=True, repr=False)
class Measurement:
    """The result of a Timer measurement.

    This class stores one or more measurements of a given statement. It is
    serializable and provides several convenience methods
    (including a detailed __repr__) for downstream consumers.
    """
    number_per_run: int
    raw_times: List[float]
    task_spec: TaskSpec
    metadata: Optional[Dict[Any, Any]] = None  # Reserved for user payloads.

    def __post_init__(self) -> None:
        # Statistics are computed lazily by `_lazy_init`. The sentinel values
        # below mark them as not-yet-computed.
        self._sorted_times: Tuple[float, ...] = ()
        self._warnings: Tuple[str, ...] = ()
        self._median: float = -1.0
        self._mean: float = -1.0
        self._p25: float = -1.0
        self._p75: float = -1.0

    def __getattr__(self, name: str) -> Any:
        # Forward TaskSpec fields for convenience.
        if name in _TASKSPEC_FIELDS:
            return getattr(self.task_spec, name)
        return super().__getattribute__(name)

    # =========================================================================
    # == Convenience methods for statistics ===================================
    # =========================================================================
    #
    # These methods use raw time divided by number_per_run; this is an
    # extrapolation and hides the fact that different number_per_run will
    # result in different amortization of overheads, however if Timer has
    # selected an appropriate number_per_run then this is a non-issue, and
    # forcing users to handle that division would result in a poor experience.
    @property
    def times(self) -> List[float]:
        """Per-run times, extrapolated from `raw_times / number_per_run`."""
        return [t / self.number_per_run for t in self.raw_times]

    @property
    def median(self) -> float:
        self._lazy_init()
        return self._median

    @property
    def mean(self) -> float:
        self._lazy_init()
        return self._mean

    @property
    def iqr(self) -> float:
        """Interquartile range (p75 - p25) of the per-run times."""
        self._lazy_init()
        return self._p75 - self._p25

    @property
    def significant_figures(self) -> int:
        """Approximate significant figure estimate.

        This property is intended to give a convenient way to estimate the
        precision of a measurement. It only uses the interquartile region to
        estimate statistics to try to mitigate skew from the tails, and
        uses a static z value of 1.645 since it is not expected to be used
        for small values of `n`, so z can approximate `t`.

        The significant figure estimation used in conjunction with the
        `trim_sigfig` method to provide a more human interpretable data
        summary. __repr__ does not use this method; it simply displays raw
        values. Significant figure estimation is intended for `Compare`.
        """
        self._lazy_init()
        n_total = len(self._sorted_times)
        lower_bound = int(n_total // 4)
        upper_bound = int(torch.tensor(3 * n_total / 4).ceil())
        interquartile_points: Tuple[float, ...] = self._sorted_times[lower_bound:upper_bound]
        std = torch.tensor(interquartile_points).std(unbiased=False).item()
        sqrt_n = torch.tensor(len(interquartile_points)).sqrt().item()

        # Rough estimates. These are by no means statistically rigorous.
        confidence_interval = max(1.645 * std / sqrt_n, _MIN_CONFIDENCE_INTERVAL)
        relative_ci = torch.tensor(self._median / confidence_interval).log10().item()
        num_significant_figures = int(torch.tensor(relative_ci).floor())
        return min(max(num_significant_figures, 1), _MAX_SIGNIFICANT_FIGURES)

    @property
    def has_warnings(self) -> bool:
        self._lazy_init()
        return bool(self._warnings)

    def _lazy_init(self) -> None:
        # Compute summary statistics once, on first access. A Measurement with
        # empty `raw_times` keeps its sentinel values.
        if self.raw_times and not self._sorted_times:
            self._sorted_times = tuple(sorted(self.times))
            _sorted_times = torch.tensor(self._sorted_times, dtype=torch.float64)
            self._median = _sorted_times.quantile(.5).item()
            self._mean = _sorted_times.mean().item()
            self._p25 = _sorted_times.quantile(.25).item()
            self._p75 = _sorted_times.quantile(.75).item()

            def add_warning(msg: str) -> None:
                rel_iqr = self.iqr / self.median * 100
                self._warnings += (
                    f"  WARNING: Interquartile range is {rel_iqr:.1f}% "
                    f"of the median measurement.\n{msg}",
                )

            # Flag suspect distributions. The gross threshold takes priority.
            if not self.meets_confidence(_IQR_GROSS_WARN_THRESHOLD):
                add_warning("This suggests significant environmental influence.")
            elif not self.meets_confidence(_IQR_WARN_THRESHOLD):
                add_warning("This could indicate system fluctuation.")

    def meets_confidence(self, threshold: float = _IQR_WARN_THRESHOLD) -> bool:
        """True if the relative IQR is below `threshold`."""
        return self.iqr / self.median < threshold

    @property
    def title(self) -> str:
        return self.task_spec.title

    @property
    def env(self) -> str:
        # NOTE: fixed `self.taskspec.env` -> `self.task_spec.env`. The field is
        # named `task_spec`, and `__getattr__` only forwards `_TASKSPEC_FIELDS`
        # names, so `self.taskspec` raised AttributeError on every access.
        return (
            "Unspecified env" if self.task_spec.env is None
            else cast(str, self.task_spec.env)
        )

    @property
    def as_row_name(self) -> str:
        return self.sub_label or self.stmt or "[Unknown]"

    def __repr__(self) -> str:
        """
        Example repr:
            <utils.common.Measurement object at 0x7f395b6ac110>
              Broadcasting add (4x8)
              Median: 5.73 us
              IQR:    2.25 us (4.01 to 6.26)
              372 measurements, 100 runs per measurement, 1 thread
              WARNING: Interquartile range is 39.4% of the median measurement.
                       This suggests significant environmental influence.
        """
        self._lazy_init()
        skip_line, newline = "MEASUREMENT_REPR_SKIP_LINE", "\n"
        n = len(self._sorted_times)
        time_unit, time_scale = select_unit(self._median)
        # The IQR line is statistically meaningless for fewer than four
        # samples; mark it with a sentinel and strip it below.
        iqr_filter = '' if n >= 4 else skip_line

        repr_str = f"""
{super().__repr__()}
{self.task_spec.summarize()}
  {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit}
  {iqr_filter}IQR:    {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f})
  {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''}
{newline.join(self._warnings)}""".strip()  # noqa: B950

        return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l)

    @staticmethod
    def merge(measurements: Iterable["Measurement"]) -> List["Measurement"]:
        """Convenience method for merging replicates.

        Merge will extrapolate times to `number_per_run=1` and will not
        transfer any metadata. (Since it might differ between replicates)
        """
        grouped_measurements: DefaultDict[TaskSpec, List["Measurement"]] = collections.defaultdict(list)
        for m in measurements:
            grouped_measurements[m.task_spec].append(m)

        def merge_group(task_spec: TaskSpec, group: List["Measurement"]) -> "Measurement":
            times: List[float] = []
            for m in group:
                # Different measurements could have different `number_per_run`,
                # so we call `.times` which normalizes the results.
                times.extend(m.times)

            return Measurement(
                number_per_run=1,
                raw_times=times,
                task_spec=task_spec,
                metadata=None,
            )

        return [merge_group(t, g) for t, g in grouped_measurements.items()]
def select_unit(t: float) -> Tuple[str, float]:
    """Determine how to scale times for O(1) magnitude.

    This utility is used to format numbers for human consumption.
    Returns a `(unit, scale)` pair, e.g. `("us", 1e-6)`.
    """
    magnitude_bucket = int(torch.tensor(t).log10().item() // 3)
    unit = {-3: "ns", -2: "us", -1: "ms"}.get(magnitude_bucket, "s")
    scale = {"ns": 1e-9, "us": 1e-6, "ms": 1e-3, "s": 1}[unit]
    return unit, scale


def unit_to_english(u: str) -> str:
    """Expand a unit abbreviation (e.g. "ns") to its English name."""
    english_names = {
        "ns": "nanosecond",
        "us": "microsecond",
        "ms": "millisecond",
        "s": "second",
    }
    return english_names[u]


def trim_sigfig(x: float, n: int) -> float:
    """Trim `x` to `n` significant figures. (e.g. 3.14159, 2 -> 3.10000)"""
    assert n == int(n)
    exponent = int(torch.tensor(x).abs().log10().ceil().item())
    scale = 10 ** (exponent - n)
    return float(torch.tensor(x / scale).round() * scale)


def ordered_unique(elements: Iterable[Any]) -> List[Any]:
    """Deduplicate `elements` while preserving first-seen order."""
    seen: "collections.OrderedDict[Any, None]" = collections.OrderedDict()
    for element in elements:
        seen.setdefault(element, None)
    return list(seen.keys())


@contextlib.contextmanager
def set_torch_threads(n: int) -> Iterator[None]:
    """Context manager: temporarily set the intra-op thread count to `n`.

    The prior thread count is restored even if the body raises.
    """
    saved_num_threads = torch.get_num_threads()
    try:
        torch.set_num_threads(n)
        yield
    finally:
        torch.set_num_threads(saved_num_threads)


def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> str:
    """Create a temporary directory. The caller is responsible for cleanup.

    This function is conceptually similar to `tempfile.mkdtemp`, but with
    the key additional feature that it will use shared memory if the
    `BENCHMARK_USE_DEV_SHM` environment variable is set. This is an
    implementation detail, but an important one for cases where many Callgrind
    measurements are collected at once. (Such as when collecting
    microbenchmarks.)

    This is an internal utility, and is exported solely so that
    microbenchmarks can reuse the util.
    """
    shm_flag = (os.getenv("BENCHMARK_USE_DEV_SHM") or "").lower()
    use_dev_shm: bool = shm_flag in ("1", "true")

    if use_dev_shm:
        root = "/dev/shm/pytorch_benchmark_utils"
        assert os.name == "posix", f"tmpfs (/dev/shm) is POSIX only, current platform is {os.name}"
        assert os.path.exists("/dev/shm"), "This system does not appear to support tmpfs (/dev/shm)."
        os.makedirs(root, exist_ok=True)

        # Because we're working in shared memory, it is more important than
        # usual to clean up ALL intermediate files. However we don't want every
        # worker to walk over all outstanding directories, so instead we only
        # check when we are sure that it won't lead to contention.
        if gc_dev_shm:
            for entry in os.listdir(root):
                owner_file = os.path.join(root, entry, "owner.pid")
                if not os.path.exists(owner_file):
                    continue

                with open(owner_file, "rt") as f:
                    owner_pid = int(f.read())

                if owner_pid == os.getpid():
                    continue

                try:
                    # Signal 0 merely probes whether the pid is alive:
                    # https://stackoverflow.com/questions/568271/how-to-check-if-there-exists-a-process-with-a-given-pid-in-python
                    os.kill(owner_pid, 0)
                except OSError:
                    print(f"Detected that {os.path.join(root, entry)} was orphaned in shared memory. Cleaning up.")
                    shutil.rmtree(os.path.join(root, entry))
    else:
        root = tempfile.gettempdir()

    # We include the time so names sort by creation time, and add a UUID
    # to ensure we don't collide.
    name = f"{prefix or tempfile.gettempprefix()}__{int(time.time())}__{uuid.uuid4()}"
    path = os.path.join(root, name)
    os.makedirs(path, exist_ok=False)
    if use_dev_shm:
        with open(os.path.join(path, "owner.pid"), "wt") as f:
            f.write(str(os.getpid()))
    return path
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.