Source code for torch.distributed.rpc.functions
# mypy: allow-untyped-defs
import functools
def async_execution(fn):
    r"""
A decorator for a function indicating that the return value of the function
is guaranteed to be a :class:`~torch.futures.Future` object and this
function can run asynchronously on the RPC callee. More specifically, the
callee extracts the :class:`~torch.futures.Future` returned by the wrapped
function and installs subsequent processing steps as a callback to that
:class:`~torch.futures.Future`. The installed callback will read the value
from the :class:`~torch.futures.Future` when completed and send the
value back as the RPC response. That also means the returned
:class:`~torch.futures.Future` only exists on the callee side and is never
sent through RPC. This decorator is useful when the execution of the wrapped
function (``fn``) needs to pause and resume, for example because it calls
:meth:`~torch.distributed.rpc.rpc_async` or waits for other signals.
.. note:: To enable asynchronous execution, applications must pass the
    function object returned by this decorator to RPC APIs. If RPC detects
    the attributes installed by this decorator, it knows that this function
    returns a ``Future`` object and will handle that accordingly.
    However, this does not mean this decorator has to be the outermost one
    when defining a function. For example, when combined with
    ``@staticmethod`` or ``@classmethod``,
    ``@rpc.functions.async_execution`` needs to be the inner decorator so
    that the target function is recognized as a static or class method.
    This target function can still execute asynchronously because, when
    accessed, the static or class method preserves the attributes
    installed by ``@rpc.functions.async_execution``.
Example::
The returned :class:`~torch.futures.Future` object can come from
:meth:`~torch.distributed.rpc.rpc_async`,
:meth:`~torch.futures.Future.then`, or the :class:`~torch.futures.Future`
constructor. The example below shows directly using the
:class:`~torch.futures.Future` returned by
:meth:`~torch.futures.Future.then`.
>>> import torch
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> @rpc.functions.async_execution
>>> def async_add_chained(to, x, y, z):
>>>     # This function runs on "worker1" and returns immediately when
>>>     # the callback is installed through the `then(cb)` API. In the
>>>     # meantime, the `rpc_async` to "worker2" can run concurrently.
>>>     # When the return value of that `rpc_async` arrives at
>>>     # "worker1", "worker1" will run the lambda function accordingly
>>>     # and set the value for the previously returned `Future`, which
>>>     # will then trigger RPC to send the result back to "worker0".
>>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>         lambda fut: fut.wait() + z
>>>     )
>>>
>>> # On worker0
>>> # xdoctest: +SKIP
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     async_add_chained,
>>>     args=("worker2", torch.ones(2), 1, 1)
>>> )
>>> print(ret)  # prints tensor([3., 3.])
When combined with TorchScript decorators, this decorator must be the
outermost one.
>>> import torch
>>> from torch import Tensor
>>> from torch.futures import Future
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> @torch.jit.script
>>> def script_add(x: Tensor, y: Tensor) -> Tensor:
>>>     return x + y
>>>
>>> @rpc.functions.async_execution
>>> @torch.jit.script
>>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
>>>     return rpc.rpc_async(to, script_add, (x, y))
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     async_add,
>>>     args=("worker2", torch.ones(2), torch.ones(2))  # y is annotated as Tensor
>>> )
>>> print(ret) # prints tensor([2., 2.])
When combined with a static or class method, this decorator must be the
inner one.
>>> import torch
>>> from torch.distributed import rpc
>>>
>>> # omitting setup and shutdown RPC
>>>
>>> # On all workers
>>> class AsyncExecutionClass:
>>>
>>>     @staticmethod
>>>     @rpc.functions.async_execution
>>>     def static_async_add(to, x, y, z):
>>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: fut.wait() + z
>>>         )
>>>
>>>     @classmethod
>>>     @rpc.functions.async_execution
>>>     def class_async_add(cls, to, x, y, z):
>>>         ret_fut = torch.futures.Future()
>>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: ret_fut.set_result(fut.wait() + z)
>>>         )
>>>         return ret_fut
>>>
>>>     @rpc.functions.async_execution
>>>     def bound_async_add(self, to, x, y, z):
>>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>>             lambda fut: fut.wait() + z
>>>         )
>>>
>>> # On worker0
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     AsyncExecutionClass.static_async_add,
>>>     args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret)  # prints tensor([4., 4.])
>>>
>>> ret = rpc.rpc_sync(
>>>     "worker1",
>>>     AsyncExecutionClass.class_async_add,
>>>     args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret)  # prints tensor([4., 4.])
This decorator also works with RRef helpers, i.e.,
:meth:`torch.distributed.rpc.RRef.rpc_sync`,
:meth:`torch.distributed.rpc.RRef.rpc_async`, and
:meth:`torch.distributed.rpc.RRef.remote`.
>>> import torch
>>> from torch.distributed import rpc
>>>
>>> # reuse the AsyncExecutionClass class above
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
>>> print(ret) # prints tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
>>> print(ret) # prints tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
>>> print(ret) # prints tensor([4., 4.])
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    # Can't declare and use attributes of function objects (mypy#2087)
    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
    return wrapper
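
The decorator itself only tags the wrapper with ``_wrapped_async_rpc_function``; the behavior described in the docstring comes from how the callee reacts to that tag. The sketch below illustrates the idea without a running RPC group. It is not PyTorch's implementation: ``dispatch`` and ``Example`` are hypothetical names, and ``concurrent.futures.Future`` stands in for ``torch.futures.Future`` (so ``add_done_callback`` and ``result`` replace ``then`` and ``wait``). It also shows why the decorator can sit inside ``@staticmethod`` or ``@classmethod``: the tag is still visible when the method is looked up on the class.

```python
import functools
from concurrent.futures import Future


def async_execution(fn):
    """Same shape as the decorator above: wrap fn and tag the wrapper."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper._wrapped_async_rpc_function = fn
    return wrapper


def dispatch(fn, args, send_response):
    """Hypothetical callee-side dispatcher (an assumption, not PyTorch code)."""
    if hasattr(fn, "_wrapped_async_rpc_function"):
        # Tagged function: it returns a Future, so reply from a callback
        # once that Future completes instead of blocking here.
        fut = fn(*args)
        fut.add_done_callback(lambda f: send_response(f.result()))
    else:
        # Plain function: the return value is the response.
        send_response(fn(*args))


class Example:
    @staticmethod
    @async_execution
    def static_add(pending, z):
        # `pending` stands in for the Future of a nested rpc_async call.
        out = Future()
        pending.add_done_callback(lambda f: out.set_result(f.result() + z))
        return out

    @classmethod
    @async_execution
    def class_add(cls, pending, z):
        out = Future()
        pending.add_done_callback(lambda f: out.set_result(f.result() + z))
        return out


# The tag stays visible when the methods are looked up on the class.
static_tagged = hasattr(Example.static_add, "_wrapped_async_rpc_function")
class_tagged = hasattr(Example.class_add, "_wrapped_async_rpc_function")

responses = []
pending = Future()
dispatch(Example.static_add, (pending, 1), responses.append)
sent_before = list(responses)  # nothing sent yet: the nested call is pending
pending.set_result(2)          # the nested result "arrives"
sent_after = list(responses)   # the callback chain has now sent the response
```

In this sketch the response is sent only after ``pending`` completes, which mirrors the docstring's point that the returned ``Future`` lives on the callee and only its value travels back to the caller.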