functorch.combine_state_for_ensemble(models) → func, params, buffers

Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, combine_state_for_ensemble() stacks their parameters and buffers to produce params and buffers. Each parameter and buffer in the result has an additional leading dimension of size M, and slice i along that dimension holds the values from the i-th model.
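To make the shape change concrete, here is a minimal sketch. It assumes that the returned params tuple follows the order of each model's parameters() (weight, then bias for nn.Linear). It checks that every stacked parameter gains a leading dimension of size M and that slice i holds model i's values.

```python
import torch
from functorch import combine_state_for_ensemble

M = 4  # number of models in the ensemble
models = [torch.nn.Linear(3, 2) for _ in range(M)]

_, params, buffers = combine_state_for_ensemble(models)

# nn.Linear has two parameters (weight, bias) and no buffers.
# Assumption: params follows the order of models[0].parameters().
weight, bias = params
print(weight.shape)  # torch.Size([4, 2, 3])
print(bias.shape)    # torch.Size([4, 2])
print(len(buffers))  # 0

# Slice i along the new leading dimension holds model i's values
for i, model in enumerate(models):
    assert torch.equal(weight[i], model.weight)
    assert torch.equal(bias[i], model.bias)
```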

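combine_state_for_ensemble() also returns func, a functional version of one of the models in models. You cannot call func(params, buffers, *args, **kwargs) directly on the stacked params and buffers, because every tensor in them carries the extra model dimension of size M. Instead, you probably want vmap(func, ...)(params, buffers, *args, **kwargs), which maps func over that dimension.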
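The sketch below shows why vmap is needed. func expects the parameters of a single model, so calling it on slice i of every stacked tensor reproduces model i, and vmap performs this slicing for all i in one call. The explicit loop exists only for comparison. It assumes func accepts tuples of tensors for params and buffers, and the names x, looped and vmapped are illustrative.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
models = [torch.nn.Linear(3, 2) for _ in range(4)]
x = torch.randn(8, 3)

func, params, buffers = combine_state_for_ensemble(models)

# Calling func on slice i of every stacked parameter/buffer runs model i alone
looped = torch.stack([
    func(tuple(p[i] for p in params), tuple(b[i] for b in buffers), x)
    for i in range(len(models))
])

# vmap maps over the leading dimension of params and buffers, sharing x
vmapped = vmap(func, in_dims=(0, 0, None))(params, buffers, x)

assert torch.allclose(looped, vmapped, atol=1e-6)
assert torch.allclose(vmapped[0], models[0](x), atol=1e-6)
```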

Here’s an example of how to ensemble over a very simple model:

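import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)

# Stack the state of all models along a new leading dimension of size num_models
fmodel, params, buffers = combine_state_for_ensemble(models)
# in_dims=(0, 0, None): map over dim 0 of params and buffers; every model sees the same data
output = vmap(fmodel, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)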
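A common variant, sketched below with the same setup, gives each model its own minibatch. The data gets a leading dimension of size num_models, and its in_dims entry changes from None to 0, so vmap pairs model i with data[i].

```python
import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]

# One minibatch per model: shape (num_models, batch_size, in_features)
data = torch.randn(num_models, batch_size, in_features)

fmodel, params, buffers = combine_state_for_ensemble(models)
# Map over dim 0 of params, buffers, and data together
output = vmap(fmodel, in_dims=(0, 0, 0))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)
# Model i only sees data[i]
assert torch.allclose(output[2], models[2](data[2]), atol=1e-6)
```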


All of the modules being stacked together must be identical except for the values of their parameters and buffers. For example, they must be instances of the same class and must all be in the same mode (training vs. eval).
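One way to meet this requirement is to put every model into the same mode before combining. The sketch below does this for a small model containing nn.BatchNorm1d, which was chosen only because it has buffers (running statistics). The helper stack_in_eval_mode is hypothetical and not part of functorch. It also changes the mode of the caller's models in place.

```python
import torch
from functorch import combine_state_for_ensemble

def stack_in_eval_mode(models):
    """Hypothetical helper: check the class, force eval mode, then combine."""
    if len({type(m) for m in models}) != 1:
        raise ValueError("all models must be instances of the same class")
    for m in models:
        m.eval()  # same mode for every model (modifies the models in place)
    return combine_state_for_ensemble(models)

M = 3
models = [
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
    for _ in range(M)
]
models[0].eval()  # now model 0 is in eval mode and the rest are in training mode

fmodel, params, buffers = stack_in_eval_mode(models)

# BatchNorm1d contributes buffers (running_mean, running_var, num_batches_tracked);
# each one gains a leading dimension of size M, just like the parameters
assert len(buffers) == 3
assert all(b.shape[0] == M for b in buffers)
assert all(p.shape[0] == M for p in params)
```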

This API is subject to change; we’re investigating better ways to create ensembles and would love your feedback on how to improve it.


We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.combine_state_for_ensemble is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.stack_module_state instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details.
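For reference, here is a sketch of the same ensemble built with the recommended torch.func API. It follows the pattern shown in the torch.func.stack_module_state documentation. stack_module_state returns dictionaries of stacked parameters and buffers, and torch.func.functional_call runs a stateless template copy of one model with them substituted in.

```python
import copy
import torch
from torch.func import functional_call, stack_module_state, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)

# Dicts mapping names like "weight" to tensors with a leading dim of size num_models
params, buffers = stack_module_state(models)

# A copy of one model on the "meta" device serves only as a template (no real storage)
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    # Run base_model with the given parameters and buffers substituted in
    return functional_call(base_model, (params, buffers), (x,))

output = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```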

