import torch
import torch.fx
from torch import nn, Tensor

from torchvision.utils import _log_api_usage_once


def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) -> Tensor:
    """
    Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
    <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
    branches of residual architectures.

    Args:
        input (Tensor[N, ...]): The input tensor of arbitrary dimensions, with the first one
                    being its batch, i.e. a batch with ``N`` rows.
        p (float): probability of the input to be zeroed.
        mode (str): ``"batch"`` or ``"row"``.
                    ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
                    randomly selected rows from the batch.
        training: apply stochastic depth if ``True``. Default: ``True``

    Returns:
        Tensor[N, ...]: The randomly zeroed tensor; surviving entries are rescaled
        by ``1 / (1 - p)``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(stochastic_depth)
    if p < 0.0 or p > 1.0:
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
    if mode not in ["batch", "row"]:
        raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
    if not training or p == 0.0:
        return input

    survival_rate = 1.0 - p
    if mode == "row":
        # One Bernoulli draw per sample, broadcast over the remaining dimensions
        size = [input.shape[0]] + [1] * (input.ndim - 1)
    else:
        # A single Bernoulli draw shared by the whole batch
        size = [1] * input.ndim
    noise = torch.empty(size, dtype=input.dtype, device=input.device)
    noise = noise.bernoulli_(survival_rate)
    if survival_rate > 0.0:
        # Rescale survivors so the expected output equals the input
        noise.div_(survival_rate)
    return input * noise
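# Keep stochastic_depth as a single leaf call in torch.fx graphs instead of
# tracing through its data-dependent control flow
torch.fx.wrap("stochastic_depth")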
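class StochasticDepth(nn.Module):
    """
    See :func:`stochastic_depth`.
    """

    def __init__(self, p: float, mode: str) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.p = p
        self.mode = mode

    def forward(self, input: Tensor) -> Tensor:
        # Delegate to the functional form; self.training follows .train() / .eval()
        return stochastic_depth(input, self.p, self.mode, self.training)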