"""``torch.autograd`` provides classes and functions implementing automaticdifferentiation of arbitrary scalar valued functions. It requires minimalchanges to the existing code - you only need to declare :class:`Tensor` sfor which gradients should be computed with the ``requires_grad=True`` keyword.As of now, we only support autograd for floating point :class:`Tensor` types (half, float, double and bfloat16) and complex :class:`Tensor` types (cfloat, cdouble)."""importtorchimportwarningsfromtorch.typesimport_TensorOrTensorsfromtypingimportAny,Callable,List,Optional,Sequence,Tuple,Unionfrom.variableimportVariablefrom.functionimportFunction,NestedIOFunctionfrom.gradcheckimportgradcheck,gradgradcheckfrom.grad_modeimportno_grad,enable_grad,set_grad_enabled,inference_modefrom.anomaly_modeimportdetect_anomaly,set_detect_anomalyfrom..overridesimporthas_torch_function,handle_torch_functionfrom.importfunctionalfrom.importforward_adfrom.importgraph__all__=['Variable','Function','backward','grad_mode']_OptionalTensor=Optional[torch.Tensor]def_make_grads(outputs:Sequence[torch.Tensor],grads:Sequence[_OptionalTensor])->Tuple[_OptionalTensor,...]:new_grads:List[_OptionalTensor]=[]forout,gradinzip(outputs,grads):ifisinstance(grad,torch.Tensor):ifnotout.shape==grad.shape:raiseRuntimeError("Mismatch in shape: grad_output["+str(grads.index(grad))+"] has a shape of "+str(grad.shape)+" and output["+str(outputs.index(out))+"] has a shape of "+str(out.shape)+".")ifout.dtype.is_complex!=grad.dtype.is_complex:raiseRuntimeError("For complex Tensors, both grad_output and output"" are required to have the same dtype."" Mismatch in dtype: grad_output["+str(grads.index(grad))+"] has a dtype of "+str(grad.dtype)+" and output["+str(outputs.index(out))+"] has a dtype of "+str(out.dtype)+".")new_grads.append(grad)elifgradisNone:ifout.requires_grad:ifout.numel()!=1:raiseRuntimeError("grad can be implicitly created only for scalar outputs")new_grads.append(torch.ones_like(out,memory_format=torch.preserve_format))else:new_grads.append(None)else:raiseTypeError("gradients can be either Tensors or None, but got "+type(grad).__name__)returntuple(new_grads)def_tensor_or_tensors_to_tuple(tensors:Optional[_TensorOrTensors],length:int)->Tuple[_OptionalTensor,...]:iftensorsisNone:return(None,)*lengthifisinstance(tensors,torch.Tensor):return(tensors,)returntuple(tensors)defbackward(tensors:_TensorOrTensors,grad_tensors:Optional[_TensorOrTensors]=None,retain_graph:Optional[bool]=None,create_graph:bool=False,grad_variables:Optional[_TensorOrTensors]=None,inputs:Optional[_TensorOrTensors]=None,)->None:r"""Computes the sum of gradients of given tensors with respect to graph leaves. The graph is differentiated using the chain rule. If any of ``tensors`` are non-scalar (i.e. their data has more than one element) and require gradient, then the Jacobian-vector product would be computed, in this case the function additionally requires specifying ``grad_tensors``. It should be a sequence of matching length, that contains the "vector" in the Jacobian-vector product, usually the gradient of the differentiated function w.r.t. corresponding tensors (``None`` is an acceptable value for all tensors that don't need gradient tensors). This function accumulates gradients in the leaves - you might need to zero ``.grad`` attributes or set them to ``None`` before calling it. See :ref:`Default gradient layouts<default-grad-layouts>` for details on the memory layout of accumulated gradients. .. 
note:: Using this method with ``create_graph=True`` will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using ``autograd.grad`` when creating the graph to avoid this. If you have to use this function, make sure to reset the ``.grad`` fields of your parameters to ``None`` after use to break the cycle and avoid the leak. .. note:: If you run any forward ops, create ``grad_tensors``, and/or call ``backward`` in a user-specified CUDA stream context, see :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`. .. note:: When ``inputs`` are provided and a given input is not a leaf, the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). It is an implementation detail on which the user should not rely. See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. Args: tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be computed. grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in the Jacobian-vector product, usually gradients w.r.t. each element of corresponding tensors. None values can be specified for scalar Tensors or ones that don't require grad. If a None value would be acceptable for all grad_tensors, then this argument is optional. retain_graph (bool, optional): If ``False``, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to ``True`` is not needed and often can be worked around in a much more efficient way. Defaults to the value of ``create_graph``. create_graph (bool, optional): If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to ``False``. inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient be will accumulated into ``.grad``. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the attr::tensors. """ifgrad_variablesisnotNone:warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")ifgrad_tensorsisNone:grad_tensors=grad_variableselse:raiseRuntimeError("'grad_tensors' and 'grad_variables' (deprecated) ""arguments both passed to backward(). Please only ""use 'grad_tensors'.")ifinputsisnotNoneandlen(inputs)==0:raiseRuntimeError("'inputs' argument to backward() cannot be empty.")tensors=(tensors,)ifisinstance(tensors,torch.Tensor)elsetuple(tensors)inputs=(inputs,)ifisinstance(inputs,torch.Tensor)else \
tuple(inputs)ifinputsisnotNoneelsetuple()grad_tensors_=_tensor_or_tensors_to_tuple(grad_tensors,len(tensors))grad_tensors_=_make_grads(tensors,grad_tensors_)ifretain_graphisNone:retain_graph=create_graphVariable._execution_engine.run_backward(tensors,grad_tensors_,retain_graph,create_graph,inputs,allow_unreachable=True,accumulate_grad=True)# allow_unreachable flag
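
# --- Illustrative sketch (editor's addition, not part of the upstream module) ---
# ``_example_backward_usage`` is a hypothetical, uncalled helper showing how
# ``backward`` above consumes ``grad_tensors`` for a non-scalar output: the
# gradient accumulated into ``x.grad`` is the vector-Jacobian product.
def _example_backward_usage():
    x = torch.randn(3, requires_grad=True)    # leaf tensor
    y = x * 2                                 # non-scalar output -> needs grad_tensors
    v = torch.tensor([0.1, 1.0, 10.0])        # the "vector" in the Jacobian-vector product
    torch.autograd.backward(y, grad_tensors=v)
    # x.grad now holds 2 * v; a second call would accumulate into x.grad,
    # so zero it (or set it to None) between calls, as noted in the docstring.
    return x.grad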

def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False
) -> Tuple[torch.Tensor, ...]:
    r"""Computes and returns the sum of gradients of outputs with respect to
    the inputs.

    ``grad_outputs`` should be a sequence of length matching ``outputs``
    containing the "vector" in the Jacobian-vector product, usually the
    pre-computed gradients w.r.t. each of the outputs. If an output doesn't
    require_grad, then the gradient can be ``None``.

    .. note::
        If you run any forward ops, create ``grad_outputs``, and/or call ``grad``
        in a user-specified CUDA stream context, see
        :ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

    .. note::
        ``only_inputs`` argument is deprecated and is ignored now (defaults to ``True``).
        To accumulate gradient for other parts of the graph, please use
        ``torch.autograd.backward``.

    Args:
        outputs (sequence of Tensor): outputs of the differentiated function.
        inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): The "vector" in the Jacobian-vector product.
            Usually gradients w.r.t. each output. None values can be specified for scalar
            Tensors or ones that don't require grad. If a None value would be acceptable
            for all grad_tensors, then this argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative products.
            Default: ``False``.
        allow_unused (bool, optional): If ``False``, specifying inputs that were not
            used when computing outputs (and therefore their grad is always zero)
            is an error. Defaults to ``False``.
    """
    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    overridable_args = outputs + inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
        )

    if not only_inputs:
        warnings.warn("only_inputs argument is deprecated and is ignored now "
                      "(defaults to True). To accumulate gradient for other "
                      "parts of the graph, please use torch.autograd.backward.")

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
    grad_outputs_ = _make_grads(outputs, grad_outputs_)

    if retain_graph is None:
        retain_graph = create_graph

    return Variable._execution_engine.run_backward(
        outputs, grad_outputs_, retain_graph, create_graph, inputs,
        allow_unused, accumulate_grad=False)
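
# --- Illustrative sketch (editor's addition, not part of the upstream module) ---
# ``_example_grad_usage`` is a hypothetical, uncalled helper showing how ``grad``
# returns gradients for the requested inputs instead of accumulating them into
# ``.grad``, and how ``allow_unused=True`` yields ``None`` for an input the
# output does not depend on.
def _example_grad_usage():
    a = torch.randn(2, requires_grad=True)
    b = torch.randn(2, requires_grad=True)   # unused below
    out = (a ** 2).sum()                     # scalar output, so grad_outputs may be omitted
    grad_a, grad_b = torch.autograd.grad(out, (a, b), allow_unused=True)
    # grad_a == 2 * a, grad_b is None, and neither a.grad nor b.grad is touched.
    return grad_a, grad_b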

# This function applies in case of gradient checkpointing for memory
# optimization. Currently, gradient checkpointing is supported only if the
# execution engine is invoked through torch.autograd.backward() and its
# inputs argument is not passed. It is not supported for torch.autograd.grad().
# This is because if inputs are specified, the gradient won't be calculated for
# anything else, e.g. model parameters like weights, bias etc.
#
# This function returns whether checkpointing is valid, i.e. whether the engine
# was invoked through torch.autograd.backward (valid) or torch.autograd.grad
# (not valid). The implementation works by maintaining a thread local variable
# in torch/csrc/autograd/engine.cpp which looks at the NodeTask in the stack
# and, before a NodeTask is executed in evaluate_function, checks whether
# reentrant backwards is imperative or not.
# See https://github.com/pytorch/pytorch/pull/4594 for more discussion/context
def _is_checkpoint_valid():
    return Variable._execution_engine.is_checkpoint_valid()


def variable(*args, **kwargs):
    warnings.warn("torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead")
    return torch.tensor(*args, **kwargs)


if not torch._C._autograd_init():
    raise RuntimeError("autograd initialization failed")

# Import all native method/classes
from torch._C._autograd import (DeviceType, ProfilerActivity, ProfilerState, ProfilerConfig, ProfilerEvent,
                                _enable_profiler_legacy, _disable_profiler_legacy, _profiler_enabled,
                                _enable_record_function, _set_empty_test_observer, kineto_available,
                                _supported_activities, _add_metadata_json, SavedTensor,
                                _register_saved_tensors_default_hooks, _reset_saved_tensors_default_hooks)

from torch._C._autograd import (_ProfilerResult, _KinetoEvent,
                                _prepare_profiler, _enable_profiler, _disable_profiler)

from . import profiler
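
# --- Illustrative sketch (editor's addition, not part of the upstream module) ---
# ``_example_requires_grad_usage`` is a hypothetical, uncalled helper showing the
# replacement suggested by the deprecation warning in ``variable`` above:
# construct tensors directly with ``torch.tensor`` and opt into autograd via
# ``requires_grad=True``, as described in the module docstring.
def _example_requires_grad_usage():
    w = torch.tensor([1.0, 2.0], requires_grad=True)   # instead of torch.autograd.variable(...)
    loss = (w * w).sum()
    loss.backward()
    return w.grad   # holds 2 * w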