Patching Batch Norm
===================

What's happening?
-----------------
Batch Norm requires in-place updates to ``running_mean`` and ``running_var``, which hold one value
per channel. functorch does not support an in-place update to a regular tensor that takes in a
batched tensor (i.e. ``regular.add_(batched)`` is not allowed). When vmapping over a batch of inputs
to a single module, the statistics used for the update are batched while the running stats are
regular tensors, so vmap raises an error.
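
The failure can be reproduced with a minimal sketch, assuming PyTorch 2.x with ``torch.func``: a
``BatchNorm2d`` in training mode that tracks running stats is vmapped over a stack of inputs. The
sizes are arbitrary. Each vmapped input is itself a 4D ``(N, C, H, W)`` tensor because
``BatchNorm2d`` expects 4D input.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import vmap

    bn = nn.BatchNorm2d(8)  # training mode, track_running_stats=True by default
    x = torch.randn(3, 2, 8, 5, 5)  # 3 vmapped inputs, each of shape (N=2, C=8, H=5, W=5)

    try:
        vmap(bn)(x)
        failed = False
    except RuntimeError as err:
        # running_mean/running_var are regular tensors, but the batch statistics
        # computed under vmap are batched, so the in-place update is rejected
        failed = True
        print(f"vmap raised: {err}")

    print("vmap failed:", failed)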

How to fix
----------
One of the best-supported ways is to switch from BatchNorm to GroupNorm. Options 1 and 2 support this.

Options 1 to 3 assume that you don't need running stats. If you're using a module, this means it's
assumed you won't use batch norm in evaluation mode. Option 4 covers running batch norm with vmap in
evaluation mode; if you have a use case that it doesn't handle, please file an issue.

Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you want to switch to GroupNorm, replace every BatchNorm with:

.. code-block:: python

    from torch.nn import GroupNorm

    GroupNorm(G, C)  # num_groups=G, num_channels=C

Here ``C`` is the same ``C`` as in the original BatchNorm, and ``G`` is the number of groups to
split ``C`` into. ``C`` must therefore be divisible by ``G`` (``C % G == 0``); as a fallback, you
can set ``G = C``, meaning each channel is treated separately.
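
As a quick check, the sketch below (assuming PyTorch 2.x; the sizes ``C = 8`` and ``G = 4`` are
arbitrary) vmaps a ``GroupNorm`` over a stack of inputs and compares the results with calling the
module on each input directly. GroupNorm keeps no running stats, so there is no in-place update for
vmap to reject.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import vmap

    C, G = 8, 4  # channels and groups; C % G == 0 must hold
    norm = nn.GroupNorm(G, C)  # stands in for nn.BatchNorm2d(C)

    x = torch.randn(3, 2, C, 5, 5)  # 3 vmapped inputs, each (N, C, H, W)
    out = vmap(norm)(x)

    # Each vmapped result should match running the module on that input alone
    print(out.shape)
    print(all(torch.allclose(out[i], norm(x[i]), atol=1e-6) for i in range(3)))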

If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
``track_running_stats`` flag to ``False``:

.. code-block:: python

    from torch.nn import BatchNorm2d

    BatchNorm2d(64, track_running_stats=False)
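
A minimal sketch of why this works, assuming PyTorch 2.x and arbitrary sizes: with
``track_running_stats=False`` the module's ``running_mean`` and ``running_var`` are ``None``, so in
training mode each vmapped input is normalized with its own batch statistics and nothing is updated
in place.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import vmap

    bn = nn.BatchNorm2d(8, track_running_stats=False)
    print(bn.running_mean, bn.running_var)  # None None: nothing to update in place

    x = torch.randn(3, 2, 8, 5, 5)  # 3 vmapped inputs, each (N=2, C=8, H=5, W=5)
    out = vmap(bn)(x)

    # Each result uses only the batch statistics of its own input
    print(torch.allclose(out[1], bn(x[1]), atol=1e-6))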


Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, like resnet and regnet, can take in a ``norm_layer`` parameter. When it is
not set, this parameter usually defaults to BatchNorm2d.

Instead, you can set it to GroupNorm.

.. code-block:: python

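    import torchvision
    from torch.nn import GroupNorm

    g = 32  # number of groups; 32 divides every channel count in resnet18 (64, 128, 256, 512)
    torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))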

Here, once again, ``c % g == 0`` must hold, so as a fallback, set ``g = c``.
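
When a model mixes channel counts, a single fixed ``g`` may not divide all of them. One option,
sketched below, is to pass a factory as ``norm_layer`` that picks a group count per layer. The
helper name ``group_norm_layer`` and the ``max_groups=32`` default are assumptions made for
illustration. ``math.gcd`` returns the largest number dividing both ``c`` and ``max_groups``, so
``c % groups == 0`` always holds.

.. code-block:: python

    import math

    import torch.nn as nn


    def group_norm_layer(num_channels: int, max_groups: int = 32) -> nn.GroupNorm:
        """Hypothetical norm_layer factory: GroupNorm whose group count divides num_channels."""
        # gcd(num_channels, max_groups) always divides num_channels, so the grouping is valid;
        # it falls back to a single group when num_channels shares no factor with max_groups
        groups = math.gcd(num_channels, max_groups)
        return nn.GroupNorm(groups, num_channels)


    for c in (3, 48, 64, 512):
        print(c, group_norm_layer(c).num_groups)

torchvision's resnet builds each norm layer as ``norm_layer(num_channels)``, so the factory could be
passed as ``torchvision.models.resnet18(norm_layer=group_norm_layer)``. Unlike the ``g = c``
fallback, this keeps several channels per group where possible.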

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d

    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch has added some functionality to allow for quick, in-place patching of the module to not
use running stats. Changing the norm layer is more fragile, so we have not offered that. If you
have a net where you want the BatchNorm to not use running stats, you can run
``replace_all_batch_norm_modules_`` to update the module in-place to not use running stats:

.. code-block:: python

    from torch.func import replace_all_batch_norm_modules_
    replace_all_batch_norm_modules_(net)
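
As a sketch, assuming PyTorch 2.x, here is the patch applied to a small hand-built ``nn.Sequential``
that stands in for ``net`` (the layer sizes are arbitrary). After patching, the BatchNorm layer no
longer tracks running stats, so the whole model can be vmapped in training mode.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import replace_all_batch_norm_modules_, vmap

    # A small stand-in model; any nn.Module containing BatchNorm layers is patched the same way
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )

    replace_all_batch_norm_modules_(net)  # patches net in place

    bn = net[1]
    print(bn.track_running_stats, bn.running_mean)  # False None

    x = torch.randn(4, 2, 3, 5, 5)  # 4 vmapped inputs, each (N=2, C=3, H=5, W=5)
    out = vmap(net)(x)
    print(out.shape)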

Option 4: eval mode
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When run under eval mode, ``running_mean`` and ``running_var`` are read but not updated. Therefore,
vmap can support this mode:

.. code-block:: python

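    from torch.func import vmap

    model.eval()   # running stats are used but not updated
    vmap(model)(x)
    model.train()  # switch back to training mode afterwards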
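
A minimal sketch of the eval-mode path, assuming PyTorch 2.x and arbitrary sizes: a ``BatchNorm2d``
that tracks running stats, which vmap rejects in training mode, runs once it is switched to eval
mode, and its running stats are left untouched.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.func import vmap

    bn = nn.BatchNorm2d(8)  # tracks running stats (the default)
    before = bn.running_mean.clone()

    x = torch.randn(3, 2, 8, 5, 5)  # 3 vmapped inputs, each (N=2, C=8, H=5, W=5)

    bn.eval()  # normalize with the stored running stats instead of updating them
    out = vmap(bn)(x)
    bn.train()  # restore training mode

    # The running stats were only read, never written
    print(torch.equal(bn.running_mean, before))
    print(out.shape)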