torch.optim¶
torch.optim
is a package implementing various optimization algorithms.
Most commonly used methods are already supported, and the interface is general enough, so that more sophisticated ones can also be easily integrated in the future.
How to use an optimizer¶
To use torch.optim
you have to construct an optimizer object that will hold
the current state and will update the parameters based on the computed gradients.
Constructing it¶
To construct an Optimizer you have to give it an iterable containing the parameters (all should be Parameter objects) or named parameters (tuples of (str, Parameter)) to optimize. Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc.
Example:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
Named parameters example:
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)
Per-parameter options¶
Optimizers also support specifying per-parameter options. To do this, instead of passing an iterable of Parameter objects, pass in an iterable of dicts. Each of them will define a separate parameter group, and should contain a params key, containing a list of parameters belonging to it. Other keys should match the keyword arguments accepted by the optimizers, and will be used as optimization options for this group.
For example, this is very useful when one wants to specify per-layer learning rates:
optim.SGD([
    {'params': model.base.parameters(), 'lr': 1e-2},
    {'params': model.classifier.parameters()}
], lr=1e-3, momentum=0.9)
optim.SGD([
    {'params': model.base.named_parameters(), 'lr': 1e-2},
    {'params': model.classifier.named_parameters()}
], lr=1e-3, momentum=0.9)
This means that model.base’s parameters will use a learning rate of 1e-2, whereas model.classifier’s parameters will stick to the default learning rate of 1e-3. Finally, a momentum of 0.9 will be used for all parameters.
Note
You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn’t override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.
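The way defaults fill in for groups that do not override them can be checked by inspecting optimizer.param_groups. The following is a minimal sketch, assuming a hypothetical TinyNet module that stands in for the model.base / model.classifier pair used above:

```python
import torch
from torch import nn, optim

# Hypothetical two-part model, standing in for the `model.base` /
# `model.classifier` pair used in the text above.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 2)

    def forward(self, x):
        return self.classifier(torch.relu(self.base(x)))

model = TinyNet()
optimizer = optim.SGD(
    [
        {"params": model.base.parameters(), "lr": 1e-2},  # overrides the default lr
        {"params": model.classifier.parameters()},        # falls back to the default lr
    ],
    lr=1e-3,
    momentum=0.9,
)

# Each dict became one entry in `optimizer.param_groups`, with defaults filled in.
group_lrs = [group["lr"] for group in optimizer.param_groups]
group_momenta = [group["momentum"] for group in optimizer.param_groups]
print(group_lrs, group_momenta)
```

Reading param_groups back like this is also a convenient way to confirm which options a group actually ended up with before training starts.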
Also consider the following example related to the distinct penalization of parameters.
Remember that parameters()
returns an iterable that
contains all learnable parameters, including biases and other
parameters that may prefer distinct penalization. To address this, one can specify
individual penalization weights for each parameter group:
bias_params = [p for name, p in model.named_parameters() if 'bias' in name]
others = [p for name, p in model.named_parameters() if 'bias' not in name]
optim.SGD([
    {'params': others},
    {'params': bias_params, 'weight_decay': 0}
], weight_decay=1e-2, lr=1e-2)
In this manner, bias terms are isolated from non-bias terms, and a weight_decay of 0 is set specifically for the bias terms, so as to avoid any penalization for this group.
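To see the effect of the separate weight_decay setting, one option is to take a single step with all-zero gradients, so that weight decay is the only thing that can move a parameter. This is an illustrative sketch on a small nn.Linear layer (the layer and the zero gradients are assumptions made for the demonstration, not part of the original example):

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(3, 2)

bias_params = [p for name, p in model.named_parameters() if "bias" in name]
others = [p for name, p in model.named_parameters() if "bias" not in name]
optimizer = optim.SGD(
    [
        {"params": others},
        {"params": bias_params, "weight_decay": 0},
    ],
    weight_decay=1e-2,
    lr=1e-2,
)

weight_before = model.weight.detach().clone()
bias_before = model.bias.detach().clone()

# Use all-zero gradients so that weight decay is the only source of change.
for p in model.parameters():
    p.grad = torch.zeros_like(p)
optimizer.step()

weight_changed = not torch.equal(model.weight, weight_before)
bias_changed = not torch.equal(model.bias, bias_before)
print(weight_changed, bias_changed)
```

With plain SGD the weights shrink by a factor of (1 - lr * weight_decay) in this setup, while the bias group is left untouched.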
Taking an optimization step¶
All optimizers implement a step()
method, that updates the
parameters. It can be used in two ways:
optimizer.step()¶
This is a simplified version supported by most optimizers. The function can be
called once the gradients are computed using e.g.
backward()
.
Example:
for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
optimizer.step(closure)¶
Some optimization algorithms such as Conjugate Gradient and L-BFGS need to re-evaluate the function multiple times, so you have to pass in a closure that allows them to recompute your model. The closure should clear the gradients, compute the loss, and return it.
Example:
for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)
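As a self-contained illustration of the closure form, the sketch below minimizes a one-dimensional quadratic with torch.optim.LBFGS. The toy objective and the hyperparameters (lr=1.0, max_iter=20) are assumptions chosen for the demonstration:

```python
import torch
from torch import optim

# A single scalar parameter; the loss (x - 3)^2 has its minimum at x = 3.
x = torch.nn.Parameter(torch.tensor([0.0]))
optimizer = optim.LBFGS([x], lr=1.0, max_iter=20)

def closure():
    # Clear old gradients, recompute the loss, backpropagate, and return the loss.
    optimizer.zero_grad()
    loss = ((x - 3.0) ** 2).sum()
    loss.backward()
    return loss

initial_loss = closure().item()
optimizer.step(closure)  # L-BFGS may call `closure` several times in here
final_loss = closure().item()
print(initial_loss, final_loss)
```

Because the optimizer decides how often to re-evaluate the objective, everything needed to recompute the loss has to live inside the closure.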
Base class¶
 class torch.optim.Optimizer(params, defaults)[source]¶
Base class for all optimizers.
Warning
Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don’t satisfy those properties are sets and iterators over values of dictionaries.
 Parameters
params (iterable) – an iterable of torch.Tensor s or dict s. Specifies what Tensors should be optimized.
defaults (Dict[str, Any]) – a dict containing default values of optimization options (used when a parameter group doesn’t specify them).
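Optimizer.add_param_group – Add a param group to the Optimizer’s param_groups.
Optimizer.load_state_dict – Load the optimizer state.
Optimizer.register_load_state_dict_pre_hook – Register a load_state_dict pre-hook which will be called before load_state_dict() is called.
Optimizer.register_load_state_dict_post_hook – Register a load_state_dict post-hook which will be called after load_state_dict() is called.
Optimizer.state_dict – Return the state of the optimizer as a dict.
Optimizer.register_state_dict_pre_hook – Register a state dict pre-hook which will be called before state_dict() is called.
Optimizer.register_state_dict_post_hook – Register a state dict post-hook which will be called after state_dict() is called.
Optimizer.step – Perform a single optimization step to update parameters.
Optimizer.register_step_pre_hook – Register an optimizer step pre-hook which will be called before the optimizer step.
Optimizer.register_step_post_hook – Register an optimizer step post-hook which will be called after the optimizer step.
Optimizer.zero_grad – Reset the gradients of all optimized torch.Tensor s.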
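A few of these base-class methods can be exercised together in a short round trip. The sketch below is illustrative only: the toy model, the loss, and the extra_layer added at the end are assumptions made for the demonstration.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so that the optimizer has per-parameter state (momentum buffers).
loss = model(torch.randn(4, 3)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# state_dict() returns a dict with "state" and "param_groups" entries.
saved = optimizer.state_dict()

# Restore into a fresh optimizer built over the same parameters in the same order.
restored = optim.SGD(model.parameters(), lr=0.5, momentum=0.0)
restored.load_state_dict(saved)

# add_param_group() registers extra parameters, e.g. a newly unfrozen layer.
extra_layer = nn.Linear(1, 1)  # hypothetical extra layer for illustration
restored.add_param_group({"params": extra_layer.parameters(), "lr": 0.01})

print(sorted(saved.keys()), [g["lr"] for g in restored.param_groups])
```

Note that load_state_dict() replaces the hyperparameters given at construction time (here lr=0.5 and momentum=0.0) with the saved ones.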
Algorithms¶
Adadelta – Implements Adadelta algorithm.
Adafactor – Implements Adafactor algorithm.
Adagrad – Implements Adagrad algorithm.
Adam – Implements Adam algorithm.
AdamW – Implements AdamW algorithm.
SparseAdam – SparseAdam implements a masked version of the Adam algorithm suitable for sparse gradients.
Adamax – Implements Adamax algorithm (a variant of Adam based on infinity norm).
ASGD – Implements Averaged Stochastic Gradient Descent.
LBFGS – Implements L-BFGS algorithm.
NAdam – Implements NAdam algorithm.
RAdam – Implements RAdam algorithm.
RMSprop – Implements RMSprop algorithm.
Rprop – Implements the resilient backpropagation algorithm.
SGD – Implements stochastic gradient descent (optionally with momentum).
Many of our algorithms have various implementations optimized for performance, readability and/or generality, so we attempt to default to the generally fastest implementation for the current device if no particular implementation has been specified by the user.
We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and fused. The most straightforward implementations are for-loops over the parameters with big chunks of computation. For-looping is usually slower than our foreach implementations, which combine parameters into a multi-tensor and run the big chunks of computation all at once, thereby saving many sequential kernel calls. A few of our optimizers have even faster fused implementations, which fuse the big chunks of computation into one kernel. We can think of foreach implementations as fusing horizontally and fused implementations as fusing vertically on top of that.
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop. So when applicable, we default to foreach over for-loop. Applicable means the foreach implementation is available, the user has not specified any implementation-specific kwargs (e.g., fused, foreach, differentiable), and all tensors are native. Note that while fused should be even faster than foreach, the implementations are newer and we would like to give them more bake-in time before flipping the switch everywhere. We summarize the stability status for each implementation in the second table below; you are welcome to try them out though!
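An implementation can also be requested explicitly through the foreach keyword argument mentioned above. The following sketch, run on CPU with a toy model (both are assumptions for illustration), takes one Adam step with each variant and compares the results:

```python
import copy
import torch
from torch import nn, optim

torch.manual_seed(0)
model_a = nn.Linear(4, 2)
model_b = copy.deepcopy(model_a)
inputs = torch.randn(8, 4)

# Same algorithm and hyperparameters; only the implementation is selected explicitly.
opt_forloop = optim.Adam(model_a.parameters(), lr=1e-2, foreach=False)
opt_foreach = optim.Adam(model_b.parameters(), lr=1e-2, foreach=True)

for model, opt in ((model_a, opt_forloop), (model_b, opt_foreach)):
    opt.zero_grad()
    model(inputs).pow(2).mean().backward()
    opt.step()

# The two implementations should agree up to floating-point rounding.
max_diff = (model_a.weight - model_b.weight).abs().max().item()
print(max_diff)
```

The choice of implementation is meant to affect speed, not the update rule, which is why the two parameter sets are expected to match closely.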
Below is a table showing the available and default implementations of each algorithm:
Algorithm     Default     Has foreach?   Has fused?
Adadelta      foreach     yes            no
Adafactor     for-loop    no             no
Adagrad       foreach     yes            yes (cpu only)
Adam          foreach     yes            yes
AdamW         foreach     yes            yes
SparseAdam    for-loop    no             no
Adamax        foreach     yes            no
ASGD          foreach     yes            no
LBFGS         for-loop    no             no
NAdam         foreach     yes            no
RAdam         foreach     yes            no
RMSprop       foreach     yes            no
Rprop         foreach     yes            no
SGD           foreach     yes            yes
The table below shows the stability status of the fused implementations:
Algorithm     CPU           CUDA          MPS
Adadelta      unsupported   unsupported   unsupported
Adafactor     unsupported   unsupported   unsupported
Adagrad       beta          unsupported   unsupported
Adam          beta          stable        beta
AdamW         beta          stable        beta
SparseAdam    unsupported   unsupported   unsupported
Adamax        unsupported   unsupported   unsupported
ASGD          unsupported   unsupported   unsupported
LBFGS         unsupported   unsupported   unsupported
NAdam         unsupported   unsupported   unsupported
RAdam         unsupported   unsupported   unsupported
RMSprop       unsupported   unsupported   unsupported
Rprop         unsupported   unsupported   unsupported
SGD           beta          beta          beta
How to adjust learning rate¶
torch.optim.lr_scheduler.LRScheduler provides several methods to adjust the learning rate based on the number of epochs. torch.optim.lr_scheduler.ReduceLROnPlateau allows dynamic learning rate reduction based on some validation measurements.
Learning rate scheduling should be applied after the optimizer’s update; e.g., you should write your code this way:
Example:
from torch.optim.lr_scheduler import ExponentialLR

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()
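To make the effect of the schedule concrete, the sketch below records the learning rate after each epoch and compares it with the closed form lr * gamma ** epoch. The training loop is replaced by a bare optimizer.step() call, which is an assumption made to keep the example short:

```python
import math
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

lrs = [scheduler.get_last_lr()[0]]  # learning rate used during epoch 0
for epoch in range(3):
    # (the per-batch training loop would run here)
    optimizer.step()   # optimizer update first ...
    scheduler.step()   # ... then the scheduler, once per epoch
    lrs.append(scheduler.get_last_lr()[0])

expected = [0.01 * 0.9 ** k for k in range(4)]
print(lrs)
```

get_last_lr() returns one value per parameter group, so the first element is the rate of the only group here.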
Most learning rate schedulers can be called back-to-back (also referred to as chaining schedulers). The result is that each scheduler is applied one after the other on the learning rate obtained by the one preceding it.
Example:
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()
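The combined effect of two chained schedulers can be verified numerically. This sketch uses a small milestone (epoch 2 instead of 30 and 80) so that the step decay is visible within a few epochs; that milestone and the empty training loop are assumptions made for the demonstration:

```python
import math
from torch import nn, optim
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[2], gamma=0.1)  # small milestone for the demo

lrs = []
for epoch in range(4):
    lrs.append(optimizer.param_groups[0]["lr"])  # rate in effect for this epoch
    optimizer.step()
    scheduler1.step()  # multiplies the current rate by 0.9
    scheduler2.step()  # additionally multiplies by 0.1 when a milestone is reached

# Closed form implied by the two rules combined.
expected = [0.01 * 0.9 ** e * (0.1 if e >= 2 else 1.0) for e in range(4)]
print(lrs)
```

Each scheduler reads the rate currently stored in the optimizer, which is why the two decay factors multiply.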
In many places in the documentation, we will use the following template to refer to scheduler algorithms.
>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()
Warning
Prior to PyTorch 1.1.0, the learning rate scheduler was expected to be called before the optimizer’s update; 1.1.0 changed this behavior in a BC-breaking way. If you use the learning rate scheduler (calling scheduler.step()) before the optimizer’s update (calling optimizer.step()), this will skip the first value of the learning rate schedule. If you are unable to reproduce results after upgrading to PyTorch 1.1.0, please check if you are calling scheduler.step() at the wrong time.
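ReduceLROnPlateau, mentioned at the start of this section, differs from the epoch-based schedulers in that its step() takes the monitored metric as an argument. The sketch below feeds it a made-up sequence of validation losses (an assumption for illustration) and records the resulting learning rate:

```python
import math
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=2)

# Made-up validation losses: improvement stops after the second epoch.
val_losses = [1.0, 0.8, 0.8, 0.8, 0.8]

lrs = []
for val_loss in val_losses:
    optimizer.step()          # stands in for one epoch of training
    scheduler.step(val_loss)  # this scheduler needs the monitored metric
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)
```

With patience=2, the rate is only reduced once the metric has failed to improve for more than two consecutive epochs, which here happens at the last entry.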
lr_scheduler.LRScheduler – Adjusts the learning rate during optimization.
lr_scheduler.LambdaLR – Sets the initial learning rate.
lr_scheduler.MultiplicativeLR – Multiply the learning rate of each parameter group by the factor given in the specified function.
lr_scheduler.StepLR – Decays the learning rate of each parameter group by gamma every step_size epochs.
lr_scheduler.MultiStepLR – Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones.
lr_scheduler.ConstantLR – Multiply the learning rate of each parameter group by a small constant factor.
lr_scheduler.LinearLR – Decays the learning rate of each parameter group by linearly changing a small multiplicative factor.
lr_scheduler.ExponentialLR – Decays the learning rate of each parameter group by gamma every epoch.
lr_scheduler.PolynomialLR – Decays the learning rate of each parameter group using a polynomial function in the given total_iters.
lr_scheduler.CosineAnnealingLR – Set the learning rate of each parameter group using a cosine annealing schedule.
lr_scheduler.ChainedScheduler – Chains a list of learning rate schedulers.
lr_scheduler.SequentialLR – Contains a list of schedulers expected to be called sequentially during the optimization process.
lr_scheduler.ReduceLROnPlateau – Reduce learning rate when a metric has stopped improving.
lr_scheduler.CyclicLR – Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).
lr_scheduler.OneCycleLR – Sets the learning rate of each parameter group according to the 1cycle learning rate policy.
lr_scheduler.CosineAnnealingWarmRestarts – Set the learning rate of each parameter group using a cosine annealing schedule.
How to utilize named parameters to load optimizer state dict¶
The function load_state_dict() stores the optional param_names content from the loaded state dict if present. However, the process of loading the optimizer state is not affected, as the order of the parameters matters to maintain compatibility (in case of different ordering). To utilize the loaded parameter names from the loaded state dict, a custom register_load_state_dict_pre_hook needs to be implemented according to the desired behavior.
This can be useful, for instance, when the model architecture changes, but the weights and optimizer states need to remain unchanged. The following example demonstrates how to implement this customization.
Example:
import torch
from torch import nn, optim

class OneLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)
Let’s say that model implements an expert (MoE), and we want to duplicate it and resume training for two experts, both initialized the same way as the fc layer. For the following model2 we create two layers identical to fc and resume training by loading the model weights and optimizer states from model into both fc1 and fc2 of model2 (and adjust them accordingly):
class TwoLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(3, 4)

    def forward(self, x):
        return (self.fc1(x) + self.fc2(x)) / 2

model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)
To load the state dict for optimizer2 with the state dict of the previous optimizer such that both fc1 and fc2 will be initialized with a copy of the fc optimizer states (to resume training for each layer from fc), we can use the following hook:
from copy import deepcopy

def adapt_state_dict_ids(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    # Maps each parameter name in the new model to its source name in the loaded state dict.
    lookup_dict = {
        'fc1.weight': 'fc.weight',
        'fc1.bias': 'fc.bias',
        'fc2.weight': 'fc.weight',
        'fc2.bias': 'fc.bias'
    }
    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
        id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
        # Copy the state of the corresponding parameter
        if id_in_loaded in state_dict['state']:
            adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict
This ensures that the adapted state_dict with the correct states for the layers of model2 will be used when the optimizer state is loaded. Note that this code is designed specifically for this example (e.g., assuming a single parameter group), and other cases might require different adaptations.
The following example shows how to handle missing parameters in a loaded state dict when the model structure changes. The Model_bypass adds a new bypass layer, which is not present in the original Model1. To resume training, a custom adapt_state_dict_missing_param hook is used to adapt the optimizer’s state_dict, ensuring existing parameters are mapped correctly, while missing ones (like the bypass layer) remain unchanged (as initialized in this example). This approach enables smooth loading and resuming of the optimizer state despite model changes. The new bypass layer will be trained from scratch:
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x) + x

model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

class Model_bypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)
        self.bypass = nn.Linear(5, 5, bias=False)
        torch.nn.init.eye_(self.bypass.weight)

    def forward(self, x):
        return self.fc(x) + self.bypass(x)

model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

def adapt_state_dict_missing_param(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    # None marks a parameter that has no counterpart in the loaded state dict.
    lookup_dict = {
        'fc.weight': 'fc.weight',
        'fc.bias': 'fc.bias',
        'bypass.weight': None,
    }

    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        if name_in_loaded in state_dict['param_groups'][0]['param_names']:
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

# Register the hook defined above (not the adapt_state_dict_ids hook from the previous example).
optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_missing_param)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict
As a third example, instead of loading a state according to the order of parameters (the default approach), this hook can be used to load according to the parameters’ names:
def names_matching(optimizer, state_dict):
    assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
    adapted_state_dict = deepcopy(optimizer.state_dict())
    for g_ind in range(len(state_dict['param_groups'])):
        assert len(state_dict['param_groups'][g_ind]['params']) == len(
            optimizer.state_dict()['param_groups'][g_ind]['params'])

        for k, v in state_dict['param_groups'][g_ind].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][g_ind][k] = v

        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][g_ind]['params'],
                optimizer.state_dict()['param_groups'][g_ind]['param_names']):
            index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
            id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict
Weight Averaging (SWA and EMA)¶
torch.optim.swa_utils.AveragedModel
implements Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA),
torch.optim.swa_utils.SWALR
implements the SWA learning rate scheduler and
torch.optim.swa_utils.update_bn()
is a utility function used to update SWA/EMA batch
normalization statistics at the end of training.
SWA has been proposed in Averaging Weights Leads to Wider Optima and Better Generalization.
EMA is a widely known technique to reduce the training time by reducing the number of weight updates needed. It is a variation of Polyak averaging, but using exponential weights instead of equal weights across iterations.
Constructing averaged models¶
The AveragedModel class serves to compute the weights of the SWA or EMA model.
You can create an SWA averaged model by running:
>>> averaged_model = AveragedModel(model)
EMA models are constructed by specifying the multi_avg_fn
argument as follows:
>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))
Decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed. If not provided to torch.optim.swa_utils.get_ema_multi_avg_fn(), the default is 0.999. The decay value should be close to 1.0, as smaller values can cause optimization convergence issues.
torch.optim.swa_utils.get_ema_multi_avg_fn() returns a function that applies the following EMA equation to the weights:
$$W^{\text{EMA}}_{t+1} = \alpha W^{\text{EMA}}_{t} + (1 - \alpha) W^{\text{model}}_{t}$$
where alpha is the EMA decay.
Here the model model can be an arbitrary torch.nn.Module object. averaged_model will keep track of the running averages of the parameters of the model. To update these averages, you should use the update_parameters() function after the optimizer.step():
>>> averaged_model.update_parameters(model)
For SWA and EMA, this call is usually done right after the optimizer step(). In the case of SWA, this is usually skipped for some number of steps at the beginning of the training.
Custom averaging strategies¶
By default, torch.optim.swa_utils.AveragedModel computes a running equal average of the parameters that you provide, but you can also use custom averaging functions with the avg_fn or multi_avg_fn parameters:
avg_fn allows defining a function operating on each parameter tuple (averaged parameter, model parameter) and should return the new averaged parameter.
multi_avg_fn allows defining more efficient operations acting on a tuple of parameter lists, (averaged parameter list, model parameter list), at the same time, for example using the torch._foreach* functions. This function must update the averaged parameters in-place.
In the following example ema_model
computes an exponential moving average using the avg_fn
parameter:
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
In the following example ema_model
computes an exponential moving average using the more efficient multi_avg_fn
parameter:
>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))
SWA learning rate schedules¶
Typically, in SWA the learning rate is set to a high constant value. SWALR
is a
learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it
constant. For example, the following code creates a scheduler that linearly anneals the
learning rate from its initial value to 0.05 in 5 epochs within each parameter group:
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>> anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
You can also use cosine annealing to a fixed value instead of linear annealing by setting
anneal_strategy="cos"
.
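The annealing behaviour can be observed by recording the learning rate after every scheduler step. This is a small sketch with an assumed initial learning rate of 0.5 and a bare optimizer.step() in place of a real training epoch:

```python
import math
from torch import nn, optim
from torch.optim.swa_utils import SWALR

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.5)
swa_scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

lrs = [optimizer.param_groups[0]["lr"]]
for epoch in range(8):
    optimizer.step()  # stands in for one epoch of training
    swa_scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)
```

The recorded values move from the initial rate towards swa_lr during the first anneal_epochs steps and then stay at swa_lr.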
Taking care of batch normalization¶
update_bn() is a utility function that computes the batchnorm statistics for the SWA model on a given dataloader loader at the end of training:
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
update_bn()
applies the swa_model
to every element in the dataloader and computes the activation
statistics for each batch normalization layer in the model.
Warning
update_bn() assumes that each batch in the dataloader loader is either a tensor or a list of tensors where the first element is the tensor that the network swa_model should be applied to. If your dataloader has a different structure, you can update the batch normalization statistics of the swa_model by doing a forward pass with the swa_model on each element of the dataset.
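For a dataloader with a different batch structure, the manual forward pass could look roughly like the sketch below. The helper update_bn_from_dict_batches is hypothetical (it is not part of torch.optim.swa_utils) and assumes that every batch is a dict holding the network input under a known key:

```python
import torch
from torch import nn

def update_bn_from_dict_batches(batches, model, key="x"):
    """Recompute BatchNorm running statistics from batches shaped like {key: tensor}.

    Hypothetical helper (not part of torch.optim.swa_utils); it assumes every
    batch is a dict holding the network input under `key`.
    """
    bn_layers = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    if not bn_layers:
        return
    was_training = model.training
    saved_momenta = {bn: bn.momentum for bn in bn_layers}
    for bn in bn_layers:
        bn.reset_running_stats()  # start the statistics from scratch
        bn.momentum = None        # None selects a cumulative (equal-weight) average
    model.train()                 # running stats are only updated in training mode
    with torch.no_grad():
        for batch in batches:
            model(batch[key])
    for bn, momentum in saved_momenta.items():
        bn.momentum = momentum
    model.train(was_training)

torch.manual_seed(0)
net = nn.Sequential(nn.BatchNorm1d(3))
data = torch.randn(8, 3) * 2.0 + 5.0
batches = [{"x": data[:4]}, {"x": data[4:]}]  # dict batches, which update_bn() does not unpack
update_bn_from_dict_batches(batches, net)
print(net[0].running_mean, data.mean(dim=0))
```

With equally sized batches, the cumulative average of the per-batch means equals the mean over the whole dataset, which is what the printed values compare.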
Putting it all together: SWA¶
In the example below, swa_model
is the SWA model that accumulates the averages of the weights.
We train the model for a total of 300 epochs and we switch to the SWA learning rate schedule
and start to collect SWA averages of the parameters at epoch 160:
>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>     for input, target in loader:
>>>         optimizer.zero_grad()
>>>         loss_fn(model(input), target).backward()
>>>         optimizer.step()
>>>     if epoch > swa_start:
>>>         swa_model.update_parameters(model)
>>>         swa_scheduler.step()
>>>     else:
>>>         scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)
Putting it all together: EMA¶
In the example below, ema_model is the EMA model that accumulates the exponentially-decayed averages of the weights with a decay rate of 0.999. We train the model for a total of 300 epochs and start to collect EMA averages immediately.
>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>     for input, target in loader:
>>>         optimizer.zero_grad()
>>>         loss_fn(model(input), target).backward()
>>>         optimizer.step()
>>>         ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)
AveragedModel – Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
SWALR – Anneals the learning rate in each parameter group to a fixed value.
 torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source]¶
Get the function applying exponential moving average (EMA) across multiple params.
 torch.optim.swa_utils.update_bn(loader, model, device=None)[source]¶
Update BatchNorm running_mean, running_var buffers in the model.
It performs one pass over data in loader to estimate the activation statistics for BatchNorm layers in the model.
 Parameters
loader (torch.utils.data.DataLoader) – dataset loader to compute the activation statistics on. Each data batch should be either a tensor, or a list/tuple whose first element is a tensor containing data.
model (torch.nn.Module) – model for which we seek to update BatchNorm statistics.
device (torch.device, optional) – If set, data will be transferred to device before being passed into model.
Example
>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)
Note
The update_bn utility assumes that each data batch in loader is either a tensor or a list or tuple of tensors; in the latter case it is assumed that model.forward() should be called on the first element of the list or tuple corresponding to the data batch.