Shortcuts

Source code for torchaudio.datasets.yesno

from __future__ import absolute_import, division, print_function, unicode_literals
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import torchaudio


class YESNO(data.Dataset):
    r"""`YesNo Hebrew <http://www.openslr.org/1/>`_ Dataset.

    Each item is a pair ``(audio, target)``:

    * ``audio`` is the signal tensor returned by ``torchaudio.load`` for one
      recording (the sample rate is discarded).
    * ``target`` is the list of strings obtained by splitting the file's base
      name (without extension) on ``_``. For example, a file named
      ``0_1_1.wav`` yields ``['0', '1', '1']``.

    Args:
        root (str): Root directory of the dataset. The processed data lives in
            ``<root>/yesno/processed/yesno.pt``.
        transform (Callable, optional): A function/transform that takes in an
            audio tensor and returns a transformed version.
            E.g, ``transforms.Spectrogram``. (Default: ``None``)
        target_transform (Callable, optional): A function/transform that takes
            in the target and transforms it. (Default: ``None``)
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again. (Default: ``False``)
        dev_mode(bool, optional): If true, clean up is not performed on
            downloaded files. Useful to keep raw audio and transcriptions.
            (Default: ``False``)

    Attributes:
        data: audio tensors as stored in the processed file (written
            longest-first by :meth:`download`).
        labels: label lists, in the same order as ``data``.
        num_samples (int): number of loaded samples.
        max_len (int): largest ``size(1)`` among the loaded tensors (0 if empty).
    """
    # Paths below are relative to ``root``.
    raw_folder = 'yesno/raw'
    processed_folder = 'yesno/processed'
    url = 'http://www.openslr.org/resources/1/waves_yesno.tar.gz'
    # Directory name the tarball extracts to, inside ``raw_folder``.
    dset_path = 'waves_yesno'
    processed_file = 'yesno.pt'

    def __init__(self, root, transform=None, target_transform=None,
                 download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.num_samples = 0
        self.max_len = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # NOTE: torch.load may unpickle arbitrary objects; only point ``root``
        # at data you trust.
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, self.processed_file))
        # Derive these from what was actually loaded so they are valid whether
        # or not download() ran during this construction.
        self.num_samples = len(self.data)
        self.max_len = (max(t.size(1) for t in self.data)
                        if self.num_samples else 0)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(audio, target)`` where ``audio`` is the (optionally
            transformed) audio tensor and ``target`` is the (optionally
            transformed) list of label strings for that recording.
        """
        audio, target = self.data[index], self.labels[index]

        if self.transform is not None:
            audio = self.transform(audio)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return audio, target

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        """Return True if the processed dataset file is present under ``root``."""
        return os.path.exists(os.path.join(
            self.root, self.processed_folder, self.processed_file))

    @staticmethod
    def _makedir(path):
        """Create ``path`` (and parents); do nothing if it already exists."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def download(self):
        """Download the yesno data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_abs_dir = os.path.join(self.root, self.raw_folder)
        processed_abs_dir = os.path.join(self.root, self.processed_folder)
        dset_abs_path = os.path.join(raw_abs_dir, self.dset_path)

        # Create each directory independently: an already-existing raw folder
        # (e.g. kept by dev_mode) must not stop the processed folder from
        # being created.
        self._makedir(raw_abs_dir)
        self._makedir(processed_abs_dir)

        url = self.url
        filename = url.rpartition('/')[2]
        file_path = os.path.join(raw_abs_dir, filename)

        if not os.path.exists(dset_abs_path):
            if not os.path.isfile(file_path):
                print('Downloading ' + url)
                # Download to a temporary name first so an interrupted
                # transfer is never mistaken for a complete tarball.
                part_path = file_path + '.part'
                urllib.request.urlretrieve(url, part_path)
                os.rename(part_path, file_path)
            else:
                print("Tar file already downloaded")
            # NOTE(review): extractall trusts member paths inside the archive;
            # it is fetched from a fixed URL, but consider validating members.
            with tarfile.open(file_path) as tar_f:
                tar_f.extractall(raw_abs_dir)
        else:
            print("Tar file already extracted")

        if not self.dev_mode and os.path.isfile(file_path):
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')
        shutil.copyfile(
            os.path.join(dset_abs_path, "README"),
            os.path.join(processed_abs_dir, "YESNO_README")
        )

        # Sort names so the processed order doesn't depend on os.listdir order.
        audios = sorted(x for x in os.listdir(dset_abs_path)
                        if x.endswith(".wav"))
        print("Found {} audio files".format(len(audios)))
        if not audios:
            raise RuntimeError('No .wav files found in ' + dset_abs_path)

        samples = []
        for f in audios:
            sig, _ = torchaudio.load(os.path.join(dset_abs_path, f))
            # "0_1_1.wav" -> ['0', '1', '1']
            samples.append((sig, f.split(".", 1)[0].split("_")))

        # sort sigs/labels: longest -> shortest (stable, so ties keep name order)
        samples.sort(key=lambda s: s[0].size(1), reverse=True)
        tensors, labels = zip(*samples)
        self.max_len = tensors[0].size(1)

        torch.save(
            (tensors, labels),
            os.path.join(processed_abs_dir, self.processed_file)
        )

        if not self.dev_mode:
            shutil.rmtree(raw_abs_dir, ignore_errors=True)

        print('Done!')

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources