functorch.combine_state_for_ensemble

functorch.combine_state_for_ensemble(models) → func, params, buffers

Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, stacks all of their parameters and buffers together to make params and buffers. Each parameter and buffer in the result will have an additional dimension of size M.
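Here is a minimal sketch in plain PyTorch of what stacking the state of M models means. It is not functorch's implementation. The helper `stack_state` is hypothetical and returns dicts keyed by parameter and buffer name, whereas combine_state_for_ensemble() returns tuples. The Linear + BatchNorm1d model is also an assumption, chosen so that there are buffers to stack as well as parameters.

```python
import torch

def stack_state(models):
    # Hypothetical helper, not functorch's code: for every parameter/buffer
    # name of the first model, stack the same-named tensors of all models
    # along a new leading dimension of size M = len(models).
    first = models[0]
    params = {
        name: torch.stack([dict(m.named_parameters())[name] for m in models]).detach()
        for name, _ in first.named_parameters()
    }
    buffers = {
        name: torch.stack([dict(m.named_buffers())[name] for m in models])
        for name, _ in first.named_buffers()
    }
    return params, buffers

num_models = 4
# Assumed model: Linear + BatchNorm1d, so there are buffers as well as parameters.
models = [
    torch.nn.Sequential(torch.nn.Linear(3, 3), torch.nn.BatchNorm1d(3))
    for _ in range(num_models)
]
params, buffers = stack_state(models)

for name, t in {**params, **buffers}.items():
    print(name, tuple(t.shape))
```

Each stacked tensor has shape (M, *original_shape). For example, 0.weight becomes (4, 3, 3), and even the 0-dimensional buffer 1.num_batches_tracked becomes a vector of length 4.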

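combine_state_for_ensemble() also returns func, a functional version of one of the models in models. You cannot run func(params, buffers, *args, **kwargs) directly, because every tensor in params and buffers carries the extra leading dimension of size M, while func expects the state of a single model. Instead, you probably want vmap(func, ...)(params, buffers, *args, **kwargs).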
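The sketch below shows why func expects the state of a single model, using a Linear setup chosen for illustration. Slicing index i out of every stacked tensor gives model i's own state. Calling func on that slice reproduces model i's output. vmap performs this slicing along dimension 0 for all M models at once.

```python
import torch
from functorch import combine_state_for_ensemble

torch.manual_seed(0)
num_models, batch_size, in_features, out_features = 5, 64, 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
x = torch.randn(batch_size, in_features)

fmodel, params, buffers = combine_state_for_ensemble(models)

# Take the i-th slice of every stacked tensor: this is model i's own state.
i = 2
params_i = tuple(p[i] for p in params)
buffers_i = tuple(b[i] for b in buffers)  # empty for nn.Linear

# Called with one model's state, fmodel behaves like models[i].
out_i = fmodel(params_i, buffers_i, x)
print(out_i.shape)
```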

Here’s an example of how to ensemble over a very simple model:

import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
data = torch.randn(batch_size, in_features)

fmodel, params, buffers = combine_state_for_ensemble(models)
# Map over dim 0 of params and buffers; None means every model sees the same data
output = vmap(fmodel, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)
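Passing None for data means all M models see the same batch. The variant sketched here assumes you instead want each model to get its own minibatch. In that case, give the data a leading dimension of size M and let vmap's default in_dims=0 apply to every argument. The loop that builds `expected` is included only as a reference check.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
num_models, batch_size, in_features, out_features = 5, 64, 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
# One minibatch per model, stacked along dim 0.
minibatches = torch.randn(num_models, batch_size, in_features)

fmodel, params, buffers = combine_state_for_ensemble(models)
# in_dims defaults to 0 for every input, so model i sees minibatches[i].
output = vmap(fmodel)(params, buffers, minibatches)

# Reference: run each model on its own minibatch in a Python loop.
expected = torch.stack([m(mb) for m, mb in zip(models, minibatches)])
print(output.shape)
```

Both approaches compute the same thing. The vmap version avoids the Python loop over models.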