import sys
import torch
import warnings
from contextlib import contextmanager
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation

# The cuDNN bindings are compiled in optionally; fall back to None when the
# installed torch build has no cuDNN/MIOpen support.
try:
    from torch._C import _cudnn
except ImportError:
    _cudnn = None  # type: ignore[assignment]

# Write:
#
#   torch.backends.cudnn.enabled = False
#
# to globally disable CuDNN/MIOpen

# Lazily-populated cuDNN version integer; filled in on the first _init() call.
__cudnn_version = None

if _cudnn is not None:
    def _init():
        """Initialize cuDNN state once; return True when the library is usable.

        Caches the version in ``__cudnn_version`` and validates that the
        runtime library is compatible with the version PyTorch was compiled
        against, raising ``RuntimeError`` on a mismatch.
        """
        global __cudnn_version
        if __cudnn_version is None:
            __cudnn_version = _cudnn.getVersionInt()
            runtime_version = _cudnn.getRuntimeVersion()
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version

            # Different major versions are always incompatible.
            # Starting with cuDNN 7, minor versions are backwards-compatible.
            # Not sure about MIOpen (ROCm), so always do a strict check.
            if runtime_major != compile_major:
                cudnn_compatible = False
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
                cudnn_compatible = runtime_minor >= compile_minor

            if not cudnn_compatible:
                raise RuntimeError(
                    'cuDNN version incompatibility: PyTorch was compiled against {} '
                    'but linked against {}'.format(compile_version, runtime_version))
        return True
else:
    def _init():
        """No cuDNN bindings in this build: initialization always fails."""
        return False
def version():
    """Returns the version of cuDNN"""
    # _init() is False when the library is absent; report None in that case.
    if not _init():
        return None
    return __cudnn_version
def is_available():
    r"""Returns a bool indicating if CUDNN is currently available."""
    # Reflects whether this torch build was compiled with cuDNN support.
    return torch._C.has_cudnn
def is_acceptable(tensor):
    """Return True if cuDNN/MIOpen may be used for operations on *tensor*.

    Checks, in order: the global enable flag, the tensor's device and dtype,
    whether this build has cuDNN support at all, and whether the runtime
    library could actually be loaded.
    """
    if not torch._C._get_cudnn_enabled():
        return False
    # NOTE(review): CUDNN_TENSOR_DTYPES is not defined in this chunk -- it is
    # presumably a module-level set defined elsewhere in this file; confirm.
    # The `or` short-circuits, so non-CUDA tensors never touch it.
    if tensor.device.type != 'cuda' or tensor.dtype not in CUDNN_TENSOR_DTYPES:
        return False
    if not is_available():
        warnings.warn(
            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
            "PyTorch making sure the library is visible to the build system.")
        return False
    if not _init():
        # Point the user at the platform-appropriate dynamic-loader path var.
        warnings.warn('cuDNN/MIOpen library not found. Check your {libpath}'.format(
            libpath={'darwin': 'DYLD_LIBRARY_PATH',
                     'win32': 'PATH'}.get(sys.platform, 'LD_LIBRARY_PATH')))
        return False
    return True


def set_flags(_enabled=None, _benchmark=None, _deterministic=None, _allow_tf32=None):
    """Set the global cuDNN flags; ``None`` leaves a flag untouched.

    Returns the previous values of all four flags as a tuple
    (enabled, benchmark, deterministic, allow_tf32) so callers can restore
    them later via ``set_flags(*orig_flags)``.
    """
    orig_flags = (torch._C._get_cudnn_enabled(),
                  torch._C._get_cudnn_benchmark(),
                  torch._C._get_cudnn_deterministic(),
                  torch._C._get_cudnn_allow_tf32())
    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        torch._C._set_cudnn_benchmark(_benchmark)
    if _deterministic is not None:
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)
    return orig_flags


@contextmanager
def flags(enabled=False, benchmark=False, deterministic=False, allow_tf32=True):
    """Context manager that temporarily sets all cuDNN flags.

    The previous flag values are restored on exit, even on exception.
    """
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, benchmark, deterministic, allow_tf32)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#   torch.backends.<cudnn|mkldnn>.enabled = True

class CudnnModule(PropModule):
    """Module stand-in whose attributes proxy the global cuDNN flags."""

    def __init__(self, m, name):
        super(CudnnModule, self).__init__(m, name)

    # Each property reads/writes the corresponding C-level global flag.
    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
    benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
    allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)

# Add type annotation for the replaced module
enabled: bool
deterministic: bool
benchmark: bool
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.