Source code for torch.nn.modules.module

from collections import OrderedDict, Iterable
import functools

import torch
from ..backends.thnn import backend as thnn_backend
from ..parameter import Parameter
from torch.autograd import Variable
import torch.utils.hooks as hooks


def _addindent(s_, numSpaces):
    """Indent every line of ``s_`` except the first one.

    Args:
        s_ (string): text that may span several lines separated by ``'\\n'``
        numSpaces (int): number of spaces put in front of each line after
            the first

    Returns:
        string: ``s_`` itself when it contains no newline, otherwise the
        first line unchanged followed by the remaining lines, each prefixed
        with ``numSpaces`` spaces.
    """
    head, newline, remainder = s_.partition('\n')
    # single-line text is handed back untouched
    if not newline:
        return s_
    padding = ' ' * numSpaces
    indented = [padding + line for line in remainder.split('\n')]
    return head + '\n' + '\n'.join(indented)


[docs]class Module(object): r"""Base class for all neural network modules. Your models should also subclass this class. Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:: import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x)) Submodules assigned in this way will be registered, and will have their parameters converted too when you call .cuda(), etc. """ dump_patches = False def __init__(self): self._backend = thnn_backend self._parameters = OrderedDict() self._buffers = OrderedDict() self._backward_hooks = OrderedDict() self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._modules = OrderedDict() self.training = True
[docs] def forward(self, *input): """Defines the computation performed at every call. Should be overriden by all subclasses. .. note:: Although the recipe for forward pass needs to be defined within this function, one should call the :class:`Module` instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them. """ raise NotImplementedError
[docs] def register_buffer(self, name, tensor): """Adds a persistent buffer to the module. This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the persistent state. Buffers can be accessed as attributes using given names. Args: name (string): name of the buffer. The buffer can be accessed from this module using the given name tensor (Tensor): buffer to be registered. Example: >>> self.register_buffer('running_mean', torch.zeros(num_features)) """ if hasattr(self, name) and name not in self._buffers: raise KeyError("attribute '{}' already exists".format(name)) self._buffers[name] = tensor
[docs] def register_parameter(self, name, param): """Adds a parameter to the module. The parameter can be accessed as an attribute using given name. Args: name (string): name of the parameter. The parameter can be accessed from this module using the given name parameter (Parameter): parameter to be added to the module. """ if '_parameters' not in self.__dict__: raise AttributeError( "cannot assign parameter before Module.__init__() call") if hasattr(self, name) and name not in self._parameters: raise KeyError("attribute '{}' already exists".format(name)) if param is None: self._parameters[name] = None elif not isinstance(param, Parameter): raise TypeError("cannot assign '{}' object to parameter '{}' " "(torch.nn.Parameter or None required)" .format(torch.typename(param), name)) elif param.grad_fn: raise ValueError( "Cannot assign non-leaf Variable to parameter '{0}'. Model " "parameters must be created explicitly. To express '{0}' " "as a function of another variable, compute the value in " "the forward() method.".format(name)) else: self._parameters[name] = param
[docs] def add_module(self, name, module): """Adds a child module to the current module. The module can be accessed as an attribute using the given name. Args: name (string): name of the child module. The child module can be accessed from this module using the given name parameter (Module): child module to be added to the module. """ if not isinstance(module, Module) and module is not None: raise TypeError("{} is not a Module subclass".format( torch.typename(module))) if hasattr(self, name) and name not in self._modules: raise KeyError("attribute '{}' already exists".format(name)) self._modules[name] = module
def _apply(self, fn): for module in self.children(): module._apply(fn) for param in self._parameters.values(): if param is not None: # Variables stored in modules are graph leaves, and we don't # want to create copy nodes, so we have to unpack the data. param.data = fn(param.data) if param._grad is not None: param._grad.data = fn(param._grad.data) for key, buf in self._buffers.items(): if buf is not None: self._buffers[key] = fn(buf) return self
[docs] def apply(self, fn): """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. Typical use includes initializing the parameters of a model (see also :ref:`torch-nn-init`). Args: fn (:class:`Module` -> None): function to be applied to each submodule Returns: Module: self Example: >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.data.fill_(1.0) >>> print(m.weight) >>> >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear (2 -> 2) Parameter containing: 1 1 1 1 [torch.FloatTensor of size 2x2] Linear (2 -> 2) Parameter containing: 1 1 1 1 [torch.FloatTensor of size 2x2] Sequential ( (0): Linear (2 -> 2) (1): Linear (2 -> 2) ) """ for module in self.children(): module.apply(fn) fn(self) return self
[docs] def cuda(self, device=None): """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. Arguments: device (int, optional): if specified, all parameters will be copied to that device Returns: Module: self """ return self._apply(lambda t: t.cuda(device))
[docs] def cpu(self): """Moves all model parameters and buffers to the CPU. Returns: Module: self """ return self._apply(lambda t: t.cpu())
[docs] def type(self, dst_type): """Casts all parameters and buffers to dst_type. Arguments: dst_type (type or string): the desired type Returns: Module: self """ return self._apply(lambda t: t.type(dst_type))
[docs] def float(self): """Casts all parameters and buffers to float datatype. Returns: Module: self """ return self._apply(lambda t: t.float())
[docs] def double(self): """Casts all parameters and buffers to double datatype. Returns: Module: self """ return self._apply(lambda t: t.double())
[docs] def half(self): """Casts all parameters and buffers to half datatype. Returns: Module: self """ return self._apply(lambda t: t.half())
[docs] def register_backward_hook(self, hook): """Registers a backward hook on the module. The hook will be called every time the gradients with respect to module inputs are computed. The hook should have the following signature:: hook(module, grad_input, grad_output) -> Tensor or None The :attr:`grad_input` and :attr:`grad_output` may be tuples if the module has multiple inputs or outputs. The hook should not modify its arguments, but it can optionally return a new gradient with respect to input that will be used in place of :attr:`grad_input` in subsequent computations. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle
[docs] def register_forward_pre_hook(self, hook): """Registers a forward pre-hook on the module. The hook will be called every time before :func:`forward` is invoked. It should have the following signature:: hook(module, input) -> None The hook should not modify the input. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ handle = hooks.RemovableHandle(self._forward_pre_hooks) self._forward_pre_hooks[handle.id] = hook return handle
[docs] def register_forward_hook(self, hook): r"""Registers a forward hook on the module. The hook will be called every time after :func:`forward` has computed an output. It should have the following signature:: hook(module, input, output) -> None The hook should not modify the input or output. Returns: :class:`torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling ``handle.remove()`` """ handle = hooks.RemovableHandle(self._forward_hooks) self._forward_hooks[handle.id] = hook return handle
def _tracing_name(self, tracing_state): if not tracing_state._traced_module_stack: return None module = tracing_state._traced_module_stack[-1] for name, child in module.named_children(): if child is self: return name return None def _slow_forward(self, *input, **kwargs): input_vars = tuple(torch.autograd.function._iter_variables(input)) tracing_state = torch.jit.get_tracing_state(input_vars) if not tracing_state: return self.forward(*input, **kwargs) if not hasattr(tracing_state, '_traced_module_stack'): tracing_state._traced_module_stack = [] name = self._tracing_name(tracing_state) if name: tracing_state.push_scope('%s[%s]' % (self.__class__.__name__, name)) else: tracing_state.push_scope(self.__class__.__name__) tracing_state._traced_module_stack.append(self) try: result = self.forward(*input, **kwargs) finally: tracing_state.pop_scope() tracing_state._traced_module_stack.pop() return result def __call__(self, *input, **kwargs): for hook in self._forward_pre_hooks.values(): hook(self, input) if torch.jit._tracing: result = self._slow_forward(*input, **kwargs) else: result = self.forward(*input, **kwargs) for hook in self._forward_hooks.values(): hook_result = hook(self, input, result) if hook_result is not None: raise RuntimeError( "forward hooks should never return any values, but '{}'" "didn't return None".format(hook)) if len(self._backward_hooks) > 0: var = result while not isinstance(var, Variable): if isinstance(var, dict): var = next((v for v in var.values() if isinstance(v, Variable))) else: var = var[0] grad_fn = var.grad_fn if grad_fn is not None: for hook in self._backward_hooks.values(): wrapper = functools.partial(hook, self) functools.update_wrapper(wrapper, hook) grad_fn.register_hook(wrapper) return result def __setstate__(self, state): self.__dict__.update(state) if '_forward_pre_hooks' not in self.__dict__: self._forward_pre_hooks = OrderedDict() def __getattr__(self, name): if '_parameters' in self.__dict__: _parameters = 
self.__dict__['_parameters'] if name in _parameters: return _parameters[name] if '_buffers' in self.__dict__: _buffers = self.__dict__['_buffers'] if name in _buffers: return _buffers[name] if '_modules' in self.__dict__: modules = self.__dict__['_modules'] if name in modules: return modules[name] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, name)) def __setattr__(self, name, value): def remove_from(*dicts): for d in dicts: if name in d: del d[name] params = self.__dict__.get('_parameters') if isinstance(value, Parameter): if params is None: raise AttributeError( "cannot assign parameters before Module.__init__() call") remove_from(self.__dict__, self._buffers, self._modules) self.register_parameter(name, value) elif params is not None and name in params: if value is not None: raise TypeError("cannot assign '{}' as parameter '{}' " "(torch.nn.Parameter or None expected)" .format(torch.typename(value), name)) self.register_parameter(name, value) else: modules = self.__dict__.get('_modules') if isinstance(value, Module): if modules is None: raise AttributeError( "cannot assign module before Module.__init__() call") remove_from(self.__dict__, self._parameters, self._buffers) modules[name] = value elif modules is not None and name in modules: if value is not None: raise TypeError("cannot assign '{}' as child module '{}' " "(torch.nn.Module or None expected)" .format(torch.typename(value), name)) modules[name] = value else: buffers = self.__dict__.get('_buffers') if buffers is not None and name in buffers: if value is not None and not torch.is_tensor(value): raise TypeError("cannot assign '{}' as buffer '{}' " "(torch.Tensor or None expected)" .format(torch.typename(value), name)) buffers[name] = value else: object.__setattr__(self, name, value) def __delattr__(self, name): if name in self._parameters: del self._parameters[name] elif name in self._buffers: del self._buffers[name] elif name in self._modules: del 
self._modules[name] else: object.__delattr__(self, name)
[docs] def state_dict(self, destination=None, prefix='', keep_vars=False): """Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. When keep_vars is ``True``, it returns a Variable for each parameter (rather than a Tensor). Args: destination (dict, optional): if not None, the return dictionary is stored into destination. Default: None prefix (string, optional): Adds a prefix to the key (name) of every parameter and buffer in the result dictionary. Default: '' keep_vars (bool, optional): if ``True``, returns a Variable for each parameter. If ``False``, returns a Tensor for each parameter. Default: ``False`` Returns: dict: a dictionary containing a whole state of the module Example: >>> module.state_dict().keys() ['bias', 'weight'] """ if destination is None: destination = OrderedDict() for name, param in self._parameters.items(): if param is not None: destination[prefix + name] = param if keep_vars else param.data for name, buf in self._buffers.items(): if buf is not None: destination[prefix + name] = buf for name, module in self._modules.items(): if module is not None: module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) return destination
[docs] def load_state_dict(self, state_dict, strict=True): """Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True`` then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :func:`state_dict()` function. Arguments: state_dict (dict): A dict containing parameters and persistent buffers. strict (bool): Strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's `:func:`state_dict()` function. """ own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, Parameter): # backwards compatibility for serialized parameters param = param.data try: own_state[name].copy_(param) except Exception: raise RuntimeError('While copying the parameter named {}, ' 'whose dimensions in the model are {} and ' 'whose dimensions in the checkpoint are {}.' .format(name, own_state[name].size(), param.size())) elif strict: raise KeyError('unexpected key "{}" in state_dict' .format(name)) if strict: missing = set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing))
[docs] def parameters(self): """Returns an iterator over module parameters. This is typically passed to an optimizer. Yields: Parameter: module parameter Example: >>> for param in model.parameters(): >>> print(type(param.data), param.size()) <class 'torch.FloatTensor'> (20L,) <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L) """ for name, param in self.named_parameters(): yield param
[docs] def named_parameters(self, memo=None, prefix=''): """Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself Yields: (string, Parameter): Tuple containing the name and parameter Example: >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size()) """ if memo is None: memo = set() for name, p in self._parameters.items(): if p is not None and p not in memo: memo.add(p) yield prefix + ('.' if prefix else '') + name, p for mname, module in self.named_children(): submodule_prefix = prefix + ('.' if prefix else '') + mname for name, p in module.named_parameters(memo, submodule_prefix): yield name, p
def _all_buffers(self, memo=None): if memo is None: memo = set() for name, b in self._buffers.items(): if b is not None and b not in memo: memo.add(b) yield b for module in self.children(): for b in module._all_buffers(memo): yield b
[docs] def children(self): """Returns an iterator over immediate children modules. Yields: Module: a child module """ for name, module in self.named_children(): yield module
[docs] def named_children(self): """Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself. Yields: (string, Module): Tuple containing a name and child module Example: >>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module) """ memo = set() for name, module in self._modules.items(): if module is not None and module not in memo: memo.add(module) yield name, module
[docs] def modules(self): """Returns an iterator over all modules in the network. Yields: Module: a module in the network Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): >>> print(idx, '->', m) 0 -> Sequential ( (0): Linear (2 -> 2) (1): Linear (2 -> 2) ) 1 -> Linear (2 -> 2) """ for name, module in self.named_modules(): yield module
[docs] def named_modules(self, memo=None, prefix=''): """Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself. Yields: (string, Module): Tuple of name and module Note: Duplicate modules are returned only once. In the following example, ``l`` will be returned only once. >>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): >>> print(idx, '->', m) 0 -> ('', Sequential ( (0): Linear (2 -> 2) (1): Linear (2 -> 2) )) 1 -> ('0', Linear (2 -> 2)) """ if memo is None: memo = set() if self not in memo: memo.add(self) yield prefix, self for name, module in self._modules.items(): if module is None: continue submodule_prefix = prefix + ('.' if prefix else '') + name for m in module.named_modules(memo, submodule_prefix): yield m
[docs] def train(self, mode=True): """Sets the module in training mode. This has any effect only on modules such as Dropout or BatchNorm. Returns: Module: self """ self.training = mode for module in self.children(): module.train(mode) return self
[docs] def eval(self): """Sets the module in evaluation mode. This has any effect only on modules such as Dropout or BatchNorm. """ return self.train(False)
[docs] def zero_grad(self): """Sets gradients of all model parameters to zero.""" for p in self.parameters(): if p.grad is not None: if p.grad.volatile: p.grad.data.zero_() else: data = p.grad.data p.grad = Variable(data.new().resize_as_(data).zero_())
def share_memory(self): return self._apply(lambda t: t.share_memory_()) def __repr__(self): tmpstr = self.__class__.__name__ + '(\n' for key, module in self._modules.items(): modstr = module.__repr__() modstr = _addindent(modstr, 2) tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n' tmpstr = tmpstr + ')' return tmpstr def __dir__(self): module_attrs = dir(self.__class__) attrs = list(self.__dict__.keys()) parameters = list(self._parameters.keys()) modules = list(self._modules.keys()) buffers = list(self._buffers.keys()) keys = module_attrs + attrs + parameters + modules + buffers return sorted(keys)