# Source code for torch.nn.modules.channelshuffle
import torch.nn.functional as F
from torch import Tensor
from .module import Module
# Public names exported by ``from ... import *``.
__all__ = [
    "ChannelShuffle",
]
class ChannelShuffle(Module):
    r"""Divides and rearranges the channels in a tensor.

    This operation splits the ``C`` channels of a tensor of shape
    :math:`(N, C, *)` into ``g`` groups, viewing it as
    :math:`(N, g, \frac{C}{g}, *)`. It then swaps the two group axes to
    obtain :math:`(N, \frac{C}{g}, g, *)` and flattens them back, so the
    channels of different groups become interleaved while the final output
    retains the original tensor shape.

    Args:
        groups (int): number of groups to divide channels in. The channel
            count ``C`` of the input must be divisible by ``groups``.

    Shape:
        - Input: :math:`(N, C, *)` where ``C`` is divisible by ``groups``.
        - Output: :math:`(N, C, *)`, same shape as the input.

    Examples::

        >>> channel_shuffle = nn.ChannelShuffle(2)
        >>> input = torch.arange(1, 17, dtype=torch.float32).view(1, 4, 2, 2)
        >>> input
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[13., 14.],
                  [15., 16.]]]])
        >>> output = channel_shuffle(input)
        >>> output
        tensor([[[[ 1.,  2.],
                  [ 3.,  4.]],
                 [[ 9., 10.],
                  [11., 12.]],
                 [[ 5.,  6.],
                  [ 7.,  8.]],
                 [[13., 14.],
                  [15., 16.]]]])
    """

    # Lets TorchScript treat ``groups`` as an immutable module constant.
    __constants__ = ["groups"]
    groups: int

    def __init__(self, groups: int) -> None:
        super().__init__()
        # Stored without validation here; checking the value (and the
        # divisibility of C) is left to F.channel_shuffle at forward time.
        self.groups = groups

    def forward(self, input: Tensor) -> Tensor:
        """Shuffle the channels of ``input`` across ``self.groups`` groups.

        Args:
            input: tensor of shape :math:`(N, C, *)`.

        Returns:
            A tensor with the same shape as ``input`` whose channels have
            been interleaved across the groups.
        """
        return F.channel_shuffle(input, self.groups)

    def extra_repr(self) -> str:
        """Return the text shown inside this module's ``repr``."""
        return f"groups={self.groups}"