Shortcuts

Source code for torchvision.ops.drop_block

import torch
import torch.fx
import torch.nn.functional as F
from torch import nn, Tensor

from ..utils import _log_api_usage_once


def drop_block2d(
    input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
    """
    Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`.

    Args:
        input (Tensor[N, C, H, W]): The input tensor of 4 dimensions with the first one
                    being its batch i.e. a batch with ``N`` rows.
        p (float): Probability of an element to be dropped.
        block_size (int): Size of the block to drop. Values larger than ``H`` or ``W`` are
                    clamped to ``min(H, W)``. Both odd and even sizes are supported.
        inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if is ``True``. Default: ``True``.

    Returns:
        Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock.

    Raises:
        ValueError: If ``p`` is not in ``[0, 1]`` (NaN included), if ``input`` is not
            4-dimensional, or (when dropblock is actually applied) if ``block_size < 1``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(drop_block2d)
    # Negated range test so that NaN (for which every comparison is False) is rejected too.
    if not (p >= 0.0 and p <= 1.0):
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
    if input.ndim != 4:
        raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.")
    if not training or p == 0.0:
        return input
    # Checked only when dropblock is applied, so eval-mode calls behave exactly as before;
    # without this a zero block size ends in a ZeroDivisionError when computing gamma.
    if block_size < 1:
        raise ValueError(f"block_size should be a positive integer. Got {block_size}.")

    N, C, H, W = input.size()
    block_size = min(block_size, W, H)
    # Probability of sampling a block seed. Seeds are only drawn at the
    # (H - block_size + 1) x (W - block_size + 1) positions where a whole block fits.
    gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1)))
    noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
    noise.bernoulli_(gamma)

    # Grow every seed into a block_size x block_size square. Zero-padding by block_size - 1 on
    # every side followed by a stride-1 max-pool without implicit padding yields exactly H x W
    # for odd AND even block sizes (padding by block_size // 2 produced H + 2 for even sizes).
    # Output (i, j) is 1 iff a seed lies in rows i-block_size+1..i and columns j-block_size+1..j,
    # i.e. each seed is the top-left corner of its dropped block.
    pad = block_size - 1
    noise = F.pad(noise, [pad] * 4, value=0)
    noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size))
    noise = 1 - noise
    # Rescale kept activations; the keep ratio is measured over the whole (N, C, H, W) mask.
    normalize_scale = noise.numel() / (eps + noise.sum())
    if inplace:
        input.mul_(noise).mul_(normalize_scale)
    else:
        input = input * noise * normalize_scale
    return input
def drop_block3d(
    input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
    """
    Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`.

    Args:
        input (Tensor[N, C, D, H, W]): The input tensor of 5 dimensions with the first one
                    being its batch i.e. a batch with ``N`` rows.
        p (float): Probability of an element to be dropped.
        block_size (int): Size of the block to drop. Values larger than ``D``, ``H`` or ``W`` are
                    clamped to ``min(D, H, W)``. Both odd and even sizes are supported.
        inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if is ``True``. Default: ``True``.

    Returns:
        Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock.

    Raises:
        ValueError: If ``p`` is not in ``[0, 1]`` (NaN included), if ``input`` is not
            5-dimensional, or (when dropblock is actually applied) if ``block_size < 1``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(drop_block3d)
    # Negated range test so that NaN (for which every comparison is False) is rejected too.
    if not (p >= 0.0 and p <= 1.0):
        raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
    if input.ndim != 5:
        raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.")
    if not training or p == 0.0:
        return input
    # Checked only when dropblock is applied, so eval-mode calls behave exactly as before;
    # without this a zero block size ends in a ZeroDivisionError when computing gamma.
    if block_size < 1:
        raise ValueError(f"block_size should be a positive integer. Got {block_size}.")

    N, C, D, H, W = input.size()
    block_size = min(block_size, D, H, W)
    # Probability of sampling a block seed. Seeds are only drawn at positions where a whole
    # block_size^3 cube fits inside the (D, H, W) volume.
    gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
    noise = torch.empty(
        (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device
    )
    noise.bernoulli_(gamma)

    # Grow every seed into a block_size^3 cube. Zero-padding by block_size - 1 on every side
    # followed by a stride-1 max-pool without implicit padding yields exactly D x H x W for odd
    # AND even block sizes (padding by block_size // 2 produced D+2 x H+2 x W+2 for even sizes).
    # Each seed ends up as the lowest-index corner of its dropped cube.
    pad = block_size - 1
    noise = F.pad(noise, [pad] * 6, value=0)
    noise = F.max_pool3d(noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size))
    noise = 1 - noise
    # Rescale kept activations; the keep ratio is measured over the whole (N, C, D, H, W) mask.
    normalize_scale = noise.numel() / (eps + noise.sum())
    if inplace:
        input.mul_(noise).mul_(normalize_scale)
    else:
        input = input * noise * normalize_scale
    return input
# Register drop_block2d as an FX "leaf" function. Symbolic tracing with torch.fx records each call
# to it as a single call_function node instead of tracing through its body. This is presumably
# needed because the body branches on tensor properties such as ``input.ndim``, which a tracing
# Proxy cannot evaluate. torch.fx.wrap must be called at module top level.
torch.fx.wrap("drop_block2d")
class DropBlock2d(nn.Module):
    """
    Module wrapper around :func:`drop_block2d`.

    The drop probability, block size, in-place flag and ``eps`` are fixed at construction time.
    Whether DropBlock is actually applied follows the module's ``training`` flag.
    See :func:`drop_block2d` for the algorithm and argument details.
    """

    def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
        super().__init__()
        # Hyper-parameters handed unchanged to the functional implementation on each forward call.
        self.p, self.block_size = p, block_size
        self.inplace, self.eps = inplace, eps

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor): Input feature map on which some areas will be randomly dropped.

        Returns:
            Tensor: The tensor after DropBlock layer.
        """
        p, block_size, inplace, eps = self.p, self.block_size, self.inplace, self.eps
        return drop_block2d(input, p, block_size, inplace, eps, self.training)

    def __repr__(self) -> str:
        name = type(self).__name__
        return f"{name}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})"
# Register drop_block3d as an FX "leaf" function. Symbolic tracing with torch.fx records each call
# to it as a single call_function node instead of tracing through its body. This is presumably
# needed because the body branches on tensor properties such as ``input.ndim``, which a tracing
# Proxy cannot evaluate. torch.fx.wrap must be called at module top level.
torch.fx.wrap("drop_block3d")
class DropBlock3d(DropBlock2d):
    """
    Module wrapper around :func:`drop_block3d`.

    Attribute storage and ``__repr__`` are inherited from :class:`DropBlock2d`. Only ``forward``
    differs: it dispatches to the 3-D functional implementation.
    See :func:`drop_block3d` for the algorithm and argument details.
    """

    def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
        super().__init__(p=p, block_size=block_size, inplace=inplace, eps=eps)

    def forward(self, input: Tensor) -> Tensor:
        """
        Args:
            input (Tensor): Input feature map on which some areas will be randomly dropped.

        Returns:
            Tensor: The tensor after DropBlock layer.
        """
        p, block_size, inplace, eps = self.p, self.block_size, self.inplace, self.eps
        return drop_block3d(input, p, block_size, inplace, eps, self.training)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources