importdifflibimportosimportioimportshutilimportstructimportsysimporttorchimporttarfileimporttempfileimportwarningsfromcontextlibimportclosing,contextmanagerfrom._utilsimport_import_dotted_namefrom._siximportstring_classesas_string_classesfromtorch._sourcesimportget_source_lines_and_filefromtorch.typesimportStoragefromtypingimportAny,BinaryIO,cast,Dict,Optional,Type,Tuple,Union,IOimportcopyregimportpickleimportpathlibDEFAULT_PROTOCOL=2LONG_SIZE=struct.Struct('=l').sizeINT_SIZE=struct.Struct('=i').sizeSHORT_SIZE=struct.Struct('=h').sizeMAGIC_NUMBER=0x1950a86a20f9469cfc6cPROTOCOL_VERSION=1001STORAGE_KEY_SEPARATOR=','classSourceChangeWarning(Warning):pass@contextmanagerdefmkdtemp():path=tempfile.mkdtemp()yieldpathshutil.rmtree(path)_package_registry=[]def_is_zipfile(f)->bool:# This is a stricter implementation than zipfile.is_zipfile().# zipfile.is_zipfile() is True if the magic number appears anywhere in the# binary. Since we expect the files here to be generated by torch.save or# torch.jit.save, it's safe to only check the start bytes and avoid# collisions and assume the zip has only 1 file.# See bugs.python.org/issue28494.# Read the first 4 bytes of the fileread_bytes=[]start=f.tell()byte=f.read(1)whilebyte!="":read_bytes.append(byte)iflen(read_bytes)==4:breakbyte=f.read(1)f.seek(start)local_header_magic_number=[b'P',b'K',b'\x03',b'\x04']returnread_bytes==local_header_magic_numberdefregister_package(priority,tagger,deserializer):queue_elem=(priority,tagger,deserializer)_package_registry.append(queue_elem)_package_registry.sort()defcheck_module_version_greater_or_equal(module,req_version_tuple,error_if_malformed=True):''' Check if a module's version satisfies requirements Usually, a module's version string will be like 'x.y.z', which would be represented as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version string does not match the given tuple's format up to the length of the tuple, then error and exit or emit a warning. Args: module: the module to check the version of req_version_tuple: tuple (usually of ints) representing the required version error_if_malformed: whether we should exit if module version string is malformed Returns: requirement_is_met: bool '''try:version_strs=module.__version__.split('.')# Cast module version fields to match the types of the required versionmodule_version=tuple(type(req_field)(version_strs[idx])foridx,req_fieldinenumerate(req_version_tuple))requirement_is_met=module_version>=req_version_tupleexceptExceptionase:message=("'%s' module version string is malformed '%s' and cannot be compared"" with tuple %s")%(module.__name__,module.__version__,str(req_version_tuple))iferror_if_malformed:raiseRuntimeError(message)fromeelse:warnings.warn(message+', but continuing assuming that requirement is met')requirement_is_met=Truereturnrequirement_is_metdef_cpu_tag(obj):iftype(obj).__module__=='torch':return'cpu'def_cuda_tag(obj):iftype(obj).__module__=='torch.cuda':return'cuda:'+str(obj.get_device())def_cpu_deserialize(obj,location):iflocation=='cpu':returnobjdefvalidate_cuda_device(location):device=torch.cuda._utils._get_device_index(location,True)ifnottorch.cuda.is_available():raiseRuntimeError('Attempting to deserialize object on a CUDA ''device but torch.cuda.is_available() is False. ''If you are running on a CPU-only machine, ''please use torch.load with map_location=torch.device(\'cpu\') ''to map your storages to the CPU.')device_count=torch.cuda.device_count()ifdevice>=device_count:raiseRuntimeError('Attempting to deserialize object on CUDA device 'f'{device} but torch.cuda.device_count() is {device_count}. Please use ''torch.load with map_location to map your storages ''to an existing device.')returndevicedef_cuda_deserialize(obj,location):iflocation.startswith('cuda'):device=validate_cuda_device(location)ifgetattr(obj,"_torch_load_uninitialized",False):storage_type=getattr(torch.cuda,type(obj).__name__)withtorch.cuda.device(device):returnstorage_type(obj.size())else:returnobj.cuda(device)register_package(10,_cpu_tag,_cpu_deserialize)register_package(20,_cuda_tag,_cuda_deserialize)deflocation_tag(storage:Storage):for_,tagger,_in_package_registry:location=tagger(storage)iflocation:returnlocationraiseRuntimeError("don't know how to determine data location of "+torch.typename(storage))defdefault_restore_location(storage,location):for_,_,fnin_package_registry:result=fn(storage,location)ifresultisnotNone:returnresultraiseRuntimeError("don't know how to restore data location of "+torch.typename(storage)+" (tagged with "+location+")")defnormalize_storage_type(storage_type):returngetattr(torch,storage_type.__name__)defstorage_to_tensor_type(storage):storage_type=type(storage)module=_import_dotted_name(storage_type.__module__)returngetattr(module,storage_type.__name__.replace('Storage','Tensor'))def_is_path(name_or_buffer):returnisinstance(name_or_buffer,str)or \
isinstance(name_or_buffer,pathlib.Path)class_opener(object):def__init__(self,file_like):self.file_like=file_likedef__enter__(self):returnself.file_likedef__exit__(self,*args):passclass_open_file(_opener):def__init__(self,name,mode):super(_open_file,self).__init__(open(name,mode))def__exit__(self,*args):self.file_like.close()class_open_buffer_reader(_opener):def__init__(self,buffer):super(_open_buffer_reader,self).__init__(buffer)_check_seekable(buffer)class_open_buffer_writer(_opener):def__exit__(self,*args):self.file_like.flush()def_open_file_like(name_or_buffer,mode):if_is_path(name_or_buffer):return_open_file(name_or_buffer,mode)else:if'w'inmode:return_open_buffer_writer(name_or_buffer)elif'r'inmode:return_open_buffer_reader(name_or_buffer)else:raiseRuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")class_open_zipfile_reader(_opener):def__init__(self,name_or_buffer)->None:super(_open_zipfile_reader,self).__init__(torch._C.PyTorchFileReader(name_or_buffer))class_open_zipfile_writer_file(_opener):def__init__(self,name)->None:super(_open_zipfile_writer_file,self).__init__(torch._C.PyTorchFileWriter(str(name)))def__exit__(self,*args)->None:self.file_like.write_end_of_file()class_open_zipfile_writer_buffer(_opener):def__init__(self,buffer)->None:self.buffer=buffersuper(_open_zipfile_writer_buffer,self).__init__(torch._C.PyTorchFileWriter(buffer))def__exit__(self,*args)->None:self.file_like.write_end_of_file()self.buffer.flush()def_open_zipfile_writer(name_or_buffer):container:Type[_opener]if_is_path(name_or_buffer):container=_open_zipfile_writer_fileelse:container=_open_zipfile_writer_bufferreturncontainer(name_or_buffer)def_is_compressed_file(f)->bool:compress_modules=['gzip']try:returnf.__module__incompress_modulesexceptAttributeError:returnFalsedef_should_read_directly(f):""" Checks if f is a file that should be read directly. It should be read directly if it is backed by a real file (has a fileno) and is not a a compressed file (e.g. gzip) """if_is_compressed_file(f):returnFalsetry:returnf.fileno()>=0exceptio.UnsupportedOperation:returnFalseexceptAttributeError:returnFalsedef_check_seekable(f)->bool:defraise_err_msg(patterns,e):forpinpatterns:ifpinstr(e):msg=(str(e)+". You can only torch.load from a file that is seekable."+" Please pre-load the data into a buffer like io.BytesIO and"+" try to load from it instead.")raisetype(e)(msg)raiseetry:f.seek(f.tell())returnTrueexcept(io.UnsupportedOperation,AttributeError)ase:raise_err_msg(["seek","tell"],e)returnFalsedef_check_dill_version(pickle_module)->None:'''Checks if using dill as the pickle module, and if so, checks if it is the correct version. If dill version is lower than 0.3.1, a ValueError is raised. Args: pickle_module: module used for pickling metadata and objects '''ifpickle_module.__name__=='dill':required_dill_version=(0,3,1)ifnotcheck_module_version_greater_or_equal(pickle_module,required_dill_version,False):raiseValueError(("'torch' supports dill >= %s, but you have dill %s."" Please upgrade dill or switch to 'pickle'")%('.'.join([str(num)fornuminrequired_dill_version]),pickle_module.__version__))
[docs]defsave(obj,f:Union[str,os.PathLike,BinaryIO,IO[bytes]],pickle_module=pickle,pickle_protocol=DEFAULT_PROTOCOL,_use_new_zipfile_serialization=True)->None:# Reference: https://github.com/pytorch/pytorch/issues/54354# The first line of this docstring overrides the one Sphinx generates for the# documentation. We need it so that Sphinx doesn't leak `pickle`s path from# the build environment (e.g. `<module 'pickle' from '/leaked/path')."""save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) Saves an object to a disk file. See also: :ref:`saving-loading-tensors` Args: obj: saved object f: a file-like object (has to implement write and flush) or a string or os.PathLike object containing a file name pickle_module: module used for pickling metadata and objects pickle_protocol: can be specified to override the default protocol .. note:: A common PyTorch convention is to save tensors using .pt file extension. .. note:: PyTorch preserves storage sharing across serialization. See :ref:`preserve-storage-sharing` for more details. .. note:: The 1.6 release of PyTorch switched ``torch.save`` to use a new zipfile-based file format. ``torch.load`` still retains the ability to load files in the old format. If for any reason you want ``torch.save`` to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``. Example: >>> # Save to file >>> x = torch.tensor([0, 1, 2, 3, 4]) >>> torch.save(x, 'tensor.pt') >>> # Save to io.BytesIO buffer >>> buffer = io.BytesIO() >>> torch.save(x, buffer) """_check_dill_version(pickle_module)with_open_file_like(f,'wb')asopened_file:if_use_new_zipfile_serialization:with_open_zipfile_writer(opened_file)asopened_zipfile:_save(obj,opened_zipfile,pickle_module,pickle_protocol)return_legacy_save(obj,opened_file,pickle_module,pickle_protocol)
def_legacy_save(obj,f,pickle_module,pickle_protocol)->None:importtorch.nnasnnserialized_container_types={}serialized_storages={}defpersistent_id(obj:Any)->Optional[Tuple]:# FIXME: the docs say that persistent_id should only return a string# but torch store returns tuples. This works only in the binary protocol# see# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537ifisinstance(obj,type)andissubclass(obj,nn.Module):ifobjinserialized_container_types:returnNoneserialized_container_types[obj]=Truesource_file=source=Nonetry:source_lines,_,source_file=get_source_lines_and_file(obj)source=''.join(source_lines)exceptException:# saving the source is optional, so we can ignore any errorswarnings.warn("Couldn't retrieve source code for container of ""type "+obj.__name__+". It won't be checked ""for correctness upon loading.")return('module',obj,source_file,source)eliftorch.is_storage(obj):view_metadata:Optional[Tuple[str,int,int]]obj=cast(Storage,obj)storage_type=normalize_storage_type(type(obj))# Offset is always 0, but we keep it for backwards compatibility# with the old serialization format (which supported storage views)offset=0obj_key=str(obj._cdata)location=location_tag(obj)serialized_storages[obj_key]=objis_view=obj._cdata!=obj._cdataifis_view:view_metadata=(str(obj._cdata),offset,obj.size())else:view_metadata=Nonereturn('storage',storage_type,obj_key,location,obj.size(),view_metadata)returnNonesys_info=dict(protocol_version=PROTOCOL_VERSION,little_endian=sys.byteorder=='little',type_sizes=dict(short=SHORT_SIZE,int=INT_SIZE,long=LONG_SIZE,),)pickle_module.dump(MAGIC_NUMBER,f,protocol=pickle_protocol)pickle_module.dump(PROTOCOL_VERSION,f,protocol=pickle_protocol)pickle_module.dump(sys_info,f,protocol=pickle_protocol)pickler=pickle_module.Pickler(f,protocol=pickle_protocol)pickler.persistent_id=persistent_idpickler.dump(obj)serialized_storage_keys=sorted(serialized_storages.keys())pickle_module.dump(serialized_storage_keys,f,protocol=pickle_protocol)f.flush()forkeyinserialized_storage_keys:serialized_storages[key]._write_file(f,_should_read_directly(f),True)def_save(obj,zip_file,pickle_module,pickle_protocol):serialized_storages={}id_map:Dict[int,str]={}defpersistent_id(obj):# FIXME: the docs say that persistent_id should only return a string# but torch store returns tuples. This works only in the binary protocol# see# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537iftorch.is_storage(obj):storage_type=normalize_storage_type(type(obj))obj_key=id_map.setdefault(obj._cdata,str(len(id_map)))location=location_tag(obj)serialized_storages[obj_key]=objreturn('storage',storage_type,obj_key,location,obj.size())returnNone# Write the pickle data for `obj`data_buf=io.BytesIO()pickler=pickle_module.Pickler(data_buf,protocol=pickle_protocol)pickler.persistent_id=persistent_idpickler.dump(obj)data_value=data_buf.getvalue()zip_file.write_record('data.pkl',data_value,len(data_value))# Write each tensor to a file named tensor/the_tensor_key in the zip archiveforkeyinsorted(serialized_storages.keys()):name=f'data/{key}'storage=serialized_storages[key]# given that we copy things around anyway, we might use storage.cpu()# this means to that to get tensors serialized, you need to implement# .cpu() on the underlying Storageifstorage.device.type!='cpu':storage=storage.cpu()# Now that it is on the CPU we can directly copy it into the zip filenum_bytes=storage.size()*storage.element_size()zip_file.write_record(name,storage.data_ptr(),num_bytes)defload(f,map_location=None,pickle_module=pickle,**pickle_load_args):# Reference: https://github.com/pytorch/pytorch/issues/54354# The first line of this docstring overrides the one Sphinx generates for the# documentation. We need it so that Sphinx doesn't leak `pickle`s path from# the build environment (e.g. `<module 'pickle' from '/leaked/path')."""load(f, map_location=None, pickle_module=pickle, **pickle_load_args) Loads an object saved with :func:`torch.save` from a file. :func:`torch.load` uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the :attr:`map_location` argument. If :attr:`map_location` is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'`` for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors. :attr:`map_location` should return either ``None`` or a storage. If :attr:`map_location` returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, :func:`torch.load` will fall back to the default behavior, as if :attr:`map_location` wasn't specified. If :attr:`map_location` is a :class:`torch.device` object or a string containing a device tag, it indicates the location where all tensors should be loaded. Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags appearing in the file (keys), to ones that specify where to put the storages (values). User extensions can register their own location tags and tagging and deserialization methods using :func:`torch.serialization.register_package`. Args: f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`), or a string or os.PathLike object containing a file name map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage locations pickle_module: module used for unpickling metadata and objects (has to match the :attr:`pickle_module` used to serialize file) pickle_load_args: (Python 3 only) optional keyword arguments passed over to :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g., :attr:`errors=...`. .. warning:: :func:`torch.load()` uses ``pickle`` module implicitly, which is known to be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling. Never load data that could have come from an untrusted source, or that could have been tampered with. **Only load data you trust**. .. note:: When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')`` and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint. .. note:: By default, we decode byte strings as ``utf-8``. This is to avoid a common error case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...`` when loading files saved by Python 2 in Python 3. If this default is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them as byte arrays which can be decoded later with ``byte_array.decode(...)``. Example: >>> torch.load('tensors.pt') # Load all tensors onto the CPU >>> torch.load('tensors.pt', map_location=torch.device('cpu')) # Load all tensors onto the CPU, using a function >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) # Load all tensors onto GPU 1 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) # Load tensor from io.BytesIO object >>> with open('tensor.pt', 'rb') as f: ... buffer = io.BytesIO(f.read()) >>> torch.load(buffer) # Load a module with 'ascii' encoding for unpickling >>> torch.load('module.pt', encoding='ascii') """_check_dill_version(pickle_module)if'encoding'notinpickle_load_args.keys():pickle_load_args['encoding']='utf-8'with_open_file_like(f,'rb')asopened_file:if_is_zipfile(opened_file):# The zipfile reader is going to advance the current file position.# If we want to actually tail call to torch.jit.load, we need to# reset back to the original position.orig_position=opened_file.tell()with_open_zipfile_reader(opened_file)asopened_zipfile:if_is_torchscript_zip(opened_zipfile):warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"" dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"" silence this warning)",UserWarning)opened_file.seek(orig_position)returntorch.jit.load(opened_file)return_load(opened_zipfile,map_location,pickle_module,**pickle_load_args)return_legacy_load(opened_file,map_location,pickle_module,**pickle_load_args)# Register pickling support for layout instances such as# torch.sparse_coo, etcdef_get_layout(name):"""Get layout extension object from its string representation. """cache=_get_layout.cache# type: ignore[attr-defined]ifnotcache:forvintorch.__dict__.values():ifisinstance(v,torch.layout):cache[str(v)]=vreturncache[name]# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087_get_layout.cache={}# type: ignore[attr-defined]copyreg.pickle(torch.layout,lambdaobj:(_get_layout,(str(obj),)))def_legacy_load(f,map_location,pickle_module,**pickle_load_args):deserialized_objects:Dict[int,Any]={}restore_location=_get_restore_location(map_location)def_check_container_source(container_type,source_file,original_source):try:current_source=''.join(get_source_lines_and_file(container_type)[0])exceptException:# saving the source is optional, so we can ignore any errorswarnings.warn("Couldn't retrieve source code for container of ""type "+container_type.__name__+". It won't be checked ""for correctness upon loading.")returniforiginal_source!=current_source:ifcontainer_type.dump_patches:file_name=container_type.__name__+'.patch'diff=difflib.unified_diff(current_source.split('\n'),original_source.split('\n'),source_file,source_file,lineterm="")lines='\n'.join(diff)try:withopen(file_name,'a+')asf:file_size=f.seek(0,2)f.seek(0)iffile_size==0:f.write(lines)eliffile_size!=len(lines)orf.read()!=lines:raiseIOErrormsg=("Saved a reverse patch to "+file_name+". ""Run `patch -p0 < "+file_name+"` to revert your ""changes.")exceptIOError:msg=("Tried to save a patch, but couldn't create a ""writable file "+file_name+". Make sure it ""doesn't exist and your working directory is ""writable.")else:msg=("you can retrieve the original source code by ""accessing the object's source attribute or set ""`torch.nn.Module.dump_patches = True` and use the ""patch tool to revert the changes.")msg=f"source code of class '{torch.typename(container_type)}' has changed. {msg}"warnings.warn(msg,SourceChangeWarning)deflegacy_load(f):deserialized_objects:Dict[int,Any]={}defpersistent_load(saved_id):ifisinstance(saved_id,tuple):# Ignore containers that don't have any sources savedifall(saved_id[1:]):_check_container_source(*saved_id)returnsaved_id[0]returndeserialized_objects[int(saved_id)]withclosing(tarfile.open(fileobj=f,mode='r:',format=tarfile.PAX_FORMAT))astar, \
mkdtemp()astmpdir:tar.extract('storages',path=tmpdir)withopen(os.path.join(tmpdir,'storages'),'rb',0)asf:num_storages=pickle_module.load(f,**pickle_load_args)foriinrange(num_storages):args=pickle_module.load(f,**pickle_load_args)key,location,storage_type=argsobj=storage_type._new_with_file(f)obj=restore_location(obj,location)deserialized_objects[key]=objstorage_views=pickle_module.load(f,**pickle_load_args)fortarget_cdata,root_cdata,offset,sizeinstorage_views:root=deserialized_objects[root_cdata]deserialized_objects[target_cdata]=root[offset:offset+size]tar.extract('tensors',path=tmpdir)withopen(os.path.join(tmpdir,'tensors'),'rb',0)asf:num_tensors=pickle_module.load(f,**pickle_load_args)for_inrange(num_tensors):args=pickle_module.load(f,**pickle_load_args)key,storage_id,original_tensor_type=argsstorage=deserialized_objects[storage_id]tensor_type=storage_to_tensor_type(storage)ndim,=struct.unpack('<i',f.read(4))# skip next 4 bytes; legacy encoding treated ndim as 8 bytesf.read(4)size=struct.unpack(f'<{ndim}q',f.read(8*ndim))stride=struct.unpack(f'<{ndim}q',f.read(8*ndim))storage_offset,=struct.unpack('<q',f.read(8))tensor=tensor_type().set_(storage,storage_offset,size,stride)deserialized_objects[key]=tensorpickle_file=tar.extractfile('pickle')unpickler=pickle_module.Unpickler(pickle_file,**pickle_load_args)unpickler.persistent_load=persistent_loadresult=unpickler.load()returnresultdeserialized_objects={}defpersistent_load(saved_id):assertisinstance(saved_id,tuple)typename=_maybe_decode_ascii(saved_id[0])data=saved_id[1:]iftypename=='module':# Ignore containers that don't have any sources savedifall(data[1:]):_check_container_source(*data)returndata[0]eliftypename=='storage':data_type,root_key,location,size,view_metadata=datalocation=_maybe_decode_ascii(location)ifroot_keynotindeserialized_objects:obj=data_type(size)obj._torch_load_uninitialized=Truedeserialized_objects[root_key]=restore_location(obj,location)storage=deserialized_objects[root_key]ifview_metadataisnotNone:view_key,offset,view_size=view_metadataifview_keynotindeserialized_objects:deserialized_objects[view_key]=storage[offset:offset+view_size]returndeserialized_objects[view_key]else:returnstorageelse:raiseRuntimeError("Unknown saved id type: %s"%saved_id[0])_check_seekable(f)f_should_read_directly=_should_read_directly(f)iff_should_read_directlyandf.tell()==0:# legacy_load requires that f has fileno()# only if offset is zero we can attempt the legacy tar file loadertry:returnlegacy_load(f)excepttarfile.TarError:if_is_zipfile(f):# .zip is used for torch.jit.save and will throw an un-pickling error hereraiseRuntimeError(f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)")fromNone# if not a tarfile, reset file offset and proceedf.seek(0)ifnothasattr(f,'readinto')and(3,8,0)<=sys.version_info<(3,8,2):raiseRuntimeError("torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this ""functionality.")magic_number=pickle_module.load(f,**pickle_load_args)ifmagic_number!=MAGIC_NUMBER:raiseRuntimeError("Invalid magic number; corrupt file?")protocol_version=pickle_module.load(f,**pickle_load_args)ifprotocol_version!=PROTOCOL_VERSION:raiseRuntimeError("Invalid protocol version: %s"%protocol_version)_sys_info=pickle_module.load(f,**pickle_load_args)unpickler=pickle_module.Unpickler(f,**pickle_load_args)unpickler.persistent_load=persistent_loadresult=unpickler.load()deserialized_storage_keys=pickle_module.load(f,**pickle_load_args)offset=f.tell()iff_should_read_directlyelseNoneforkeyindeserialized_storage_keys:assertkeyindeserialized_objectsdeserialized_objects[key]._set_from_file(f,offset,f_should_read_directly)ifoffsetisnotNone:offset=f.tell()torch._utils._validate_loaded_sparse_tensors()returnresultdef_maybe_decode_ascii(bytes_str:Union[bytes,str])->str:# When using encoding='bytes' in Py3, some **internal** keys stored as# strings in Py2 are loaded as bytes. This function decodes them with# ascii encoding, one that Py3 uses by default.## NOTE: This should only be used on internal keys (e.g., `typename` and# `location` in `persistent_load` below!ifisinstance(bytes_str,bytes):returnbytes_str.decode('ascii')returnbytes_strdef_get_restore_location(map_location):ifmap_locationisNone:restore_location=default_restore_locationelifisinstance(map_location,dict):defrestore_location(storage,location):location=map_location.get(location,location)returndefault_restore_location(storage,location)elifisinstance(map_location,_string_classes):defrestore_location(storage,location):returndefault_restore_location(storage,map_location)elifisinstance(map_location,torch.device):defrestore_location(storage,location):returndefault_restore_location(storage,str(map_location))else:defrestore_location(storage,location):result=map_location(storage,location)ifresultisNone:result=default_restore_location(storage,location)returnresultreturnrestore_locationdef_load(zip_file,map_location,pickle_module,pickle_file='data.pkl',**pickle_load_args):restore_location=_get_restore_location(map_location)loaded_storages={}defload_tensor(data_type,size,key,location):name=f'data/{key}'dtype=data_type(0).dtypestorage=zip_file.get_storage_from_record(name,size,dtype).storage()loaded_storages[key]=restore_location(storage,location)defpersistent_load(saved_id):assertisinstance(saved_id,tuple)typename=_maybe_decode_ascii(saved_id[0])data=saved_id[1:]asserttypename=='storage', \
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"data_type,key,location,size=dataifkeynotinloaded_storages:load_tensor(data_type,size,key,_maybe_decode_ascii(location))storage=loaded_storages[key]returnstorageload_module_mapping:Dict[str,str]={# See https://github.com/pytorch/pytorch/pull/51633'torch.tensor':'torch._tensor'}# Need to subclass Unpickler instead of directly monkey-patching the find_class method# because it's marked readonly in pickle.# The type: ignore is because mypy can't statically determine the type of this class.classUnpicklerWrapper(pickle_module.Unpickler):# type: ignore[name-defined]# from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732# Lets us override the imports that pickle uses when unpickling an object.# This is useful for maintaining BC if we change a module path that tensor instantiation relies on.deffind_class(self,mod_name,name):mod_name=load_module_mapping.get(mod_name,mod_name)returnsuper().find_class(mod_name,name)# Load the data (which may in turn use `persistent_load` to load tensors)data_file=io.BytesIO(zip_file.get_record(pickle_file))unpickler=UnpicklerWrapper(data_file,**pickle_load_args)unpickler.persistent_load=persistent_loadresult=unpickler.load()torch._utils._validate_loaded_sparse_tensors()returnresultdef_is_torchscript_zip(zip_file):return'constants.pkl'inzip_file.get_all_records()
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.