Shortcuts

Source code for torchaudio.transforms

from __future__ import absolute_import, division, print_function, unicode_literals
from warnings import warn
import math
import torch
from typing import Optional
from . import functional as F
from .compliance import kaldi


# Public API of this module: the transform classes exported by
# ``from torchaudio.transforms import *``.
__all__ = [
    'Spectrogram',
    'AmplitudeToDB',
    'MelScale',
    'MelSpectrogram',
    'MFCC',
    'MuLawEncoding',
    'MuLawDecoding',
    'Resample',
]


class Spectrogram(torch.jit.ScriptModule):
    r"""Create a spectrogram from an audio signal.

    Args:
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (int): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
    """
    __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

    def __init__(self, n_fft=400, win_length=None, hop_length=None,
                 pad=0, window_fn=torch.hann_window,
                 power=2, normalized=False, wkwargs=None):
        super(Spectrogram, self).__init__()

        # Resolve the dependent defaults: the window spans the whole FFT unless
        # told otherwise, and consecutive frames overlap by half a window.
        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 2

        # Number of FFT bins. The returned STFT result will have
        # n_fft // 2 + 1 frequencies due to onesided=True in torch.stft.
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.pad = pad
        self.power = power
        self.normalized = normalized

        # Build the analysis window once; extra keyword arguments are only
        # forwarded when the caller supplied some.
        window_kwargs = {} if wkwargs is None else wkwargs
        window = window_fn(win_length, **window_kwargs)
        self.window = torch.jit.Attribute(window, torch.Tensor)

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: Dimension (channel, freq, time), where channel
            is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frames).
        """
        specgram = F.spectrogram(waveform, self.pad, self.window, self.n_fft,
                                 self.hop_length, self.win_length, self.power,
                                 self.normalized)
        return specgram
class AmplitudeToDB(torch.jit.ScriptModule):
    r"""Convert a tensor from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input tensor, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        stype (str): scale of input tensor ('power' or 'magnitude'). The
            power being the elementwise square of the magnitude. (Default: ``'power'``)
        top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
            is 80. (Default: ``None``)

    Raises:
        ValueError: If ``top_db`` is given and is negative.
    """
    __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']

    def __init__(self, stype='power', top_db=None):
        super(AmplitudeToDB, self).__init__()
        self.stype = torch.jit.Attribute(stype, str)

        negative_cutoff = top_db is not None and top_db < 0
        if negative_cutoff:
            raise ValueError('top_db must be positive value')
        self.top_db = torch.jit.Attribute(top_db, Optional[float])

        # Power inputs scale by 10 * log10, every other stype by 20 * log10.
        if stype == 'power':
            self.multiplier = 10.0
        else:
            self.multiplier = 20.0

        self.amin = 1e-10
        self.ref_value = 1.0
        # log10 of the reference value, floored at amin; passed on to
        # F.amplitude_to_DB by forward.
        reference = max(self.amin, self.ref_value)
        self.db_multiplier = math.log10(reference)

    @torch.jit.script_method
    def forward(self, x):
        r"""Numerically stable implementation from Librosa
        https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

        Args:
            x (torch.Tensor): Input tensor before being converted to decibel scale

        Returns:
            torch.Tensor: Output tensor in decibel scale
        """
        x_db = F.amplitude_to_DB(x, self.multiplier, self.amin,
                                 self.db_multiplier, self.top_db)
        return x_db
class MelScale(torch.jit.ScriptModule):
    r"""Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix. This uses triangular filter banks.

    User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

    Args:
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
        n_stft (int, optional): Number of bins in STFT. Calculated from first input
            if None is given.  See ``n_fft`` in :class:`Spectrogram`.
    """
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
        super(MelScale, self).__init__()
        self.n_mels = n_mels
        self.sample_rate = sample_rate

        # Without an explicit upper bound, fall back to half the sample rate.
        if f_max is None:
            self.f_max = float(sample_rate // 2)
        else:
            self.f_max = f_max
        assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
        self.f_min = f_min

        # An empty tensor marks "filter bank not built yet"; forward fills it
        # in from the first spectrogram it sees.
        if n_stft is None:
            fb = torch.empty(0)
        else:
            fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels)
        self.fb = torch.jit.Attribute(fb, torch.Tensor)

    @torch.jit.script_method
    def forward(self, specgram):
        r"""
        Args:
            specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
        """
        if self.fb.numel() == 0:
            lazy_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
            # Attributes cannot be reassigned outside __init__, so the existing
            # tensor is resized and overwritten in place instead.
            self.fb.resize_(lazy_fb.size())
            self.fb.copy_(lazy_fb)

        # (channel, frequency, time) -> (channel, time, frequency)
        time_major = specgram.transpose(1, 2)
        # (channel, time, frequency) dot (frequency, n_mels) -> (channel, time, n_mels)
        projected = torch.matmul(time_major, self.fb)
        # (channel, time, n_mels) -> (channel, n_mels, time)
        mel_specgram = projected.transpose(1, 2)
        return mel_specgram
class MelSpectrogram(torch.jit.ScriptModule):
    r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
    and MelScale.

    Sources
        * https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
        * https://timsainb.github.io/spectrograms-mfccs-and-inversion-in-python.html
        * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int): Window size. (Default: ``n_fft``)
        hop_length (int, optional): Length of hop between STFT windows.
            (Default: ``win_length // 2``)
        f_min (float): Minimum frequency. (Default: ``0.``)
        f_max (float, optional): Maximum frequency. (Default: ``None``)
        pad (int): Two sided padding of signal. (Default: ``0``)
        n_mels (int): Number of mel filterbanks. (Default: ``128``)
        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform)  # (channel, n_mels, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']

    def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0.,
                 f_max=None, pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
        super(MelSpectrogram, self).__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft

        # The window defaults to the FFT size and the hop to half a window.
        if win_length is None:
            self.win_length = n_fft
        else:
            self.win_length = win_length
        if hop_length is None:
            self.hop_length = self.win_length // 2
        else:
            self.hop_length = hop_length

        self.pad = pad
        self.n_mels = n_mels  # number of mel frequency bins
        self.f_max = torch.jit.Attribute(f_max, Optional[float])
        self.f_min = f_min

        # Stage 1: power spectrogram (power=2, not normalized).
        self.spectrogram = Spectrogram(n_fft=self.n_fft,
                                       win_length=self.win_length,
                                       hop_length=self.hop_length,
                                       pad=self.pad,
                                       window_fn=window_fn,
                                       power=2,
                                       normalized=False,
                                       wkwargs=wkwargs)
        # Stage 2: projection of the frequency axis onto the mel filter bank.
        self.mel_scale = MelScale(n_mels=self.n_mels,
                                  sample_rate=self.sample_rate,
                                  f_min=self.f_min,
                                  f_max=self.f_max)

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time)
        """
        return self.mel_scale(self.spectrogram(waveform))
class MFCC(torch.jit.ScriptModule):
    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Args:
        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
        dct_type (int): type of DCT (discrete cosine transform) to use. Only
            type ``2`` is supported. (Default: ``2``)
        norm (str, optional): norm to use. (Default: ``'ortho'``)
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
            ``False``)
        melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)

    Raises:
        ValueError: If ``dct_type`` is not supported, or if ``n_mfcc`` exceeds
            the number of mel bins of the underlying MelSpectrogram.
    """
    __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']

    def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
                 melkwargs=None):
        super(MFCC, self).__init__()
        supported_dct_types = [2]
        if dct_type not in supported_dct_types:
            # Include the rejected value so the caller can see what was passed.
            raise ValueError('DCT type not supported: {}'.format(dct_type))
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.dct_type = dct_type
        self.norm = torch.jit.Attribute(norm, Optional[str])
        # Fixed cut-off handed to the dB conversion used when log_mels is False.
        self.top_db = 80.0
        self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)

        # Extra MelSpectrogram arguments are forwarded only when supplied.
        mel_args = {} if melkwargs is None else melkwargs
        self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **mel_args)

        if self.n_mfcc > self.MelSpectrogram.n_mels:
            raise ValueError('Cannot select more MFCC coefficients than # mel bins')
        dct_mat = F.create_dct(self.n_mfcc, self.MelSpectrogram.n_mels, self.norm)
        self.dct_mat = torch.jit.Attribute(dct_mat, torch.Tensor)
        self.log_mels = log_mels

    @torch.jit.script_method
    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): Tensor of audio of dimension (channel, time)

        Returns:
            torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time)
        """
        mel_specgram = self.MelSpectrogram(waveform)
        if self.log_mels:
            # Small offset keeps the logarithm finite for zero-valued bins.
            log_offset = 1e-6
            mel_specgram = torch.log(mel_specgram + log_offset)
        else:
            mel_specgram = self.amplitude_to_DB(mel_specgram)
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc
class MuLawEncoding(torch.jit.ScriptModule):
    r"""Encode signal based on mu-law companding. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawEncoding, self).__init__()
        # Stored as a script constant and handed to F.mu_law_encoding on every call.
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x):
        r"""
        Args:
            x (torch.Tensor): A signal to be encoded

        Returns:
            x_mu (torch.Tensor): An encoded signal
        """
        x_mu = F.mu_law_encoding(x, self.quantization_channels)
        return x_mu
class MuLawDecoding(torch.jit.ScriptModule):
    r"""Decode mu-law encoded signal. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Args:
        quantization_channels (int): Number of channels (Default: ``256``)
    """
    __constants__ = ['quantization_channels']

    def __init__(self, quantization_channels=256):
        super(MuLawDecoding, self).__init__()
        # Stored as a script constant and handed to F.mu_law_decoding on every call.
        self.quantization_channels = quantization_channels

    @torch.jit.script_method
    def forward(self, x_mu):
        r"""
        Args:
            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded

        Returns:
            torch.Tensor: The signal decoded
        """
        decoded = F.mu_law_decoding(x_mu, self.quantization_channels)
        return decoded
class Resample(torch.nn.Module):
    r"""Resample a signal from one frequency to another. A resampling method can
    be given.

    Args:
        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
        new_freq (float): The desired frequency. (Default: ``16000``)
        resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
    """

    def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
        super(Resample, self).__init__()
        # The method name is only checked when forward is called.
        self.resampling_method = resampling_method
        self.orig_freq = orig_freq
        self.new_freq = new_freq

    def forward(self, waveform):
        r"""
        Args:
            waveform (torch.Tensor): The input signal of dimension (channel, time)

        Returns:
            torch.Tensor: Output signal of dimension (channel, time)

        Raises:
            ValueError: If ``resampling_method`` is not ``'sinc_interpolation'``.
        """
        # Guard clause: sinc interpolation is the only method handled here.
        if self.resampling_method != 'sinc_interpolation':
            raise ValueError('Invalid resampling method: %s' % (self.resampling_method))

        return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources