Shortcuts

Source code for torch.nn.modules.channelshuffle

import torch.nn.functional as F
from torch import Tensor

from .module import Module


# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ["ChannelShuffle"]


class ChannelShuffle(Module):
    r"""Divides and rearranges the channels in a tensor.

    The channels of a tensor of shape :math:`(N, C, *)` are split into
    :math:`g` = ``groups`` groups by viewing the tensor as
    :math:`(N, g, \frac{C}{g}, *)`. The two channel axes are then swapped to
    :math:`(N, \frac{C}{g}, g, *)` and flattened back. The output therefore
    keeps the original shape :math:`(N, C, *)`, but channels from different
    groups are interleaved. Because the intermediate view has
    :math:`\frac{C}{g}` channels per group, :math:`C` is expected to be
    divisible by :math:`g`.

    Args:
        groups (int): number of groups to divide channels in.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C, *)` (same shape as the input)

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],

                 [[ 5.,  6.],
                  [ 7.,  8.]],

                 [[ 9., 10.],
                  [11., 12.]],

                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],

                 [[ 9., 10.],
                  [11., 12.]],

                 [[ 5.,  6.],
                  [ 7.,  8.]],

                 [[13., 14.],
                  [15., 16.]]]])
    """

    # Listed as a constant so TorchScript can treat ``groups`` as a fixed
    # module attribute rather than a mutable one.
    __constants__ = ["groups"]
    groups: int

    def __init__(self, groups: int) -> None:
        """Store the number of channel groups used by :meth:`forward`.

        Args:
            groups: number of groups to divide the channels into. It is not
                validated here; any invalid value surfaces when
                ``F.channel_shuffle`` is called in :meth:`forward`.
        """
        super().__init__()
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with its channels shuffled across ``self.groups`` groups.

        Delegates to ``F.channel_shuffle``; the result has the same shape as
        ``input``.
        """
        return F.channel_shuffle(input, self.groups)

    def extra_repr(self) -> str:
        """Return the ``groups=<n>`` summary embedded in the module's ``repr``."""
        return f"groups={self.groups}"

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources