import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
]


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)


def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads
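
# Illustrative sketch (not part of the library API): how the RNG stash/replay
# pattern used by CheckpointFunction behaves in isolation. The helper name
# `_demo_rng_replay` and the tensor shapes are hypothetical and exist only to
# show that restoring the stashed CPU/CUDA RNG state before recomputation makes
# a stochastic forward (here, dropout) reproduce the same result.
def _demo_rng_replay():
    x = torch.randn(4, 4)
    cpu_state = torch.get_rng_state()
    # For a CPU-only input these are empty lists; with CUDA tensor args they
    # would hold the per-device CUDA RNG states.
    gpu_devices, gpu_states = get_device_states(x)

    first = torch.nn.functional.dropout(x, p=0.5)

    # Later (e.g. during backward), mimic the stashed state before re-running.
    with torch.random.fork_rng(devices=gpu_devices, enabled=True):
        torch.set_rng_state(cpu_state)
        set_device_states(gpu_devices, gpu_states)
        second = torch.nn.functional.dropout(x, p=0.5)

    # The recomputed activations match the original ones.
    assert torch.equal(first, second)
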
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.

    .. warning::
        If ``use_reentrant=True`` is specified, then if the checkpointed segment
        contains tensors detached from the computational graph by `detach()` or
        `torch.no_grad()`, the backward pass will raise an error. This is
        because `checkpoint` makes all the outputs require gradients which
        causes issues when a tensor is defined to have no gradient in the
        model. To circumvent this, detach the tensors outside of the
        `checkpoint` function. Note that the checkpointed segment can contain
        tensors detached from the computational graph if ``use_reentrant=False``
        is specified.

    .. warning::
        If ``use_reentrant=True`` is specified, at least one of the inputs needs
        to have :code:`requires_grad=True` if grads are needed for model inputs,
        otherwise the checkpointed part of the model won't have gradients. At
        least one of the outputs needs to have :code:`requires_grad=True` as
        well. Note that this does not apply if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, checkpointing currently only
        supports :func:`torch.autograd.backward` and only if its `inputs`
        argument is not passed. :func:`torch.autograd.grad` is not supported.
        If ``use_reentrant=False`` is specified, checkpointing will work with
        :func:`torch.autograd.grad`.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional):  If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            passing keyword arguments into the checkpointed function. Note that
            future versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs and use_reentrant:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    if use_reentrant:
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            preserve,
            *args,
            **kwargs,
        )
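
# Illustrative usage sketch (not part of the library API): checkpointing one
# block of a model with the default reentrant implementation. The module sizes
# and the name `_demo_checkpoint_usage` are hypothetical; the point is that the
# input requires grad (see the reentrant-mode warning above) and that the
# block's activations are recomputed during backward rather than stored.
def _demo_checkpoint_usage():
    block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(),
                                torch.nn.Linear(16, 16))
    x = torch.randn(2, 16, requires_grad=True)

    out = checkpoint(block, x)   # forward runs under no_grad; only inputs are saved
    out.sum().backward()         # block is re-run here to rebuild its activations

    assert x.grad is not None
    assert all(p.grad is not None for p in block.parameters())
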
def checkpoint_sequential(functions, segments, input, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional):  If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(run_function(start, end, functions), input,
                           preserve_rng_state=preserve)
    return run_function(end + 1, len(functions) - 1, functions)(input)
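
# Illustrative sketch (not part of the library API): splitting a hypothetical
# 4-layer Sequential into 2 segments. With segments=2 and segment_size=2, the
# first two layers form one checkpointed chunk and the last two run normally,
# as the loop above computes. The model, shapes, and the helper name
# `_demo_checkpoint_sequential_usage` are assumptions for illustration only.
def _demo_checkpoint_sequential_usage():
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(),
                                torch.nn.Linear(8, 8), torch.nn.ReLU())
    x = torch.randn(4, 8, requires_grad=True)

    out = checkpoint_sequential(model, 2, x)
    out.sum().backward()
    assert x.grad is not None
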
def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional):  If ``False``, omit stashing and
            restoring the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:
            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return

                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this.")

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
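
# Illustrative sketch (not part of the library API): the non-reentrant path
# above works with torch.autograd.grad, which the reentrant implementation
# rejects. The module, shapes, and the name `_demo_non_reentrant_usage` are
# hypothetical and only demonstrate the behavior described in checkpoint()'s
# use_reentrant documentation.
def _demo_non_reentrant_usage():
    block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Tanh())
    x = torch.randn(3, 8, requires_grad=True)

    out = checkpoint(block, x, use_reentrant=False)
    # The pack/unpack saved_tensors_hooks above re-run `block` the first time
    # a saved activation is actually needed during this grad call.
    (grad_x,) = torch.autograd.grad(out.sum(), x)
    assert grad_x.shape == x.shape
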