Shortcuts

Source code for torch._functorch.deprecated

# mypy: allow-untyped-defs
"""
The APIs in this file are exposed as `functorch.*`. They are thin wrappers
around the torch.func.* APIs that have deprecation warnings -- we're trying
to move people to the torch.func.* equivalents.

NB: We don't use *args, **kwargs in the signatures because that changes the
documentation.
"""

import textwrap
import warnings
from typing import Any, Callable, Optional, Tuple, Union

import torch._functorch.apis as apis
import torch._functorch.eager_transforms as _impl
import torch._functorch.make_functional as _nn_impl
import torch.nn as nn
from torch._functorch.eager_transforms import argnums_t
from torch._functorch.vmap import in_dims_t, out_dims_t


def get_warning(api, new_api=None, replace_newlines=False):
    """Build the deprecation message for ``functorch.{api}``.

    Args:
        api: Name of the deprecated ``functorch`` function.
        new_api: Fully qualified replacement to recommend. When ``None`` the
            replacement defaults to ``torch.func.{api}``.
        replace_newlines: When truthy, every newline character is removed from
            the finished message so that it reads as a single line.

    Returns:
        The message as a string: six newline-separated segments, or one line
        when ``replace_newlines`` is set.
    """
    replacement = f"torch.func.{api}" if new_api is None else new_api
    # Every segment but the last ends with a space, so dropping the newlines
    # still leaves the words separated.
    segments = [
        "We've integrated functorch into PyTorch. As the final step of the ",
        f"integration, `functorch.{api}` is deprecated as of PyTorch ",
        "2.0 and will be deleted in a future version of PyTorch >= 2.3. ",
        f"Please use `{replacement}` instead; see the PyTorch 2.0 release notes ",
        "and/or the `torch.func` migration guide for more details ",
        "https://pytorch.org/docs/main/func.migrating.html",
    ]
    message = "\n".join(segments)
    return message.replace("\n", "") if replace_newlines else message


def warn_deprecated(api, new_api=None):
    """Emit a ``FutureWarning`` saying that ``functorch.{api}`` is deprecated.

    Args:
        api: Name of the deprecated ``functorch`` function.
        new_api: Replacement to recommend; passed straight to ``get_warning``,
            which falls back to ``torch.func.{api}`` when this is ``None``.
    """
    # stacklevel=3 skips this helper's frame and the frame of the function
    # that called it, so when a functorch wrapper calls this directly the
    # warning is attributed to the wrapper's caller.
    warnings.warn(
        get_warning(api, new_api, replace_newlines=True),
        category=FutureWarning,
        stacklevel=3,
    )


def setup_docs(functorch_api, torch_func_api=None, new_api_name=None):
    """Copy a docstring onto a deprecated wrapper and append a warning note.

    Args:
        functorch_api: The deprecated wrapper whose ``__doc__`` is assigned.
        torch_func_api: The object whose docstring is reused. When ``None`` it
            is looked up on ``_impl`` by the wrapper's ``__name__``.
        new_api_name: Replacement name forwarded to ``get_warning``.

    Returns:
        ``None``. The wrapper's ``__doc__`` is left untouched when the source
        object has no docstring.
    """
    name = functorch_api.__name__
    source_api = getattr(_impl, name) if torch_func_api is None else torch_func_api
    source_doc = source_api.__doc__
    # Docstrings can be absent, e.g. when Python runs with -OO; see
    # https://docs.python.org/3/using/cmdline.html#cmdoption-OO
    if source_doc is None:
        return

    # The message is indented once to sit under the directive, then the whole
    # note is indented again before being appended to the copied docstring.
    directive_body = textwrap.indent(get_warning(name, new_api_name), "    ")
    note = textwrap.indent("\n.. warning::\n\n" + directive_body, "    ")
    functorch_api.__doc__ = source_doc + note


def vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    randomness: str = "error",
    *,
    chunk_size=None,
) -> Callable:
    # Deprecated `functorch.vmap` entry point: warn (recommending
    # `torch.vmap`), then forward every argument unchanged to `apis.vmap`.
    # No docstring is written here because setup_docs() assigns __doc__ at
    # import time.
    warn_deprecated("vmap", "torch.vmap")
    return apis.vmap(func, in_dims, out_dims, randomness, chunk_size=chunk_size)
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
    # Deprecated `functorch.grad` entry point: warn (the default replacement
    # is `torch.func.grad`), then forward the arguments to `apis.grad`.
    # No docstring is written here because setup_docs() assigns __doc__ at
    # import time.
    warn_deprecated("grad")
    return apis.grad(func, argnums, has_aux)
def grad_and_value(
    func: Callable, argnums: argnums_t = 0, has_aux: bool = False
) -> Callable:
    # Deprecated `functorch.grad_and_value` entry point: warn (the default
    # replacement is `torch.func.grad_and_value`), then forward the arguments
    # to `apis.grad_and_value`. No docstring is written here because
    # setup_docs() assigns __doc__ at import time.
    warn_deprecated("grad_and_value")
    return apis.grad_and_value(func, argnums, has_aux)
def vjp(func: Callable, *primals, has_aux: bool = False):
    # Deprecated `functorch.vjp` entry point: warn (the default replacement
    # is `torch.func.vjp`), then forward `func`, the positional primals and
    # `has_aux` to `_impl.vjp`. No docstring is written here because
    # setup_docs() assigns __doc__ at import time.
    warn_deprecated("vjp")
    return _impl.vjp(func, *primals, has_aux=has_aux)
def jvp(
    func: Callable,
    primals: Any,
    tangents: Any,
    *,
    strict: bool = False,
    has_aux: bool = False,
):
    # Deprecated `functorch.jvp` entry point: warn (the default replacement
    # is `torch.func.jvp`), then forward every argument to `_impl.jvp`.
    # No docstring is written here because setup_docs() assigns __doc__ at
    # import time.
    warn_deprecated("jvp")
    return _impl.jvp(func, primals, tangents, strict=strict, has_aux=has_aux)
def jacrev(
    func: Callable,
    argnums: Union[int, Tuple[int]] = 0,
    *,
    has_aux=False,
    chunk_size: Optional[int] = None,
    _preallocate_and_copy=False,
):
    # Deprecated `functorch.jacrev` entry point: warn (the default
    # replacement is `torch.func.jacrev`), then forward every argument to
    # `_impl.jacrev`. No docstring is written here because setup_docs()
    # assigns __doc__ at import time.
    warn_deprecated("jacrev")
    return _impl.jacrev(
        func,
        argnums,
        has_aux=has_aux,
        chunk_size=chunk_size,
        _preallocate_and_copy=_preallocate_and_copy,
    )
def jacfwd(
    func: Callable,
    argnums: argnums_t = 0,
    has_aux: bool = False,
    *,
    randomness: str = "error",
):
    # Deprecated `functorch.jacfwd` entry point: warn (the default
    # replacement is `torch.func.jacfwd`), then forward every argument to
    # `_impl.jacfwd`. No docstring is written here because setup_docs()
    # assigns __doc__ at import time.
    warn_deprecated("jacfwd")
    return _impl.jacfwd(func, argnums, has_aux, randomness=randomness)
def hessian(func, argnums=0):
    # Deprecated `functorch.hessian` entry point: warn (the default
    # replacement is `torch.func.hessian`), then forward to `_impl.hessian`.
    # No docstring is written here because setup_docs() assigns __doc__ at
    # import time.
    warn_deprecated("hessian")
    return _impl.hessian(func, argnums=argnums)
def functionalize(func: Callable, *, remove: str = "mutations") -> Callable:
    # Deprecated `functorch.functionalize` entry point: warn (the default
    # replacement is `torch.func.functionalize`), then forward to
    # `_impl.functionalize`. No docstring is written here because
    # setup_docs() assigns __doc__ at import time.
    warn_deprecated("functionalize")
    return _impl.functionalize(func, remove=remove)
def make_functional(model: nn.Module, disable_autograd_tracking: bool = False):
    # Deprecated `functorch.make_functional` entry point: warn (recommending
    # `torch.func.functional_call`), then forward both arguments to
    # `_nn_impl.make_functional`. No docstring is written here because
    # setup_docs() assigns __doc__ at import time.
    warn_deprecated("make_functional", "torch.func.functional_call")
    return _nn_impl.make_functional(model, disable_autograd_tracking)
def make_functional_with_buffers(
    model: nn.Module, disable_autograd_tracking: bool = False
):
    # Deprecated `functorch.make_functional_with_buffers` entry point: warn
    # (recommending `torch.func.functional_call`), then forward both
    # arguments to `_nn_impl.make_functional_with_buffers`. No docstring is
    # written here because setup_docs() assigns __doc__ at import time.
    warn_deprecated("make_functional_with_buffers", "torch.func.functional_call")
    return _nn_impl.make_functional_with_buffers(model, disable_autograd_tracking)
def combine_state_for_ensemble(models):
    # Deprecated `functorch.combine_state_for_ensemble` entry point: warn
    # (recommending `torch.func.stack_module_state`), then forward `models`
    # to `_nn_impl.combine_state_for_ensemble`. No docstring is written here
    # because setup_docs() assigns __doc__ at import time.
    warn_deprecated("combine_state_for_ensemble", "torch.func.stack_module_state")
    return _nn_impl.combine_state_for_ensemble(models)
# Give every deprecated wrapper the docstring of the function it forwards to,
# followed by a deprecation warning note. Wrappers whose implementation lives
# in `_impl` under the same name rely on setup_docs()'s default lookup.
setup_docs(vmap, apis.vmap, "torch.vmap")
setup_docs(grad, apis.grad)
setup_docs(grad_and_value, apis.grad_and_value)
setup_docs(vjp)
setup_docs(jvp)
setup_docs(jacrev)
setup_docs(jacfwd)
setup_docs(hessian)
setup_docs(functionalize)
setup_docs(make_functional, _nn_impl.make_functional, "torch.func.functional_call")
# Use the docstring of the function this wrapper actually forwards to, not
# the one belonging to make_functional.
setup_docs(
    make_functional_with_buffers,
    _nn_impl.make_functional_with_buffers,
    "torch.func.functional_call",
)
setup_docs(
    combine_state_for_ensemble,
    _nn_impl.combine_state_for_ensemble,
    "torch.func.stack_module_state",
)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources