class MovingMNIST(VisionDataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
            If ``split=None``, the full data is returned.
        split_ratio (int, optional): The frame index at which each sequence is split. If ``split="train"``,
            the first ``split_ratio`` frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the
            remaining frames ``data[:, split_ratio:]`` are returned. If ``split=None``, this parameter is
            ignored and all frames are returned. It is always validated, even when ``split=None``: it must be
            an ``int`` (``bool`` is rejected) with ``1 <= split_ratio <= 19``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a torch Tensor
            and returns a transformed version. E.g., ``transforms.RandomCrop``

    Raises:
        ValueError: If ``split`` is not ``None``, ``"train"`` or ``"test"`` (raised by ``verify_str_arg`` —
            presumably ValueError; confirm), or if ``split_ratio`` is outside ``[1, 19]``.
        TypeError: If ``split_ratio`` is not an ``int`` (or is a ``bool``).
        RuntimeError: If the data file is missing after the optional download.
    """

    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

    def __init__(
        self,
        root: Union[str, Path],
        split: Optional[str] = None,
        split_ratio: int = 10,
        download: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform)

        # The data file lives at <root>/MovingMNIST/<basename of _URL>.
        self._base_folder = os.path.join(self.root, self.__class__.__name__)
        self._filename = self._URL.split("/")[-1]

        if split is not None:
            verify_str_arg(split, "split", ("train", "test"))
        self.split = split

        # ``bool`` is a subclass of ``int``, so ``isinstance(True, int)`` holds;
        # reject it explicitly because True/False is never a meaningful frame count.
        if not isinstance(split_ratio, int) or isinstance(split_ratio, bool):
            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
        if not (1 <= split_ratio <= 19):
            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
        self.split_ratio = split_ratio

        # NOTE(review): ``download`` and ``_check_exists`` are not defined in the visible
        # excerpt — presumably defined elsewhere in this class or inherited; confirm.
        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
        # Before the transpose below, axis 0 is the frame axis: slicing it here is what
        # the docstring describes as ``data[:, :split_ratio]`` on the final layout.
        if self.split == "train":
            data = data[: self.split_ratio]
        elif self.split == "test":
            data = data[self.split_ratio :]
        # Swap the first two axes so the sequence index leads, then insert a singleton
        # channel axis at dim 2, so each item is laid out as [T, C=1, H, W].
        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Args:
            idx (int): Index

        Returns:
            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
        """
        data = self.data[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data

    def __len__(self) -> int:
        """Return the number of sequences, i.e. the size of the leading axis of ``self.data``."""
        return len(self.data)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As Facebook is the current maintainer of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls, in the Cookies Policy.