class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor.

    For use with :class:`~nn.Sequential`. See :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """

    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Store the inclusive ``[start_dim, end_dim]`` range to flatten."""
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dims ``start_dim`` through ``end_dim`` merged into one."""
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        """Describe the configured dim range for the module's ``repr``."""
        return f'start_dim={self.start_dim}, end_dim={self.end_dim}'
class Unflatten(Module):
    r"""
    Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Raises:
        TypeError: if :attr:`dim` is neither `int` nor `str`; if :attr:`dim` is an `int` and
            :attr:`unflattened_size` is not a tuple/list of ints; or if :attr:`dim` is a `str`
            and :attr:`unflattened_size` is not a tuple whose elements are all tuples.

    Examples::
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, (2, 5, 5))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, torch.Size([2, 5, 5]))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    # NOTE(review): `Tuple[Tuple[str, int]]` literally annotates a tuple holding exactly one
    # (name, size) pair; the variadic form would be `Tuple[Tuple[str, int], ...]`. Left as-is
    # because the annotations may be consumed by TorchScript — confirm before changing.
    NamedShape = Tuple[Tuple[str, int]]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        """Validate ``unflattened_size`` against the kind of ``dim`` and store both.

        An ``int`` dim requires a tuple/list of ints; a ``str`` dim (named tensors)
        requires a tuple of tuples. Any other ``dim`` type raises ``TypeError``.
        """
        super().__init__()

        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples."""
        if not isinstance(input, tuple):
            raise TypeError("unflattened_size must be a tuple of tuples, "
                            f"but found type {type(input).__name__}")
        for idx, elem in enumerate(input):
            if not isinstance(elem, tuple):
                raise TypeError("unflattened_size must be tuple of tuples, "
                                f"but found element of type {type(elem).__name__} at pos {idx}")

    def _require_tuple_int(self, input):
        """Raise ``TypeError`` unless ``input`` is a tuple or list whose elements are all ints.

        ``torch.Size`` is a ``tuple`` subclass, so it passes the container check.
        """
        if not isinstance(input, (tuple, list)):
            raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}")
        for idx, elem in enumerate(input):
            if not isinstance(elem, int):
                raise TypeError("unflattened_size must be tuple of ints, "
                                f"but found element of type {type(elem).__name__} at pos {idx}")

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dimension ``dim`` expanded to ``unflattened_size``."""
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """Describe the configured dim and target shape for the module's ``repr``."""
        return f'dim={self.dim}, unflattened_size={self.unflattened_size}'
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.