import builtins
import importlib
import importlib.machinery
import inspect
import io
import linecache
import os.path
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer

__all__ = ["PackageImporter"]


# Modules that may be imported from the host environment even when the package
# did not mark them as extern. Torch depends on numpy implicitly and the
# packager cannot track that dependency.
# https://github.com/pytorch/MultiPy/issues/46
IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
    "numpy",
    "numpy.core",
    "numpy.core._multiarray_umath",
    # FX GraphModule might depend on the builtins module, and users usually
    # don't extern builtins, so it is allowed by default.
    "builtins",
]
class PackageImporter(Importer):
    """Importers allow you to load code written to packages by :class:`PackageExporter`.

    Code is loaded in a hermetic way, using files from the package
    rather than the normal python import system. This allows
    for the packaging of PyTorch model code and data so that it can be run
    on a server or used in the future for transfer learning.

    The importer for packages ensures that code in the module can only be loaded from
    within the package, except for modules explicitly listed as external during export.
    The file ``extern_modules`` in the zip archive lists all the modules that a package
    externally depends on. This prevents "implicit" dependencies where the package runs
    locally because it is importing a locally-installed package, but then fails when the
    package is copied to another machine.
    """

    # The dictionary of already loaded modules from this package, equivalent to
    # ``sys.modules`` but local to this importer.
    modules: Dict[str, types.ModuleType]

    def __init__(
        self,
        file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
        module_allowed: Callable[[str], bool] = lambda module_name: True,
    ):
        """Open ``file_or_buffer`` for importing.

        This checks that the imported package only requires modules allowed by
        ``module_allowed``.

        Args:
            file_or_buffer: a file-like object (has to implement :meth:`read`,
                :meth:`readline`, :meth:`tell`, and :meth:`seek`), a string, or an
                ``os.PathLike`` object containing a filename.
            module_allowed (Callable[[str], bool], optional): A method to determine if
                an externally provided module should be allowed. Can be used to ensure
                packages loaded do not depend on modules that the server does not
                support. Defaults to allowing anything.

        Raises:
            ImportError: If the package will use a disallowed module.
        """
        torch._C._log_api_usage_once("torch.package.PackageImporter")

        # Pick a record reader: an existing reader is used as-is, a path is opened
        # either as an archive or (when it is a directory) as an unzipped package,
        # and anything else is handed to PyTorchFileReader as a buffer.
        self.zip_reader: Any
        if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
            self.filename = "<pytorch_file_reader>"
            self.zip_reader = file_or_buffer
        elif isinstance(file_or_buffer, (Path, str)):
            self.filename = str(file_or_buffer)
            if not os.path.isdir(self.filename):
                self.zip_reader = torch._C.PyTorchFileReader(self.filename)
            else:
                self.zip_reader = DirectoryReader(self.filename)
        else:
            self.filename = "<binary>"
            self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)

        torch._C._log_api_usage_metadata(
            "torch.package.PackageImporter.metadata",
            {"serialization_id": self.zip_reader.serialization_id()},
        )

        self.root = _PackageNode(None)
        self.modules = {}
        self.extern_modules = self._read_extern()

        for extern_module in self.extern_modules:
            if not module_allowed(extern_module):
                raise ImportError(
                    f"package '{file_or_buffer}' needs the external module '{extern_module}' "
                    f"but that module has been disallowed"
                )
            self._add_extern(extern_module)

        for fname in self.zip_reader.get_all_records():
            self._add_file(fname)

        # Packaged modules get a copy of builtins whose ``__import__`` routes
        # through this importer.
        self.patched_builtins = builtins.__dict__.copy()
        self.patched_builtins["__import__"] = self.__import__
        # Allow packaged modules to reference their PackageImporter
        self.modules["torch_package_importer"] = self  # type: ignore[assignment]

        self._mangler = PackageMangler()

        # used for reduce deserialization
        self.storage_context: Any = None
        self.last_map_location = None

        # used for torch.serialization._load
        self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)

    def import_module(self, name: str, package=None):
        """Load a module from the package if it hasn't already been loaded, and then return
        the module. Modules are loaded locally to the importer and will appear in
        ``self.modules`` rather than ``sys.modules``.

        Args:
            name (str): Fully qualified name of the module to load.
            package ([type], optional): Unused, but present to match the signature of
                importlib.import_module. Defaults to ``None``.

        Returns:
            types.ModuleType: The (possibly already) loaded module.
        """
        # We should always be able to support importing modules from this package.
        # This is to support something like:
        #   obj = importer.load_pickle(...)
        #   importer.import_module(obj.__module__)  <- this string will be mangled
        #
        # Note that _mangler.demangle will not demangle any module names
        # produced by a different PackageImporter instance.
        name = self._mangler.demangle(name)

        return self._gcd_import(name)

    def load_binary(self, package: str, resource: str) -> bytes:
        """Load raw bytes.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.

        Returns:
            bytes: The loaded data.
        """
        path = self._zipfile_path(package, resource)
        return self.zip_reader.get_record(path)

    def load_text(
        self,
        package: str,
        resource: str,
        encoding: str = "utf-8",
        errors: str = "strict",
    ) -> str:
        """Load a string.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
            errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.

        Returns:
            str: The loaded text.
        """
        data = self.load_binary(package, resource)
        return data.decode(encoding, errors)

    def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
        """Unpickles the resource from the package, loading any modules that are needed
        to construct the objects using :meth:`import_module`.

        Args:
            package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
            resource (str): The unique name for the resource.
            map_location: Passed to `torch.load` to determine how tensors are mapped to
                devices. Defaults to ``None``.

        Returns:
            Any: The unpickled object.

        Raises:
            pickle.UnpicklingError: If the pickle stream contains a persistent id whose
                typename is neither ``'storage'`` nor ``'reduce_package'``.
        """
        # Function-scope import: only needed for the error raised by persistent_load.
        import pickle

        pickle_file = self._zipfile_path(package, resource)
        restore_location = _get_restore_location(map_location)
        loaded_storages = {}
        loaded_reduces = {}
        storage_context = torch._C.DeserializationStorageContext()

        def load_tensor(dtype, size, key, location, restore_location):
            name = f"{key}.storage"

            if storage_context.has_storage(name):
                storage = storage_context.get_storage(name, dtype)._typed_storage()
            else:
                tensor = self.zip_reader.get_storage_from_record(
                    ".data/" + name, size, dtype
                )
                if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
                    storage_context.add_storage(name, tensor)
                storage = tensor._typed_storage()
            loaded_storages[key] = restore_location(storage, location)

        def persistent_load(saved_id):
            assert isinstance(saved_id, tuple)
            typename = _maybe_decode_ascii(saved_id[0])
            data = saved_id[1:]

            if typename == "storage":
                storage_type, key, location, size = data
                dtype = storage_type.dtype

                if key not in loaded_storages:
                    load_tensor(
                        dtype,
                        size,
                        key,
                        _maybe_decode_ascii(location),
                        restore_location,
                    )
                storage = loaded_storages[key]
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                return torch.storage.TypedStorage(
                    wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
                )
            elif typename == "reduce_package":
                # to fix BC breaking change, objects on this load path
                # will be loaded multiple times erroneously
                if len(data) == 2:
                    func, args = data
                    return func(self, *args)
                reduce_id, func, args = data
                if reduce_id not in loaded_reduces:
                    loaded_reduces[reduce_id] = func(self, *args)
                return loaded_reduces[reduce_id]
            else:
                # Previously this message was built as a bare expression and discarded,
                # so an unknown persistent id silently unpickled to ``None``.
                raise pickle.UnpicklingError(
                    f"Unknown typename for persistent_load, expected "
                    f"'storage' or 'reduce_package' but got '{typename}'"
                )

        # Load the data (which may in turn use `persistent_load` to load tensors)
        data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
        unpickler = self.Unpickler(data_file)
        unpickler.persistent_load = persistent_load  # type: ignore[assignment]

        @contextmanager
        def set_deserialization_context():
            # to let reduce_package access deserialization context
            self.storage_context = storage_context
            self.last_map_location = map_location
            try:
                yield
            finally:
                self.storage_context = None
                self.last_map_location = None

        with set_deserialization_context():
            result = unpickler.load()

        # TODO from zdevito:
        #   This stateful weird function will need to be removed in our efforts
        #   to unify the format. It has a race condition if multiple python
        #   threads try to read independent files
        torch._utils._validate_loaded_sparse_tensors()

        return result

    def id(self):
        """
        Returns internal identifier that torch.package uses to distinguish
        :class:`PackageImporter` instances.

        Looks like::

            <torch_package_0>
        """
        return self._mangler.parent_name()

    def file_structure(
        self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
    ) -> Directory:
        """Returns a file structure representation of package's zipfile.

        Args:
            include (Union[List[str], str]): An optional string e.g.
                ``"my_package.my_subpackage"``, or optional list of strings for the
                names of the files to be included in the zipfile representation. This
                can also be a glob-style pattern, as described in
                :meth:`PackageExporter.mock`

            exclude (Union[List[str], str]): An optional pattern that excludes files
                whose name match the pattern.

        Returns:
            :class:`Directory`
        """
        return _create_directory_from_file_list(
            self.filename, self.zip_reader.get_all_records(), include, exclude
        )

    def python_version(self):
        """Returns the version of python that was used to create this package.

        Note: this function is experimental and not Forward Compatible. The plan is to
        move this into a lock file later on.

        Returns:
            :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version
            was stored with this package
        """
        python_version_path = ".data/python_version"
        return (
            self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
            if self.zip_reader.has_record(python_version_path)
            else None
        )

    def _read_extern(self):
        """Return the names listed in the ``.data/extern_modules`` record, one per line."""
        return (
            self.zip_reader.get_record(".data/extern_modules")
            .decode("utf-8")
            .splitlines(keepends=False)
        )

    def _make_module(
        self, name: str, filename: Optional[str], is_package: bool, parent: str
    ):
        """Create, register and execute the packaged module ``name``.

        Args:
            name: Demangled, fully qualified module name.
            filename: Record path of the module source, or ``None`` when there is no
                source to execute.
            is_package: Whether the module spec should be created as a package.
            parent: Name of the parent module ("" for top-level modules).

        Returns:
            The new module object; its ``__name__`` and ``__file__`` are mangled.
        """
        mangled_filename = self._mangler.mangle(filename) if filename else None
        spec = importlib.machinery.ModuleSpec(
            name,
            self,  # type: ignore[arg-type]
            origin="<package_importer>",
            is_package=is_package,
        )
        module = importlib.util.module_from_spec(spec)
        self.modules[name] = module
        module.__name__ = self._mangler.mangle(name)
        ns = module.__dict__
        ns["__spec__"] = spec
        ns["__loader__"] = self
        ns["__file__"] = mangled_filename
        ns["__cached__"] = None
        ns["__builtins__"] = self.patched_builtins
        ns["__torch_package__"] = True

        # Add this module to our private global registry. It should be unique due to mangling.
        assert module.__name__ not in _package_imported_modules
        _package_imported_modules[module.__name__] = module

        # pre-emptively install on the parent to prevent IMPORT_FROM from trying to
        # access sys.modules
        self._install_on_parent(parent, name, module)

        if filename is not None:
            assert mangled_filename is not None
            # pre-emptively install the source in `linecache` so that stack traces,
            # `inspect`, etc. work.
            assert filename not in linecache.cache  # type: ignore[attr-defined]
            linecache.lazycache(mangled_filename, ns)

            code = self._compile_source(filename, mangled_filename)
            exec(code, ns)

        return module

    def _load_module(self, name: str, parent: str):
        """Resolve ``name`` against the package tree and load it.

        Extern nodes and names in ``IMPLICIT_IMPORT_ALLOWLIST`` are imported with the
        regular ``importlib.import_module``; everything else is built from the
        packaged source via :meth:`_make_module`.

        Raises:
            ModuleNotFoundError: If ``name`` is neither in the package nor allowed
                to be imported externally.
        """
        cur: _PathNode = self.root
        for atom in name.split("."):
            if not isinstance(cur, _PackageNode) or atom not in cur.children:
                if name in IMPLICIT_IMPORT_ALLOWLIST:
                    module = self.modules[name] = importlib.import_module(name)
                    return module
                raise ModuleNotFoundError(
                    f'No module named "{name}" in self-contained archive "{self.filename}"'
                    f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
                    name=name,
                )
            cur = cur.children[atom]
            if isinstance(cur, _ExternNode):
                module = self.modules[name] = importlib.import_module(name)
                return module
        return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent)  # type: ignore[attr-defined]

    def _compile_source(self, fullpath: str, mangled_filename: str):
        """Read the record ``fullpath`` and compile it under ``mangled_filename``."""
        source = self.zip_reader.get_record(fullpath)
        source = _normalize_line_endings(source)
        return compile(source, mangled_filename, "exec", dont_inherit=True)

    # note: named `get_source` so that linecache can find the source
    # when this is the __loader__ of a module.
    def get_source(self, module_name) -> str:
        """Return the UTF-8 decoded source of the (possibly mangled) ``module_name``."""
        # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
        module = self.import_module(demangle(module_name))
        return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")

    # note: named `get_resource_reader` so that importlib.resources can find it.
    # This is otherwise considered an internal method.
    def get_resource_reader(self, fullname):
        """Return a resource reader for package ``fullname``.

        Returns ``None`` when ``fullname`` cannot be imported or was not loaded by
        this importer.
        """
        try:
            package = self._get_package(fullname)
        except ImportError:
            return None
        if package.__loader__ is not self:
            return None
        return _PackageResourceReader(self, fullname)

    def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
        """Bind ``module`` as an attribute of its parent, if this importer loaded the parent."""
        if not parent:
            return
        # Set the module as an attribute on its parent.
        parent_module = self.modules[parent]
        if parent_module.__loader__ is self:
            setattr(parent_module, name.rpartition(".")[2], module)

    # note: copied from cpython's import code, with call to create module replaced with _make_module
    def _do_find_and_load(self, name):
        """Import ``name`` (importing its parent first) and install it on the parent.

        Raises:
            ModuleNotFoundError: If the parent is not a package, or is a C extension
                package that cannot provide ``name``.
        """
        path = None
        parent = name.rpartition(".")[0]
        module_name_no_parent = name.rpartition(".")[-1]
        if parent:
            if parent not in self.modules:
                self._gcd_import(parent)
            # Crazy side-effects!
            if name in self.modules:
                return self.modules[name]
            parent_module = self.modules[parent]

            try:
                path = parent_module.__path__  # type: ignore[attr-defined]

            except AttributeError:
                # when we attempt to import a package only containing pybinded files,
                # the parent directory isn't always a package as defined by python,
                # so we search if the package is actually there or not before calling the error.
                if isinstance(
                    parent_module.__loader__,
                    importlib.machinery.ExtensionFileLoader,
                ):
                    if name not in self.extern_modules:
                        # The template must not contain an unescaped "}", otherwise
                        # str.format raises ValueError and hides the real error.
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a c extension module which was not externed. "
                            "C extension modules need to be externed by the PackageExporter "
                            "in order to be used as we do not support interning them."
                        ).format(name, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                    if not isinstance(
                        parent_module.__dict__.get(module_name_no_parent),
                        types.ModuleType,
                    ):
                        msg = (
                            _ERR_MSG
                            + "; {!r} is a c extension package which does not contain {!r}."
                        ).format(name, parent, name)
                        raise ModuleNotFoundError(msg, name=name) from None
                else:
                    msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
                    raise ModuleNotFoundError(msg, name=name) from None

        module = self._load_module(name, parent)

        self._install_on_parent(parent, name, module)

        return module

    # note: copied from cpython's import code
    def _find_and_load(self, name):
        """Return ``name`` from ``self.modules``, loading it first when it is missing.

        Raises:
            ModuleNotFoundError: If ``self.modules[name]`` is ``None``.
        """
        module = self.modules.get(name, _NEEDS_LOADING)
        if module is _NEEDS_LOADING:
            return self._do_find_and_load(name)

        if module is None:
            message = f"import of {name} halted; None in sys.modules"
            raise ModuleNotFoundError(message, name=name)

        # To handle https://github.com/pytorch/pytorch/issues/57490, where std's
        # creation of fake submodules via the hacking of sys.modules is not import
        # friendly
        if name == "os":
            self.modules["os.path"] = cast(Any, module).path
        elif name == "typing":
            # `typing.io` / `typing.re` were removed in Python 3.13, so only
            # register the pseudo-submodules that actually exist.
            for pseudo in ("io", "re"):
                submodule = getattr(module, pseudo, None)
                if submodule is not None:
                    self.modules[f"typing.{pseudo}"] = submodule

        return module

    def _gcd_import(self, name, package=None, level=0):
        """Import and return the module based on its name, the package the call is
        being made from, and the level adjustment.

        This function represents the greatest common denominator of functionality
        between import_module and __import__. This includes setting __package__ if
        the loader did not.
        """
        _sanity_check(name, package, level)
        if level > 0:
            name = _resolve_name(name, package, level)

        return self._find_and_load(name)

    # note: copied from cpython's import code
    def _handle_fromlist(self, module, fromlist, *, recursive=False):
        """Figure out what __import__ should return.

        For a package, every name in ``fromlist`` that is not yet an attribute of
        ``module`` is imported as a submodule; ``"*"`` expands to ``module.__all__``.

        Raises:
            TypeError: If an entry of ``fromlist`` (or of ``__all__``) is not a str.
        """
        module_name = demangle(module.__name__)
        # The hell that is fromlist ...
        # If a package was imported, try to import stuff from fromlist.
        if hasattr(module, "__path__"):
            for x in fromlist:
                if not isinstance(x, str):
                    if recursive:
                        where = module_name + ".__all__"
                    else:
                        where = "``from list''"
                    raise TypeError(
                        f"Item in {where} must be str, " f"not {type(x).__name__}"
                    )
                elif x == "*":
                    if not recursive and hasattr(module, "__all__"):
                        self._handle_fromlist(module, module.__all__, recursive=True)
                elif not hasattr(module, x):
                    from_name = f"{module_name}.{x}"
                    try:
                        self._gcd_import(from_name)
                    except ModuleNotFoundError as exc:
                        # Backwards-compatibility dictates we ignore failed
                        # imports triggered by fromlist for modules that don't
                        # exist.
                        if (
                            exc.name == from_name
                            and self.modules.get(from_name, _NEEDS_LOADING) is not None
                        ):
                            continue
                        raise
        return module

    def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
        """Replacement for ``builtins.__import__`` installed into packaged modules."""
        if level == 0:
            module = self._gcd_import(name)
        else:
            globals_ = globals if globals is not None else {}
            package = _calc___package__(globals_)
            module = self._gcd_import(name, package, level)
        if not fromlist:
            # Return up to the first dot in 'name'. This is complicated by the fact
            # that 'name' may be relative.
            if level == 0:
                return self._gcd_import(name.partition(".")[0])
            elif not name:
                return module
            else:
                # Figure out where to slice the module's name up to the first dot
                # in 'name'.
                cut_off = len(name) - len(name.partition(".")[0])
                # Slice end needs to be positive to alleviate need to special-case
                # when ``'.' not in name``.
                module_name = demangle(module.__name__)
                return self.modules[module_name[: len(module_name) - cut_off]]
        else:
            return self._handle_fromlist(module, fromlist)

    def _get_package(self, package):
        """Take a package name or module object and return the module.

        If a name, the module is imported. If the passed or imported module
        object is not a package, raise an exception.

        Raises:
            TypeError: If the resolved module is not a package.
        """
        if hasattr(package, "__spec__"):
            if package.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package.__spec__.name!r} is not a package")
            else:
                return package
        else:
            module = self.import_module(package)
            if module.__spec__.submodule_search_locations is None:
                raise TypeError(f"{package!r} is not a package")
            else:
                return module

    def _zipfile_path(self, package, resource=None):
        """Return the record path of ``package`` (or of ``resource`` inside it)."""
        package = self._get_package(package)
        assert package.__loader__ is self
        name = demangle(package.__name__)
        if resource is not None:
            resource = _normalize_path(resource)
            return f"{name.replace('.', '/')}/{resource}"
        else:
            return f"{name.replace('.', '/')}"

    def _get_or_create_package(
        self, atoms: List[str]
    ) -> "Union[_PackageNode, _ExternNode]":
        """Walk ``atoms`` from the root, creating missing package nodes on the way.

        Returns:
            The node for the last atom, or the first extern node met on the path.

        Raises:
            ImportError: If a path component is already registered as a plain module.
        """
        cur = self.root
        for i, atom in enumerate(atoms):
            node = cur.children.get(atom, None)
            if node is None:
                node = cur.children[atom] = _PackageNode(None)
            if isinstance(node, _ExternNode):
                return node
            if isinstance(node, _ModuleNode):
                # Include the current atom so the message names the offending module.
                name = ".".join(atoms[: i + 1])
                raise ImportError(
                    f"inconsistent module structure. module {name} is not a package, but has submodules"
                )
            assert isinstance(node, _PackageNode)
            cur = node
        return cur

    def _add_file(self, filename: str):
        """Assembles a Python module out of the given file.

        Files nested below the top level of the ``.data`` directory are ignored.
        Only ``.py`` files register a module or set a package's source file.

        Args:
            filename (str): the name of the file inside of the package archive to be added

        Raises:
            ImportError: If the file lies underneath a module marked external.
        """
        *prefix, last = filename.split("/")
        if len(prefix) > 1 and prefix[0] == ".data":
            return
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            raise ImportError(
                f"inconsistent module structure. package contains a module file {filename}"
                f" that is a subpackage of a module marked external."
            )
        if last == "__init__.py":
            package.source_file = filename
        elif last.endswith(".py"):
            package_name = last[: -len(".py")]
            package.children[package_name] = _ModuleNode(filename)

    def _add_extern(self, extern_name: str):
        """Register ``extern_name`` as an extern node in the package tree."""
        *prefix, last = extern_name.split(".")
        package = self._get_or_create_package(prefix)
        if isinstance(package, _ExternNode):
            return  # the shorter extern covers this extern case
        package.children[last] = _ExternNode()
# Sentinel meaning "no entry in PackageImporter.modules yet".
_NEEDS_LOADING = object()
_ERR_MSG_PREFIX = "No module named "
_ERR_MSG = _ERR_MSG_PREFIX + "{!r}"


class _PathNode:
    """Base class for the nodes of the importer's module tree."""

    pass


class _PackageNode(_PathNode):
    """A package: an optional source file plus named child nodes."""

    def __init__(self, source_file: Optional[str]):
        self.source_file = source_file
        self.children: Dict[str, _PathNode] = {}


class _ModuleNode(_PathNode):
    """A plain module backed by a single source file."""

    __slots__ = ["source_file"]

    def __init__(self, source_file: str):
        self.source_file = source_file


class _ExternNode(_PathNode):
    """Marker node for a module that is marked external."""

    pass


# A private global registry of all modules that have been package-imported.
_package_imported_modules: WeakValueDictionary = WeakValueDictionary()

# `inspect` by default only looks in `sys.modules` to find source files for classes.
# Patch it to check our private registry of package-imported modules as well.
_orig_getfile = inspect.getfile


def _patched_getfile(object):
    """``inspect.getfile`` replacement that also knows package-imported classes."""
    if inspect.isclass(object):
        if object.__module__ in _package_imported_modules:
            return _package_imported_modules[object.__module__].__file__
    return _orig_getfile(object)


inspect.getfile = _patched_getfile


class _PackageResourceReader:
    """Private class used to support PackageImporter.get_resource_reader().

    Conforms to the importlib.abc.ResourceReader interface. Allowed to access
    the innards of PackageImporter.
    """

    def __init__(self, importer, fullname):
        self.importer = importer
        self.fullname = fullname

    def open_resource(self, resource):
        """Return a binary stream over the bytes of ``resource``."""
        from io import BytesIO

        return BytesIO(self.importer.load_binary(self.fullname, resource))

    def resource_path(self, resource):
        """Return the file system path of ``resource``.

        Raises:
            FileNotFoundError: Unless the importer reads from a directory that
                holds the resource.
        """
        # The contract for resource_path is that it either returns a concrete
        # file system path or raises FileNotFoundError.
        if isinstance(
            self.importer.zip_reader, DirectoryReader
        ) and self.importer.zip_reader.has_record(
            os.path.join(self.fullname, resource)
        ):
            return os.path.join(
                self.importer.zip_reader.directory, self.fullname, resource
            )
        raise FileNotFoundError

    def is_resource(self, name):
        """Return whether the package contains a record called ``name``."""
        path = self.importer._zipfile_path(self.fullname, name)
        return self.importer.zip_reader.has_record(path)

    def contents(self):
        """Yield the names found under this package's path in the archive."""
        from pathlib import Path

        fullname_path = Path(self.importer._zipfile_path(self.fullname))
        files = self.importer.zip_reader.get_all_records()
        subdirs_seen = set()
        for filename in files:
            try:
                relative = Path(filename).relative_to(fullname_path)
            except ValueError:
                # Record lives outside of this package.
                continue
            # A record sitting directly in the package is reported by its own
            # name. A record inside a sub-directory is represented by the name
            # of its immediate parent directory, reported only once.
            parent_name = relative.parent.name
            if len(parent_name) == 0:
                yield relative.name
            elif parent_name not in subdirs_seen:
                subdirs_seen.add(parent_name)
                yield parent_name
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.