functorch.combine_state_for_ensemble
functorch.combine_state_for_ensemble(models) → func, params, buffers [source]
Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, stacks all of their parameters and buffers together to make params and buffers. Each parameter and buffer in the result will have an additional dimension of size M.

combine_state_for_ensemble() also returns func, a functional version of one of the models in models.
One cannot run func(params, buffers, *args, **kwargs) directly; you probably want to use vmap(func, ...)(params, buffers, *args, **kwargs).
Here’s an example of how to ensemble over a very simple model:
import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
Warning
All of the modules being stacked together must be the same (except for the values of their parameters/buffers). For example, they should be in the same mode (training vs eval).
This API is subject to change – we’re investigating better ways to create ensembles and would love your feedback on how to improve it.
Warning
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.combine_state_for_ensemble is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.stack_module_state instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/main/func.migrating.html
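To give a rough idea of the migration, here is a minimal sketch of the same ensemble written against torch.func.stack_module_state, torch.func.functional_call, and torch.func.vmap, following the pattern described in the migration guide. The base_model meta-device copy is one way (not the only way) to obtain a stateless module for functional_call to drive:

import copy
import torch
from torch.func import stack_module_state, functional_call, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, 3)

# Stack each parameter/buffer across the M models along a new leading dim.
params, buffers = stack_module_state(models)

# Copy one model to the meta device so this template holds no real parameter
# storage; functional_call supplies the stacked params/buffers at call time.
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

output = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

Note that stack_module_state returns dicts mapping parameter/buffer names to stacked tensors, so vmap maps over every leaf of those dicts with in_dims=0 rather than over a single params tensor.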