import sys

import torch


def is_available():
    return hasattr(torch._C, "_dist_autograd_init")


if is_available() and not torch._C._dist_autograd_init():
    raise RuntimeError("Failed to initialize torch.distributed.autograd")

if is_available():
    from torch._C._distributed_autograd import (
        get_gradients,
        backward,
        _init,
        _new_context,
        _release_context,
        _get_max_id,
        _is_valid_context,
        _retrieve_context,
        _current_context,
        _get_debug_info,
        DistAutogradContext,
    )
class context:
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''

    def __enter__(self):
        self.autograd_context = _new_context()
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        _release_context(self.autograd_context._context_id())
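
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module above). It assumes
# two RPC workers named "worker0" and "worker1", a localhost rendezvous, and
# one process per rank; those names and settings are assumptions for the
# example, not requirements of this module. Gradients computed by a
# distributed backward pass accumulate per-context and are read back with
# get_gradients() rather than from Tensor.grad.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "2"))
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    if rank == 0:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        with dist_autograd.context() as context_id:
            # The remote forward pass is recorded in this autograd context.
            loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
            # Run the distributed backward pass across all involved workers.
            dist_autograd.backward(context_id, [loss])
            # Gradients live in the context, keyed by the original tensors.
            grads = dist_autograd.get_gradients(context_id)
            print(grads[t1], grads[t2])

    # Block until outstanding RPCs finish, then tear down the RPC agent.
    rpc.shutdown()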