Shortcuts

Source code for torchaudio.datasets.yesno

import os
from pathlib import Path
from typing import List, Tuple, Union

import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _extract_tar


# Per-release download metadata: the dataset's top-level directory inside the
# archive, the archive URL, and the checksum handed to the downloader.
_RELEASE_CONFIGS = {
    "release1": dict(
        folder_in_archive="waves_yesno",
        url="http://www.openslr.org/resources/1/waves_yesno.tar.gz",
        checksum="c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
    ),
}


class YESNO(Dataset):
    """*YesNo* :cite:`YesNo` dataset.

    Samples are the ``.wav`` files found directly inside
    ``root / folder_in_archive``. Each file name (without the extension) is a
    ``_``-separated list of integers, which is returned as the sample's labels.

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional): The URL to download the dataset from.
            (default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
        folder_in_archive (str, optional):
            The top-level directory of the dataset. (default: ``"waves_yesno"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).

    Raises:
        RuntimeError: If the dataset directory does not exist after the
            optional download and extraction step.
    """

    def __init__(
        self,
        root: Union[str, Path],
        url: str = _RELEASE_CONFIGS["release1"]["url"],
        folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
        download: bool = False,
    ) -> None:
        self._parse_filesystem(root, url, folder_in_archive, download)

    def _parse_filesystem(
        self,
        root: Union[str, Path],
        url: str,
        folder_in_archive: str,
        download: bool,
    ) -> None:
        """Locate the dataset under ``root``, fetching it first if requested, and index it.

        Sets ``self._path`` to the dataset directory and ``self._walker`` to the
        sorted list of ``.wav`` file stems found in it.

        Args:
            root (str or Path): Directory holding (or receiving) the archive and dataset.
            url (str): Download URL; its base name is used as the archive file name.
            folder_in_archive (str): Dataset directory name, relative to ``root``.
            download (bool): Download and extract the archive when the dataset
                directory is missing.

        Raises:
            RuntimeError: If the dataset directory is still missing afterwards.
        """
        root = Path(root)
        archive = root / os.path.basename(url)
        self._path = root / folder_in_archive

        if download and not self._path.is_dir():
            if not archive.is_file():
                # Verify the download only against a checksum recorded for this
                # exact URL; applying the release checksum to a user-supplied
                # URL would reject any archive other than the official one.
                checksum = next(
                    (cfg["checksum"] for cfg in _RELEASE_CONFIGS.values() if cfg["url"] == url),
                    None,
                )
                # Make sure the destination directory exists before the
                # archive is written into it.
                root.mkdir(parents=True, exist_ok=True)
                download_url_to_file(url, archive, hash_prefix=checksum)
            _extract_tar(archive)

        if not self._path.is_dir():
            if download:
                # Download/extraction ran, so asking for `download=True` again
                # would be misleading; point at the likely mismatch instead.
                raise RuntimeError(
                    f"Dataset not found at {self._path} after extracting {archive}. "
                    "Check that `folder_in_archive` matches the archive's top-level directory."
                )
            raise RuntimeError("Dataset not found. Please use `download=True` to download it.")

        self._walker = sorted(p.stem for p in self._path.glob("*.wav"))

    def _load_item(self, fileid: str, path: Union[str, Path]) -> Tuple[Tensor, int, List[int]]:
        """Load ``<path>/<fileid>.wav`` and decode the labels from ``fileid``.

        Args:
            fileid (str): File stem made of ``_``-separated integers.
            path (str or Path): Directory containing the ``.wav`` file.

        Returns:
            Tuple of the waveform, the sample rate and the list of integer labels.
        """
        labels = [int(c) for c in fileid.split("_")]
        file_audio = os.path.join(path, fileid + ".wav")
        waveform, sample_rate = torchaudio.load(file_audio)
        return waveform, sample_rate, labels

    def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            List[int]:
                labels
        """
        fileid = self._walker[n]
        return self._load_item(fileid, self._path)

    def __len__(self) -> int:
        """Return the number of ``.wav`` samples indexed in the dataset directory."""
        return len(self._walker)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources