.. currentmodule:: torch.func

.. _ux-limitations:

UX Limitations
==============

torch.func, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX’s limitations are that transforms
only work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On one hand, writing code
compatible with function transforms may involve changing how you write PyTorch
code, on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All torch.func transforms share a limitation in that a function should not
assign to global variables. Instead, all outputs to a function must be returned
from the function. This restriction comes from how torch.func is implemented:
each transform wraps Tensor inputs in special torch.func Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

  import torch
  from torch.func import grad

  # Don't do this
  intermediate = None

  def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z

  x = torch.randn([])
  grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

  def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    return z, intermediate

  grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of torch.func's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and the reason why we designed the torch.func library. Please instead use the torch.func
equivalents of the ``torch.autograd`` APIs:
- ``torch.autograd.grad``, ``Tensor.backward`` -> ``torch.func.vjp`` or ``torch.func.grad``
- ``torch.autograd.functional.jvp`` -> ``torch.func.jvp``
- ``torch.autograd.functional.jacobian`` -> ``torch.func.jacrev`` or ``torch.func.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``torch.func.hessian``

vmap limitations
----------------

.. note::
  :func:`vmap` is our most restrictive transform.
  The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
  have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
  implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
  :func:`jvp` so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

  torch.stack([f(x_i) for x_i in x.unbind(0)])

Mutation: Arbitrary mutation of Python data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the presence of side effects, :func:`vmap` no longer acts like it is running
a for-loop. For example, the following function:

::

  def f(x, list):
    list.pop()
    print("hello!")
    return x.sum(0)

  x = torch.randn(3, 1)
  lst = [0, 1, 2, 3]

  result = vmap(f, in_dims=(0, None))(x, lst)

will print "hello!" once and pop only one element from ``lst``.


:func:`vmap` executes ``f`` a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. torch.func has a special,
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
behavior for each PyTorch operator.


Mutation: in-place PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You might be here due to receiving an error about vmap-incompatible in-place
operations. :func:`vmap` will raise an error if it encounters an unsupported PyTorch
in-place operation and it will succeed otherwise. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here's an example of how this can occur:

::

  def f(x, y):
    x.add_(y)
    return x

  x = torch.randn(1)
  y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

  # Raises an error because `x` has fewer elements than `y`.
  vmap(f, in_dims=(None, 0))(x, y)

``x`` is a Tensor with one element, ``y`` is a Tensor with three elements.
``x + y`` has three elements (due to broadcasting), but attempting to write
three elements back into ``x``, which only has one element, raises an error
due to attempting to write three elements into a Tensor with a single element.

There is no problem if the Tensor being written to is batched under
:func:`~torch.vmap` (i.e. it is being vmapped over).

::

  def f(x, y):
    x.add_(y)
    return x

  x = torch.randn(3, 1)
  y = torch.randn(3, 1)
  expected = x + y

  # Does not raise an error because x is being vmapped over.
  vmap(f, in_dims=(0, 0))(x, y)
  assert torch.allclose(x, expected)

One common fix for this is to replace calls to factory functions with
their "new_*" equivalent. For example:

- Replace :func:`torch.zeros` with :meth:`Tensor.new_zeros`
- Replace :func:`torch.empty` with :meth:`Tensor.new_empty`

To see why this helps, consider the following.

::

  def diag_embed(vec):
    assert vec.dim() == 1
    result = torch.zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result

  vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

  # RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
  vmap(diag_embed)(vecs)

Inside of :func:`~torch.vmap`, ``result`` is a Tensor of shape [3, 3].
However, although ``vec`` looks like it has shape [3], ``vec`` actually has
underlying shape [2, 3].
It is not possible to copy ``vec`` into ``result.diagonal()``, which has
shape [3], because it has too many elements.

::

  def diag_embed(vec):
    assert vec.dim() == 1
    result = vec.new_zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result

  vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
  vmap(diag_embed)(vecs)

Replacing :func:`torch.zeros` with :meth:`Tensor.new_zeros` makes it so that
``result`` has an underlying Tensor of shape [2, 3, 3], so it is now possible
to copy ``vec``, which has underlying shape [2, 3], into ``result.diagonal()``.


Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters that in your code.

This is not a fundamental limitation; we could theoretically support this in the
future but we have chosen not to for now.

Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error message:

::

  def relu(x):
    if x > 0:
      return x
    return 0

  x = torch.randn(3)
  vmap(relu)(x)

However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

  def custom_dot(x):
    if x.dim() == 1:
      return torch.dot(x, x)
    return (x * x).sum()

  x = torch.randn(3)
  vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to PyTorch.

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

  def f(x):
    return x.item()

  x = torch.randn(3)
  vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.

You may also encounter an error message about using ``.item()`` but you might
not have used it. In those cases, it is possible that PyTorch internally is
calling ``.item()`` -- please file an issue on GitHub and we'll fix
PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``,
``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

  xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
  vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape 2;
but ``torch.nonzero(xs[1])`` returns a Tensor of shape 1.
We are unable to construct a single Tensor as an output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).


Randomness
----------
The user's intention when calling a random operation can be unclear. Specifically, some users may want
the random behavior to be the same across batches while others may want it to differ across batches.
To address this, ``vmap`` takes a randomness flag.

The flag can only be passed to vmap and can take on 3 values, "error," "different," or "same," defaulting
to error. Under "error" mode, any call to a random function will produce an error asking the user to use
one of the other two flags based on their use case.

Under "different" randomness, elements in a batch produce different random values. For instance,

::

  def add_noise(x):
    y = torch.randn(())  # y will be different across the batch
    return x + y

  x = torch.ones(3)
  result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under "same" randomness, elements in a batch produce same random values. For instance,

::

  def add_noise(x):
    y = torch.randn(())  # y will be the same across the batch
    return x + y

  x = torch.ones(3)
  result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times


.. warning::
    Our system only determine the randomness behavior of PyTorch operators and cannot control the
    behavior of other libraries, like numpy. This is similar to JAX's limitations with their solutions

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. Like with standard PyTorch, a user can get randomness reproducibility through
    either using ``torch.manual_seed()`` outside of vmap or by using generators.

.. note::
    Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch
    doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the
    most common forms of randomness that we see. If your use case does not fit these forms of randomness, please
    file an issue.