# Source code for torch.serialization
# mypy: allow-untyped-defs
import difflib
import functools
import os
import io
import re
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
from ._utils import _import_dotted_name
from torch._sources import get_source_lines_and_file
from torch.types import Storage
from torch.storage import _get_dtype_from_pickle_storage_type
from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
import copyreg
import pickle
import torch._weights_only_unpickler as _weights_only_unpickler
# Default pickle protocol used by ``save`` (its ``pickle_protocol`` default).
DEFAULT_PROTOCOL = 2

# Sizes in bytes of C long / int / short as reported by ``struct`` with the '='
# prefix. '=' selects *standard* sizes (not native ones), so these are 4 / 4 / 2 on
# every platform. They are recorded in the legacy format's ``sys_info`` header.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# First object pickled by the legacy (non-zipfile) format written by ``_legacy_save``.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
# Legacy-format protocol version, pickled right after MAGIC_NUMBER.
PROTOCOL_VERSION = 1001
# NOTE(review): not referenced in this part of the module; confirm its consumers.
STORAGE_KEY_SEPARATOR = ','

# Accepted types for the ``f`` argument of ``save`` / ``load``: a path or a binary file-like object.
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
# Accepted types for ``load``'s ``map_location``: a callable, a device, a device string,
# or a dict remapping location tags.
MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
# Any storage flavour that taggers / deserializers may receive.
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]

IS_WINDOWS = sys.platform == "win32"
# The mmap flags are only imported off Windows. On Windows both names are None,
# and ``set_default_mmap_options`` refuses to run there.
if not IS_WINDOWS:
    from mmap import MAP_SHARED, MAP_PRIVATE
else:
    MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
# Public API of this module. The mmap-option accessors are public, documented
# functions, so they are listed here too; otherwise ``from torch.serialization import *``
# and API-listing tools skip them.
__all__ = [
    'SourceChangeWarning',
    'mkdtemp',
    'register_package',
    'check_module_version_greater_or_equal',
    'validate_cuda_device',
    'validate_hpu_device',
    'location_tag',
    'default_restore_location',
    'normalize_storage_type',
    'storage_to_tensor_type',
    'save',
    'load',
    'StorageType',
    'LoadEndianness',
    'get_default_load_endianness',
    'set_default_load_endianness',
    'get_default_mmap_options',
    'set_default_mmap_options',
    'clear_safe_globals',
    'get_safe_globals',
    'add_safe_globals',
]
class SourceChangeWarning(Warning):
    """Warning category related to saved module source code.

    ``_legacy_save`` records the source of ``nn.Module`` subclasses so it can be
    checked on load. NOTE(review): this class is presumably what that check emits
    on a mismatch; the emitting code is not in this section.
    """
@contextmanager
def mkdtemp():
    """Yield the path of a fresh temporary directory and delete it afterwards.

    The directory is created with :func:`tempfile.mkdtemp`. When the ``with`` block
    exits, normally or via an exception, it is removed with :func:`shutil.rmtree`
    along with its contents.
    """
    tmp_dir = tempfile.mkdtemp()
    try:
        yield tmp_dir
    finally:
        shutil.rmtree(tmp_dir)
# (priority, tagger, deserializer) entries added by ``register_package``, which
# re-sorts the list after every insertion. ``location_tag`` and
# ``default_restore_location`` walk it front to back, so lower priorities are tried first.
_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
class LoadEndianness(Enum):
    """Fallback byte order for checkpoints that carry no byte-order record.

    Members are the accepted (non-``None``) arguments of
    ``set_default_load_endianness``.
    """
    NATIVE = 1  # native byte order
    LITTLE = 2  # little-endian
    BIG = 3     # big-endian
# Fallback byte order set through ``set_default_load_endianness``. It starts as None,
# which the getter's docstring describes as the "native" byte order.
_default_load_endian: Optional[LoadEndianness] = None
def get_default_load_endianness() -> Optional[LoadEndianness]:
    '''
    Return the byte order used as a fallback when loading files.

    The fallback applies when a saved checkpoint does not contain a byte-order
    mark. Its initial value is ``None``, meaning the "native" byte order.
    Change it with :func:`set_default_load_endianness`.

    Returns:
        default_load_endian: Optional[LoadEndianness]
    '''
    return _default_load_endian
def set_default_load_endianness(endianness):
    '''
    Set the fallback byte order used when loading files.

    The fallback applies when a saved checkpoint does not contain a byte-order
    mark. ``None`` restores the initial "native" behaviour.

    Args:
        endianness: the new fallback byte order, a :class:`LoadEndianness` member
            or ``None``

    Raises:
        TypeError: if ``endianness`` is neither ``None`` nor a ``LoadEndianness``.
    '''
    global _default_load_endian
    if endianness is None or isinstance(endianness, LoadEndianness):
        _default_load_endian = endianness
        return
    raise TypeError("Invalid argument type in function set_default_load_endianness")
# mmap flags used by ``torch.load(..., mmap=True)``; changed through
# ``set_default_mmap_options``. On Windows MAP_PRIVATE is None (see the platform
# import near the top of the module), so this value is None there.
_default_mmap_options: int = MAP_PRIVATE
def get_default_mmap_options() -> int:
    '''
    Return the mmap flags :func:`torch.load` uses when called with ``mmap=True``.

    The initial value is ``mmap.MAP_PRIVATE``. On Windows, where the mmap flags
    are not imported, it is ``None``.

    Returns:
        default_mmap_options: int
    '''
    return _default_mmap_options
def set_default_mmap_options(flags: int):
    '''
    Choose the mmap flags :func:`torch.load` uses when called with ``mmap=True``.

    Only ``mmap.MAP_PRIVATE`` and ``mmap.MAP_SHARED`` are accepted for now.
    Please open an issue if you need another option added here.

    .. note::
        This feature is currently not supported for Windows.

    Args:
        flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``

    Raises:
        RuntimeError: when called on Windows.
        ValueError: when ``flags`` is neither of the supported values.
    '''
    global _default_mmap_options
    if IS_WINDOWS:
        raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
    if not (flags == MAP_PRIVATE or flags == MAP_SHARED):
        raise ValueError("Invalid argument in function set_default_mmap_options, "
                         f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
    _default_mmap_options = flags
[docs]def clear_safe_globals() -> None:
'''
Clears the list of globals that are safe for ``weights_only`` load.
'''
_weights_only_unpickler._clear_safe_globals()
[docs]def get_safe_globals() -> List[Any]:
'''
Returns the list of user-added globals that are safe for ``weights_only`` load.
'''
return _weights_only_unpickler._get_safe_globals()
[docs]def add_safe_globals(safe_globals: List[Any]) -> None:
'''
Marks the given globals as safe for ``weights_only`` load. For example, functions
added to this list can be called during unpickling, classes could be instantiated
and have state set.
Args:
safe_globals (List[Any]): list of globals to mark as safe
Example:
>>> # xdoctest: +SKIP("Can't torch.save(t, ...) as doctest thinks MyTensor is defined on torch.serialization")
>>> import tempfile
>>> class MyTensor(torch.Tensor):
... pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
... torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
... torch.serialization.add_safe_globals([MyTensor])
... torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
# [-0.8234, 2.0500, -0.3657]])
'''
_weights_only_unpickler._add_safe_globals(safe_globals)
def _is_zipfile(f) -> bool:
    """Return True if ``f`` starts with a ZIP local-file-header signature.

    This is stricter than :func:`zipfile.is_zipfile`, which accepts the magic
    number anywhere in the data (see bugs.python.org/issue28494). Files written
    by torch.save / torch.jit.save start with the header, so checking only the
    leading bytes avoids false positives. The stream position is restored before
    returning.
    """
    zip_local_header = b'PK\x03\x04'
    origin = f.tell()
    head = f.read(len(zip_local_header))
    f.seek(origin)
    return head == zip_local_header
def register_package(
    priority: int,
    tagger: Callable[[STORAGE], Optional[str]],
    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
):
    '''
    Registers callables for tagging and deserializing storage objects with an associated priority.

    Tagging associates a device with a storage object at save time, while deserializing moves a
    storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
    are run in the order given by their :attr:`priority` until a tagger/deserializer returns a
    value that is not `None`.

    To override the deserialization behavior for a device in the global registry, one can register a
    tagger with a higher priority than the existing tagger.

    This function can also be used to register a tagger and deserializer for new devices.

    Entries that share a priority are tried in the order they were registered.

    Args:
        priority: Indicates the priority associated with the tagger and deserializer, where a lower
            value indicates higher priority.
        tagger: Callable that takes in a storage object and returns its tagged device as a string
            or None.
        deserializer: Callable that takes in storage object and a device string and returns a storage
            object on the appropriate device or None.

    Returns:
        `None`

    Example:
        >>> def ipu_tag(obj):
        ...     if obj.device.type == 'ipu':
        ...         return 'ipu'
        >>> def ipu_deserialize(obj, location):
        ...     if location.startswith('ipu'):
        ...         ipu = getattr(torch, "ipu", None)
        ...         assert ipu is not None, "IPU device module is not loaded"
        ...         assert torch.ipu.is_available(), "ipu is not available"
        ...         return obj.ipu(location)
        >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
    '''
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    # Sort on the priority only. Sorting the raw tuples falls back to comparing the
    # tagger callables whenever two entries share a priority. Functions and
    # functools.partial objects are not orderable, so that raised TypeError with the
    # new entry already appended. list.sort is stable, so ties keep registration order.
    _package_registry.sort(key=lambda elem: elem[0])
def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
    '''
    Check if a module's version satisfies requirements

    Usually, a module's version string will be like 'x.y.z', which would be represented
    as a tuple (x, y, z), but sometimes it could be an unexpected format. If the version
    string does not match the given tuple's format up to the length of the tuple (or the
    module has no ``__version__`` at all), then error and exit or emit a warning.

    Args:
        module: the module to check the version of
        req_version_tuple: tuple (usually of ints) representing the required version
        error_if_malformed: whether we should exit if module version string is malformed

    Returns:
        requirement_is_met: bool. Malformed versions count as met when
        ``error_if_malformed`` is False; a warning is emitted in that case.

    Raises:
        RuntimeError: if the version cannot be parsed/compared and
            ``error_if_malformed`` is True.
    '''
    # Read __version__ once and tolerate its absence. The error message below also
    # embeds it, and a second failing attribute access inside the ``except`` would
    # replace the intended RuntimeError / warning with an AttributeError.
    version = getattr(module, '__version__', None)
    try:
        version_strs = version.split('.')
        # Cast module version fields to match the types of the required version
        module_version = tuple(
            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
        )
        requirement_is_met = module_version >= req_version_tuple

    except Exception as e:
        message = (
            f"'{module.__name__}' module version string is malformed '{version}' and cannot be compared"
            f" with tuple {str(req_version_tuple)}"
        )
        if error_if_malformed:
            raise RuntimeError(message) from e
        else:
            warnings.warn(message + ', but continuing assuming that requirement is met')
            requirement_is_met = True

    return requirement_is_met
def _cpu_tag(obj):
    # Tagger: 'cpu' for storages on the CPU, otherwise None (so later taggers run).
    return 'cpu' if obj.device.type == 'cpu' else None
def _mps_tag(obj):
    # Tagger: 'mps' for storages on the MPS device, otherwise None.
    return 'mps' if obj.device.type == 'mps' else None
def _meta_tag(obj):
    # Tagger: 'meta' for storages on the meta device, otherwise None.
    return 'meta' if obj.device.type == 'meta' else None
def _backend_tag(backend_name, obj):
    # Generic tagger, bound per backend with functools.partial. Returns
    # '<backend>' or '<backend>:<index>' when obj lives on that backend, else None.
    # 'privateuse1' is translated to the name the custom backend registered under.
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    device = obj.device
    if device.type != backend_name:
        return None
    if device.index is None:
        return backend_name
    return backend_name + ':' + str(device.index)
def _cpu_deserialize(obj, location):
    # Deserializer for the exact tag 'cpu': the storage is returned unchanged.
    return obj if location == 'cpu' else None
def _mps_deserialize(obj, location):
    # Deserializer for tags starting with 'mps': moves the storage with obj.mps().
    if not location.startswith('mps'):
        return None
    return obj.mps()
def _meta_deserialize(obj, location):
    # Deserializer for the exact tag 'meta'. It builds a fresh meta-device
    # UntypedStorage with the same byte count; the source bytes are not copied.
    if location != 'meta':
        return None
    return torch.UntypedStorage(obj.nbytes(), device='meta')
def _validate_device(location, backend_name):
    '''
    Resolve ``location`` to a device of ``backend_name`` and check that it is usable.

    In case of the privateuse1 backend, you must first register a device_module for
    privateuse1 using torch._register_device_module. The following device_module
    members are used when present (as they are for cuda):
    ``device_module._utils._get_device_index(location, True)``,
    ``device_module.is_available()`` and ``device_module.device_count()``.

    Args:
        location: string of device
        backend_name: the backend name or the name of privateuse1, which can be renamed

    Returns:
        device: the resolved :class:`torch.device`. Callers such as
        ``validate_cuda_device`` read its ``.index``.

    Raises:
        RuntimeError: if ``torch.<backend_name>`` does not exist, reports that it is
            not available, or has fewer devices than the requested index.
    '''
    if not hasattr(torch, backend_name):
        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')
    backend_module = getattr(torch, backend_name)

    # Prefer the backend's own index parser; otherwise parse the location string
    # and treat a missing index as device 0 for the count check.
    index_utils = getattr(backend_module, '_utils', None)
    if index_utils is not None and hasattr(index_utils, '_get_device_index'):
        index = index_utils._get_device_index(location, True)
        resolved = torch.device(backend_name, index)
    else:
        resolved = torch.device(location)
        index = resolved.index if resolved.index else 0

    if hasattr(backend_module, 'is_available') and not backend_module.is_available():
        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
                           f'device but torch.{backend_name}.is_available() is False. '
                           'If you are running on a CPU-only machine, '
                           'please use torch.load with map_location=torch.device(\'cpu\') '
                           'to map your storages to the CPU.')

    if hasattr(backend_module, 'device_count'):
        available = backend_module.device_count()
        if index >= available:
            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
                               f'{index} but torch.{backend_name}.device_count() is {available}. '
                               'Please use torch.load with map_location to map your storages '
                               'to an existing device.')
    return resolved
def validate_cuda_device(location):
    """Validate a CUDA ``location`` string and return the resolved device's index.

    Raises RuntimeError (from ``_validate_device``) if CUDA is missing, unavailable,
    or the index exceeds the device count.
    """
    device = _validate_device(location, 'cuda')
    return device.index
def validate_hpu_device(location):
    """Validate an HPU ``location`` string and return the resolved device's index.

    Raises RuntimeError (from ``_validate_device``) if HPU is missing, unavailable,
    or the index exceeds the device count.
    """
    device = _validate_device(location, 'hpu')
    return device.index
def _deserialize(backend_name, obj, location):
    # Generic deserializer, bound per backend with functools.partial. For locations
    # starting with the backend name, validate the device and move obj onto it;
    # otherwise return None so other deserializers get a turn.
    if backend_name == 'privateuse1':
        backend_name = torch._C._get_privateuse1_backend_name()
    if not location.startswith(backend_name):
        return None
    target = _validate_device(location, backend_name)
    return obj.to(device=target)
# Built-in taggers/deserializers. register_package keeps the registry sorted by
# priority and lookups walk it in order, so lower numbers are consulted first
# (CPU before the device backends).
register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
# 'privateuse1' is mapped to the custom backend's registered name each time the
# tagger/deserializer runs (see _backend_tag / _deserialize).
register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))
def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
    """Return the location tag that the registered taggers assign to ``storage``.

    Taggers are consulted lazily in registry order, and the first truthy result is
    returned (for example ``'cpu'`` or ``'cuda:1'``).

    Raises:
        RuntimeError: if no registered tagger recognizes the storage.
    """
    candidate_tags = (tagger(storage) for _, tagger, _ in _package_registry)
    tag = next((t for t in candidate_tags if t), None)
    if tag is None:
        raise RuntimeError("don't know how to determine data location of "
                           + torch.typename(storage))
    return tag
def default_restore_location(storage, location):
    """Move ``storage`` to the device named by the tag ``location``.

    Registered deserializers are tried lazily in registry order, and the first
    non-``None`` result is returned.

    Raises:
        RuntimeError: if every deserializer returns ``None`` for this location.
    """
    attempts = (deserialize(storage, location) for _, _, deserialize in _package_registry)
    restored = next((r for r in attempts if r is not None), None)
    if restored is None:
        raise RuntimeError("don't know how to restore data location of "
                           + torch.typename(storage) + " (tagged with "
                           + location + ")")
    return restored
def normalize_storage_type(storage_type):
    """Return the attribute of the top-level ``torch`` module named like ``storage_type``."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)
def storage_to_tensor_type(storage):
storage_type = type(storage)
module = _import_dotted_name(storage_type.__module__)
return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
return isinstance(name_or_buffer, (str, os.PathLike))
class _opener:
    """Base context manager that hands back the wrapped object on ``__enter__``.

    The base ``__exit__`` does nothing; subclasses override it to close, flush, or
    finalize the wrapped object.
    """

    def __init__(self, file_like):
        self.file_like = file_like

    def __enter__(self):
        return self.file_like

    def __exit__(self, *args):
        return None
class _open_file(_opener):
    """Opener that opens ``name`` with ``mode`` itself and closes it on exit."""

    def __init__(self, name, mode):
        handle = open(name, mode)
        super().__init__(handle)

    def __exit__(self, *args):
        self.file_like.close()
class _open_buffer_reader(_opener):
    """Opener for a caller-owned readable buffer; the buffer is not closed on exit."""

    def __init__(self, buffer):
        super().__init__(buffer)
        # Fail early (with a hint to use io.BytesIO) if the buffer cannot seek/tell.
        _check_seekable(buffer)
class _open_buffer_writer(_opener):
    """Opener for a caller-owned writable buffer: flushed, but not closed, on exit."""

    def __exit__(self, *args):
        buffer = self.file_like
        buffer.flush()
def _open_file_like(name_or_buffer, mode):
    """Wrap a path or buffer in the appropriate opener context manager.

    Paths are opened (and later closed) here. Buffers are wrapped as writers when
    ``mode`` contains 'w' and as readers when it contains 'r'; 'w' is checked first.

    Raises:
        RuntimeError: for a buffer when ``mode`` contains neither 'r' nor 'w'.
    """
    if _is_path(name_or_buffer):
        return _open_file(name_or_buffer, mode)
    if 'w' in mode:
        return _open_buffer_writer(name_or_buffer)
    if 'r' in mode:
        return _open_buffer_reader(name_or_buffer)
    raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
class _open_zipfile_reader(_opener):
    """Opener around a ``torch._C.PyTorchFileReader`` built from a path or buffer.

    It inherits the no-op ``__exit__``, so nothing is explicitly closed on exit.
    """

    def __init__(self, name_or_buffer) -> None:
        reader = torch._C.PyTorchFileReader(name_or_buffer)
        super().__init__(reader)
class _open_zipfile_writer_file(_opener):
    """Zip-archive writer context for a filesystem path.

    ASCII paths are passed to ``torch._C.PyTorchFileWriter`` directly. Other paths
    are opened in Python with :class:`io.FileIO`, and the writer is built on that
    stream. On exit the end-of-file record is written and any Python-side stream
    is closed, even if finalizing the archive fails.
    """

    def __init__(self, name) -> None:
        self.file_stream = None
        self.name = str(name)
        try:
            self.name.encode('ascii')
        except UnicodeEncodeError:
            # PyTorchFileWriter only supports ascii filename.
            # For filenames with non-ascii characters, we rely on Python
            # for writing out the file.
            self.file_stream = io.FileIO(self.name, mode='w')
            try:
                writer = torch._C.PyTorchFileWriter(self.file_stream)
            except BaseException:
                # Don't leak the descriptor we just opened if the writer can't be built.
                self.file_stream.close()
                raise
            super().__init__(writer)
        else:
            super().__init__(torch._C.PyTorchFileWriter(self.name))

    def __exit__(self, *args) -> None:
        try:
            self.file_like.write_end_of_file()
        finally:
            # Close the Python-owned stream even if finalizing the archive raised;
            # otherwise the file descriptor stays open until garbage collection.
            if self.file_stream is not None:
                self.file_stream.close()
class _open_zipfile_writer_buffer(_opener):
    """Zip-archive writer context for a caller-owned, writable buffer.

    The buffer must expose a callable ``write``. On exit the end-of-file record is
    written and the buffer is flushed (not closed).

    Raises:
        AttributeError: if the buffer has no ``write`` attribute.
        TypeError: if ``write`` exists but is not callable.
    """

    def __init__(self, buffer) -> None:
        write_attr = getattr(buffer, "write", None)
        if not callable(write_attr):
            msg = f"Buffer of {str(type(buffer)).strip('<>')} has no callable attribute 'write'"
            error_type = TypeError if hasattr(buffer, "write") else AttributeError
            raise error_type(msg)
        self.buffer = buffer
        super().__init__(torch._C.PyTorchFileWriter(buffer))

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
        self.buffer.flush()
def _open_zipfile_writer(name_or_buffer):
    """Pick the zip writer context for a path or a buffer and instantiate it."""
    container: Type[_opener] = (
        _open_zipfile_writer_file if _is_path(name_or_buffer) else _open_zipfile_writer_buffer
    )
    return container(name_or_buffer)
def _is_compressed_file(f) -> bool:
    # Only gzip is recognized. Detection uses f.__module__, which for an instance
    # is the module that defined its class; objects without it count as uncompressed.
    compress_modules = ['gzip']
    return getattr(f, '__module__', None) in compress_modules
def _should_read_directly(f):
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a non-negative fileno) and is not
    a compressed file (e.g. gzip).
    """
    if _is_compressed_file(f):
        return False
    try:
        descriptor = f.fileno()
    except (io.UnsupportedOperation, AttributeError):
        return False
    return descriptor >= 0
def _check_seekable(f) -> bool:
    """Return True if ``f`` supports ``tell`` and ``seek``; otherwise raise.

    A failure whose message mentions seek/tell is re-raised as the same exception
    type, with a hint to pre-load the data into an ``io.BytesIO``. Other failures of
    the caught types are re-raised unchanged.
    """
    def raise_err_msg(patterns, e):
        if any(pattern in str(e) for pattern in patterns):
            msg = (str(e) + ". You can only torch.load from a file that is seekable."
                   + " Please pre-load the data into a buffer like io.BytesIO and"
                   + " try to load from it instead.")
            raise type(e)(msg)
        raise e

    try:
        f.seek(f.tell())
    except (io.UnsupportedOperation, AttributeError) as e:
        raise_err_msg(["seek", "tell"], e)
        return False
    return True
def _check_dill_version(pickle_module) -> None:
    '''Reject dill older than 0.3.1 when it is used as the pickle module.

    Non-dill modules (and ``None``) are accepted without checks. A malformed dill
    version string only warns, because the version check is called with
    ``error_if_malformed=False``.

    Args:
        pickle_module: module used for pickling metadata and objects

    Raises:
        ValueError: if ``pickle_module`` is dill with a version below 0.3.1.
    '''
    if pickle_module is None or pickle_module.__name__ != 'dill':
        return
    required_dill_version = (0, 3, 1)
    if check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
        return
    required_str = '.'.join(str(num) for num in required_dill_version)
    raise ValueError(
        f"'torch' supports dill >= {required_str}, but you have dill {pickle_module.__version__}."
        " Please upgrade dill or switch to 'pickle'"
    )
def _check_save_filelike(f):
    # Accept a path, or any object exposing a ``write`` attribute; raise
    # AttributeError for anything else.
    if _is_path(f) or hasattr(f, 'write'):
        return
    raise AttributeError(
        "expected 'f' to be string, path, or a file-like object with "
        "a 'write' attribute")
def save(
    obj: object,
    f: FILE_LIKE,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False
) -> None:
    # Reference: https://github.com/pytorch/pytorch/issues/54354
    # The first line of this docstring overrides the one Sphinx generates for the
    # documentation. We need it so that Sphinx doesn't leak `pickle`s path from
    # the build environment (e.g. `<module 'pickle' from '/leaked/path').

    """save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

    Saves an object to a disk file or a writable buffer.

    See also: :ref:`saving-loading-tensors`

    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol

    .. note::
        A common PyTorch convention is to save tensors using .pt file extension.

    .. note::
        PyTorch preserves storage sharing across serialization. See
        :ref:`preserve-storage-sharing` for more details.

    .. note::
        The 1.6 release of PyTorch switched ``torch.save`` to use a new
        zipfile-based file format. ``torch.load`` still retains the ability to
        load files in the old format. If for any reason you want ``torch.save``
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)
    """
    torch._C._log_api_usage_once("torch.save")
    _check_dill_version(pickle_module)
    _check_save_filelike(f)

    if not _use_new_zipfile_serialization:
        # Legacy format: a sequence of pickles followed by the raw storage bytes.
        with _open_file_like(f, 'wb') as opened_file:
            _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
        return

    with _open_zipfile_writer(f) as opened_zipfile:
        _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    """Write ``obj`` to the open file ``f`` in the legacy (non-zipfile) format.

    Written in order: pickled MAGIC_NUMBER, pickled PROTOCOL_VERSION, the pickled
    ``sys_info`` dict, the pickled object (with nn.Module classes and storages
    replaced by persistent ids), the pickled sorted list of storage keys, and then
    each storage's bytes via ``_write_file`` in that key order.

    Raises:
        RuntimeError: if two saved storages share a data pointer but have different dtypes.
        TypeError: if a storage of an unrecognized type is encountered.
    """
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            # Each nn.Module subclass is emitted once, along with its source when available.
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)
                source = ''.join(source_lines)
            except Exception:  # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
            storage: torch.UntypedStorage

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                dtype = obj.dtype
                storage_numel = obj._size()

            elif isinstance(obj, torch.UntypedStorage):
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                dtype = torch.uint8
                storage_numel = storage.nbytes()
            else:
                raise TypeError(f'type not recognized: {type(obj)}')

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            view_metadata: Optional[Tuple[str, int, int]]

            # Offset is always 0, but we keep it for backwards compatibility
            # with the old serialization format (which supported storage views)
            offset = 0
            # Keyed by str(storage._cdata), so repeated references to the same
            # storage map to a single serialized_storages entry.
            storage_key = str(storage._cdata)
            location = location_tag(storage)

            # TODO: There's an issue here with FC. It might be impossible to
            # solve, but it's worth noting. Imagine we save a list `[storage,
            # tensor]`, where `tensor.storage()` is the same as `storage`, and
            # `tensor.element_size() > 1`. Let's say that `tensor.dtype ==
            # torch.float`. The storage will be serialized with element size
            # of 1, since we're choosing to serialize the first occurrence of
            # a duplicate storage. Since this legacy serialization format saves
            # the numel of the storage, rather than nbytes directly, we'll be
            # effectively saving nbytes in this case. We'll be able to load it
            # and the tensor back up with no problems in _this_ and future
            # versions of pytorch, but in older versions, here's the problem:
            # the storage will be loaded up as a UntypedStorage, and then the
            # FloatTensor will loaded and the UntypedStorage will be assigned to
            # it. Since the storage dtype does not match the tensor dtype, this
            # will cause an error. If we reverse the list, like `[tensor,
            # storage]`, then we will save the `tensor.storage()` as a faked
            # `FloatStorage`, and the saved size will be the correct
            # dtype-specific numel count that old versions expect. `tensor`
            # will be able to load up properly in old versions, pointing to
            # a FloatStorage. However, `storage` is still being translated to
            # a UntypedStorage, and it will try to resolve to the same
            # FloatStorage that `tensor` contains. This will also cause an
            # error. It doesn't seem like there's any way around this.
            # Probably, we just cannot maintain FC for the legacy format if the
            # saved list contains both a tensor and a storage that point to the
            # same data. We should still be able to maintain FC for lists of
            # just tensors, as long as all views share the same dtype as the
            # tensor they are viewing.

            if storage_key not in serialized_storages:
                serialized_storages[storage_key] = (storage, dtype)
            # NOTE(review): this compares storage._cdata with itself, so is_view is
            # always False and view_metadata is always None. The 6-tuple shape below
            # is what the legacy format records, so it is left unchanged.
            is_view = storage._cdata != storage._cdata
            if is_view:
                view_metadata = (str(storage._cdata), offset, storage.nbytes())
            else:
                view_metadata = None

            res = ('storage',
                   storage_type,
                   storage_key,
                   location,
                   storage_numel,
                   view_metadata)
            return res
        return None

    # Header describing the writing platform.
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    # Storage payloads follow the pickles, in sorted-key order.
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
    for key in serialized_storage_keys:
        storage, dtype = serialized_storages[key]
        # The second argument tells _write_file whether f is backed by a real,
        # uncompressed file (see _should_read_directly).
        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
    """Write ``obj`` into an open zip writer using the zipfile-based format.

    Records are written in this order:
      * ``data.pkl``: the pickled object, with storages replaced by persistent ids;
      * ``byteorder``: ``sys.byteorder``, unless ``_disable_byteorder_record``;
      * ``data/<key>``: each storage's bytes (copied to CPU first if needed),
        in sorted string order of the keys.

    Raises:
        RuntimeError: if two saved storages share a data pointer but have different dtypes.
        ValueError: if ``sys.byteorder`` is neither 'little' nor 'big'.
    """
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    # Since loading storages that view the same data with different dtypes is
    # not supported, we need to keep track of the dtype associated with each
    # storage data_ptr and throw an error if the dtype is ever different.
    # TODO: This feature could be added in the future
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary protocol
        # see
        # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
        # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._untyped_storage
                storage_dtype = obj.dtype
                storage_type_str = obj._pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj._size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            # A storage seen for the first time gets the next sequential key
            # ('0', '1', ...); later references to the same _cdata reuse that key.
            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel)

        return None

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write byte order marker
    if not _disable_byteorder_record:
        if sys.byteorder not in ['little', 'big']:
            raise ValueError('Unknown endianness type: ' + sys.byteorder)

        zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))

    # Write each storage to a record named data/<key> in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu()
        # this means to that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage, num_bytes)
[docs]def load(
f: FILE_LIKE,
map_location: MAP_LOCATION = None,
pickle_module: Any = None,
*,
weights_only: Optional[bool] = None,
mmap: Optional[bool] = None,
**pickle_load_args: Any
) -> Any:
# Reference: https://github.com/pytorch/pytorch/issues/54354
# The first line of this docstring overrides the one Sphinx generates for the
# documentation. We need it so that Sphinx doesn't leak `pickle`s path from
# the build environment (e.g. `<module 'pickle' from '/leaked/path').
"""load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)
Loads an object saved with :func:`torch.save` from a file.
:func:`torch.load` uses Python's unpickling facilities but treats storages,
which underlie tensors, specially. They are first deserialized on the
CPU and are then moved to the device they were saved from. If this fails
(e.g. because the run time system doesn't have certain devices), an exception
is raised. However, storages can be dynamically remapped to an alternative
set of devices using the :attr:`map_location` argument.
If :attr:`map_location` is a callable, it will be called once for each serialized
storage with two arguments: storage and location. The storage argument
will be the initial deserialization of the storage, residing on the CPU.
Each serialized storage has a location tag associated with it which
identifies the device it was saved from, and this tag is the second
argument passed to :attr:`map_location`. The builtin location tags are ``'cpu'``
for CPU tensors and ``'cuda:device_id'`` (e.g. ``'cuda:2'``) for CUDA tensors.
:attr:`map_location` should return either ``None`` or a storage. If
:attr:`map_location` returns a storage, it will be used as the final deserialized
object, already moved to the right device. Otherwise, :func:`torch.load` will
fall back to the default behavior, as if :attr:`map_location` wasn't specified.
If :attr:`map_location` is a :class:`torch.device` object or a string containing
a device tag, it indicates the location where all tensors should be loaded.
Otherwise, if :attr:`map_location` is a dict, it will be used to remap location tags
appearing in the file (keys), to ones that specify where to put the
storages (values).
User extensions can register their own location tags and tagging and
deserialization methods using :func:`torch.serialization.register_package`.
Args:
f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
or a string or os.PathLike object containing a file name
map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage
locations
pickle_module: module used for unpickling metadata and objects (has to
match the :attr:`pickle_module` used to serialize file)
weights_only: Indicates whether unpickler should be restricted to
loading only tensors, primitive types, dictionaries
and any types added via :func:`torch.serialization.add_safe_globals`.
mmap: Indicates whether the file should be mmaped rather than loading all the storages into memory.
Typically, tensor storages in the file will first be moved from disk to CPU memory, after which they
are moved to the location that they were tagged with when saving, or specified by ``map_location``. This
second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the
tensor storages from disk to CPU memory in the first step, ``f`` is mmaped.
pickle_load_args: (Python 3 only) optional keyword arguments passed over to
:func:`pickle_module.load` and :func:`pickle_module.Unpickler`, e.g.,
:attr:`errors=...`.
.. warning::
:func:`torch.load()` unless `weights_only` parameter is set to `True`,
uses ``pickle`` module implicitly, which is known to be insecure.
It is possible to construct malicious pickle data which will execute arbitrary code
during unpickling. Never load data that could have come from an untrusted
source in an unsafe mode, or that could have been tampered with. **Only load data you trust**.
.. note::
When you call :func:`torch.load()` on a file which contains GPU tensors, those tensors
will be loaded to GPU by default. You can call ``torch.load(.., map_location='cpu')``
and then :meth:`load_state_dict` to avoid GPU RAM surge when loading a model checkpoint.
.. note::
By default, we decode byte strings as ``utf-8``. This is to avoid a common error
case ``UnicodeDecodeError: 'ascii' codec can't decode byte 0x...``
when loading files saved by Python 2 in Python 3. If this default
is incorrect, you may use an extra :attr:`encoding` keyword argument to specify how
these objects should be loaded, e.g., :attr:`encoding='latin1'` decodes them
to strings using ``latin1`` encoding, and :attr:`encoding='bytes'` keeps them
as byte arrays which can be decoded later with ``byte_array.decode(...)``.
Example:
>>> # xdoctest: +SKIP("undefined filepaths")
>>> torch.load('tensors.pt', weights_only=True)
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open('tensor.pt', 'rb') as f:
... buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load('module.pt', encoding='ascii', weights_only=False)
"""
torch._C._log_api_usage_once("torch.load")
UNSAFE_MESSAGE = (
"Re-running `torch.load` with `weights_only` set to `False` will likely succeed, "
"but it can result in arbitrary code execution. Do it only if you got the file from a "
"trusted source."
)
DOCS_MESSAGE = (
"\n\nCheck the documentation of torch.load to learn more about types accepted by default with "
"weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
)
def _get_wo_message(message: str) -> str:
pattern = r"GLOBAL (\S+) was not an allowed global by default."
has_unsafe_global = re.search(pattern, message) is not None
if has_unsafe_global:
updated_message = (
"Weights only load failed. This file can still be loaded, to do so you have two options "
f"\n\t(1) {UNSAFE_MESSAGE}\n\t(2) Alternatively, to load with `weights_only=True` please check "
"the recommended steps in the following error message.\n\tWeightsUnpickler error: "
+ message
)
else:
updated_message = (
f"Weights only load failed. {UNSAFE_MESSAGE}\n Please file an issue with the following "
"so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler "
"error: " + message
)
return updated_message + DOCS_MESSAGE
if weights_only is None:
weights_only, warn_weights_only = False, True
else:
warn_weights_only = False
# Add ability to force safe only weight loads via environment variable
if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
weights_only = True
if weights_only:
if pickle_module is not None:
raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
else:
if pickle_module is None:
if warn_weights_only:
warnings.warn(
"You are using `torch.load` with `weights_only=False` (the current default value), which uses "
"the default pickle module implicitly. It is possible to construct malicious pickle data "
"which will execute arbitrary code during unpickling (See "
"https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). "
"In a future release, the default value for `weights_only` will be flipped to `True`. This "
"limits the functions that could be executed during unpickling. Arbitrary objects will no "
"longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the "
"user via `torch.serialization.add_safe_globals`. We recommend you start setting "
"`weights_only=True` for any use case where you don't have full control of the loaded file. "
"Please open an issue on GitHub for any issues related to this experimental feature.",
FutureWarning,
stacklevel=2,
)
pickle_module = pickle
# make flipping default BC-compatible
if mmap is None:
mmap = False
_check_dill_version(pickle_module)
if 'encoding' not in pickle_load_args.keys():
pickle_load_args['encoding'] = 'utf-8'
with _open_file_like(f, 'rb') as opened_file:
if _is_zipfile(opened_file):
# The zipfile reader is going to advance the current file position.
# If we want to actually tail call to torch.jit.load, we need to
# reset back to the original position.
orig_position = opened_file.tell()
overall_storage = None
with _open_zipfile_reader(opened_file) as opened_zipfile:
if _is_torchscript_zip(opened_zipfile):
warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
" dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
" silence this warning)", UserWarning)
opened_file.seek(orig_position)
return torch.jit.load(opened_file, map_location=map_location)
if mmap:
if not _is_path(f):
raise ValueError("f must be a file path in order to use the mmap argument")
size = os.path.getsize(f)
if not IS_WINDOWS:
shared = get_default_mmap_options() == MAP_SHARED
else:
shared = False
overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
if weights_only:
try:
return _load(opened_zipfile,
map_location,
_weights_only_unpickler,
overall_storage=overall_storage,
**pickle_load_args)
except RuntimeError as e:
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
return _load(
opened_zipfile,
map_location,
pickle_module,
overall_storage=overall_storage,
**pickle_load_args,
)
if mmap:
f_name = "" if not isinstance(f, str) else f"{f}, "
raise RuntimeError("mmap can only be used with files saved with "
f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
"please torch.save your checkpoint with this option in order to use mmap.")
if weights_only:
try:
return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
except RuntimeError as e:
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
return _legacy_load(
opened_file, map_location, pickle_module, **pickle_load_args
)
# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
"""Get layout extension object from its string representation.
"""
cache = _get_layout.cache # type: ignore[attr-defined]
if not cache:
for v in torch.__dict__.values():
if isinstance(v, torch.layout):
cache[str(v)] = v
return cache[name]
# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
_get_layout.cache = {} # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    """Load an object saved in one of the pre-zipfile ``torch.save`` formats.

    Two layouts are handled:

    * The tar-based format, read by the nested ``legacy_load``. It is only
      attempted when ``_should_read_directly(f)`` is true and ``f`` is at
      offset 0.
    * The pickle-stream format. It consists of a magic number, a protocol
      version and a system-info record, then the pickled object, then the list
      of storage keys whose data follows in the stream.

    Args:
        f: file-like object holding the checkpoint. It must support ``tell`` and
            ``seek`` (both are used below).
        map_location: passed to ``_get_restore_location`` to build the callback
            that places each deserialized storage.
        pickle_module: module supplying ``load`` and ``Unpickler``.
        **pickle_load_args: extra keyword arguments forwarded to every
            ``pickle_module.load`` call and to the ``Unpickler`` constructor.

    Returns:
        The unpickled object.

    Raises:
        RuntimeError: in any of these cases:
            * ``f`` turns out to be a zip archive;
            * the magic number or protocol version does not match;
            * a persistent id has an unknown type;
            * on Python 3.8.0/3.8.1, ``f`` does not implement ``readinto``.
    """
    deserialized_objects: Dict[int, Any] = {}

    # Callback (storage, location) -> storage derived from map_location.
    restore_location = _get_restore_location(map_location)

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]

        # Pickled globals whose name contains 'Storage' resolve to a StorageType
        # placeholder. Names StorageType rejects with KeyError fall through to the
        # regular lookup.
        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            return super().find_class(mod_name, name)

    def _check_container_source(container_type, source_file, original_source):
        # Compare the source saved for a container class with the class's current
        # source, and emit a SourceChangeWarning if they differ. If the class sets
        # `dump_patches`, also try to write a reverse patch to the relative path
        # `<ClassName>.patch`.
        try:
            current_source = ''.join(get_source_lines_and_file(container_type)[0])
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + container_type.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
            return
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                # Diff from the current source back to the saved one, i.e. a patch
                # that reverts the local changes.
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        # Write the patch into an empty file. An existing file is
                        # accepted only if it already holds exactly this patch.
                        # NOTE(review): file_size is the position returned by seek(),
                        # while len(lines) counts characters. The length comparison
                        # therefore assumes single-byte text; confirm.
                        file_size = f.seek(0, 2)
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            raise OSError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except OSError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Read the tar-based format. The archive has three members:
        #   'storages': a storage count, then per storage a pickled
        #       (key, location, storage_type) record followed by the data read via
        #       `_new_with_file`, then a pickled list of storage views.
        #   'tensors': a tensor count, then per tensor a pickled
        #       (key, storage_id, original_tensor_type) record followed by a packed
        #       little-endian header (ndim, sizes, strides, storage offset).
        #   'pickle': the object itself, whose persistent ids refer to the above.
        deserialized_objects: Dict[int, Any] = {}

        def persistent_load(saved_id):
            # Tuple ids describe a container class: (type, source_file, source).
            # Any other id is a key into deserialized_objects.
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            # `f` is rebound to the extracted member file from here on.
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f, **pickle_load_args)
                for i in range(num_storages):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, location, storage_type = args
                    dtype = storage_type._dtype
                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
                    obj = restore_location(obj, location)
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[key] = torch.storage.TypedStorage(
                        wrap_storage=obj,
                        dtype=dtype,
                        _internal=True)

                # Each view is the element range [offset, offset + numel) of an
                # already-loaded root storage, converted to a byte slice below.
                storage_views = pickle_module.load(f, **pickle_load_args)
                for target_cdata, root_cdata, offset, numel in storage_views:
                    root = deserialized_objects[root_cdata]
                    element_size = torch._utils._element_size(root.dtype)
                    offset_bytes = offset * element_size
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[target_cdata] = torch.storage.TypedStorage(
                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
                        dtype=root.dtype,
                        _internal=True)

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f, **pickle_load_args)
                for _ in range(num_tensors):
                    args = pickle_module.load(f, **pickle_load_args)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    ndim, = struct.unpack('<i', f.read(4))
                    # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                    f.read(4)
                    # Despite its name, `numel` holds the ndim per-dimension sizes
                    # that are passed to set_() as the tensor size.
                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
                    storage_offset, = struct.unpack('<q', f.read(8))
                    tensor = torch.empty((0,), dtype=storage.dtype).set_(
                        storage._untyped_storage, storage_offset, numel, stride)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    # Storages of the pickle-stream format, keyed by the ids recorded in the
    # stream. They are filled by persistent_load and read back after unpickling.
    deserialized_objects = {}

    def persistent_load(saved_id):
        # Persistent ids are tuples tagged with a typename:
        #   ('module', container_type, source_file, source)
        #   ('storage', storage_type, root_key, location, numel, view_metadata)
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            storage_type, root_key, location, numel, view_metadata = data
            location = _maybe_decode_ascii(location)
            dtype = storage_type.dtype

            nbytes = numel * torch._utils._element_size(dtype)

            if root_key not in deserialized_objects:
                # Allocate the root storage without its data. The bytes are read
                # from the file after unpickling finishes (see the loop over
                # deserialized_storage_keys at the end of this function).
                if torch._guards.active_fake_mode() is not None:
                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
                else:
                    obj = cast(Storage, torch.UntypedStorage(nbytes))
                    # NOTE(review): presumably read by restore helpers to know the
                    # storage has not been filled yet; confirm.
                    obj._torch_load_uninitialized = True
                    obj = restore_location(obj, location)
                # TODO: Once we decide to break serialization FC, we can
                # stop wrapping with TypedStorage
                typed_storage = torch.storage.TypedStorage(
                    wrap_storage=obj,
                    dtype=dtype,
                    _internal=True)
                deserialized_objects[root_key] = typed_storage
            else:
                typed_storage = deserialized_objects[root_key]
                # A cached root with a null data pointer is replaced by a fresh
                # TypedStorage of this dtype on the same device.
                # NOTE(review): presumably this handles zero-size storages; confirm.
                if typed_storage._data_ptr() == 0:
                    typed_storage = torch.storage.TypedStorage(
                        device=typed_storage._untyped_storage.device,
                        dtype=dtype,
                        _internal=True)

            if view_metadata is not None:
                # view_metadata is (view_key, offset, view_size), in elements.
                # Each view is created once and then reused by its key.
                view_key, offset, view_size = view_metadata
                offset_bytes = offset * torch._utils._element_size(dtype)
                view_size_bytes = view_size * torch._utils._element_size(dtype)
                if view_key not in deserialized_objects:
                    # TODO: Once we decide to break serialization FC, we can
                    # stop wrapping with TypedStorage
                    deserialized_objects[view_key] = torch.storage.TypedStorage(
                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
                        dtype=dtype,
                        _internal=True)
                res = deserialized_objects[view_key]

            else:
                res = typed_storage
            return res
        else:
            raise RuntimeError(f"Unknown saved id type: {saved_id[0]}")

    _check_seekable(f)
    f_should_read_directly = _should_read_directly(f)

    if f_should_read_directly and f.tell() == 0:
        # legacy_load requires that f has fileno()
        # only if offset is zero we can attempt the legacy tar file loader
        try:
            return legacy_load(f)
        except tarfile.TarError:
            if _is_zipfile(f):
                # .zip is used for torch.jit.save and will throw an un-pickling error here
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
            # if not a tarfile, reset file offset and proceed
            f.seek(0)

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
            "functionality.")

    # Pickle-stream header: magic number, protocol version, system info.
    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError(f"Invalid protocol version: {protocol_version}")

    # The system-info record is read to advance the stream but is not used.
    _sys_info = pickle_module.load(f, **pickle_load_args)
    unpickler = UnpicklerWrapper(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    # Keys of the storages whose data follows in the stream, in file order.
    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)

    # Under fake mode the storages were allocated on 'meta', so no data is read.
    if torch._guards.active_fake_mode() is None:
        # When reading directly, track an explicit file offset and advance it after
        # each storage. Otherwise pass None and read from the current position.
        offset = f.tell() if f_should_read_directly else None
        for key in deserialized_storage_keys:
            assert key in deserialized_objects
            typed_storage = deserialized_objects[key]
            typed_storage._untyped_storage._set_from_file(
                f, offset, f_should_read_directly,
                torch._utils._element_size(typed_storage.dtype))
            if offset is not None:
                offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str:
# When using encoding='bytes' in Py3, some **internal** keys stored as
# strings in Py2 are loaded as bytes. This function decodes them with
# ascii encoding, one that Py3 uses by default.
#
# NOTE: This should only be used on internal keys (e.g., `typename` and
# `location` in `persistent_load` below!
if isinstance(bytes_str, bytes):
return bytes_str.decode('ascii')
return bytes_str
def _get_restore_location(map_location):
if map_location is None:
restore_location = default_restore_location
elif isinstance(map_location, dict):
def restore_location(storage, location):
location = map_location.get(location, location)
return default_restore_location(storage, location)
elif isinstance(map_location, (str, bytes)):
def restore_location(storage, location):
return default_restore_location(storage, map_location)
elif isinstance(map_location, torch.device):
def restore_location(storage, location):
return default_restore_location(storage, str(map_location))
else:
def restore_location(storage, location):
result = map_location(storage, location)
if result is None:
result = default_restore_location(storage, location)
return result
return restore_location
class StorageType:
    """Stand-in for a legacy ``<Type>Storage`` class met while unpickling.

    The ``find_class`` overrides in this module return an instance of this class
    for pickled globals whose name contains ``'Storage'``. The instance only
    carries the element dtype decoded from that name.
    """

    def __init__(self, name):
        # Expected to raise KeyError for names that are not known storage types;
        # the find_class callers catch it and fall back to the normal lookup.
        self._dtype = _get_dtype_from_pickle_storage_type(name)

    def __str__(self):
        return f'StorageType(dtype={self.dtype})'

    @property
    def dtype(self):
        """The ``torch.dtype`` decoded from the pickled storage-type name."""
        return self._dtype
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
    """Load an object from an already-opened zipfile-format checkpoint.

    Args:
        zip_file: reader for the checkpoint archive. This function uses its
            ``has_record``, ``get_record``, ``get_record_offset``,
            ``get_storage_from_record`` and ``serialization_id`` methods.
        map_location: passed to ``_get_restore_location`` to place each storage.
            While unpickling it is also stored on
            ``torch._utils._thread_local_state`` for rebuild helpers whose tensor
            device is not tied to the storage device.
        pickle_module: module supplying ``Unpickler``.
        pickle_file: name of the record holding the pickled object.
        overall_storage: optional untyped storage covering the whole file, e.g.
            the one ``torch.load`` creates with ``UntypedStorage.from_file`` when
            ``mmap=True``. If given, each storage's bytes are sliced out of it at
            the record offset instead of being read from the archive.
        **pickle_load_args: forwarded to the ``Unpickler`` constructor.

    Returns:
        The unpickled object.

    Raises:
        ValueError: if the ``byteorder`` record is neither ``little`` nor ``big``,
            or if the default load endianness is not a recognised value.
    """
    restore_location = _get_restore_location(map_location)

    # Storages already materialised, keyed by record key, so that tensors sharing
    # a storage also share the same TypedStorage.
    loaded_storages = {}

    # check if byteswapping is needed
    byteordername = 'byteorder'
    byteorderdata = None
    has_byteorder_record = zip_file.has_record(byteordername)
    default_load_endianness = get_default_load_endianness()
    if has_byteorder_record:
        byteorderdata = zip_file.get_record(byteordername)
        if byteorderdata not in [b'little', b'big']:
            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
    elif default_load_endianness == LoadEndianness.LITTLE or \
            default_load_endianness is None:
        byteorderdata = b'little'
    elif default_load_endianness == LoadEndianness.BIG:
        byteorderdata = b'big'
    elif default_load_endianness == LoadEndianness.NATIVE:
        # Leave byteorderdata as None, so the data is never byteswapped.
        pass
    else:
        raise ValueError('Invalid load endianness type')

    if not has_byteorder_record and \
            default_load_endianness is None and \
            sys.byteorder == 'big':
        # Default behaviour was changed
        # See https://github.com/pytorch/pytorch/issues/101688
        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
                      "on big endian machines was changed from 'native' to 'little' endian, "
                      "to avoid this behavior please use "
                      "torch.serialization.set_default_load_endianness to set "
                      "the desired default load endianness",
                      UserWarning)

    def load_tensor(dtype, nbytes, key, location):
        # Materialise the storage stored under record `data/{key}`. `nbytes` is the
        # size in bytes: persistent_load has already multiplied the element count
        # by the element size, so it must not be scaled again here.
        name = f'data/{key}'
        if torch._guards.detect_fake_mode(None) is not None:
            # Under fake mode only a 'meta' storage of the right byte size is created.
            storage = torch.UntypedStorage(nbytes, device='meta')
        elif overall_storage is not None:
            # Slice this record's bytes out of the storage that covers the whole file.
            storage_offset = zip_file.get_record_offset(name)
            storage = overall_storage[storage_offset:storage_offset + nbytes]
        else:
            storage = zip_file.get_storage_from_record(name, nbytes, torch.UntypedStorage)._typed_storage()._untyped_storage
        # swap here if byteswapping is needed
        if byteorderdata is not None and byteorderdata.decode() != sys.byteorder:
            storage.byteswap(dtype)

        # TODO: Once we decide to break serialization FC, we can
        # stop wrapping with TypedStorage
        typed_storage = torch.storage.TypedStorage(
            wrap_storage=restore_location(storage, location),
            dtype=dtype,
            _internal=True)

        # Only storages with a non-null data pointer are cached for reuse.
        if typed_storage._data_ptr() != 0:
            loaded_storages[key] = typed_storage

        return typed_storage

    def persistent_load(saved_id):
        # Persistent ids have the form ('storage', storage_type, key, location, numel),
        # where numel counts elements of the storage's dtype.
        assert isinstance(saved_id, tuple)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        storage_type, key, location, numel = data
        if storage_type is torch.UntypedStorage:
            dtype = torch.uint8
        else:
            dtype = storage_type.dtype

        if key in loaded_storages:
            typed_storage = loaded_storages[key]
        else:
            nbytes = numel * torch._utils._element_size(dtype)
            typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))

        return typed_storage

    load_module_mapping: Dict[str, str] = {
        # See https://github.com/pytorch/pytorch/pull/51633
        'torch.tensor': 'torch._tensor'
    }

    # Need to subclass Unpickler instead of directly monkey-patching the find_class method
    # because it's marked readonly in pickle.
    # The type: ignore is because mypy can't statically determine the type of this class.
    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        # from https://stackoverflow.com/questions/13398462/unpickling-python-objects-with-a-changed-module-path/13405732
        # Lets us override the imports that pickle uses when unpickling an object.
        # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
        def find_class(self, mod_name, name):
            if type(name) is str and 'Storage' in name:
                try:
                    return StorageType(name)
                except KeyError:
                    pass
            mod_name = load_module_mapping.get(mod_name, mod_name)
            return super().find_class(mod_name, name)

    # Load the data (which may in turn use `persistent_load` to load tensors)
    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    # Needed for tensors where storage device and rebuild tensor device are
    # not connected (wrapper subclasses and tensors rebuilt using numpy)
    thread_state = torch._utils._thread_local_state
    thread_state.map_location = map_location
    try:
        result = unpickler.load()
    finally:
        # Clear the stored map_location even if unpickling failed; otherwise a stale
        # value would remain for later rebuilds on this thread. The hasattr guard
        # keeps a missing attribute (e.g. already removed by a nested load) from
        # masking the original exception.
        if hasattr(thread_state, 'map_location'):
            del thread_state.map_location

    torch._utils._validate_loaded_sparse_tensors()
    torch._C._log_api_usage_metadata(
        "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
    )
    return result
def _is_torchscript_zip(zip_file):
return 'constants.pkl' in zip_file.get_all_records()