import copy
import glob
import importlib
import os
import re
import shlex
import shutil
import setuptools
import subprocess
import sys
import sysconfig
import warnings
import collections

import torch
import torch._appdirs
from .file_baton import FileBaton
from ._cpp_extension_versioner import ExtensionVersioner
from .hipify import hipify_python
from .hipify.hipify_python import get_hip_file_path, GeneratedFileCleaner
from typing import List, Optional, Union

from setuptools.command.build_ext import build_ext
from pkg_resources import packaging  # type: ignore[attr-defined]

IS_WINDOWS = sys.platform == 'win32'
IS_MACOS = sys.platform.startswith('darwin')
IS_LINUX = sys.platform.startswith('linux')
LIB_EXT = '.pyd' if IS_WINDOWS else '.so'
EXEC_EXT = '.exe' if IS_WINDOWS else ''
CLIB_PREFIX = '' if IS_WINDOWS else 'lib'
CLIB_EXT = '.dll' if IS_WINDOWS else '.so'
SHARED_FLAG = '/DLL' if IS_WINDOWS else '-shared'

_HERE = os.path.abspath(__file__)
_TORCH_PATH = os.path.dirname(os.path.dirname(_HERE))
TORCH_LIB_PATH = os.path.join(_TORCH_PATH, 'lib')

# Split-CUDA builds ship torch_cuda_cu and torch_cuda_cpp instead of a single
# torch_cuda library; detect that either from the environment or from disk.
BUILD_SPLIT_CUDA = os.getenv('BUILD_SPLIT_CUDA') or (
    os.path.exists(os.path.join(TORCH_LIB_PATH, f'{CLIB_PREFIX}torch_cuda_cu{CLIB_EXT}')) and
    os.path.exists(os.path.join(TORCH_LIB_PATH, f'{CLIB_PREFIX}torch_cuda_cpp{CLIB_EXT}')))

# Arguments for bytes.decode() on subprocess output ('oem' code page on Windows).
SUBPROCESS_DECODE_ARGS = ('oem',) if IS_WINDOWS else ()


# Taken directly from python stdlib < 3.9
# See https://github.com/pytorch/pytorch/issues/48617
def _nt_quote_args(args: Optional[List[str]]) -> List[str]:
    """Quote command-line arguments for DOS/Windows conventions.

    Just wraps every argument which contains blanks in double quotes, and
    returns a new argument list.
    """
    # Cover None-type
    if not args:
        return []
    return [f'"{arg}"' if ' ' in arg else arg for arg in args]


def _find_cuda_home() -> Optional[str]:
    r'''Finds the CUDA install path.

    Checks, in order: the ``CUDA_HOME``/``CUDA_PATH`` environment variables,
    the directory two levels above ``nvcc`` found on ``PATH``, and finally a
    platform default location. Returns ``None`` if the default does not exist.
    '''
    # Guess #1
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if cuda_home is None:
        # Guess #2: shutil.which returns a single path (or None), whereas
        # `where nvcc` on Windows may print several matches on separate lines.
        nvcc = shutil.which('nvcc')
        if nvcc is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        else:
            # Guess #3
            if IS_WINDOWS:
                cuda_homes = glob.glob(
                    'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*')
                if len(cuda_homes) == 0:
                    cuda_home = ''
                else:
                    cuda_home = cuda_homes[0]
            else:
                cuda_home = '/usr/local/cuda'
            if not os.path.exists(cuda_home):
                cuda_home = None
    if cuda_home and not torch.cuda.is_available():
        print(f"No CUDA runtime is found, using CUDA_HOME='{cuda_home}'")
    return cuda_home


def _find_rocm_home() -> Optional[str]:
    r'''Finds the ROCm install path.

    Checks, in order: the ``ROCM_HOME``/``ROCM_PATH`` environment variables,
    the root derived from ``hipcc`` found on ``PATH`` (symlinks resolved), and
    finally ``/opt/rocm``. Returns ``None`` if none of these apply.
    '''
    # Guess #1
    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if rocm_home is None:
        # Guess #2: shutil.which returns None when hipcc is missing. A shell
        # pipeline (`which hipcc | xargs readlink -f`) does not fail in that
        # case and yields '', which would skip the /opt/rocm fallback.
        hipcc = shutil.which('hipcc')
        if hipcc is not None:
            # this will be either <ROCM_HOME>/hip/bin/hipcc or <ROCM_HOME>/bin/hipcc
            rocm_home = os.path.dirname(os.path.dirname(os.path.realpath(hipcc)))
            if os.path.basename(rocm_home) == 'hip':
                rocm_home = os.path.dirname(rocm_home)
        else:
            # Guess #3
            rocm_home = '/opt/rocm'
            if not os.path.exists(rocm_home):
                rocm_home = None
    if rocm_home and torch.version.hip is None:
        print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
    return rocm_home


def _join_rocm_home(*paths) -> str:
    r'''
    Joins paths with ROCM_HOME, or raises an error if ROCM_HOME is not set.

    This is basically a lazy way of raising an error for missing $ROCM_HOME
    only once we need to get any ROCm-specific path.

    Raises:
        EnvironmentError: if ROCM_HOME is None, or when running on Windows.
    '''
    if ROCM_HOME is None:
        raise EnvironmentError('ROCM_HOME environment variable is not set. '
                               'Please set it to your ROCm install root.')
    elif IS_WINDOWS:
        raise EnvironmentError('Building PyTorch extensions using '
                               'ROCm and Windows is not supported.')
    return os.path.join(ROCM_HOME, *paths)


MINIMUM_GCC_VERSION = (5, 0, 0)
MINIMUM_MSVC_VERSION = (19, 0, 24215)
ABI_INCOMPATIBILITY_WARNING = '''

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({}) may be ABI-incompatible with PyTorch!
Please use a compiler that is ABI-compatible with GCC 5.0 and above.
See https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html.

See https://gist.github.com/goldsborough/d466f43e8ffc948ff92de7486c5216d6
for instructions on how to install GCC 5 or higher.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!
'''
WRONG_COMPILER_WARNING = '''

                               !! WARNING !!

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Your compiler ({user_compiler}) is not compatible with the compiler Pytorch was
built with for this platform, which is {pytorch_compiler} on {platform}. Please
use {pytorch_compiler} to compile your extension. Alternatively, you may
compile PyTorch from source using {user_compiler}, and then you can also use
{user_compiler} to compile your extension.

See https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md for help
with compiling PyTorch from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

                              !! WARNING !!
'''
CUDA_MISMATCH_MESSAGE = '''
The detected CUDA version ({0}) mismatches the version that was used to compile
PyTorch ({1}). Please make sure to use the same CUDA versions.
'''
CUDA_MISMATCH_WARN = "The detected CUDA version ({0}) has a minor version mismatch with the version that was used to compile PyTorch ({1}). Most likely this shouldn't be a problem."
CUDA_NOT_FOUND_MESSAGE = '''
CUDA was not found on the system, please set the CUDA_HOME or the CUDA_PATH
environment variable or add NVCC to your system PATH. The extension compilation will fail.
'''
ROCM_HOME = _find_rocm_home()
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
IS_HIP_EXTENSION = (ROCM_HOME is not None) and (torch.version.hip is not None)
ROCM_VERSION = None
if torch.version.hip is not None:
    # (major, minor) of the ROCm version PyTorch was built against.
    ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])

CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas when
# PyTorch is built from source, we append the git commit hash, which gives
# it the below pattern.
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')

COMMON_MSVC_FLAGS = ['/MD', '/wd4819', '/wd4251', '/wd4244', '/wd4267', '/wd4275', '/wd4018', '/wd4190', '/EHsc']

MSVC_IGNORE_CUDAFE_WARNINGS = [
    'base_class_has_different_dll_interface',
    'field_without_dll_interface',
    'dll_interface_conflict_none_assumed',
    'dll_interface_conflict_dllexport_assumed'
]

COMMON_NVCC_FLAGS = [
    '-D__CUDA_NO_HALF_OPERATORS__',
    '-D__CUDA_NO_HALF_CONVERSIONS__',
    '-D__CUDA_NO_BFLOAT16_CONVERSIONS__',
    '-D__CUDA_NO_HALF2_OPERATORS__',
    '--expt-relaxed-constexpr'
]

COMMON_HIP_FLAGS = [
    '-fPIC',
    '-D__HIP_PLATFORM_HCC__=1',
    '-DUSE_ROCM=1',
]

COMMON_HIPCC_FLAGS = [
    '-DCUDA_HAS_FP16=1',
    '-D__HIP_NO_HALF_OPERATORS__=1',
    '-D__HIP_NO_HALF_CONVERSIONS__=1',
]

JIT_EXTENSION_VERSIONER = ExtensionVersioner()

PLAT_TO_VCVARS = {
    'win32': 'x86',
    'win-amd64': 'x86_amd64',
}


def _is_binary_build() -> bool:
    # Source builds carry a "+<hash>" suffix (see BUILT_FROM_SOURCE_VERSION_PATTERN).
    return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)


def _accepted_compilers_for_platform() -> List[str]:
    # gnu-c++ and gnu-cc are the conda gcc compilers
    return ['clang++', 'clang'] if IS_MACOS else ['g++', 'gcc', 'gnu-c++', 'gnu-cc']


def get_default_build_root() -> str:
    r'''
    Returns the path to the root folder under which extensions will be built.

    For each extension module built, there will be one folder underneath the
    folder returned by this function. For example, if ``p`` is the path
    returned by this function and ``ext`` the name of an extension, the build
    folder for the extension will be ``p/ext``.

    This directory is **user-specific** so that multiple users on the same
    machine won't meet permission issues.
    '''
    return os.path.realpath(torch._appdirs.user_cache_dir(appname='torch_extensions'))


def check_compiler_ok_for_platform(compiler: str) -> bool:
    r'''
    Verifies that the compiler is the expected one for the current platform.

    Args:
        compiler (str): The compiler executable to check.

    Returns:
        True if the compiler is gcc/g++ on Linux or clang/clang++ on macOS,
        and always True for Windows.
    '''
    if IS_WINDOWS:
        return True
    which = subprocess.check_output(['which', compiler], stderr=subprocess.STDOUT)
    # Use os.path.realpath to resolve any symlinks, in particular from 'c++' to e.g. 'g++'.
    compiler_path = os.path.realpath(which.decode(*SUBPROCESS_DECODE_ARGS).strip())
    # Check the compiler name
    if any(name in compiler_path for name in _accepted_compilers_for_platform()):
        return True
    # If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag
    version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT).decode(*SUBPROCESS_DECODE_ARGS)
    if IS_LINUX:
        # Check for 'gcc' or 'g++' for sccache wrapper
        pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE)
        results = re.findall(pattern, version_string)
        if len(results) != 1:
            return False
        compiler_path = os.path.realpath(results[0].strip())
        # On RHEL/CentOS c++ is a gcc compiler wrapper
        if os.path.basename(compiler_path) == 'c++' and 'gcc version' in version_string:
            return True
        return any(name in compiler_path for name in _accepted_compilers_for_platform())
    if IS_MACOS:
        # Check for 'clang' or 'clang++'
        return version_string.startswith("Apple clang")
    return False
def check_compiler_abi_compatibility(compiler: str) -> bool:
    r'''
    Verifies that the given compiler is ABI-compatible with PyTorch.

    Args:
        compiler (str): The compiler executable name to check (e.g. ``g++``).
            Must be executable in a shell process.

    Returns:
        False if the compiler is (likely) ABI-incompatible with PyTorch,
        else True. Source (non-binary) builds of PyTorch, or setting
        ``TORCH_DONT_CHECK_COMPILER_ABI``, skip the check and return True.
    '''
    if not _is_binary_build():
        return True
    if os.environ.get('TORCH_DONT_CHECK_COMPILER_ABI') in ['ON', '1', 'YES', 'TRUE', 'Y']:
        return True

    # First check if the compiler is one of the expected ones for the particular platform.
    if not check_compiler_ok_for_platform(compiler):
        warnings.warn(WRONG_COMPILER_WARNING.format(
            user_compiler=compiler,
            pytorch_compiler=_accepted_compilers_for_platform()[0],
            platform=sys.platform))
        return False

    if IS_MACOS:
        # There is no particular minimum version we need for clang, so we're good here.
        return True
    try:
        if IS_LINUX:
            minimum_required_version = MINIMUM_GCC_VERSION
            versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
            version = versionstr.decode(*SUBPROCESS_DECODE_ARGS).strip().split('.')
        else:
            minimum_required_version = MINIMUM_MSVC_VERSION
            compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
            match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.decode(*SUBPROCESS_DECODE_ARGS).strip())
            version = ['0', '0', '0'] if match is None else list(match.groups())
        # Parse here so an unexpected (non-numeric) version string is reported
        # through the warning below instead of escaping as a ValueError.
        parsed_version = tuple(int(part) for part in version)
    except Exception as error:
        warnings.warn(f'Error checking compiler version for {compiler}: {error}')
        return False

    # Pad to the length of the minimum: a bare "5" must compare as (5, 0, 0);
    # the 1-tuple (5,) would otherwise sort below (5, 0, 0).
    parsed_version += (0,) * (len(minimum_required_version) - len(parsed_version))
    if parsed_version >= minimum_required_version:
        return True

    compiler = f'{compiler} {".".join(version)}'
    warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler))

    return False
# BuildExtension also inherits from object; the link below explains why this
# matters for super() when a base class is an old-style class.
# https://stackoverflow.com/questions/1713038/super-fails-with-error-typeerror-argument-1-must-be-type-not-classobj-when
class BuildExtension(build_ext, object):
    r'''
    A custom :mod:`setuptools` build extension .

    This :class:`setuptools.build_ext` subclass takes care of passing the
    minimum required compiler flags (e.g. ``-std=c++14``) as well as mixed
    C++/CUDA compilation (and support for CUDA files in general).

    When using :class:`BuildExtension`, it is allowed to supply a dictionary
    for ``extra_compile_args`` (rather than the usual list) that maps from
    languages (``cxx`` or ``nvcc``) to a list of additional compiler flags to
    supply to the compiler. This makes it possible to supply different flags to
    the C++ and CUDA compiler during mixed compilation.

    ``use_ninja`` (bool): If ``use_ninja`` is ``True`` (default), then we
    attempt to build using the Ninja backend. Ninja greatly speeds up
    compilation compared to the standard ``setuptools.build_ext``.
    Fallbacks to the standard distutils backend if Ninja is not available.

    .. note::
        By default, the Ninja backend uses #CPUS + 2 workers to build the
        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the `MAX_JOBS` environment
        variable to a non-negative number.
    '''

    @classmethod
    def with_options(cls, **options):
        r'''
        Returns a subclass with alternative constructor that extends any original keyword
        arguments to the original constructor with the given options.
        '''
        class cls_with_options(cls):  # type: ignore[misc, valid-type]
            def __init__(self, *args, **kwargs):
                kwargs.update(options)
                super().__init__(*args, **kwargs)

        return cls_with_options

    def __init__(self, *args, **kwargs) -> None:
        super(BuildExtension, self).__init__(*args, **kwargs)
        self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)

        self.use_ninja = kwargs.get('use_ninja', True)
        if self.use_ninja:
            # Test if we can use ninja. Fallback otherwise.
            msg = ('Attempted to use ninja as the BuildExtension backend but '
                   '{}. Falling back to using the slow distutils backend.')
            if not is_ninja_available():
                warnings.warn(msg.format('we could not find ninja.'))
                self.use_ninja = False

    def finalize_options(self) -> None:
        super().finalize_options()
        if self.use_ninja:
            # Always hand the build to ninja; it decides what is out of date.
            self.force = True

    def build_extensions(self) -> None:
        self._check_abi()

        # Only verify the CUDA toolkit when at least one .cu source is present.
        cuda_ext = any(os.path.splitext(source)[1] == '.cu'
                       for extension in self.extensions
                       for source in extension.sources)
        if cuda_ext and not IS_HIP_EXTENSION:
            self._check_cuda_version()

        for extension in self.extensions:
            # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
            # extra_compile_args is a dict. Otherwise, default torch flags do
            # not get passed. Necessary when only one of 'cxx' and 'nvcc' is
            # passed to extra_compile_args in CUDAExtension, i.e.
            #   CUDAExtension(..., extra_compile_args={'cxx': [...]})
            # or
            #   CUDAExtension(..., extra_compile_args={'nvcc': [...]})
            if isinstance(extension.extra_compile_args, dict):
                for ext in ['cxx', 'nvcc']:
                    if ext not in extension.extra_compile_args:
                        extension.extra_compile_args[ext] = []

            self._add_compile_flag(extension, '-DTORCH_API_INCLUDE_EXTENSION_H')
            # See note [Pybind11 ABI constants]
            for name in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
                val = getattr(torch._C, f"_PYBIND11_{name}")
                if val is not None and not IS_WINDOWS:
                    self._add_compile_flag(extension, f'-DPYBIND11_{name}="{val}"')
            self._define_torch_extension_name(extension)
            self._add_gnu_cpp_abi_flag(extension)

        # Register .cu, .cuh and .hip as valid source extensions.
        self.compiler.src_extensions += ['.cu', '.cuh', '.hip']
        # Save the original _compile method for later.
        if self.compiler.compiler_type == 'msvc':
            self.compiler._cpp_extensions += ['.cu', '.cuh']
            original_compile = self.compiler.compile
            original_spawn = self.compiler.spawn
        else:
            original_compile = self.compiler._compile

        def append_std14_if_no_std_present(cflags) -> None:
            # NVCC does not allow multiple -std to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            cpp_format_prefix = '/{}:' if self.compiler.compiler_type == 'msvc' else '-{}='
            cpp_flag_prefix = cpp_format_prefix.format('std')
            cpp_flag = cpp_flag_prefix + 'c++14'
            if not any(flag.startswith(cpp_flag_prefix) for flag in cflags):
                cflags.append(cpp_flag)

        def unix_cuda_flags(cflags):
            cflags = (COMMON_NVCC_FLAGS +
                      ['--compiler-options', "'-fPIC'"] +
                      cflags + _get_cuda_arch_flags(cflags))

            # NVCC does not allow multiple -ccbin/--compiler-bindir to be passed, so we avoid
            # overriding the option if the user explicitly passed it.
            _ccbin = os.getenv("CC")
            if (
                _ccbin is not None
                and not any(flag.startswith('-ccbin') or flag.startswith('--compiler-bindir') for flag in cflags)
            ):
                cflags.extend(['-ccbin', _ccbin])

            return cflags

        def convert_to_absolute_paths_inplace(paths):
            # Helper function. See Note [Absolute include_dirs]
            if paths is not None:
                for i in range(len(paths)):
                    if not os.path.isabs(paths[i]):
                        paths[i] = os.path.abspath(paths[i])

        def unix_wrap_single_compile(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
            # Copy before we make any modifications.
            cflags = copy.deepcopy(extra_postargs)
            # Read the current compiler before the try so the finally clause
            # always has a value to restore.
            original_compiler = self.compiler.compiler_so
            try:
                if _is_cuda_file(src):
                    nvcc = [_join_rocm_home('bin', 'hipcc') if IS_HIP_EXTENSION else _join_cuda_home('bin', 'nvcc')]
                    self.compiler.set_executable('compiler_so', nvcc)
                    if isinstance(cflags, dict):
                        cflags = cflags['nvcc']
                    if IS_HIP_EXTENSION:
                        cflags = COMMON_HIPCC_FLAGS + cflags + _get_rocm_arch_flags(cflags)
                    else:
                        cflags = unix_cuda_flags(cflags)
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']
                if IS_HIP_EXTENSION:
                    cflags = COMMON_HIP_FLAGS + cflags
                append_std14_if_no_std_present(cflags)

                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                # Put the original compiler back in place.
                self.compiler.set_executable('compiler_so', original_compiler)

        def unix_wrap_ninja_compile(sources,
                                    output_dir=None,
                                    macros=None,
                                    include_dirs=None,
                                    debug=0,
                                    extra_preargs=None,
                                    extra_postargs=None,
                                    depends=None):
            r"""Compiles sources by outputting a ninja file and running it."""
            # NB: I copied some lines from self.compiler (which is an instance
            # of distutils.UnixCCompiler). See the following link.
            # https://github.com/python/cpython/blob/f03a8f8d5001963ad5b5b28dbd95497e9cc15596/Lib/distutils/ccompiler.py#L564-L567
            # This can be fragile, but a lot of other repos also do this
            # (see https://github.com/search?q=_setup_compile&type=Code)
            # so it is probably OK; we'll also get CI signal if/when
            # we update our python version (which is when distutils can be
            # upgraded)

            # Use absolute path for output_dir so that the object file paths
            # (`objects`) get generated with absolute paths.
            output_dir = os.path.abspath(output_dir)

            # See Note [Absolute include_dirs]
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)

            _, objects, extra_postargs, pp_opts, _ = \
                self.compiler._setup_compile(output_dir, macros,
                                             include_dirs, sources,
                                             depends, extra_postargs)
            common_cflags = self.compiler._get_cc_args(pp_opts, debug, extra_preargs)
            extra_cc_cflags = self.compiler.compiler_so[1:]
            with_cuda = any(map(_is_cuda_file, sources))

            # extra_postargs can be either:
            # - a dict mapping cxx/nvcc to extra flags
            # - a list of extra flags.
            # Both are copied so the extension's own flag lists are not mutated
            # (matching the deepcopy in the non-ninja path).
            if isinstance(extra_postargs, dict):
                post_cflags = list(extra_postargs['cxx'])
            else:
                post_cflags = list(extra_postargs)
            if IS_HIP_EXTENSION:
                post_cflags = COMMON_HIP_FLAGS + post_cflags
            append_std14_if_no_std_present(post_cflags)

            cuda_post_cflags = None
            cuda_cflags = None
            if with_cuda:
                cuda_cflags = common_cflags
                if isinstance(extra_postargs, dict):
                    cuda_post_cflags = list(extra_postargs['nvcc'])
                else:
                    cuda_post_cflags = list(extra_postargs)
                if IS_HIP_EXTENSION:
                    cuda_post_cflags = cuda_post_cflags + _get_rocm_arch_flags(cuda_post_cflags)
                    cuda_post_cflags = COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS + cuda_post_cflags
                else:
                    cuda_post_cflags = unix_cuda_flags(cuda_post_cflags)
                append_std14_if_no_std_present(cuda_post_cflags)
                cuda_cflags = [shlex.quote(f) for f in cuda_cflags]
                cuda_post_cflags = [shlex.quote(f) for f in cuda_post_cflags]

            _write_ninja_file_and_compile_objects(
                sources=sources,
                objects=objects,
                cflags=[shlex.quote(f) for f in extra_cc_cflags + common_cflags],
                post_cflags=[shlex.quote(f) for f in post_cflags],
                cuda_cflags=cuda_cflags,
                cuda_post_cflags=cuda_post_cflags,
                build_directory=output_dir,
                verbose=True,
                with_cuda=with_cuda)

            # Return *all* object filenames, not just the ones we just built.
            return objects

        def win_cuda_flags(cflags):
            return (COMMON_NVCC_FLAGS +
                    cflags + _get_cuda_arch_flags(cflags))

        def win_wrap_single_compile(sources,
                                    output_dir=None,
                                    macros=None,
                                    include_dirs=None,
                                    debug=0,
                                    extra_preargs=None,
                                    extra_postargs=None,
                                    depends=None):

            self.cflags = copy.deepcopy(extra_postargs)
            extra_postargs = None

            def spawn(cmd):
                # Using regex to match src, obj and include files
                src_regex = re.compile('/T(p|c)(.*)')
                src_list = [
                    m.group(2) for m in (src_regex.match(elem) for elem in cmd)
                    if m
                ]

                obj_regex = re.compile('/Fo(.*)')
                obj_list = [
                    m.group(1) for m in (obj_regex.match(elem) for elem in cmd)
                    if m
                ]

                include_regex = re.compile(r'((\-|\/)I.*)')
                include_list = [
                    m.group(1)
                    for m in (include_regex.match(elem) for elem in cmd) if m
                ]

                if len(src_list) >= 1 and len(obj_list) >= 1:
                    src = src_list[0]
                    obj = obj_list[0]
                    if _is_cuda_file(src):
                        # Replace the MSVC command line with an nvcc invocation.
                        nvcc = _join_cuda_home('bin', 'nvcc')
                        if isinstance(self.cflags, dict):
                            cflags = self.cflags['nvcc']
                        elif isinstance(self.cflags, list):
                            cflags = self.cflags
                        else:
                            cflags = []

                        cflags = win_cuda_flags(cflags) + ['--use-local-env']
                        for flag in COMMON_MSVC_FLAGS:
                            cflags = ['-Xcompiler', flag] + cflags
                        for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
                            cflags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cflags
                        cmd = [nvcc, '-c', src, '-o', obj] + include_list + cflags
                    elif isinstance(self.cflags, dict):
                        cflags = COMMON_MSVC_FLAGS + self.cflags['cxx']
                        cmd += cflags
                    elif isinstance(self.cflags, list):
                        cflags = COMMON_MSVC_FLAGS + self.cflags
                        cmd += cflags

                return original_spawn(cmd)

            try:
                self.compiler.spawn = spawn
                return original_compile(sources, output_dir, macros,
                                        include_dirs, debug, extra_preargs,
                                        extra_postargs, depends)
            finally:
                self.compiler.spawn = original_spawn

        def win_wrap_ninja_compile(sources,
                                   output_dir=None,
                                   macros=None,
                                   include_dirs=None,
                                   debug=0,
                                   extra_preargs=None,
                                   extra_postargs=None,
                                   depends=None):

            if not self.compiler.initialized:
                self.compiler.initialize()
            output_dir = os.path.abspath(output_dir)

            # Note [Absolute include_dirs]
            # Convert relative path in self.compiler.include_dirs to absolute path if any,
            # For ninja build, the build location is not local, the build happens
            # in a in script created build folder, relative path lost their correctness.
            # To be consistent with jit extension, we allow user to enter relative include_dirs
            # in setuptools.setup, and we convert the relative path to absolute path here
            convert_to_absolute_paths_inplace(self.compiler.include_dirs)

            _, objects, extra_postargs, pp_opts, _ = \
                self.compiler._setup_compile(output_dir, macros,
                                             include_dirs, sources,
                                             depends, extra_postargs)
            # Copy so the COMMON_MSVC_FLAGS appended below never leak into a
            # caller-supplied extra_preargs list.
            common_cflags = list(extra_preargs or [])
            cflags = []
            if debug:
                cflags.extend(self.compiler.compile_options_debug)
            else:
                cflags.extend(self.compiler.compile_options)
            common_cflags.extend(COMMON_MSVC_FLAGS)
            cflags = cflags + common_cflags + pp_opts
            with_cuda = any(map(_is_cuda_file, sources))

            # extra_postargs can be either:
            # - a dict mapping cxx/nvcc to extra flags
            # - a list of extra flags.
            if isinstance(extra_postargs, dict):
                post_cflags = list(extra_postargs['cxx'])
            else:
                post_cflags = list(extra_postargs)
            append_std14_if_no_std_present(post_cflags)

            cuda_post_cflags = None
            cuda_cflags = None
            if with_cuda:
                cuda_cflags = ['--use-local-env']
                for common_cflag in common_cflags:
                    cuda_cflags.append('-Xcompiler')
                    cuda_cflags.append(common_cflag)
                for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
                    cuda_cflags.append('-Xcudafe')
                    cuda_cflags.append('--diag_suppress=' + ignore_warning)
                cuda_cflags.extend(pp_opts)
                if isinstance(extra_postargs, dict):
                    cuda_post_cflags = list(extra_postargs['nvcc'])
                else:
                    cuda_post_cflags = list(extra_postargs)
                cuda_post_cflags = win_cuda_flags(cuda_post_cflags)

            cflags = _nt_quote_args(cflags)
            post_cflags = _nt_quote_args(post_cflags)
            if with_cuda:
                cuda_cflags = _nt_quote_args(cuda_cflags)
                cuda_post_cflags = _nt_quote_args(cuda_post_cflags)

            _write_ninja_file_and_compile_objects(
                sources=sources,
                objects=objects,
                cflags=cflags,
                post_cflags=post_cflags,
                cuda_cflags=cuda_cflags,
                cuda_post_cflags=cuda_post_cflags,
                build_directory=output_dir,
                verbose=True,
                with_cuda=with_cuda)

            # Return *all* object filenames, not just the ones we just built.
            return objects

        # Monkey-patch the _compile or compile method.
        # https://github.com/python/cpython/blob/dc0284ee8f7a270b6005467f26d8e5773d76e959/Lib/distutils/ccompiler.py#L511
        if self.compiler.compiler_type == 'msvc':
            if self.use_ninja:
                self.compiler.compile = win_wrap_ninja_compile
            else:
                self.compiler.compile = win_wrap_single_compile
        else:
            if self.use_ninja:
                self.compiler.compile = unix_wrap_ninja_compile
            else:
                self.compiler._compile = unix_wrap_single_compile

        build_ext.build_extensions(self)

    def get_ext_filename(self, ext_name):
        # Get the original shared library name. For Python 3, this name will be
        # suffixed with "<SOABI>.so", where <SOABI> will be something like
        # cpython-37m-x86_64-linux-gnu.
        ext_filename = super(BuildExtension, self).get_ext_filename(ext_name)
        # If `no_python_abi_suffix` is `True`, we omit the Python 3 ABI
        # component. This makes building shared libraries with setuptools that
        # aren't Python modules nicer.
        if self.no_python_abi_suffix:
            # The parts will be e.g. ["my_extension", "cpython-37m-x86_64-linux-gnu", "so"].
            ext_filename_parts = ext_filename.split('.')
            # Omit the second to last element.
            without_abi = ext_filename_parts[:-2] + ext_filename_parts[-1:]
            ext_filename = '.'.join(without_abi)
        return ext_filename

    def _check_abi(self):
        # On some platforms, like Windows, compiler_cxx is not available.
        if hasattr(self.compiler, 'compiler_cxx'):
            compiler = self.compiler.compiler_cxx[0]
        elif IS_WINDOWS:
            compiler = os.environ.get('CXX', 'cl')
        else:
            compiler = os.environ.get('CXX', 'c++')
        check_compiler_abi_compatibility(compiler)
        # Abort (by raising UserWarning) if the VC env is activated but
        # `DISTUTILS_USE_SDK` is not set.
        if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
            msg = ('It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set. '
                   'This may lead to multiple activations of the VC env. '
                   'Please set `DISTUTILS_USE_SDK=1` and try again.')
            raise UserWarning(msg)

    def _check_cuda_version(self):
        r'''
        Compares the nvcc found under CUDA_HOME with the CUDA version PyTorch
        was built with: a major mismatch raises RuntimeError, a minor one warns.

        Raises:
            RuntimeError: if CUDA_HOME is not set, or on a major version mismatch.
        '''
        if CUDA_HOME:
            nvcc = os.path.join(CUDA_HOME, 'bin', 'nvcc')
            cuda_version_str = subprocess.check_output([nvcc, '--version']).strip().decode(*SUBPROCESS_DECODE_ARGS)
            cuda_version = re.search(r'release (\d+[.]\d+)', cuda_version_str)
            if cuda_version is not None:
                cuda_str_version = cuda_version.group(1)
                cuda_ver = packaging.version.parse(cuda_str_version)
                # A CPU-only PyTorch has torch.version.cuda == None; there is
                # nothing to compare against, and parse(None) would fail.
                if torch.version.cuda is not None:
                    torch_cuda_version = packaging.version.parse(torch.version.cuda)
                    if cuda_ver != torch_cuda_version:
                        # major/minor attributes are only available in setuptools>=49.6.0
                        if getattr(cuda_ver, "major", float("nan")) != getattr(torch_cuda_version, "major", float("nan")):
                            raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
                        warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))
        else:
            raise RuntimeError(CUDA_NOT_FOUND_MESSAGE)

    def _add_compile_flag(self, extension, flag):
        # Copy first so flag lists shared between extensions are not mutated.
        extension.extra_compile_args = copy.deepcopy(extension.extra_compile_args)
        if isinstance(extension.extra_compile_args, dict):
            for args in extension.extra_compile_args.values():
                args.append(flag)
        else:
            extension.extra_compile_args.append(flag)

    def _define_torch_extension_name(self, extension):
        # pybind11 doesn't support dots in the names
        # so in order to support extensions in the packages
        # like torch._C, we take the last part of the string
        # as the library name
        names = extension.name.split('.')
        name = names[-1]
        define = f'-DTORCH_EXTENSION_NAME={name}'
        self._add_compile_flag(extension, define)

    def _add_gnu_cpp_abi_flag(self, extension):
        # use the same CXX ABI as what PyTorch was compiled with
        self._add_compile_flag(extension, '-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI)))
def CppExtension(name, sources, *args, **kwargs):
    r'''
    Creates a :class:`setuptools.Extension` for C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a C++ extension.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor. ``include_dirs``, ``library_dirs`` and ``libraries`` are
    extended with the PyTorch entries; the caller's lists are left unmodified.

    Example:
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CppExtension
        >>> setup(
                name='extension',
                ext_modules=[
                    CppExtension(
                        name='extension',
                        sources=['extension.cpp'],
                        extra_compile_args=['-g']),
                ],
                cmdclass={
                    'build_ext': BuildExtension
                })
    '''
    # Work on copies: appending to the caller's lists in place would make the
    # torch entries accumulate when the same list is reused for several
    # extensions. An explicit None is treated like an omitted argument.
    include_dirs = list(kwargs.get('include_dirs') or [])
    include_dirs += include_paths()
    kwargs['include_dirs'] = include_dirs

    library_dirs = list(kwargs.get('library_dirs') or [])
    library_dirs += library_paths()
    kwargs['library_dirs'] = library_dirs

    libraries = list(kwargs.get('libraries') or [])
    libraries.extend(['c10', 'torch', 'torch_cpu', 'torch_python'])
    kwargs['libraries'] = libraries

    kwargs['language'] = 'c++'
    return setuptools.Extension(name, sources, *args, **kwargs)
def CUDAExtension(name, sources, *args, **kwargs):
    r'''
    Creates a :class:`setuptools.Extension` for CUDA/C++.

    Convenience method that creates a :class:`setuptools.Extension` with the
    bare minimum (but often sufficient) arguments to build a CUDA/C++
    extension. This includes the CUDA include path, library path and runtime
    library.

    All arguments are forwarded to the :class:`setuptools.Extension`
    constructor. ``include_dirs``, ``library_dirs`` and ``libraries`` are
    extended with the PyTorch entries; the caller's lists are left unmodified.

    Example:
        >>> from setuptools import setup
        >>> from torch.utils.cpp_extension import BuildExtension, CUDAExtension
        >>> setup(
                name='cuda_extension',
                ext_modules=[
                    CUDAExtension(
                            name='cuda_extension',
                            sources=['extension.cpp', 'extension_kernel.cu'],
                            extra_compile_args={'cxx': ['-g'],
                                                'nvcc': ['-O2']})
                ],
                cmdclass={
                    'build_ext': BuildExtension
                })

    Compute capabilities:

    By default the extension will be compiled to run on all archs of the cards visible during the
    building process of the extension, plus PTX. If down the road a new card is installed the
    extension may need to be recompiled. If a visible card has a compute capability (CC) that's
    newer than the newest version for which your nvcc can build fully-compiled binaries, Pytorch
    will make nvcc fall back to building kernels with the newest version of PTX your nvcc does
    support (see below for details on PTX).

    You can override the default behavior using `TORCH_CUDA_ARCH_LIST` to explicitly specify which
    CCs you want the extension to support:

    TORCH_CUDA_ARCH_LIST="6.1 8.6" python build_my_extension.py
    TORCH_CUDA_ARCH_LIST="5.2 6.0 6.1 7.0 7.5 8.0 8.6+PTX" python build_my_extension.py

    The +PTX option causes extension kernel binaries to include PTX instructions for the specified
    CC. PTX is an intermediate representation that allows kernels to runtime-compile for any CC >=
    the specified CC (for example, 8.6+PTX generates PTX that can runtime-compile for any GPU with
    CC >= 8.6). This improves your binary's forward compatibility. However, relying on older PTX to
    provide forward compat by runtime-compiling for newer CCs can modestly reduce performance on
    those newer CCs. If you know exact CC(s) of the GPUs you want to target, you're always better
    off specifying them individually. For example, if you want your extension to run on 8.0 and 8.6,
    "8.0+PTX" would work functionally because it includes PTX that can runtime-compile for 8.6, but
    "8.0 8.6" would be better.

    Note that while it's possible to include all supported archs, the more archs get included the
    slower the building process will be, as it will build a separate kernel image for each arch.

    Note that CUDA-11.5 nvcc will hit internal compiler error while parsing torch/extension.h on Windows.
    To workaround the issue, move python binding logic to pure C++ file.

    Example use:
        >>> #include <ATen/ATen.h>
        >>> at::Tensor SigmoidAlphaBlendForwardCuda(....)

    Instead of:
        >>> #include <torch/extension.h>
        >>> torch::Tensor SigmoidAlphaBlendForwardCuda(...)

    Currently open issue for nvcc bug: https://github.com/pytorch/pytorch/issues/69460
    Complete workaround code example: https://github.com/facebookresearch/pytorch3d/commit/cb170ac024a949f1f9614ffe6af1c38d972f7d48
    '''
    # Work on copies so the caller's lists are not extended in place (they
    # would accumulate torch entries when reused). None behaves like omitted.
    library_dirs = list(kwargs.get('library_dirs') or [])
    library_dirs += library_paths(cuda=True)
    kwargs['library_dirs'] = library_dirs

    libraries = list(kwargs.get('libraries') or [])
    libraries.extend(['c10', 'torch', 'torch_cpu', 'torch_python'])
    if IS_HIP_EXTENSION:
        assert ROCM_VERSION is not None
        libraries.append('amdhip64' if ROCM_VERSION >= (3, 5) else 'hip_hcc')
        libraries.append('c10_hip')
        libraries.append('torch_hip')
    else:
        libraries.append('cudart')
        libraries.append('c10_cuda')
        if BUILD_SPLIT_CUDA:
            libraries.append('torch_cuda_cu')
            libraries.append('torch_cuda_cpp')
        else:
            libraries.append('torch_cuda')
    kwargs['libraries'] = libraries

    include_dirs = list(kwargs.get('include_dirs') or [])

    if IS_HIP_EXTENSION:
        build_dir = os.getcwd()
        hipify_result = hipify_python.hipify(
            project_directory=build_dir,
            output_directory=build_dir,
            includes=[os.path.join(os.path.relpath(include_dir, build_dir), '*') for include_dir in include_dirs] if include_dirs else ['*'],
            extra_files=[os.path.abspath(s) for s in sources],
            show_detailed=True,
            is_pytorch_extension=True,
        )

        # Collect hipified paths in a dict rather than a set: duplicates are
        # still dropped, but the caller's source order is kept, so the build
        # does not depend on the interpreter's string-hash seed.
        hipified_sources = {}
        for source in sources:
            s_abs = os.path.abspath(source)
            hipified_path = hipify_result[s_abs]["hipified_path"] if s_abs in hipify_result else None
            # Fall back to the original file when hipify reported no output path.
            hipified_sources[hipified_path if hipified_path is not None else s_abs] = None

        sources = list(hipified_sources)

    include_dirs += include_paths(cuda=True)
    kwargs['include_dirs'] = include_dirs

    kwargs['language'] = 'c++'

    return setuptools.Extension(name, sources, *args, **kwargs)
def include_paths(cuda: bool = False) -> List[str]:
    '''
    Get the include paths required to build a C++ or CUDA extension.

    Args:
        cuda: If `True`, includes CUDA-specific include paths.

    Returns:
        A list of include path strings.
    '''
    torch_include = os.path.join(_TORCH_PATH, 'include')
    result = [
        torch_include,
        # Remove this once torch/torch.h is officially no longer supported for C++ extensions.
        os.path.join(torch_include, 'torch', 'csrc', 'api', 'include'),
        # Some internal (old) Torch headers don't properly prefix their includes,
        # so we need to pass -Itorch/lib/include/TH as well.
        os.path.join(torch_include, 'TH'),
        os.path.join(torch_include, 'THC'),
    ]
    if not cuda:
        return result

    if IS_HIP_EXTENSION:
        # ROCm build: THH headers, the ROCm include dir, and MIOpen when known.
        result.append(os.path.join(torch_include, 'THH'))
        result.append(_join_rocm_home('include'))
        if MIOPEN_HOME is not None:
            result.append(os.path.join(MIOPEN_HOME, 'include'))
        return result

    toolkit_include = _join_cuda_home('include')
    # With the Debian/Ubuntu cuda packages CUDA home is /usr, and gcc does not
    # like /usr/include being passed explicitly, so skip it in that case.
    if toolkit_include != '/usr/include':
        result.append(toolkit_include)
    if CUDNN_HOME is not None:
        result.append(os.path.join(CUDNN_HOME, 'include'))
    return result
def library_paths(cuda: bool = False) -> List[str]:
    r'''
    Get the library paths required to build a C++ or CUDA extension.

    ``TORCH_LIB_PATH`` is always included so libtorch can be linked. When
    ``cuda`` is ``True`` the toolkit library directory is appended: ROCm's
    ``lib`` for HIP extensions; otherwise CUDA's ``lib/x64`` on Windows or
    ``lib64`` (falling back to ``lib``) elsewhere, plus the matching cuDNN
    subdirectory when ``CUDNN_HOME`` is set.

    Args:
        cuda: If `True`, includes CUDA-specific library paths.

    Returns:
        A list of library path strings.
    '''
    # libtorch itself always has to be found by the linker.
    search_dirs = [TORCH_LIB_PATH]
    if not cuda:
        return search_dirs

    if IS_HIP_EXTENSION:
        search_dirs.append(_join_rocm_home('lib'))
        return search_dirs

    if IS_WINDOWS:
        subdir = 'lib/x64'
    else:
        subdir = 'lib64'
        # 64-bit CUDA may be installed in 'lib' instead (see e.g. gh-16955).
        # Both may also be missing (see _find_cuda_home); then keep 'lib64'.
        if not os.path.exists(_join_cuda_home(subdir)):
            if os.path.exists(_join_cuda_home('lib')):
                subdir = 'lib'

    search_dirs.append(_join_cuda_home(subdir))
    if CUDNN_HOME is not None:
        search_dirs.append(os.path.join(CUDNN_HOME, subdir))
    return search_dirs
def load(name,
         sources: Union[str, List[str]],
         extra_cflags=None,
         extra_cuda_cflags=None,
         extra_ldflags=None,
         extra_include_paths=None,
         build_directory=None,
         verbose=False,
         with_cuda: Optional[bool] = None,
         is_python_module=True,
         is_standalone=False,
         keep_intermediates=True):
    r'''
    Loads a PyTorch C++ extension just-in-time (JIT).

    To load an extension, a Ninja build file is emitted, which is used to
    compile the given sources into a dynamic library. This library is
    subsequently loaded into the current Python process as a module and
    returned from this function, ready for use.

    By default, the directory to which the build file is emitted and the
    resulting library compiled to is ``<tmp>/torch_extensions/<name>``, where
    ``<tmp>`` is the temporary folder on the current platform and ``<name>``
    the name of the extension. This location can be overridden in two ways.
    First, if the ``TORCH_EXTENSIONS_DIR`` environment variable is set, it
    replaces ``<tmp>/torch_extensions`` and all extensions will be compiled
    into subfolders of this directory. Second, if the ``build_directory``
    argument to this function is supplied, it overrides the entire path, i.e.
    the library will be compiled into that folder directly.

    To compile the sources, the default system compiler (``c++``) is used,
    which can be overridden by setting the ``CXX`` environment variable. To pass
    additional arguments to the compilation process, ``extra_cflags`` or
    ``extra_ldflags`` can be provided. For example, to compile your extension
    with optimizations, pass ``extra_cflags=['-O3']``. You can also use
    ``extra_cflags`` to pass further include directories.

    CUDA support with mixed compilation is provided. Simply pass CUDA source
    files (``.cu`` or ``.cuh``) along with other sources. Such files will be
    detected and compiled with nvcc rather than the C++ compiler. This includes
    passing the CUDA lib64 directory as a library directory, and linking
    ``cudart``. You can pass additional flags to nvcc via
    ``extra_cuda_cflags``, just like with ``extra_cflags`` for C++. Various
    heuristics for finding the CUDA install directory are used, which usually
    work fine. If not, setting the ``CUDA_HOME`` environment variable is the
    safest option.

    Args:
        name: The name of the extension to build. This MUST be the same as the
            name of the pybind11 module!
        sources: A path, or a list of relative or absolute paths, to C++
            source files.
        extra_cflags: optional list of compiler flags to forward to the build.
        extra_cuda_cflags: optional list of compiler flags to forward to nvcc
            when building CUDA sources.
        extra_ldflags: optional list of linker flags to forward to the build.
        extra_include_paths: optional list of include directories to forward
            to the build.
        build_directory: optional path to use as build workspace.
        verbose: If ``True``, turns on verbose logging of load steps.
        with_cuda: Determines whether CUDA headers and libraries are added to
            the build. If set to ``None`` (default), this value is
            automatically determined based on the existence of ``.cu`` or
            ``.cuh`` in ``sources``. Set it to ``True`` to force CUDA headers
            and libraries to be included.
        is_python_module: If ``True`` (default), imports the produced shared
            library as a Python module. If ``False``, behavior depends on
            ``is_standalone``.
        is_standalone: If ``False`` (default) loads the constructed extension
            into the process as a plain dynamic library. If ``True``, build a
            standalone executable.
        keep_intermediates: Forwarded to the generated-file cleaner used
            during HIP builds.

    Returns:
        If ``is_python_module`` is ``True``:
            Returns the loaded PyTorch extension as a Python module.

        If ``is_python_module`` is ``False`` and ``is_standalone`` is ``False``:
            Returns nothing. (The shared library is loaded into the process as
            a side effect.)

        If ``is_standalone`` is ``True``.
            Return the path to the executable. (On Windows, TORCH_LIB_PATH is
            added to the PATH environment variable as a side effect.)

    Example:
        >>> from torch.utils.cpp_extension import load
        >>> module = load(
                name='extension',
                sources=['extension.cpp', 'extension_kernel.cu'],
                extra_cflags=['-O2'],
                verbose=True)
    '''
    # A single path is accepted for convenience; normalize it to a list.
    source_list = [sources] if isinstance(sources, str) else sources
    workspace = build_directory or _get_build_directory(name, verbose)
    return _jit_compile(
        name,
        source_list,
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        workspace,
        verbose,
        with_cuda,
        is_python_module,
        is_standalone,
        keep_intermediates=keep_intermediates)
def load_inline(name,
                cpp_sources,
                cuda_sources=None,
                functions=None,
                extra_cflags=None,
                extra_cuda_cflags=None,
                extra_ldflags=None,
                extra_include_paths=None,
                build_directory=None,
                verbose=False,
                with_cuda=None,
                is_python_module=True,
                with_pytorch_error_handling=True,
                keep_intermediates=True):
    r"""
    Loads a PyTorch C++ extension just-in-time (JIT) from string sources.

    This function behaves exactly like :func:`load`, but takes its sources as
    strings rather than filenames. These strings are stored to files in the
    build directory, after which the behavior of :func:`load_inline` is
    identical to :func:`load`.

    See `the
    tests <https://github.com/pytorch/pytorch/blob/master/test/test_cpp_extensions_jit.py>`_
    for good examples of using this function.

    Sources may omit two required parts of a typical non-inline C++ extension:
    the necessary header includes, as well as the (pybind11) binding code. More
    precisely, strings passed to ``cpp_sources`` are first concatenated into a
    single ``.cpp`` file. This file is then prepended with ``#include
    <torch/extension.h>``.

    Furthermore, if the ``functions`` argument is supplied, bindings will be
    automatically generated for each function specified. ``functions`` can
    either be a list of function names, or a dictionary mapping from function
    names to docstrings. If a list is given, the name of each function is used
    as its docstring.

    The sources in ``cuda_sources`` are concatenated into a separate ``.cu``
    file and  prepended with ``torch/types.h``, ``cuda.h`` and
    ``cuda_runtime.h`` includes. The ``.cpp`` and ``.cu`` files are compiled
    separately, but ultimately linked into a single library. Note that no
    bindings are generated for functions in ``cuda_sources`` per  se. To bind
    to a CUDA kernel, you must create a C++ function that calls it, and either
    declare or define this C++ function in one of the ``cpp_sources`` (and
    include its name in ``functions``).

    The lists passed as ``cpp_sources`` / ``cuda_sources`` are not modified.

    See :func:`load` for a description of arguments omitted below.

    Args:
        cpp_sources: A string, or list (or tuple) of strings, containing C++
            source code.
        cuda_sources: A string, or list (or tuple) of strings, containing CUDA
            source code.
        functions: A function name, or a list (or tuple) of function names,
            for which to generate function bindings. If a dictionary is
            given, it should map function names to docstrings (which are
            otherwise just the function names).
        with_cuda: Determines whether CUDA headers and libraries are added to
            the build. If set to ``None`` (default), this value is
            automatically determined based on whether ``cuda_sources`` is
            provided. Set it to ``True`` to force CUDA headers
            and libraries to be included.
        with_pytorch_error_handling: Determines whether pytorch error and
            warning macros are handled by pytorch instead of pybind. To do
            this, each function ``foo`` is called via an intermediary ``_safe_foo``
            function. This redirection might cause issues in obscure cases
            of cpp. This flag should be set to ``False`` when this redirect
            causes issues.

    Raises:
        ValueError: if ``functions`` is not a string, list, tuple or dict.

    Example:
        >>> from torch.utils.cpp_extension import load_inline
        >>> source = '''
        at::Tensor sin_add(at::Tensor x, at::Tensor y) {
          return x.sin() + y.sin();
        }
        '''
        >>> module = load_inline(name='inline_extension',
                                 cpp_sources=[source],
                                 functions=['sin_add'])

    .. note::
        By default, the Ninja backend uses #CPUS + 2 workers to build the
        extension. This may use up too many resources on some systems. One
        can control the number of workers by setting the `MAX_JOBS` environment
        variable to a non-negative number.
    """
    build_directory = build_directory or _get_build_directory(name, verbose)

    # Build new lists instead of calling insert()/+= on the caller's lists:
    # otherwise a list reused for a second call would accumulate duplicate
    # #include and PYBIND11_MODULE lines.
    if isinstance(cpp_sources, str):
        cpp_sources = [cpp_sources]
    cpp_sources = ['#include <torch/extension.h>', *cpp_sources]

    cuda_sources = cuda_sources or []
    if isinstance(cuda_sources, str):
        cuda_sources = [cuda_sources]

    # If `functions` is supplied, we create the pybind11 bindings for the user.
    # Here, `functions` is (or becomes, after some processing) a map from
    # function names to function docstrings.
    if functions is not None:
        if isinstance(functions, str):
            functions = [functions]
        if isinstance(functions, (list, tuple)):
            # Make the function docstring the same as the function name.
            functions = {f: f for f in functions}
        elif not isinstance(functions, dict):
            raise ValueError(f"Expected 'functions' to be a list or dict, but was {type(functions)}")

        module_def = ['PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {']
        for function_name, docstring in functions.items():
            if with_pytorch_error_handling:
                module_def.append(
                    'm.def("{0}", torch::wrap_pybind_function({0}), "{1}");'.format(
                        function_name, docstring))
            else:
                module_def.append('m.def("{0}", {0}, "{1}");'.format(function_name, docstring))
        module_def.append('}')
        cpp_sources += module_def  # extends our private copy only

    cpp_source_path = os.path.join(build_directory, 'main.cpp')
    with open(cpp_source_path, 'w') as cpp_source_file:
        cpp_source_file.write('\n'.join(cpp_sources))

    sources = [cpp_source_path]

    if cuda_sources:
        cuda_sources = [
            '#include <torch/types.h>',
            '#include <cuda.h>',
            '#include <cuda_runtime.h>',
            *cuda_sources,
        ]
        cuda_source_path = os.path.join(build_directory, 'cuda.cu')
        with open(cuda_source_path, 'w') as cuda_source_file:
            cuda_source_file.write('\n'.join(cuda_sources))

        sources.append(cuda_source_path)

    return _jit_compile(
        name,
        sources,
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        build_directory,
        verbose,
        with_cuda,
        is_python_module,
        is_standalone=False,
        keep_intermediates=keep_intermediates)
def _resolve_cxx_compiler() -> str:
    """Return the host C++ compiler: ``$CXX`` if set, else ``cl`` on Windows and ``c++`` elsewhere."""
    return os.environ.get('CXX', 'cl' if IS_WINDOWS else 'c++')


def _jit_compile(name,
                 sources,
                 extra_cflags,
                 extra_cuda_cflags,
                 extra_ldflags,
                 extra_include_paths,
                 build_directory: str,
                 verbose: bool,
                 with_cuda: Optional[bool],
                 is_python_module,
                 is_standalone,
                 keep_intermediates=True):
    """
    Build a JIT extension if its inputs changed, then load it.

    The extension versioner bumps the version whenever the build inputs
    change. A bumped version is built under the name ``<name>_v<version>``.
    Only the process that acquires the ``lock`` file in ``build_directory``
    runs the build; others wait on it.

    Returns:
        The path to the executable when ``is_standalone``; otherwise the
        imported module when ``is_python_module``; otherwise ``None`` (the
        library is loaded into the process as a side effect).

    Raises:
        ValueError: if ``is_python_module`` and ``is_standalone`` are both set.
    """
    if is_python_module and is_standalone:
        raise ValueError("`is_python_module` and `is_standalone` are mutually exclusive.")

    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))
    with_cudnn = any('cudnn' in f for f in extra_ldflags or [])
    old_version = JIT_EXTENSION_VERSIONER.get_version(name)
    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
        name,
        sources,
        build_arguments=[extra_cflags, extra_cuda_cflags, extra_ldflags, extra_include_paths],
        build_directory=build_directory,
        with_cuda=with_cuda,
        is_python_module=is_python_module,
        is_standalone=is_standalone,
    )
    if version > 0:
        if version != old_version and verbose:
            print(f'The input conditions for extension module {name} have changed. ' +
                  f'Bumping to version {version} and re-building as {name}_v{version}...')
        name = f'{name}_v{version}'

    if version != old_version:
        baton = FileBaton(os.path.join(build_directory, 'lock'))
        if baton.try_acquire():
            try:
                with GeneratedFileCleaner(keep_intermediates=keep_intermediates) as clean_ctx:
                    if IS_HIP_EXTENSION and (with_cuda or with_cudnn):
                        hipify_python.hipify(
                            project_directory=build_directory,
                            output_directory=build_directory,
                            includes=os.path.join(build_directory, '*'),
                            extra_files=[os.path.abspath(s) for s in sources],
                            show_detailed=verbose,
                            is_pytorch_extension=True,
                            clean_ctx=clean_ctx
                        )
                    _write_ninja_file_and_build_library(
                        name=name,
                        sources=sources,
                        extra_cflags=extra_cflags or [],
                        extra_cuda_cflags=extra_cuda_cflags or [],
                        extra_ldflags=extra_ldflags or [],
                        extra_include_paths=extra_include_paths or [],
                        build_directory=build_directory,
                        verbose=verbose,
                        with_cuda=with_cuda,
                        is_standalone=is_standalone)
            finally:
                baton.release()
        else:
            # Another process holds the lock and is building; wait for it.
            baton.wait()
    elif verbose:
        print('No modifications detected for re-loaded extension '
              f'module {name}, skipping build step...')

    if verbose:
        print(f'Loading extension module {name}...')

    if is_standalone:
        return _get_exec_path(name, build_directory)

    return _import_module_from_library(name, build_directory, is_python_module)


def _write_ninja_file_and_compile_objects(
        sources: List[str],
        objects,
        cflags,
        post_cflags,
        cuda_cflags,
        cuda_post_cflags,
        build_directory: str,
        verbose: bool,
        with_cuda: Optional[bool]) -> None:
    """
    Emit ``build.ninja`` in ``build_directory`` that only compiles ``sources``
    into ``objects`` (no link step), then run ninja.

    ``with_cuda=None`` means "detect from the source file extensions".
    Raises ``RuntimeError`` if ninja is missing or the build fails.
    """
    verify_ninja_availability()
    check_compiler_abi_compatibility(_resolve_cxx_compiler())
    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))
    build_file_path = os.path.join(build_directory, 'build.ninja')
    if verbose:
        print(f'Emitting ninja build file {build_file_path}...')
    _write_ninja_file(
        path=build_file_path,
        cflags=cflags,
        post_cflags=post_cflags,
        cuda_cflags=cuda_cflags,
        cuda_post_cflags=cuda_post_cflags,
        sources=sources,
        objects=objects,
        ldflags=None,
        library_target=None,
        with_cuda=with_cuda)
    if verbose:
        print('Compiling objects...')
    _run_ninja_build(
        build_directory,
        verbose,
        # It would be better if we could tell users the name of the extension
        # that failed to build but there isn't a good way to get it here.
        error_prefix='Error compiling objects for extension')


def _write_ninja_file_and_build_library(
        name,
        sources: List[str],
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
        build_directory: str,
        verbose: bool,
        with_cuda: Optional[bool],
        is_standalone: bool = False) -> None:
    """
    Emit ``build.ninja`` in ``build_directory`` that compiles and links
    ``sources`` into ``name`` (a shared library, or an executable when
    ``is_standalone``), then run ninja.

    The ``extra_*`` lists passed in are not modified.
    Raises ``RuntimeError`` if ninja is missing or the build fails.
    """
    verify_ninja_availability()
    check_compiler_abi_compatibility(_resolve_cxx_compiler())
    if with_cuda is None:
        with_cuda = any(map(_is_cuda_file, sources))
    # _prepare_ldflags appends to the list it is given; hand it a copy so the
    # caller's list (often the user's own `extra_ldflags`) is left untouched
    # and repeated load() calls don't accumulate duplicate link flags.
    extra_ldflags = _prepare_ldflags(
        list(extra_ldflags or []),
        with_cuda,
        verbose,
        is_standalone)
    build_file_path = os.path.join(build_directory, 'build.ninja')
    if verbose:
        print(f'Emitting ninja build file {build_file_path}...')
    # NOTE: Emitting a new ninja build file does not cause re-compilation if
    # the sources did not change, so it's ok to re-emit (and it's fast).
    _write_ninja_file_to_build_library(
        path=build_file_path,
        name=name,
        sources=sources,
        extra_cflags=extra_cflags or [],
        extra_cuda_cflags=extra_cuda_cflags or [],
        extra_ldflags=extra_ldflags,
        extra_include_paths=extra_include_paths or [],
        with_cuda=with_cuda,
        is_standalone=is_standalone)

    if verbose:
        print(f'Building extension module {name}...')
    _run_ninja_build(
        build_directory,
        verbose,
        error_prefix=f"Error building extension '{name}'")
def is_ninja_available():
    """
    Report whether the `ninja <https://ninja-build.org/>`_ build system can be run.

    Returns:
        ``True`` if ``ninja --version`` runs successfully, ``False`` otherwise
        (including when the executable cannot be found at all).
    """
    probe = ['ninja', '--version']
    try:
        subprocess.check_output(probe)
        return True
    except Exception:
        return False
def verify_ninja_availability():
    """
    Ensure the `ninja <https://ninja-build.org/>`_ build system is usable.

    Returns silently when ninja is available.

    Raises:
        RuntimeError: if ninja cannot be run on this system.
    """
    if is_ninja_available():
        return
    raise RuntimeError("Ninja is required to load C++ extensions")
def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone):
    """
    Append the platform-specific PyTorch (and CUDA/ROCm) link flags.

    Args:
        extra_ldflags: list of user linker flags; it is appended to in place.
        with_cuda: whether to add the GPU runtime / torch GPU libraries.
        verbose: print a note when CUDA flags are added.
        is_standalone: building an executable rather than a Python module
            (skips torch_python, adds an rpath to TORCH_LIB_PATH on non-Windows).

    Returns:
        The same ``extra_ldflags`` list, extended.
    """
    if IS_WINDOWS:
        python_path = os.path.dirname(sys.executable)
        python_lib_path = os.path.join(python_path, 'libs')

        extra_ldflags.append('c10.lib')
        if with_cuda:
            extra_ldflags.append('c10_cuda.lib')
        extra_ldflags.append('torch_cpu.lib')
        if BUILD_SPLIT_CUDA and with_cuda:
            extra_ldflags.append('torch_cuda_cu.lib')
            # See [Note about _torch_cuda_cu_linker_symbol_op and torch_cuda_cu] in native_functions.yaml
            extra_ldflags.append('-INCLUDE:?_torch_cuda_cu_linker_symbol_op_cuda@native@at@@YA?AVTensor@2@AEBV32@@Z')
            extra_ldflags.append('torch_cuda_cpp.lib')
            # /INCLUDE is used to ensure torch_cuda_cpp is linked against in a project that relies on it.
            # Related issue: https://github.com/pytorch/pytorch/issues/31611
            extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
        elif with_cuda:
            extra_ldflags.append('torch_cuda.lib')
            # /INCLUDE is used to ensure torch_cuda is linked against in a project that relies on it.
            # Related issue: https://github.com/pytorch/pytorch/issues/31611
            extra_ldflags.append('-INCLUDE:?warp_size@cuda@at@@YAHXZ')
        extra_ldflags.append('torch.lib')
        extra_ldflags.append(f'/LIBPATH:{TORCH_LIB_PATH}')
        if not is_standalone:
            extra_ldflags.append('torch_python.lib')
            extra_ldflags.append(f'/LIBPATH:{python_lib_path}')

    else:
        extra_ldflags.append(f'-L{TORCH_LIB_PATH}')
        extra_ldflags.append('-lc10')
        if with_cuda:
            extra_ldflags.append('-lc10_hip' if IS_HIP_EXTENSION else '-lc10_cuda')
        extra_ldflags.append('-ltorch_cpu')
        if BUILD_SPLIT_CUDA and with_cuda:
            extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda_cu -ltorch_cuda_cpp')
        elif with_cuda:
            extra_ldflags.append('-ltorch_hip' if IS_HIP_EXTENSION else '-ltorch_cuda')
        extra_ldflags.append('-ltorch')
        if not is_standalone:
            extra_ldflags.append('-ltorch_python')

        if is_standalone and "TBB" in torch.__config__.parallel_info():
            extra_ldflags.append('-ltbb')

        if is_standalone:
            extra_ldflags.append(f"-Wl,-rpath,{TORCH_LIB_PATH}")

    if with_cuda:
        if verbose:
            print('Detected CUDA files, patching ldflags')
        if IS_WINDOWS:
            extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib/x64")}')
            extra_ldflags.append('cudart.lib')
            if CUDNN_HOME is not None:
                # Must be a /LIBPATH: search directory; a bare path would be
                # handed to link.exe as an input file.
                extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib/x64")}')
        elif not IS_HIP_EXTENSION:
            extra_ldflags.append(f'-L{_join_cuda_home("lib64")}')
            extra_ldflags.append('-lcudart')
            if CUDNN_HOME is not None:
                extra_ldflags.append(f'-L{os.path.join(CUDNN_HOME, "lib64")}')
        else:  # HIP extension on a non-Windows platform
            assert ROCM_VERSION is not None
            extra_ldflags.append(f'-L{_join_rocm_home("lib")}')
            extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc')
    return extra_ldflags


def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
    r'''
    Determine CUDA arch flags to use.

    For an arch, say "6.1", the added compile flag will be
    ``-gencode=arch=compute_61,code=sm_61``.
    For an added "+PTX", an additional
    ``-gencode=arch=compute_xx,code=compute_xx`` is added.

    The arch list comes from ``TORCH_CUDA_ARCH_LIST`` when set (names such as
    "Ampere" are expanded; ' ' and ';' both separate entries, empty entries
    are ignored). Otherwise it is derived from the visible GPUs, clamped to
    the newest ``sm_`` arch this PyTorch build supports, with "+PTX" added to
    the highest one.

    See select_compute_arch.cmake for corresponding named and supported arches
    when building with CMake.

    Args:
        cflags: optional existing flags; if any contains 'arch' the user is
            assumed to have chosen architectures and ``[]`` is returned.

    Raises:
        ValueError: for an arch not in the supported list.
        RuntimeError: if no arch list is given and no CUDA device is visible.
    '''
    # If cflags is given, there may already be user-provided arch flags in it
    # (from `extra_compile_args`)
    if cflags is not None and any('arch' in flag for flag in cflags):
        return []

    # Note: keep combined names ("arch1+arch2") above single names, otherwise
    # string replacement may not do the right thing
    named_arches = collections.OrderedDict([
        ('Kepler+Tesla', '3.7'),
        ('Kepler', '3.5+PTX'),
        ('Maxwell+Tegra', '5.3'),
        ('Maxwell', '5.0;5.2+PTX'),
        ('Pascal', '6.0;6.1+PTX'),
        ('Volta', '7.0+PTX'),
        ('Turing', '7.5+PTX'),
        ('Ampere', '8.0;8.6+PTX'),
    ])

    supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
                        '7.0', '7.2', '7.5', '8.0', '8.6']
    valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]

    # The default is sm_30 for CUDA 9.x and 10.x
    # First check for an env var (same as used by the main setup.py)
    # Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX"
    # See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
    _arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', '').strip()

    # If not given, determine what's best for the GPU / CUDA version that can be found
    if not _arch_list:
        device_count = torch.cuda.device_count()
        if device_count == 0:
            # Without this check the code below would fail with an opaque
            # IndexError on the empty arch list.
            raise RuntimeError(
                'Cannot determine which CUDA architectures to compile for: no CUDA '
                'device is visible and TORCH_CUDA_ARCH_LIST is not set. Set '
                'TORCH_CUDA_ARCH_LIST (e.g. "7.5" or "8.0;8.6+PTX") or pass explicit '
                'arch flags (e.g. -gencode=...) with the extra CUDA compile flags.')
        # Capability of the device may be higher than what's supported by the user's
        # NVCC, causing compilation error. User's NVCC is expected to match the one
        # used to build pytorch, so we use the maximum supported capability of pytorch
        # to clamp the capability.
        supported_sm = [int(arch.split('_')[1])
                        for arch in torch.cuda.get_arch_list() if 'sm_' in arch]
        max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm)
        arch_list = []
        # the assumption is that the extension should run on any of the currently visible cards,
        # which could be of different types - therefore all archs for visible cards should be included
        for i in range(device_count):
            capability = min(max_supported_sm, torch.cuda.get_device_capability(i))
            arch = f'{capability[0]}.{capability[1]}'
            if arch not in arch_list:
                arch_list.append(arch)
        arch_list = sorted(arch_list)
        arch_list[-1] += '+PTX'
    else:
        # Deal with lists that are ' ' separated (only deal with ';' after)
        _arch_list = _arch_list.replace(' ', ';')
        # Expand named arches
        for named_arch, archval in named_arches.items():
            _arch_list = _arch_list.replace(named_arch, archval)
        # Skip empty entries produced by doubled/trailing separators ("7.0; 8.0;").
        arch_list = [arch for arch in _arch_list.split(';') if arch]

    flags = []
    for arch in arch_list:
        if arch not in valid_arch_strings:
            raise ValueError(f"Unknown CUDA arch ({arch}) or GPU not supported")
        num = arch[0] + arch[2]
        flags.append(f'-gencode=arch=compute_{num},code=sm_{num}')
        if arch.endswith('+PTX'):
            flags.append(f'-gencode=arch=compute_{num},code=compute_{num}')

    return sorted(set(flags))


def _get_rocm_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
    """
    Determine the hipcc ``--amdgpu-target`` flags to use (always followed by
    ``-fno-gpu-rdc``).

    If ``cflags`` already contains an 'amdgpu-target' flag, only
    ``['-fno-gpu-rdc']`` is returned. Otherwise targets come from
    ``PYTORCH_ROCM_ARCH`` (' ' or ';' separated, empty entries ignored) or, if
    unset, from ``torch.cuda.get_arch_list()``.
    """
    # If cflags is given, there may already be user-provided arch flags in it
    # (from `extra_compile_args`)
    if cflags is not None and any('amdgpu-target' in flag for flag in cflags):
        return ['-fno-gpu-rdc']
    # Use same defaults as used for building PyTorch
    # Allow env var to override, just like during initial cmake build.
    _archs = os.environ.get('PYTORCH_ROCM_ARCH', '').strip()
    if not _archs:
        archs = torch.cuda.get_arch_list()
    else:
        archs = [arch for arch in _archs.replace(' ', ';').split(';') if arch]
    flags = [f'--amdgpu-target={arch}' for arch in archs]
    flags += ['-fno-gpu-rdc']
    return flags


def _get_build_directory(name: str, verbose: bool) -> str:
    """
    Return (creating it if needed) the JIT build directory for extension ``name``.

    The root is ``$TORCH_EXTENSIONS_DIR`` when set; otherwise the default build
    root joined with a ``py<major><minor>_<cpu|cuXYZ>`` subfolder.
    """
    root_extensions_directory = os.environ.get('TORCH_EXTENSIONS_DIR')
    if root_extensions_directory is None:
        root_extensions_directory = get_default_build_root()
        cu_str = ('cpu' if torch.version.cuda is None else
                  f'cu{torch.version.cuda.replace(".", "")}')  # type: ignore[attr-defined]
        python_version = f'py{sys.version_info.major}{sys.version_info.minor}'
        build_folder = f'{python_version}_{cu_str}'

        root_extensions_directory = os.path.join(
            root_extensions_directory, build_folder)

    if verbose:
        print(f'Using {root_extensions_directory} as PyTorch extensions root...')

    build_directory = os.path.join(root_extensions_directory, name)
    if not os.path.exists(build_directory):
        if verbose:
            print(f'Creating extension directory {build_directory}...')
        # This is like mkdir -p, i.e. will also create parent directories.
        os.makedirs(build_directory, exist_ok=True)

    return build_directory


def _get_num_workers(verbose: bool) -> Optional[int]:
    """Return ``int($MAX_JOBS)`` if it is a non-negative integer, else ``None`` (ninja's default)."""
    max_jobs = os.environ.get('MAX_JOBS')
    if max_jobs is not None and max_jobs.isdigit():
        if verbose:
            print(f'Using envvar MAX_JOBS ({max_jobs}) as the number of workers...')
        return int(max_jobs)
    if verbose:
        print('Allowing ninja to set a default number of workers... '
              '(overridable by setting the environment variable MAX_JOBS=N)')
    return None


def _run_ninja_build(build_directory: str, verbose: bool, error_prefix: str) -> None:
    """
    Run ``ninja -v`` in ``build_directory``.

    On Windows, when no Visual Studio developer environment is active
    (``VSCMD_ARG_TGT_ARCH`` unset), the MSVC environment is merged into the
    subprocess environment first.

    Raises:
        RuntimeError: if ninja exits non-zero; the message is ``error_prefix``
            followed by the captured build output (when not ``verbose``).
    """
    command = ['ninja', '-v']
    num_workers = _get_num_workers(verbose)
    if num_workers is not None:
        command.extend(['-j', str(num_workers)])
    env = os.environ.copy()
    # Try to activate the vc env for the users
    if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' not in env:
        from setuptools import distutils

        plat_name = distutils.util.get_platform()
        plat_spec = PLAT_TO_VCVARS[plat_name]

        vc_env = distutils._msvccompiler._get_vc_env(plat_spec)
        vc_env = {k.upper(): v for k, v in vc_env.items()}
        # Keep the caller's variables that MSVC's environment doesn't define.
        for k, v in env.items():
            uk = k.upper()
            if uk not in vc_env:
                vc_env[uk] = v
        env = vc_env
    try:
        sys.stdout.flush()
        sys.stderr.flush()
        # Warning: don't pass stdout=None to subprocess.run to get output.
        # subprocess.run assumes that sys.__stdout__ has not been modified and
        # attempts to write to it by default.  However, when we call _run_ninja_build
        # from ahead-of-time cpp extensions, the following happens:
        # 1) If the stdout encoding is not utf-8, setuptools detachs __stdout__.
        #    https://github.com/pypa/setuptools/blob/7e97def47723303fafabe48b22168bbc11bb4821/setuptools/dist.py#L1110
        #    (it probably shouldn't do this)
        # 2) subprocess.run (on POSIX, with no stdout override) relies on
        #    __stdout__ not being detached:
        #    https://github.com/python/cpython/blob/c352e6c7446c894b13643f538db312092b351789/Lib/subprocess.py#L1214
        # To work around this, we pass in the fileno directly and hope that
        # it is valid.
        stdout_fileno = 1
        subprocess.run(
            command,
            stdout=stdout_fileno if verbose else subprocess.PIPE,
            stderr=subprocess.STDOUT,
            cwd=build_directory,
            check=True,
            env=env)
    except subprocess.CalledProcessError as e:
        # e.output holds the combined stdout/stderr of the build attempt; it is
        # None when verbose, because the output went straight to fd 1.
        message = error_prefix
        if e.output:
            message += f": {e.output.decode(*SUBPROCESS_DECODE_ARGS)}"
        raise RuntimeError(message) from e


def _get_exec_path(module_name, path):
    """
    Return the path of the standalone executable ``module_name`` in ``path``.

    On Windows, TORCH_LIB_PATH is prepended to ``PATH`` (unless an equivalent
    entry is already present) so the executable can find the torch DLLs.
    """
    if IS_WINDOWS:
        path_entries = os.getenv('PATH', '').split(';')
        if TORCH_LIB_PATH not in path_entries:
            torch_lib_in_path = any(
                os.path.exists(p) and os.path.samefile(p, TORCH_LIB_PATH)
                for p in path_entries
            )
            if not torch_lib_in_path:
                os.environ['PATH'] = f"{TORCH_LIB_PATH};{os.getenv('PATH', '')}"
    return os.path.join(path, f'{module_name}{EXEC_EXT}')


def _import_module_from_library(module_name, path, is_python_module):
    """
    Load the built library ``<path>/<module_name><LIB_EXT>``.

    Returns the imported module when ``is_python_module``; otherwise loads it
    via ``torch.ops.load_library`` and returns ``None``.

    Raises:
        ImportError: if no loadable module spec can be created for the file.
    """
    filepath = os.path.join(path, f"{module_name}{LIB_EXT}")
    if is_python_module:
        # `import importlib` at file level does not guarantee that the `util`
        # and `abc` submodules are loaded, so import them explicitly.
        import importlib.abc
        import importlib.util
        # https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
        spec = importlib.util.spec_from_file_location(module_name, filepath)
        # An explicit check (unlike `assert`) survives `python -O`.
        if spec is None or not isinstance(spec.loader, importlib.abc.Loader):
            raise ImportError(f"Cannot load extension module '{module_name}' from {filepath}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    torch.ops.load_library(filepath)
    return None


def _write_ninja_file_to_build_library(path,
                                       name,
                                       sources,
                                       extra_cflags,
                                       extra_cuda_cflags,
                                       extra_ldflags,
                                       extra_include_paths,
                                       with_cuda,
                                       is_standalone) -> None:
    """
    Compute compile/link flags for extension ``name`` and write the ninja file
    at ``path`` that builds it (shared library, or executable when
    ``is_standalone``).

    User arch flags in ``extra_cuda_cflags`` (any flag containing 'arch')
    suppress the automatically chosen ``-gencode`` flags, mirroring how the
    ROCm branch treats ``--amdgpu-target``.
    """
    extra_cflags = [flag.strip() for flag in extra_cflags]
    extra_cuda_cflags = [flag.strip() for flag in extra_cuda_cflags]
    extra_ldflags = [flag.strip() for flag in extra_ldflags]
    extra_include_paths = [flag.strip() for flag in extra_include_paths]

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    user_includes = [os.path.abspath(file) for file in extra_include_paths]

    # include_paths() gives us the location of torch/extension.h
    system_includes = include_paths(with_cuda)
    # sysconfig.get_path('include') gives us the location of Python.h
    # Explicitly specify 'posix_prefix' scheme on non-Windows platforms to workaround error on some MacOS
    # installations where default `get_path` points to non-existing `/Library/Python/M.m/include` folder
    python_include_path = sysconfig.get_path('include', scheme='nt' if IS_WINDOWS else 'posix_prefix')
    if python_include_path is not None:
        system_includes.append(python_include_path)

    # Windows does not understand `-isystem`.
    if IS_WINDOWS:
        user_includes += system_includes
        system_includes.clear()

    common_cflags = []
    if not is_standalone:
        common_cflags.append(f'-DTORCH_EXTENSION_NAME={name}')
        common_cflags.append('-DTORCH_API_INCLUDE_EXTENSION_H')

    # Note [Pybind11 ABI constants]
    #
    # Pybind11 before 2.4 used to build an ABI strings using the following pattern:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_BUILD_TYPE}__"
    # Since 2.4 compier type, stdlib and build abi parameters are also encoded like this:
    # f"__pybind11_internals_v{PYBIND11_INTERNALS_VERSION}{PYBIND11_INTERNALS_KIND}{PYBIND11_COMPILER_TYPE}{PYBIND11_STDLIB}{PYBIND11_BUILD_ABI}{PYBIND11_BUILD_TYPE}__"
    #
    # This was done in order to further narrow down the chances of compiler ABI incompatibility
    # that can cause a hard to debug segfaults.
    # For PyTorch extensions we want to relax those restrictions and pass compiler, stdlib and abi properties
    # captured during PyTorch native library compilation in torch/csrc/Module.cpp

    for pname in ["COMPILER_TYPE", "STDLIB", "BUILD_ABI"]:
        pval = getattr(torch._C, f"_PYBIND11_{pname}")
        if pval is not None and not IS_WINDOWS:
            common_cflags.append(f'-DPYBIND11_{pname}=\\"{pval}\\"')

    common_cflags += [f'-I{include}' for include in user_includes]
    common_cflags += [f'-isystem {include}' for include in system_includes]

    common_cflags += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]

    if IS_WINDOWS:
        cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags
        cflags = _nt_quote_args(cflags)
    else:
        cflags = common_cflags + ['-fPIC', '-std=c++14'] + extra_cflags

    if with_cuda and IS_HIP_EXTENSION:
        cuda_flags = ['-DWITH_HIP'] + cflags + COMMON_HIP_FLAGS + COMMON_HIPCC_FLAGS
        cuda_flags += extra_cuda_cflags
        cuda_flags += _get_rocm_arch_flags(cuda_flags)
        # hipify wrote translated copies of the CUDA sources; compile those instead.
        sources = [s if not _is_cuda_file(s) else
                   os.path.abspath(os.path.join(
                       path, get_hip_file_path(os.path.relpath(s, path), is_pytorch_extension=True)))
                   for s in sources]
    elif with_cuda:
        cuda_flags = common_cflags + COMMON_NVCC_FLAGS + _get_cuda_arch_flags(extra_cuda_cflags)
        if IS_WINDOWS:
            for flag in COMMON_MSVC_FLAGS:
                cuda_flags = ['-Xcompiler', flag] + cuda_flags
            for ignore_warning in MSVC_IGNORE_CUDAFE_WARNINGS:
                cuda_flags = ['-Xcudafe', '--diag_suppress=' + ignore_warning] + cuda_flags
            cuda_flags = _nt_quote_args(cuda_flags)
            cuda_flags += _nt_quote_args(extra_cuda_cflags)
        else:
            cuda_flags += ['--compiler-options', "'-fPIC'"]
            cuda_flags += extra_cuda_cflags
            if not any(flag.startswith('-std=') for flag in cuda_flags):
                cuda_flags.append('-std=c++14')
            if os.getenv("CC") is not None:
                cuda_flags = ['-ccbin', os.getenv("CC")] + cuda_flags
    else:
        cuda_flags = None

    def object_file_path(source_file: str) -> str:
        # '/path/to/file.cpp' -> 'file'
        file_name = os.path.splitext(os.path.basename(source_file))[0]
        if _is_cuda_file(source_file) and with_cuda:
            # Use a different object filename in case a C++ and CUDA file have
            # the same filename but different extension (.cpp vs. .cu).
            target = f'{file_name}.cuda.o'
        else:
            target = f'{file_name}.o'
        return target

    objects = [object_file_path(src) for src in sources]
    ldflags = ([] if is_standalone else [SHARED_FLAG]) + extra_ldflags

    # The darwin linker needs explicit consent to ignore unresolved symbols.
    if IS_MACOS:
        ldflags.append('-undefined dynamic_lookup')
    elif IS_WINDOWS:
        ldflags = _nt_quote_args(ldflags)

    ext = EXEC_EXT if is_standalone else LIB_EXT
    library_target = f'{name}{ext}'

    _write_ninja_file(
        path=path,
        cflags=cflags,
        post_cflags=None,
        cuda_cflags=cuda_flags,
        cuda_post_cflags=None,
        sources=sources,
        objects=objects,
        ldflags=ldflags,
        library_target=library_target,
        with_cuda=with_cuda)


def _write_ninja_file(path,
                      cflags,
                      post_cflags,
                      cuda_cflags,
                      cuda_post_cflags,
                      sources,
                      objects,
                      ldflags,
                      library_target,
                      with_cuda) -> None:
    r"""Write a ninja file that does the desired compiling and linking.

    `path`: Where to write this file
    `cflags`: list of flags to pass to $cxx. Can be None.
    `post_cflags`: list of flags to append to the $cxx invocation. Can be None.
    `cuda_cflags`: list of flags to pass to $nvcc. Can be None.
    `cuda_post_cflags`: list of flags to append to the $nvcc invocation. Can be None.
    `sources`: list of paths to source files
    `objects`: list of desired paths to objects, one per source.
    `ldflags`: list of flags to pass to linker. Can be None.
    `library_target`: Name of the output library. Can be None; in that case,
                      we do no linking.
    `with_cuda`: If we should be compiling with CUDA.

    Raises:
        RuntimeError: on Windows, if MSVC's ``cl`` cannot be located when a
            link step is requested.
    """
    def sanitize_flags(flags):
        if flags is None:
            return []
        return [flag.strip() for flag in flags]

    def ninja_escape(file_path: str) -> str:
        # In ninja build statements ' ' (and ':' in Windows drive letters)
        # must be escaped with '$'.
        if IS_WINDOWS:
            file_path = file_path.replace(':', '$:')
        return file_path.replace(' ', '$ ')

    cflags = sanitize_flags(cflags)
    post_cflags = sanitize_flags(post_cflags)
    cuda_cflags = sanitize_flags(cuda_cflags)
    cuda_post_cflags = sanitize_flags(cuda_post_cflags)
    ldflags = sanitize_flags(ldflags)

    # Sanity checks...
    assert len(sources) == len(objects)
    assert len(sources) > 0

    if IS_WINDOWS:
        compiler = os.environ.get('CXX', 'cl')
    else:
        compiler = os.environ.get('CXX', 'c++')

    # Version 1.3 is required for the `deps` directive.
    config = ['ninja_required_version = 1.3']
    config.append(f'cxx = {compiler}')
    if with_cuda:
        if IS_HIP_EXTENSION:
            nvcc = _join_rocm_home('bin', 'hipcc')
        else:
            nvcc = _join_cuda_home('bin', 'nvcc')
        config.append(f'nvcc = {nvcc}')

    flags = [f'cflags = {" ".join(cflags)}']
    flags.append(f'post_cflags = {" ".join(post_cflags)}')
    if with_cuda:
        flags.append(f'cuda_cflags = {" ".join(cuda_cflags)}')
        flags.append(f'cuda_post_cflags = {" ".join(cuda_post_cflags)}')
    flags.append(f'ldflags = {" ".join(ldflags)}')

    # Turn into absolute paths so we can emit them into the ninja build
    # file wherever it is.
    sources = [os.path.abspath(file) for file in sources]

    # See https://ninja-build.org/build.ninja.html for reference.
    compile_rule = ['rule compile']
    if IS_WINDOWS:
        compile_rule.append(
            '  command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags')
        compile_rule.append('  deps = msvc')
    else:
        compile_rule.append(
            '  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags')
        compile_rule.append('  depfile = $out.d')
        compile_rule.append('  deps = gcc')

    if with_cuda:
        cuda_compile_rule = ['rule cuda_compile']
        nvcc_gendeps = ''
        # --generate-dependencies-with-compile was added in CUDA 10.2.
        # Compilation will work on earlier CUDA versions but header file
        # dependencies are not correctly computed.
        required_cuda_version = packaging.version.parse('10.2')
        has_cuda_version = torch.version.cuda is not None
        if has_cuda_version and packaging.version.parse(torch.version.cuda) >= required_cuda_version:
            cuda_compile_rule.append('  depfile = $out.d')
            cuda_compile_rule.append('  deps = gcc')
            # Note: non-system deps with nvcc are only supported
            # on Linux so use --generate-dependencies-with-compile
            # to make this work on Windows too.
            if IS_WINDOWS:
                nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d'
        cuda_compile_rule.append(
            f'  command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags')

    # Emit one build rule per source to enable incremental build.
    build = []
    for source_file, object_file in zip(sources, objects):
        is_cuda_source = _is_cuda_file(source_file) and with_cuda
        rule = 'cuda_compile' if is_cuda_source else 'compile'
        # The space between the rule name and the input is required syntax.
        build.append(f'build {ninja_escape(object_file)}: {rule} {ninja_escape(source_file)}')

    if library_target is not None:
        link_rule = ['rule link']
        if IS_WINDOWS:
            try:
                where_output = subprocess.check_output(['where', 'cl'])
            except (OSError, subprocess.CalledProcessError) as e:
                raise RuntimeError("MSVC is required to load C++ extensions") from e
            cl_paths = [p for p in where_output.decode(*SUBPROCESS_DECODE_ARGS).split('\r\n') if p]
            if not cl_paths:
                raise RuntimeError("MSVC is required to load C++ extensions")
            cl_path = os.path.dirname(cl_paths[0]).replace(':', '$:')
            link_rule.append(f'  command = "{cl_path}/link.exe" $in /nologo $ldflags /out:$out')
        else:
            link_rule.append('  command = $cxx $in $ldflags -o $out')

        # Escape object names exactly as in the compile statements above so
        # the link inputs match those outputs.
        link_inputs = ' '.join(ninja_escape(obj) for obj in objects)
        link = [f'build {library_target}: link {link_inputs}']

        default = [f'default {library_target}']
    else:
        link_rule, link, default = [], [], []

    # 'Blocks' should be separated by newlines, for visual benefit.
    blocks = [config, flags, compile_rule]
    if with_cuda:
        blocks.append(cuda_compile_rule)
    blocks += [link_rule, build, link, default]
    with open(path, 'w') as build_file:
        for block in blocks:
            lines = '\n'.join(block)
            build_file.write(f'{lines}\n\n')


def _join_cuda_home(*paths) -> str:
    r'''
    Join ``paths`` onto CUDA_HOME, or raise if CUDA_HOME is not set.

    This is basically a lazy way of raising an error for missing $CUDA_HOME
    only once we need to get any CUDA-specific path.

    Raises:
        EnvironmentError: if CUDA_HOME is ``None``.
    '''
    if CUDA_HOME is None:
        raise EnvironmentError('CUDA_HOME environment variable is not set. '
                               'Please set it to your CUDA install root.')
    return os.path.join(CUDA_HOME, *paths)


def _is_cuda_file(path: str) -> bool:
    """Return whether ``path`` has a GPU source extension (.cu/.cuh, plus .hip for HIP builds)."""
    valid_ext = ['.cu', '.cuh']
    if IS_HIP_EXTENSION:
        valid_ext.append('.hip')
    return os.path.splitext(path)[1] in valid_ext
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Because Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about the available controls, in the Cookies Policy.