Source code for torch.nn.modules.flatten

from .module import Module

from torch import Tensor

class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Dimensions ``start_dim`` through ``end_dim`` (inclusive; negative values count
    from the end) are merged into a single dimension. All other dimensions are
    left unchanged. The work is delegated to :meth:`Tensor.flatten`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Shape:
        - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means
          any number of dimensions, including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
          With the default arguments this maps :math:`(N, *dims)` to
          :math:`(N, \prod *dims)`.

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> m = nn.Sequential(
        ...     nn.Conv2d(1, 32, 5, 1, 1),
        ...     nn.Flatten()
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 288])
    """

    # Declares these attributes as constants for TorchScript compilation.
    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Store the (inclusive) range of dimensions to flatten."""
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dims ``start_dim..end_dim`` merged into one."""
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        """Describe the configured range so ``repr(module)`` / ``print(model)`` shows it."""
        return f'start_dim={self.start_dim}, end_dim={self.end_dim}'


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources