# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch import Tensor
from typing import Dict, List, Tuple
from .named_members_polyfill import _named_parameters, _named_buffers
import copy
# Utilities to make nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes the
# parameters as input and evaluates the nn.Module using fixed inputs.
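#
# A minimal sketch of the idea (illustrative only; `make_functional` below is
# the supported entry point):
#
#   model = nn.Linear(3, 3)
#   func, params = make_functional(model)
#   out = func(params, torch.randn(4, 3))  # stateless call: params are passed explicitly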
def _del_nested_attr(obj: nn.Module, names: List[str]) -> None:
"""
Deletes the attribute specified by the given list of names.
For example, to delete the attribute obj.conv.weight,
use _del_nested_attr(obj, ['conv', 'weight'])
"""
if len(names) == 1:
delattr(obj, names[0])
else:
_del_nested_attr(getattr(obj, names[0]), names[1:])
def _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:
"""
    Sets the attribute specified by the given list of names to value.
    For example, to set the attribute obj.conv.weight,
    use _set_nested_attr(obj, ['conv', 'weight'], value)
"""
if len(names) == 1:
setattr(obj, names[0], value)
else:
_set_nested_attr(getattr(obj, names[0]), names[1:], value)
def _get_nested_attr(obj: nn.Module, names: List[str]) -> Tensor:
    if len(names) == 1:
        return getattr(obj, names[0])
    else:
        return _get_nested_attr(getattr(obj, names[0]), names[1:])
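# A minimal sketch of how the nested-attribute helpers compose (illustrative
# only; assumes a plain nn.Sequential):
#
#   net = nn.Sequential(nn.Linear(3, 3))
#   w = _get_nested_attr(net, ['0', 'weight'])    # same tensor as net[0].weight
#   _del_nested_attr(net, ['0', 'weight'])        # removes the nn.Parameter
#   _set_nested_attr(net, ['0', 'weight'], w)     # restores it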
def raise_parameter_tying_error():
raise RuntimeError(
"make_functional(module): we don't yet support models that "
"do parameter tying (also sometimes known as weight sharing). "
"Please try to rewrite your model by replacing all instances of the "
"tied parameter with another and/or comment your support in "
"https://github.com/pytorch/functorch/issues/446")
def create_names_map(named_params, tied_named_params):
"""
    named_params is a dictionary of tensors: {'A': A, 'B': B}.
    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
    with potentially tied (or 'duplicated') tensors.
    This function creates a mapping from the names in named_params to the
    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}
    (each tied name is stored split on '.', ready for the *_nested_attr helpers).
"""
named_params = {k: v for k, v in named_params}
tied_named_params = {k: v for k, v in tied_named_params}
tensors_dict_keys = set(named_params.keys())
tied_tensors_dict_keys = set(tied_named_params.keys())
assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
tensor_to_mapping = {}
for key, tensor in named_params.items():
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
tensor_to_mapping[tensor][1].append(key.split('.'))
result = {key: value for key, value in tensor_to_mapping.values()}
return result
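# A small illustrative example of the resulting names map (hypothetical tied
# weight named 'head.w'):
#
#   w = nn.Parameter(torch.randn(3, 3))
#   named = [('w', w)]
#   tied = [('w', w), ('head.w', w)]
#   create_names_map(named, tied)  # {'w': [['w'], ['head', 'w']]}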
def _extract_members(mod: nn.Module, _named_members, named_members, subclass):
all_named_members = tuple(_named_members(mod, remove_duplicate=False))
named_members = tuple(named_members())
names_map = create_names_map(named_members, all_named_members)
# Remove all the members in the model
memo = {}
for name, p in all_named_members:
if p not in memo:
memo[p] = subclass(torch.empty_like(p, device='meta'))
replacement = memo[p]
_set_nested_attr(mod, name.split("."), replacement)
if len(named_members) == 0:
names, params = (), ()
else:
names, params = zip(*named_members)
return params, names, names_map
def extract_weights(mod: nn.Module):
"""
    This function removes all the Parameters from the model and
    returns them as a tuple, together with their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place and that, after this
    call, mod.parameters() will be empty.
"""
return _extract_members(mod, _named_parameters, mod.named_parameters, nn.Parameter)
def extract_buffers(mod: nn.Module):
    """
    Like `extract_weights`, but removes and returns the buffers of `mod`
    instead of its Parameters.
    """
    return _extract_members(mod, _named_buffers, mod.named_buffers, lambda x: x)
def load_weights(mod: nn.Module, names: List[str], params: Tuple[Tensor, ...], as_params=False) -> None:
"""
Reload a set of weights so that `mod` can be used again to perform a forward pass.
Note that the `params` are regular Tensors (that can have history) and so are left
as Tensors. This means that mod.parameters() will still be empty after this call.
"""
for name, p in zip(names, params):
if as_params:
p = nn.Parameter(p)
_del_nested_attr(mod, name.split("."))
_set_nested_attr(mod, name.split("."), p)
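# A minimal round-trip sketch (illustrative only): extract the weights, then
# load them back so the module is usable again.
#
#   model = nn.Linear(3, 3)
#   params, names, _ = extract_weights(model)   # model.parameters() is now empty
#   load_weights(model, names, params)          # forward passes work again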
def _swap_state(mod: nn.Module, names_map: Dict[str, List[List[str]]], elems):
result = []
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
if i == 0:
result.append(_get_nested_attr(mod, attr_name))
_del_nested_attr(mod, attr_name)
_set_nested_attr(mod, attr_name, elem)
return result
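# A minimal sketch of the swap semantics (illustrative only): _swap_state
# installs `elems` under the mapped attribute names and returns whatever was
# there before, so a second call with the old values undoes the first.
#
#   func, params = make_functional(nn.Linear(3, 3))
#   old = _swap_state(func.stateless_model, func.names_map, params)
#   _swap_state(func.stateless_model, func.names_map, old)  # undo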
def load_buffers(mod: nn.Module, names: List[str], buffers: Tuple[Tensor, ...], as_params=False) -> None:
for name, p in zip(names, buffers):
_set_nested_attr(mod, name.split("."), p)
def load_state(
model: nn.Module,
weights: List[Tensor], weight_names: List[str],
buffers=(), buffer_names=()):
"""load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
load_state takes `weights` and `buffers` and assigns them to the model.
This is the inverse operation of `make_functional_deprecated_v1`.
"""
assert len(weight_names) == len(weights)
load_weights(model, weight_names, weights)
if len(buffers) > 0:
assert len(buffer_names) == len(buffers)
load_buffers(model, buffer_names, buffers)
return model
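# A minimal sketch of the round trip with load_state (illustrative only):
#
#   model = nn.Linear(3, 3)
#   weights, func, names = make_functional_deprecated_v1(model)  # strips `model`
#   load_state(model, weights, names)                            # restores it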
def make_functional_deprecated_v1(model: nn.Module):
"""make_functional_deprecated_v1(model) -> weights, func, weight_names
    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
    and returns a functional version of the model, `func`. This makes
    it possible to use transforms over the parameters of
    `model`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, func, _ = make_functional_deprecated_v1(model)
func(weights, (x,))
```
And here is an example of applying the grad transform:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
grad_weights = grad(func)(weights, (x,))
```
To put the state back into a model, use `load_state`.
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError('make_functional_deprecated_v1(model): `model` has buffers. Please use '
'make_functional_with_buffers_deprecated_v1(model) instead.')
weights, descriptors, _ = extract_weights(model)
def fun(weights, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, descriptors, weights)
return mutable_model(*data)
return weights, fun, descriptors
def make_functional_with_buffers_deprecated_v1(model: nn.Module):
"""make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names
Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers)
and returns a functional version of the model, `func`.
`func` can be invoked as follows:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
func(weights, buffers, (x,))
```
And here is an example of applying the grad transform:
```
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    grad_weights = grad(func)(weights, buffers, (x,))
```
To put the state back into a model, use `load_state`.
"""
weights, weight_descriptors, _ = extract_weights(model)
buffers, buf_descriptors, _ = extract_buffers(model)
def fun(weights, buffers, data):
mutable_model = copy.deepcopy(model)
load_weights(mutable_model, weight_descriptors, weights)
load_buffers(mutable_model, buf_descriptors, buffers)
return mutable_model(*data)
return weights, buffers, fun, weight_descriptors, buf_descriptors
class FunctionalModuleWithBuffers(nn.Module):
"""
This is the callable object returned by :func:`make_functional_with_buffers`.
"""
def __init__(self, stateless_model, param_names, buffer_names,
param_names_map, buffer_names_map):
super(FunctionalModuleWithBuffers, self).__init__()
self.stateless_model = stateless_model
self.param_names = param_names
self.buffer_names = buffer_names
self.all_names_map = dict(param_names_map)
self.all_names_map.update(buffer_names_map)
@staticmethod
def _create_from(model):
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, param_names_map = extract_weights(model_copy)
buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
return (
FunctionalModuleWithBuffers(model_copy, param_names, buffer_names,
param_names_map, buffer_names_map),
params,
buffers,
)
def forward(self, params, buffers, *args, **kwargs):
# Temporarily load the state back onto self.stateless_model
old_state = _swap_state(
self.stateless_model,
self.all_names_map,
list(params) + list(buffers))
try:
return self.stateless_model(*args, **kwargs)
finally:
# Remove the loaded state on self.stateless_model
_swap_state(self.stateless_model, self.all_names_map, old_state)
class FunctionalModule(nn.Module):
"""
This is the callable object returned by :func:`make_functional`.
"""
def __init__(self, stateless_model, param_names, names_map):
super(FunctionalModule, self).__init__()
self.stateless_model = stateless_model
self.param_names = param_names
self.names_map = names_map
@staticmethod
def _create_from(model):
# TODO: We don't need to copy the model to create a stateless copy
model_copy = copy.deepcopy(model)
params, param_names, names_map = extract_weights(model_copy)
return FunctionalModule(model_copy, param_names, names_map), params
def forward(self, params, *args, **kwargs):
# Temporarily load the state back onto self.stateless_model
old_state = _swap_state(self.stateless_model, self.names_map, params)
try:
return self.stateless_model(*args, **kwargs)
finally:
# Remove the loaded state on self.stateless_model
_swap_state(self.stateless_model, self.names_map, old_state)
def make_functional(model: nn.Module):
"""make_functional(model) -> func, params
Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
    (params) and returns a functional version of the model, ``func``. This
    makes it possible to use transforms over the parameters of
    ``model``.
``func`` can be invoked as follows:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
func(params, x)
And here is an example of applying the grad transform over the parameters
of a model.
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
def compute_loss(params, x, t):
y = func(params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, x, t)
If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError('make_functional(model): `model` has buffers. Please use '
'make_functional_with_buffers(model) instead.')
return FunctionalModule._create_from(model)
def make_functional_with_buffers(model: nn.Module):
"""make_functional_with_buffers(model) -> func, params, buffers
Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
    state (params and buffers) and returns a functional version of the model,
    ``func``, that can be invoked like a function.
``func`` can be invoked as follows:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
func(params, buffers, x)
And here is an example of applying the grad transform over the parameters
of a model:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional_with_buffers, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params, buffers = make_functional_with_buffers(model)
def compute_loss(params, buffers, x, t):
y = func(params, buffers, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, buffers, x, t)
"""
return FunctionalModuleWithBuffers._create_from(model)
def transpose_stack(tuple_of_tuple_of_tensors):
tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
results = tuple(torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors)
return results
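# A small illustrative example of transpose_stack: given per-model parameter
# tuples, it groups corresponding tensors together and stacks each group along
# a new leading dimension.
#
#   per_model = ((torch.randn(3), torch.randn(2)),
#                (torch.randn(3), torch.randn(2)))
#   stacked = transpose_stack(per_model)
#   # stacked[0].shape == (2, 3); stacked[1].shape == (2, 2)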
def combine_state_for_ensemble(models):
"""combine_state_for_ensemble(models) -> func, params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
parameters and buffers together to make ``params`` and ``buffers``.
Each parameter and buffer in the result will have an additional dimension
of size ``M``.
:func:`combine_state_for_ensemble` also returns ``func``, a functional
    version of one of the models in :attr:`models`. One cannot directly run
    ``func(params, buffers, *args, **kwargs)``; instead, you probably want to
    use ``vmap(func, ...)(params, buffers, *args, **kwargs)``.
Here's an example of how to ensemble over a very simple model:
.. code-block:: python
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
.. warning::
All of the modules being stacked together must be the same (except for
the values of their parameters/buffers). For example, they should be in the
same mode (training vs eval).
        This API is subject to change -- we're investigating better ways to
        create ensembles and would love your feedback on how to improve this.
"""
if len(models) == 0:
raise RuntimeError('combine_state_for_ensemble: Expected at least one model, got 0.')
if not (all(m.training for m in models) or all(not m.training for m in models)):
raise RuntimeError('combine_state_for_ensemble: Expected all models to '
'have the same training/eval mode.')
    model0_type = type(models[0])
    if not all(type(m) == model0_type for m in models):
raise RuntimeError('combine_state_for_ensemble: Expected all models to '
'be of the same class.')
funcs, params, buffers = zip(*[make_functional_with_buffers(model)
for model in models])
params = transpose_stack(params)
buffers = transpose_stack(buffers)
return funcs[0], params, buffers
def functional_init(model_class, ensemble_shape=(), device='cpu'):
def wrapped(*args, **kwargs):
if len(ensemble_shape) >= 2:
raise ValueError('NYI: ensemble_shape with more than 1 element')
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
return make_functional_deprecated_v1(model)
num_models = ensemble_shape[0]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
models = tuple(model_class(*args, **kwargs).to(device)
for _ in range(num_models))
_, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
return weights, fn, names
return wrapped
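# A minimal usage sketch for functional_init (illustrative only):
#
#   init_fn = functional_init(nn.Linear, ensemble_shape=(5,))
#   weights, fn, names = init_fn(3, 3)  # 5 independently initialized Linear(3, 3) models
#   # each entry of `weights` is stacked with a leading dimension of size 5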
def functional_init_with_buffers(model_class, ensemble_shape=(), device='cpu'):
def wrapped(*args, **kwargs):
if len(ensemble_shape) >= 2:
raise ValueError('NYI: ensemble_shape with more than 1 element')
if len(ensemble_shape) == 0:
model = model_class(*args, **kwargs).to(device)
            return make_functional_with_buffers_deprecated_v1(model)
num_models = ensemble_shape[0]
if num_models <= 0:
raise ValueError(f"num_models {num_models} should be > 0")
# NB: Not very efficient, more of a POC
models = tuple(model_class(*args, **kwargs).to(device)
for _ in range(num_models))
_, _, fn, weight_names, buffer_names = \
make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
weights, buffers = zip(*tuple(make_functional_with_buffers_deprecated_v1(model)[:2]
for model in models))
weights = tuple(zip(*weights))
weights = tuple(torch.stack(shards).detach() for shards in weights)
buffers = tuple(zip(*buffers))
buffers = tuple(torch.stack(shards).detach() for shards in buffers)
return weights, buffers, fn, weight_names, buffer_names
return wrapped