import inspect
import warnings
from collections import abc, defaultdict
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple

import torch

from .common import amp_definitely_not_available

__all__ = ["OptState", "GradScaler"]


class _MultiDeviceReplicator:
    """
    Lazily serves copies of a tensor to requested devices.  Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda or master_tensor.device.type == "xla"
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


# Defines default_factory for GradScaler's _per_optimizer_states defaultdict,
# as well as associated "enum" values.  Prefers defining these at top level because
# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory.
# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler
#   causes a circular reference, which we'd rather avoid.
class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


def _refresh_per_optimizer_state():
    return {"stage": OptState.READY, "found_inf_per_device": {}}
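# Example (illustrative sketch, not part of the module): _MultiDeviceReplicator
# caches one copy of the master tensor per device, so repeated get() calls for
# the same device reuse the cached copy instead of re-copying.  The device
# strings below are placeholders for whatever GPUs are actually present:
#
#     scale = torch.full((), 65536.0, dtype=torch.float32, device="cuda:0")
#     rep = _MultiDeviceReplicator(scale)
#     s0 = rep.get(torch.device("cuda:0"))  # first call copies (and caches)
#     s1 = rep.get(torch.device("cuda:1"))  # copies to cuda:1, then caches
#     assert rep.get(torch.device("cuda:1")) is s1  # cache hit, no new copy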
class GradScaler:
    _scale: Optional[torch.Tensor]
    _growth_tracker: Optional[torch.Tensor]
    _per_optimizer_states: Dict[int, Dict[str, Any]]
    """
    An instance ``scaler`` of :class:`GradScaler` helps perform the steps of
    gradient scaling conveniently.

    * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor.
    * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``.
    * ``scaler.update()`` updates ``scaler``'s scale factor.

    Example::

        # Creates a GradScaler once at the beginning of training.
        scaler = GradScaler()

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)

                # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
                scaler.scale(loss).backward()

                # scaler.step() first unscales gradients of the optimizer's params.
                # If gradients don't contain infs/NaNs, optimizer.step() is then called,
                # otherwise, optimizer.step() is skipped.
                scaler.step(optimizer)

                # Updates the scale for next iteration.
                scaler.update()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage
    (along with autocasting) in more complex cases like gradient clipping,
    gradient accumulation, gradient penalty, and multiple losses/optimizers.

    ``scaler`` dynamically estimates the scale factor each iteration.  To minimize gradient underflow,
    a large scale factor should be used.  However, ``float16`` values can "overflow" (become inf or NaN)
    if the scale factor is too large.  Therefore, the optimal scale factor is the largest factor that
    can be used without incurring inf or NaN gradient values.
    ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and
    NaNs during every ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``,
    see :meth:`unscale_`).

    * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()``
      (so the params themselves remain uncorrupted) and ``update()`` multiplies the scale by
      ``backoff_factor``.

    * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()``
      as usual.  If ``growth_interval`` unskipped iterations occur consecutively, ``update()``
      multiplies the scale by ``growth_factor``.

    The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its
    value calibrates.  ``scaler.step`` will skip the underlying ``optimizer.step()`` for these
    iterations.  After that, step skipping should occur rarely (once every few hundred or thousand
    iterations).

    Args:
        init_scale (float, optional, default=2.**16):  Initial scale factor.
        growth_factor (float, optional, default=2.0):  Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5):  Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000):  Number of consecutive iterations without inf/NaN
            gradients that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional):  If ``False``, disables gradient scaling.  :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
            Default: ``True``
    """

    def __init__(
        self,
        init_scale=2.0**16,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=2000,
        enabled=True,
    ):
        if enabled and amp_definitely_not_available():
            warnings.warn(
                "torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling."
            )
            self._enabled = False
        else:
            self._enabled = enabled

        if self._enabled:
            assert growth_factor > 1.0, "The growth factor must be > 1.0."
            assert backoff_factor < 1.0, "The backoff factor must be < 1.0."

            self._init_scale = init_scale
            # self._scale will be lazily initialized during the first call to scale()
            self._scale = None
            self._growth_factor = growth_factor
            self._backoff_factor = backoff_factor
            self._growth_interval = growth_interval
            self._init_growth_tracker = 0
            # self._growth_tracker will be lazily initialized during the first call to scale()
            self._growth_tracker = None
            self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

    def _check_scale_growth_tracker(self, funcname) -> Tuple[torch.Tensor, torch.Tensor]:
        fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration."
        assert self._scale is not None, (
            f"Attempted {funcname} but _scale is None.  " + fix
        )
        assert self._growth_tracker is not None, (
            f"Attempted {funcname} but _growth_tracker is None.  " + fix
        )
        return (self._scale, self._growth_tracker)

    def _lazy_init_scale_growth_tracker(self, dev):
        assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
        self._scale = torch.full((), self._init_scale, dtype=torch.float32, device=dev)
        self._growth_tracker = torch.full(
            (), self._init_growth_tracker, dtype=torch.int32, device=dev
        )
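    # Example (illustrative only): constructing a scaler with non-default
    # hyperparameters.  The values below are arbitrary placeholders, not
    # recommendations:
    #
    #     scaler = GradScaler(
    #         init_scale=2.0**10,     # start lower if early steps overflow often
    #         backoff_factor=0.5,     # halve the scale after an inf/NaN step
    #         growth_interval=1000,   # double the scale after 1000 clean steps
    #     )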
    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.

        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled,
        outputs are returned unmodified.

        Args:
            outputs (Tensor or iterable of Tensors):  Outputs to scale.
        """
        if not self._enabled:
            return outputs

        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda or outputs.device.type == "xla"
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda or val.device.type == "xla"
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, abc.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, (list, tuple)):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)
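    # Example (illustrative only): scale() also accepts an iterable of tensors,
    # e.g. when an iteration produces several losses.  For lists and tuples the
    # container type is preserved; all tensors are scaled by the same factor:
    #
    #     scaled_a, scaled_b = scaler.scale((loss_a, loss_b))
    #     (scaled_a + scaled_b).backward()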
    def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16):
        per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
        per_device_found_inf = _MultiDeviceReplicator(found_inf)

        # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
        # There could be hundreds of grads, so we'd like to iterate through them just once.
        # However, we don't know their devices or dtypes in advance.

        # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
        # Google says mypy struggles with defaultdicts type annotations.
        per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))  # type: ignore[var-annotated]
        with torch.no_grad():
            for group in optimizer.param_groups:
                for param in group["params"]:
                    if param.grad is None:
                        continue
                    if (not allow_fp16) and param.grad.dtype == torch.float16:
                        raise ValueError("Attempting to unscale FP16 gradients.")
                    if param.grad.is_sparse:
                        # is_coalesced() == False means the sparse grad has values with duplicate indices.
                        # coalesce() deduplicates indices and adds all values that have the same index.
                        # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                        # so we should check the coalesced _values().
                        if param.grad.dtype is torch.float16:
                            param.grad = param.grad.coalesce()
                        to_unscale = param.grad._values()
                    else:
                        to_unscale = param.grad

                    # TODO: is there a way to split by device and dtype without appending in the inner loop?
                    per_device_and_dtype_grads[to_unscale.device][to_unscale.dtype].append(
                        to_unscale
                    )

        for device, per_dtype_grads in per_device_and_dtype_grads.items():
            for grads in per_dtype_grads.values():
                torch._amp_foreach_non_finite_check_and_unscale_(
                    grads,
                    per_device_found_inf.get(device),
                    per_device_inv_scale.get(device),
                )

        return per_device_found_inf._per_device_tensors
    def unscale_(self, optimizer):
        """
        Divides ("unscales") the optimizer's gradient tensors by the scale factor.

        :meth:`unscale_` is optional, serving cases where you need to
        :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
        between the backward pass(es) and :meth:`step`.
        If :meth:`unscale_` is not called explicitly, gradients will be unscaled
        automatically during :meth:`step`.

        Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::

            ...
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()

        Args:
            optimizer (torch.optim.Optimizer):  Optimizer that owns the gradients to be unscaled.

        .. note::
            :meth:`unscale_` does not incur a CPU-GPU sync.

        .. warning::
            :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
            and only after all gradients for that optimizer's assigned parameters have been
            accumulated.  Calling :meth:`unscale_` twice for a given optimizer between each
            :meth:`step` triggers a RuntimeError.

        .. warning::
            :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
        """
        if not self._enabled:
            return

        self._check_scale_growth_tracker("unscale_")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.UNSCALED:
            raise RuntimeError(
                "unscale_() has already been called on this optimizer since the last update()."
            )
        elif optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("unscale_() is being called after step().")

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float()
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device)

        optimizer_state["found_inf_per_device"] = self._unscale_grads_(
            optimizer, inv_scale, found_inf, False
        )
        optimizer_state["stage"] = OptState.UNSCALED
    def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
        retval = None
        if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
            retval = optimizer.step(*args, **kwargs)
        return retval

    def step(self, optimizer, *args, **kwargs):
        """
        :meth:`step` carries out the following two operations:

        1.  Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly
            called for ``optimizer`` earlier in the iteration).  As part of the :meth:`unscale_`,
            gradients are checked for infs/NaNs.
        2.  If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled
            gradients.  Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params.

        ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``.

        Returns the return value of ``optimizer.step(*args, **kwargs)``.

        Args:
            optimizer (torch.optim.Optimizer):  Optimizer that applies the gradients.
            args:  Any arguments.
            kwargs:  Any keyword arguments.

        .. warning::
            Closure use is not currently supported.
        """
        if not self._enabled:
            return optimizer.step(*args, **kwargs)

        if "closure" in kwargs:
            raise RuntimeError(
                "Closure use is not currently supported if GradScaler is enabled."
            )

        self._check_scale_growth_tracker("step")

        optimizer_state = self._per_optimizer_states[id(optimizer)]

        if optimizer_state["stage"] is OptState.STEPPED:
            raise RuntimeError("step() has already been called since the last update().")

        retval = None

        if (
            hasattr(optimizer, "_step_supports_amp_scaling")
            and optimizer._step_supports_amp_scaling
        ):
            # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
            # The contract with custom optimizers is that their step() should accept an additional,
            # optional grad_scaler kwarg.  We append self to the kwargs so the custom optimizer has full
            # information: it can query its own state, invoke unscale_ on itself, etc.
            # The contract above is being deprecated to avoid introducing a `grad_scaler: GradScaler`
            # argument to `Optimizer.step`. The new behavior instead attaches two Tensor attributes,
            # `grad_scale` and `found_inf`, to the passed optimizer so that the optimizer can use them
            # to skip the parameter updates or unscale gradients before updating parameters in
            # the fused kernel, e.g. `FusedAdamMathFunctor`.
            # In this behavior, `GradScaler._check_inf_per_device` is called if the state is
            # `OptState.READY`; otherwise the method is expected to be called on the user's side,
            # i.e. by their optimizer.
            kwargs_ = kwargs
            has_grad_scaler_kwarg = (
                "grad_scaler" in inspect.signature(optimizer.step).parameters
            )
            if has_grad_scaler_kwarg:
                warnings.warn(
                    "GradScaler is going to stop passing itself as a keyword argument to the passed "
                    "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and "
                    "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.",
                    FutureWarning,
                )
                kwargs_.update({"grad_scaler": self})
            else:
                if optimizer_state["stage"] is OptState.READY:
                    self._check_inf_per_device(optimizer)
                scaler = self._get_scale_async()
                found_inf = cast(
                    torch.Tensor,
                    sum(
                        [
                            t.to(scaler.device, non_blocking=True)
                            for t in optimizer_state["found_inf_per_device"].values()
                        ]
                    ),
                )
                optimizer.grad_scale = (
                    None if optimizer_state["stage"] == OptState.UNSCALED else scaler
                )
                optimizer.found_inf = found_inf
            retval = optimizer.step(*args, **kwargs_)
            optimizer_state["stage"] = OptState.STEPPED
            if not has_grad_scaler_kwarg:
                del optimizer.grad_scale
                del optimizer.found_inf
            return retval

        if optimizer_state["stage"] is OptState.READY:
            self.unscale_(optimizer)

        assert (
            len(optimizer_state["found_inf_per_device"]) > 0
        ), "No inf checks were recorded for this optimizer."

        retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

        optimizer_state["stage"] = OptState.STEPPED

        return retval
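    # Example (illustrative only): with two optimizers participating in the same
    # iteration, step() is called once per optimizer and update() once at the end.
    # If unscale_() was already called for an optimizer, step() does not unscale
    # its gradients again:
    #
    #     scaler.scale(loss).backward()
    #     scaler.unscale_(optimizer0)   # optional, e.g. to clip optimizer0's grads
    #     scaler.step(optimizer0)
    #     scaler.step(optimizer1)       # unscales optimizer1's grads internally
    #     scaler.update()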
    def update(self, new_scale=None):
        """
        Updates the scale factor.

        If any optimizer steps were skipped, the scale is multiplied by ``backoff_factor``
        to reduce it.  If ``growth_interval`` unskipped iterations occurred consecutively,
        the scale is multiplied by ``growth_factor`` to increase it.

        Passing ``new_scale`` sets the new scale value manually.  (``new_scale`` is not
        used directly; it is used to fill GradScaler's internal scale tensor.  So if
        ``new_scale`` was a tensor, later in-place changes to that tensor will not further
        affect the scale GradScaler uses internally.)

        Args:
            new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None):  New scale factor.

        .. warning::
            :meth:`update` should only be called at the end of the iteration, after
            ``scaler.step(optimizer)`` has been invoked for all optimizers used this iteration.

        .. warning::
            For performance reasons, we do not check the scale factor value to avoid
            synchronizations, so the scale factor is not guaranteed to be above 1.  If the
            scale falls below 1 and/or you are seeing NaNs in your gradients or loss,
            something is likely wrong.  For example, bf16-pretrained models are often
            incompatible with AMP/fp16 due to differing dynamic ranges.
        """
        if not self._enabled:
            return

        _scale, _growth_tracker = self._check_scale_growth_tracker("update")

        if new_scale is not None:
            # Accept a new user-defined scale.
            if isinstance(new_scale, float):
                self._scale.fill_(new_scale)  # type: ignore[union-attr]
            else:
                reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
                assert isinstance(new_scale, torch.cuda.FloatTensor), reason  # type: ignore[attr-defined]
                assert new_scale.numel() == 1, reason
                assert new_scale.requires_grad is False, reason
                self._scale.copy_(new_scale)  # type: ignore[union-attr]
        else:
            # Consume shared inf/nan data collected from optimizers to update the scale.
            # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
            found_infs = [
                found_inf.to(device=_scale.device, non_blocking=True)
                for state in self._per_optimizer_states.values()
                for found_inf in state["found_inf_per_device"].values()
            ]

            assert len(found_infs) > 0, "No inf checks were recorded prior to update."

            found_inf_combined = found_infs[0]
            if len(found_infs) > 1:
                for i in range(1, len(found_infs)):
                    found_inf_combined += found_infs[i]

            torch._amp_update_scale_(
                _scale,
                _growth_tracker,
                found_inf_combined,
                self._growth_factor,
                self._backoff_factor,
                self._growth_interval,
            )

        # To prepare for next iteration, clear the data collected from optimizers this iteration.
        self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
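    # Example (illustrative only): one common way to detect whether the previous
    # step was skipped is to compare the scale before and after update().  A
    # decreased scale means infs/NaNs were found.  Note that get_scale() incurs
    # a CPU-GPU sync, so this adds a synchronization per iteration:
    #
    #     prev_scale = scaler.get_scale()
    #     scaler.step(optimizer)
    #     scaler.update()
    #     step_was_skipped = scaler.get_scale() < prev_scale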
    def _get_scale_async(self):
        return self._scale
    def get_scale(self):
        """
        Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

        .. warning::
            :meth:`get_scale` incurs a CPU-GPU sync.
        """
        if self._enabled:
            return (
                self._init_scale
                if self._scale is None
                else self._get_scale_async().item()
            )
        else:
            return 1.0
    def get_growth_factor(self):
        r"""
        Returns a Python float containing the scale growth factor.
        """
        return self._growth_factor
    def set_growth_factor(self, new_factor):
        r"""
        Args:
            new_factor (float):  Value to use as the new scale growth factor.
        """
        self._growth_factor = new_factor
    def get_backoff_factor(self):
        r"""
        Returns a Python float containing the scale backoff factor.
        """
        return self._backoff_factor
    def set_backoff_factor(self, new_factor):
        r"""
        Args:
            new_factor (float):  Value to use as the new scale backoff factor.
        """
        self._backoff_factor = new_factor
    def get_growth_interval(self):
        r"""
        Returns a Python int containing the growth interval.
        """
        return self._growth_interval
    def set_growth_interval(self, new_interval):
        r"""
        Args:
            new_interval (int):  Value to use as the new growth interval.
        """
        self._growth_interval = new_interval
    def is_enabled(self):
        r"""
        Returns a bool indicating whether this instance is enabled.
        """
        return self._enabled
    def _get_growth_tracker(self):
        return (
            self._init_growth_tracker
            if self._growth_tracker is None
            else self._growth_tracker.item()
        )

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`.  It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration,
           :meth:`state_dict` should be called after :meth:`update`.
        """
        return (
            {
                "scale": self.get_scale(),
                "growth_factor": self._growth_factor,
                "backoff_factor": self._backoff_factor,
                "growth_interval": self._growth_interval,
                "_growth_tracker": self._get_growth_tracker(),
            }
            if self._enabled
            else {}
        )
    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state.  If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
           state_dict(dict): scaler state.  Should be an object returned from a call to :meth:`state_dict`.
        """
        if not self._enabled:
            return

        if len(state_dict) == 0:
            raise RuntimeError(
                "The source state dict is empty, possibly because it was saved "
                "from a disabled instance of GradScaler."
            )

        self._init_scale = state_dict["scale"]
        if self._scale is not None:
            self._scale.fill_(state_dict["scale"])
        self._growth_factor = state_dict["growth_factor"]
        self._backoff_factor = state_dict["backoff_factor"]
        self._growth_interval = state_dict["growth_interval"]
        self._init_growth_tracker = state_dict["_growth_tracker"]
        if self._growth_tracker is not None:
            self._growth_tracker.fill_(state_dict["_growth_tracker"])
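    # Example (illustrative only): checkpointing the scaler alongside the model
    # and optimizer, then restoring it.  Per the note on state_dict(), call it
    # after update().  `PATH` is a placeholder:
    #
    #     torch.save({"model": model.state_dict(),
    #                 "optimizer": optimizer.state_dict(),
    #                 "scaler": scaler.state_dict()}, PATH)
    #
    #     checkpoint = torch.load(PATH)
    #     scaler.load_state_dict(checkpoint["scaler"])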
    def __getstate__(self):
        state = self.__dict__.copy()
        if self._enabled:
            assert len(self._per_optimizer_states) == 0, (
                "A GradScaler instance may only be pickled at the beginning "
                "of an iteration, or at the end after scaler.update()."
            )
            # Pickling _scale and _growth_tracker Tensors directly triggers
            # "warnings.warn("pickle support for Storage will be removed in 1.5..."
            # so instead, we set the unpickled instance up to reinitialize them lazily.
            state["_init_scale"] = self.get_scale()
            state["_init_growth_tracker"] = self._get_growth_tracker()
            state["_scale"] = None
            state["_growth_tracker"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

    def _check_inf_per_device(self, optimizer):
        _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device")

        dummy_inv_scale = torch.full((), 1.0, dtype=torch.float32, device=_scale.device)
        found_inf = torch.full((), 0.0, dtype=torch.float32, device=_scale.device)

        self._per_optimizer_states[id(optimizer)][
            "found_inf_per_device"
        ] = self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)

        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]

    def _found_inf_per_device(self, optimizer):
        return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]