# Source code for torch.serialization

import difflib
import inspect
import os
import shutil
import struct
import sys
import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
# Python 2 ships a C-accelerated pickler as `cPickle`; on Python 3 the plain
# `pickle` module already uses the C implementation when available.
if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle


# Pickle protocol `save` uses when the caller does not choose one.
# NOTE(review): this definition was missing from the file although `save`
# references it; the value matches upstream PyTorch -- confirm.
DEFAULT_PROTOCOL = 2

# Sizes (in bytes) of the standard-size C types, recorded in the file header
# written by `_save`.
LONG_SIZE = struct.Struct('=l').size
INT_SIZE = struct.Struct('=i').size
SHORT_SIZE = struct.Struct('=h').size

# First object pickled into every file; `_load` rejects files whose first
# object differs.
MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
# Format version written by `_save` and checked by `_load`.
# NOTE(review): this definition was missing from the file although both
# functions reference it; the value matches upstream PyTorch -- confirm.
PROTOCOL_VERSION = 1001

class SourceChangeWarning(Warning):
    """Warning category used when a loaded class's source code has changed.

    Emitted while loading when the source saved alongside a container class
    differs from the source of that class as currently importable.
    """

@contextmanager
def mkdtemp():
    """Context manager yielding the path of a fresh temporary directory.

    The directory is created with ``tempfile.mkdtemp`` and removed, together
    with everything inside it, when the ``with`` block exits -- including
    when the block raises.
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)

# Registered (priority, tagger, deserializer) triples, kept sorted by
# ascending priority; iterated in that order when tagging / restoring.
_package_registry = []


def register_package(priority, tagger, deserializer):
    """Register a tagger/deserializer pair for a storage backend.

    Args:
        priority: sort key; entries with a lower value are consulted first
        tagger: callable taking a storage and returning a location tag, or a
            falsy value if it does not recognise the storage
        deserializer: callable taking ``(storage, location)`` and returning
            the restored storage, or None if it does not handle the location
    """
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    # Sort on priority only: comparing whole tuples would fall through to
    # comparing the callables on a tie, which raises TypeError on Python 3.
    # list.sort is stable, so equal priorities keep registration order.
    _package_registry.sort(key=lambda elem: elem[0])

def _cpu_tag(obj):
    """Return ``'cpu'`` when obj's type lives in the ``torch`` module.

    Returns None for any other object.
    """
    defining_module = type(obj).__module__
    return 'cpu' if defining_module == 'torch' else None

def _cuda_tag(obj):
    """Return ``'cuda:<device>'`` when obj's type lives in ``torch.cuda``.

    The device number is whatever ``obj.get_device()`` reports. Returns None
    for any other object.
    """
    if type(obj).__module__ != 'torch.cuda':
        return None
    device = obj.get_device()
    return 'cuda:' + str(device)

def _cpu_deserialize(obj, location):
    """Return obj unchanged for the ``'cpu'`` location tag, else None."""
    return obj if location == 'cpu' else None

def _cuda_deserialize(obj, location):
    """Move obj to the CUDA device named by a ``'cuda:<n>'`` location tag.

    Negative device numbers are clamped to 0. Returns None when the tag does
    not start with ``'cuda'``.
    """
    if not location.startswith('cuda'):
        return None
    # Everything after the 5-character 'cuda:' prefix is the device number.
    requested = int(location[5:])
    device_id = requested if requested > 0 else 0
    return obj.cuda(device_id)

# Built-in backends: the CPU handlers get the lower priority value (10) and
# the CUDA handlers the higher one (20).
register_package(priority=10, tagger=_cpu_tag, deserializer=_cpu_deserialize)
register_package(priority=20, tagger=_cuda_tag, deserializer=_cuda_deserialize)

def location_tag(storage):
    """Return the location tag for storage.

    Each registered tagger is tried in registry order; the first truthy
    result wins.

    Raises:
        RuntimeError: if no registered tagger recognises the storage.
    """
    for _, tagger, _ in _package_registry:
        location = tagger(storage)
        if location:
            return location
    raise RuntimeError("don't know how to determine data location of " +
                       torch.typename(storage))

def default_restore_location(storage, location):
    """Restore storage to the place named by location.

    Each registered deserializer is tried in registry order; the first
    result that is not None is returned.

    Raises:
        RuntimeError: if every registered deserializer returns None.
    """
    for _priority, _tagger, deserialize in _package_registry:
        restored = deserialize(storage, location)
        if restored is None:
            continue
        return restored
    message = ("don't know how to restore data location of " +
               torch.typename(storage) + " (tagged with " +
               location + ")")
    raise RuntimeError(message)

def normalize_storage_type(storage_type):
    """Return the attribute of ``torch`` that shares storage_type's name."""
    type_name = storage_type.__name__
    return getattr(torch, type_name)

def storage_to_tensor_type(storage):
    """Return the tensor class matching storage's class.

    The class is looked up in the module that defines the storage's type,
    under the storage class name with 'Storage' replaced by 'Tensor'.
    """
    storage_cls = type(storage)
    owning_module = _import_dotted_name(storage_cls.__module__)
    tensor_name = storage_cls.__name__.replace('Storage', 'Tensor')
    return getattr(owning_module, tensor_name)

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    """Saves an object to a disk file.

    See also: :ref:`recommend-saving-models`

    Args:
        obj: saved object
        f: a file-like object (has to implement fileno that returns a file
            descriptor) or a string containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    # `unicode` only exists on Python 2; the version test short-circuits
    # before the name is looked up on Python 3.
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)):  # noqa: F821
        # We opened the file, so we are responsible for closing it.
        with open(f, "wb") as opened_file:
            return _save(obj, opened_file, pickle_module, pickle_protocol)
    # A caller-supplied file object is left open for the caller to manage.
    return _save(obj, f, pickle_module, pickle_protocol)
def _save(obj, f, pickle_module, pickle_protocol):
    """Serialize obj into the open binary file f.

    Writes, in order: the magic number, the protocol version, a dict with
    system information, the pickled object (with storages and container
    classes replaced by persistent ids), the sorted list of root-storage
    keys, and finally the contents of each root storage.

    Args:
        obj: object to serialize
        f: writable binary file object
        pickle_module: module providing ``dump`` and ``Pickler``
        pickle_protocol: pickle protocol passed to every dump
    """
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj):
        # FIXME: the docs say that persistent_id should only return a string
        # but torch store returns tuples. This works only in the binary
        # protocol.
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            # Record each container class (and its source) only once.
            if obj in serialized_container_types:
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:
                # saving the source is optional, so we can ignore any errors
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)
        elif torch.is_storage(obj):
            storage_type = normalize_storage_type(type(obj))
            root, offset = obj._root_storage()
            root_key = str(root._cdata)
            location = location_tag(obj)
            # Only root storages are written out; views are described by
            # metadata relative to their root.
            serialized_storages[root_key] = root
            is_view = obj._cdata != root._cdata
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage', storage_type, root_key, location,
                    root.size(), view_metadata)

        return None

    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    # Flush buffered pickle output before the storages write their data.
    f.flush()
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f)
def load(f, map_location=None, pickle_module=pickle):
    """Loads an object saved with :func:`torch.save` from a file.

    torch.load can dynamically remap storages to be loaded on a different
    device using the map_location argument. If it's a callable, it will be
    called with two arguments: storage and location tag. It's expected to
    either return a storage that's been moved to a different location, or
    None (and the location will be resolved using the default method). If
    this argument is a dict it's expected to be a mapping from location tags
    used in a file, to location tags of the current system.

    By default the location tags are 'cpu' for host tensors and
    'cuda:device_id' (e.g. 'cuda:2') for cuda tensors. User extensions can
    register their own tagging and deserialization methods using
    register_package.

    .. warning::
        Loading unpickles the file contents, and unpickling can execute
        arbitrary code. Only load files from sources you trust.

    Args:
        f: a file-like object (has to implement fileno that returns a file
            descriptor, and must implement seek), or a string containing a
            file name
        map_location: a function or a dict specifying how to remap storage
            locations
        pickle_module: module used for unpickling metadata and objects (has
            to match the pickle_module used to serialize file)

    Example:
        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
    """
    # `unicode` only exists on Python 2; the version test short-circuits
    # before the name is looked up on Python 3.
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)):  # noqa: F821
        # We opened the file, so we are responsible for closing it.
        with open(f, 'rb') as opened_file:
            return _load(opened_file, map_location, pickle_module)
    # A caller-supplied file object is left open for the caller to manage.
    return _load(f, map_location, pickle_module)
def _load(f, map_location, pickle_module):
    """Deserialize an object from the open binary file f.

    The legacy tar-based format is tried first; if f is not a tar archive
    the current format (magic number, protocol version, system info, pickled
    object, storage keys, storage data) is read instead.

    Args:
        f: readable, seekable binary file object
        map_location: None, a dict mapping saved location tags to new ones,
            or a callable ``(storage, location)`` returning a storage or None
        pickle_module: module providing ``load`` and ``Unpickler``

    Raises:
        RuntimeError: on a wrong magic number, an unsupported protocol
            version, or an unknown persistent id type.
    """
    deserialized_objects = {}

    # Pick the strategy used to place each storage on a device.
    if map_location is None:
        restore_location = default_restore_location
    elif isinstance(map_location, dict):
        def restore_location(storage, location):
            # Tags missing from the mapping are used unchanged.
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    else:
        def restore_location(storage, location):
            result = map_location(storage, location)
            # None means "no opinion": fall back to the default resolution.
            if result is None:
                result = default_restore_location(storage, location)
            return result

    def _check_container_source(container_type, source_file, original_source):
        # Warn (and optionally dump a reverse patch) when the class source
        # saved in the file differs from the currently importable source.
        current_source = inspect.getsource(container_type)
        if original_source != current_source:
            if container_type.dump_patches:
                file_name = container_type.__name__ + '.patch'
                diff = difflib.unified_diff(current_source.split('\n'),
                                            original_source.split('\n'),
                                            source_file,
                                            source_file, lineterm="")
                lines = '\n'.join(diff)
                try:
                    with open(file_name, 'a+') as f:
                        # Measure the existing file, then rewind so the
                        # comparison below reads it from the start. tell()
                        # is used because file.seek returns None on Python 2.
                        f.seek(0, 2)
                        file_size = f.tell()
                        f.seek(0)
                        if file_size == 0:
                            f.write(lines)
                        elif file_size != len(lines) or f.read() != lines:
                            # A different patch is already there; refuse to
                            # append to it.
                            raise IOError
                    msg = ("Saved a reverse patch to " + file_name + ". "
                           "Run `patch -p0 < " + file_name + "` to revert your "
                           "changes.")
                except IOError:
                    msg = ("Tried to save a patch, but couldn't create a "
                           "writable file " + file_name + ". Make sure it "
                           "doesn't exist and your working directory is "
                           "writable.")
            else:
                msg = ("you can retrieve the original source code by "
                       "accessing the object's source attribute or set "
                       "`torch.nn.Module.dump_patches = True` and use the "
                       "patch tool to revert the changes.")
            msg = ("source code of class '{}' has changed. {}"
                   .format(torch.typename(container_type), msg))
            warnings.warn(msg, SourceChangeWarning)

    def legacy_load(f):
        # Loader for the tar-based format: members 'storages', 'tensors'
        # and 'pickle'. Raises tarfile.TarError if f is not a tar archive.
        deserialized_objects = {}

        def persistent_load(saved_id):
            if isinstance(saved_id, tuple):
                # Ignore containers that don't have any sources saved
                if all(saved_id[1:]):
                    _check_container_source(*saved_id)
                return saved_id[0]
            return deserialized_objects[int(saved_id)]

        with closing(tarfile.open(fileobj=f, mode='r:',
                                  format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:

            tar.extract('storages', path=tmpdir)
            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
                num_storages = pickle_module.load(f)
                for i in range(num_storages):
                    args = pickle_module.load(f)
                    key, location, storage_type = args
                    obj = storage_type._new_with_file(f)
                    obj = restore_location(obj, location)
                    deserialized_objects[key] = obj

                storage_views = pickle_module.load(f)
                for target_cdata, root_cdata, offset, size in storage_views:
                    root = deserialized_objects[root_cdata]
                    deserialized_objects[target_cdata] = \
                        root[offset:offset + size]

            tar.extract('tensors', path=tmpdir)
            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
                num_tensors = pickle_module.load(f)
                for i in range(num_tensors):
                    args = pickle_module.load(f)
                    key, storage_id, original_tensor_type = args
                    storage = deserialized_objects[storage_id]
                    tensor_type = storage_to_tensor_type(storage)
                    tensor = tensor_type._new_with_metadata_file(f, storage)
                    deserialized_objects[key] = tensor

            pickle_file = tar.extractfile('pickle')
            unpickler = pickle_module.Unpickler(pickle_file)
            unpickler.persistent_load = persistent_load
            result = unpickler.load()
            return result

    def persistent_load(saved_id):
        # Resolve the tuples produced by persistent_id at save time.
        assert isinstance(saved_id, tuple)
        typename = saved_id[0]
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)
            return data[0]
        elif typename == 'storage':
            data_type, root_key, location, size, view_metadata = data
            # Allocate each root storage once; its contents are filled in
            # after unpickling, from the data at the end of the file.
            if root_key not in deserialized_objects:
                deserialized_objects[root_key] = restore_location(
                    data_type(size), location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = \
                        storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    # try the legacy loader first, which only works if f is a tarfile
    try:
        return legacy_load(f)
    except tarfile.TarError:
        pass

    # The failed tar probe has already read from f; rewind so the header is
    # read from the beginning of the file.
    f.seek(0)

    magic_number = pickle_module.load(f)
    if magic_number != MAGIC_NUMBER:
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f)
    if protocol_version != PROTOCOL_VERSION:
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f)
    unpickler = pickle_module.Unpickler(f)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    deserialized_storage_keys = pickle_module.load(f)

    # Only the first storage is given an explicit file offset; later calls
    # pass None.
    offset = f.tell()
    for key in deserialized_storage_keys:
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset)
        offset = None

    return result