functorch.combine_state_for_ensemble
functorch.combine_state_for_ensemble(models) → func, params, buffers

Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, stacks all of their parameters and buffers together to make params and buffers. Each parameter and buffer in the result will have an additional dimension of size M.

combine_state_for_ensemble() also returns func, a functional version of one of the models in models. One cannot run func(params, buffers, *args, **kwargs) directly, because the stacked params and buffers carry the extra dimension of size M; you probably want to use vmap(func, ...)(params, buffers, *args, **kwargs) instead.
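To make the stacking described above concrete, here is a minimal sketch that builds the stacked tensors by hand with torch.stack. It illustrates the resulting shapes only and is not the library's implementation; the names `stacked` and `per_model` are hypothetical and chosen for this example.

```python
import torch

num_models = 4  # M in the description above
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]

# Collect each model's parameters by name (illustrative helper structure).
per_model = [dict(m.named_parameters()) for m in models]

# Stack same-named parameters along a new leading dimension of size M.
stacked = {
    name: torch.stack([p[name] for p in per_model])
    for name in per_model[0]
}

for name, tensor in stacked.items():
    print(name, tuple(tensor.shape))
```

Each stacked tensor has the shape of the original parameter with one extra leading dimension of size M, which is the dimension vmap() is later asked to map over.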
Here’s an example of how to ensemble over a very simple model:
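    import torch
    from functorch import combine_state_for_ensemble, vmap

    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3

    # Build M independent models of the same class.
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, in_features)

    # Stack parameters and buffers; fmodel is the functional version of one model.
    fmodel, params, buffers = combine_state_for_ensemble(models)

    # Map over dim 0 of params and buffers; share the same data across all models.
    output = vmap(fmodel, (0, 0, None))(params, buffers, data)

    assert output.shape == (num_models, batch_size, out_features)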
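As a sanity check on the example above, the following sketch compares the vmap() ensemble output against a plain Python loop over the original models. It assumes an environment where the functorch package is importable; the names `ensemble_out`, `loop_out`, and `matches` are introduced here for illustration only.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
num_models, batch_size = 5, 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)

fmodel, params, buffers = combine_state_for_ensemble(models)

# Ensemble path: one vmapped call over the stacked state, same data for every model.
ensemble_out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)

# Reference path: call each original model separately and stack the results.
loop_out = torch.stack([m(data) for m in models])

# The two paths should agree up to floating-point tolerance.
matches = torch.allclose(ensemble_out, loop_out, atol=1e-5)
print(matches)
```

Passing None for the data entry of in_dims feeds the same minibatch to every model; passing 0 there instead would expect data with its own leading dimension of size M, one minibatch per model.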