Note
This page describes an internal API which is not intended to be used outside of the PyTorch codebase and can be modified or removed without notice.
Source code for torch.jit._async
# mypy: allow-untyped-defs
"""Async API.
This module contains the API for parallelism in TorchScript, notably:
* torch.jit.fork
* torch.jit.wait
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch
from torch._jit_internal import Future
from torch.jit._builtins import _register_builtin
from torch.utils import set_module
set_module(Future, "torch.jit")
def fork(func, *args, **kwargs):
r"""
Create an asynchronous task executing `func` and a reference to the value of the result of this execution.
`fork` will return immediately, so the return value of `func` may not have been computed yet. To force completion
of the task and access the return value invoke `torch.jit.wait` on the Future. `fork` invoked
with a `func` which returns `T` is typed as `torch.jit.Future[T]`. `fork` calls can be arbitrarily
nested, and may be invoked with positional and keyword arguments.
Asynchronous execution will only occur when run in TorchScript. If run in pure python,
`fork` will not execute in parallel. `fork` will also not execute in parallel when invoked
while tracing, however the `fork` and `wait` calls will be captured in the exported IR Graph.
.. warning::
`fork` tasks will execute non-deterministically. We recommend only spawning
parallel fork tasks for pure functions that do not modify their inputs,
module attributes, or global state.
Args:
func (callable or torch.nn.Module): A Python function or `torch.nn.Module`
that will be invoked. If executed in TorchScript, it will execute asynchronously,
otherwise it will not. Traced invocations of fork will be captured in the IR.
``*args``, ``**kwargs``: arguments to invoke `func` with.
Returns:
`torch.jit.Future[T]`: a reference to the execution of `func`. The value `T`
can only be accessed by forcing completion of `func` through `torch.jit.wait`.
    Example (fork a free function):

    .. code-block:: python

        import torch
        from torch import Tensor

        def foo(a: Tensor, b: int) -> Tensor:
            return a + b

        def bar(a):
            fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2)
            return torch.jit.wait(fut)

        script_bar = torch.jit.script(bar)
        input = torch.tensor(2)
        # only the scripted version executes asynchronously
        assert script_bar(input) == bar(input)
        # trace is not run asynchronously, but fork is captured in IR
        graph = torch.jit.trace(bar, (input,)).graph
        assert "fork" in str(graph)
    Example (fork a module method):

    .. code-block:: python

        import torch
        from torch import Tensor

        class AddMod(torch.nn.Module):
            def forward(self, a: Tensor, b: int):
                return a + b

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod = AddMod()

            def forward(self, input):
                fut = torch.jit.fork(self.mod, input, b=2)
                return torch.jit.wait(fut)

        input = torch.tensor(2)
        mod = Mod()
        assert mod(input) == torch.jit.script(mod).forward(input)
"""
return torch._C.fork(func, *args, **kwargs)
def wait(future):
r"""
Force completion of a `torch.jit.Future[T]` asynchronous task, returning the result of the task.
See :func:`~fork` for docs and examples.
Args:
future (torch.jit.Future[T]): an asynchronous task reference, created through `torch.jit.fork`
Returns:
`T`: the return value of the completed task
"""
return torch._C.wait(future)
_register_builtin(wait, "aten::wait")
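The final `_register_builtin(wait, "aten::wait")` call registers `wait` as a TorchScript builtin, so `torch.jit.wait` calls in scripted code are lowered to the `aten::wait` operator. Below is a minimal usage sketch (not part of the module source; the function names `twice` and `run` are illustrative) that exercises `fork`/`wait` in both eager and scripted mode, assuming the usual behavior that the scripted graph contains an `aten::wait` node and a fork node for these calls:

.. code-block:: python

    import torch
    from torch import Tensor

    def twice(x: Tensor) -> Tensor:
        return x + x

    def run(x: Tensor) -> Tensor:
        # fork schedules `twice`; it only runs asynchronously when scripted.
        fut = torch.jit.fork(twice, x)
        # wait blocks until the forked task completes and returns its value.
        return torch.jit.wait(fut)

    scripted = torch.jit.script(run)
    x = torch.ones(3)
    # Eager and scripted execution produce the same result; only the
    # scripted call may actually run the forked task in parallel.
    assert torch.equal(scripted(x), run(x))

    # Because of the registration above, the wait call should appear as an
    # `aten::wait` node in the scripted graph, and the fork call as a fork node.
    graph_str = str(scripted.graph)
    assert "aten::wait" in graph_str
    assert "fork" in graph_str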