
Chapter 23: Speech and Audio Processing

Audio signal processing fundamentals, speech recognition architectures, and the techniques behind modern speech AI systems.

Learning Objectives

["Build RAG pipelines", "Use vector databases", "Apply advanced retrieval"]


23.1 Audio Signal Processing Fundamentals

Audio signal processing forms the foundation for all speech and audio AI systems. Understanding how sound is captured, represented, and transformed is essential for building effective models for speech recognition, synthesis, and audio understanding. This section covers the core concepts of digital audio, spectral analysis, and feature extraction techniques that underpin modern audio AI.

Digital Audio Basics

Sound waves must be converted to digital form for processing:

PYTHON
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Union
from dataclasses import dataclass
import matplotlib.pyplot as plt

@dataclass
class AudioConfig:
    """Configuration for audio processing."""
    sample_rate: int = 16000  # Samples per second (Hz)
    n_fft: int = 400  # FFT window size
    hop_length: int = 160  # Samples between STFT frames
    n_mels: int = 80  # Number of mel filterbank channels
    f_min: float = 0.0  # Minimum frequency for mel filterbank
    f_max: Optional[float] = 8000.0  # Maximum frequency
    window: str = "hann"  # Window function


def audio_fundamentals():
    """
    Digital Audio Fundamentals.

    Key concepts:
    - Sampling: Converting continuous signal to discrete samples
    - Sample rate: Number of samples per second (Hz)
    - Nyquist theorem: Sample rate must be > 2x highest frequency
    - Bit depth: Bits per sample (quantization resolution)
    """

    # Common sample rates
    sample_rates = {
        8000: "Telephone quality",
        16000: "Speech recognition standard",
        22050: "Low-quality music",
        44100: "CD quality",
        48000: "Professional audio/video",
        96000: "High-resolution audio"
    }

    # Nyquist frequency = sample_rate / 2
    # Human hearing: ~20 Hz to ~20,000 Hz
    # CD quality (44.1 kHz) can represent up to 22.05 kHz

    print("Common Sample Rates:")
    for rate, description in sample_rates.items():
        nyquist = rate // 2
        print(f"  {rate:5d} Hz: {description} (Nyquist: {nyquist} Hz)")

    return sample_rates


class AudioLoader:
    """
    Load and preprocess audio files.
    """

    def __init__(self, config: AudioConfig):
        self.config = config

    def load(
        self,
        path: str,
        normalize: bool = True
    ) -> Tuple[torch.Tensor, int]:
        """
        Load audio file and resample if needed.

        Args:
            path: Path to audio file
            normalize: Whether to normalize amplitude

        Returns:
            waveform: [1, num_samples] or [2, num_samples] tensor
            sample_rate: Sample rate of the returned waveform (config.sample_rate)
        """
        import torchaudio

        waveform, sr = torchaudio.load(path)

        # Resample if necessary
        if sr != self.config.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=self.config.sample_rate
            )
            waveform = resampler(waveform)

        # Convert stereo to mono
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Normalize
        if normalize:
            waveform = waveform / (waveform.abs().max() + 1e-8)

        return waveform, self.config.sample_rate

    def load_batch(
        self,
        paths: List[str],
        max_length: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load and pad batch of audio files.

        Returns:
            waveforms: [batch, max_samples] padded tensor
            lengths: [batch] actual lengths
        """
        waveforms = []
        lengths = []

        for path in paths:
            wav, _ = self.load(path)
            waveforms.append(wav.squeeze(0))
            lengths.append(wav.size(1))

        # Pad to max length
        if max_length is None:
            max_length = max(lengths)

        padded = torch.zeros(len(waveforms), max_length)
        for i, (wav, length) in enumerate(zip(waveforms, lengths)):
            actual_length = min(length, max_length)
            padded[i, :actual_length] = wav[:actual_length]
            lengths[i] = actual_length

        return padded, torch.tensor(lengths)


def generate_waveforms():
    """Generate example waveforms for visualization."""

    sample_rate = 16000
    duration = 0.1  # 100ms
    t = np.linspace(0, duration, int(sample_rate * duration))

    # Pure sine wave
    freq = 440  # A4 note
    sine_wave = np.sin(2 * np.pi * freq * t)

    # Complex wave (multiple harmonics)
    complex_wave = (
        np.sin(2 * np.pi * freq * t) +  # Fundamental
        0.5 * np.sin(2 * np.pi * 2 * freq * t) +  # 2nd harmonic
        0.25 * np.sin(2 * np.pi * 3 * freq * t)  # 3rd harmonic
    )
    complex_wave = complex_wave / np.abs(complex_wave).max()

    # White noise
    noise = np.random.randn(len(t))
    noise = noise / np.abs(noise).max()

    return {
        'sine': (t, sine_wave),
        'complex': (t, complex_wave),
        'noise': (t, noise)
    }
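
To make the Nyquist limit concrete, here is a minimal sketch of aliasing: a 10 kHz tone sampled at 16 kHz (Nyquist 8 kHz) shows up at 6 kHz instead (the specific frequencies are arbitrary illustrative choices):

PYTHON
sr = 16000                            # Nyquist frequency = 8000 Hz
t = np.arange(0, 0.01, 1 / sr)        # 10 ms of samples
tone = np.sin(2 * np.pi * 10000 * t)  # 10 kHz tone, above Nyquist

spectrum = np.abs(np.fft.rfft(tone))
freqs = np.fft.rfftfreq(len(tone), d=1 / sr)

# The energy folds back to 16000 - 10000 = 6000 Hz, not 10000 Hz
print(f"Peak at ~{freqs[spectrum.argmax()]:.0f} Hz (aliased from 10000 Hz)")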

Fourier Transform and Spectral Analysis

The Fourier transform reveals frequency content of audio signals:

PYTHON
class SpectralAnalysis:
    """
    Spectral analysis tools for audio.
    """

    def __init__(self, config: AudioConfig):
        self.config = config

    def fft(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute Fast Fourier Transform.

        Transforms time-domain signal to frequency domain.

        Args:
            waveform: [batch, samples] or [samples]

        Returns:
            spectrum: Complex frequency spectrum
        """
        return torch.fft.fft(waveform)

    def magnitude_spectrum(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute magnitude spectrum (amplitude at each frequency).
        """
        spectrum = self.fft(waveform)
        magnitude = torch.abs(spectrum)

        # Only keep positive frequencies (up to Nyquist)
        n = waveform.size(-1)
        return magnitude[..., :n // 2 + 1]

    def power_spectrum(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute power spectrum (squared magnitude).
        """
        magnitude = self.magnitude_spectrum(waveform)
        return magnitude ** 2

    def frequency_bins(self, n_fft: int) -> np.ndarray:
        """
        Get frequency values for each FFT bin.
        """
        return np.fft.rfftfreq(n_fft, d=1/self.config.sample_rate)


class STFT(nn.Module):
    """
    Short-Time Fourier Transform.

    Analyzes how frequency content changes over time by applying
    FFT to overlapping windows of the signal.
    """

    def __init__(self, config: AudioConfig):
        super().__init__()
        self.config = config

        # Create window function
        if config.window == "hann":
            window = torch.hann_window(config.n_fft)
        elif config.window == "hamming":
            window = torch.hamming_window(config.n_fft)
        else:
            window = torch.ones(config.n_fft)

        self.register_buffer('window', window)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute STFT.

        Args:
            waveform: [batch, samples] audio signal

        Returns:
            stft: [batch, freq_bins, time_frames] complex spectrogram
        """
        # Add channel dim if needed
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Compute STFT
        stft = torch.stft(
            waveform,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.n_fft,
            window=self.window,
            center=True,
            pad_mode='reflect',
            normalized=False,
            onesided=True,
            return_complex=True
        )

        return stft

    def spectrogram(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute magnitude spectrogram.

        Returns:
            spectrogram: [batch, freq_bins, time_frames]
        """
        stft = self.forward(waveform)
        return torch.abs(stft)

    def power_spectrogram(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute power spectrogram.
        """
        return self.spectrogram(waveform) ** 2

    def phase(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Extract phase information.
        """
        stft = self.forward(waveform)
        return torch.angle(stft)


class InverseSTFT(nn.Module):
    """
    Inverse Short-Time Fourier Transform.

    Reconstructs waveform from STFT representation.
    """

    def __init__(self, config: AudioConfig):
        super().__init__()
        self.config = config

        if config.window == "hann":
            window = torch.hann_window(config.n_fft)
        else:
            window = torch.ones(config.n_fft)

        self.register_buffer('window', window)

    def forward(
        self,
        stft: torch.Tensor,
        length: Optional[int] = None
    ) -> torch.Tensor:
        """
        Reconstruct waveform from STFT.

        Args:
            stft: [batch, freq_bins, time_frames] complex spectrogram
            length: Target output length

        Returns:
            waveform: [batch, samples]
        """
        waveform = torch.istft(
            stft,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.n_fft,
            window=self.window,
            center=True,
            normalized=False,
            onesided=True,
            length=length
        )

        return waveform


def stft_parameters_explained():
    """Explain STFT parameters and their effects."""

    parameters = {
        'n_fft': {
            'description': 'FFT window size (samples)',
            'effect': 'Larger = better frequency resolution, worse time resolution',
            'typical': '400-2048 for speech, 2048-4096 for music',
            'tradeoff': 'Time-frequency resolution tradeoff (uncertainty principle)'
        },
        'hop_length': {
            'description': 'Samples between consecutive frames',
            'effect': 'Smaller = more frames, finer time resolution',
            'typical': 'n_fft // 4 for 75% overlap',
            'tradeoff': 'Computation vs. time resolution'
        },
        'window': {
            'description': 'Window function applied to each frame',
            'options': ['hann', 'hamming', 'blackman', 'rectangular'],
            'effect': 'Reduces spectral leakage (artifacts from non-periodic signals)',
            'typical': 'Hann window most common'
        },
        'center': {
            'description': 'Whether to pad signal so frames are centered',
            'effect': 'True = first/last frames centered on signal edges',
            'typical': 'True for most applications'
        }
    }

    print("STFT Parameters:")
    print("=" * 60)
    for param, info in parameters.items():
        print(f"\n{param}:")
        for k, v in info.items():
            print(f"  {k}: {v}")

Mel Spectrograms

Mel spectrograms provide a perceptually-motivated representation:

PYTHON
class MelSpectrogram(nn.Module):
    """
    Mel Spectrogram computation.

    Applies mel filterbank to power spectrogram, mimicking
    human auditory perception (logarithmic frequency scaling).
    """

    def __init__(self, config: AudioConfig):
        super().__init__()
        self.config = config
        self.stft = STFT(config)

        # Create mel filterbank
        mel_fb = self._create_mel_filterbank()
        self.register_buffer('mel_filterbank', mel_fb)

    def _create_mel_filterbank(self) -> torch.Tensor:
        """
        Create mel filterbank matrix.

        Maps linear frequency bins to mel frequency bins.
        """
        n_freqs = self.config.n_fft // 2 + 1

        # Mel scale conversion functions
        def hz_to_mel(hz):
            return 2595 * np.log10(1 + hz / 700)

        def mel_to_hz(mel):
            return 700 * (10 ** (mel / 2595) - 1)

        # Frequency range
        f_min = self.config.f_min
        f_max = self.config.f_max or self.config.sample_rate / 2

        # Mel points
        mel_min = hz_to_mel(f_min)
        mel_max = hz_to_mel(f_max)
        mel_points = np.linspace(mel_min, mel_max, self.config.n_mels + 2)
        hz_points = mel_to_hz(mel_points)

        # Convert to FFT bin indices
        bin_points = np.floor(
            (self.config.n_fft + 1) * hz_points / self.config.sample_rate
        ).astype(int)

        # Create filterbank
        filterbank = np.zeros((self.config.n_mels, n_freqs))

        for i in range(self.config.n_mels):
            # Rising edge
            for j in range(bin_points[i], bin_points[i + 1]):
                filterbank[i, j] = (j - bin_points[i]) / (bin_points[i + 1] - bin_points[i])
            # Falling edge
            for j in range(bin_points[i + 1], bin_points[i + 2]):
                filterbank[i, j] = (bin_points[i + 2] - j) / (bin_points[i + 2] - bin_points[i + 1])

        return torch.FloatTensor(filterbank)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute mel spectrogram.

        Args:
            waveform: [batch, samples] audio signal

        Returns:
            mel_spec: [batch, n_mels, time_frames]
        """
        # Compute power spectrogram
        power_spec = self.stft.power_spectrogram(waveform)

        # Apply mel filterbank
        mel_spec = torch.matmul(self.mel_filterbank, power_spec)

        return mel_spec

    def log_mel_spectrogram(
        self,
        waveform: torch.Tensor,
        log_offset: float = 1e-6
    ) -> torch.Tensor:
        """
        Compute log mel spectrogram.

        Log compression matches human loudness perception.
        """
        mel_spec = self.forward(waveform)
        return torch.log(mel_spec + log_offset)

    def normalized_log_mel(
        self,
        waveform: torch.Tensor,
        mean: Optional[torch.Tensor] = None,
        std: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute normalized log mel spectrogram.

        Normalization improves model training stability.
        """
        log_mel = self.log_mel_spectrogram(waveform)

        if mean is not None and std is not None:
            log_mel = (log_mel - mean) / (std + 1e-6)
        else:
            # Per-utterance normalization
            log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)

        return log_mel


class MelScale:
    """
    Mel scale utilities.

    The mel scale approximates human perception of pitch,
    where equal distances sound equally different to humans.
    """

    @staticmethod
    def hz_to_mel(hz: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Convert Hz to mel scale."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def mel_to_hz(mel: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """Convert mel to Hz."""
        return 700 * (10 ** (mel / 2595) - 1)

    @staticmethod
    def visualize_mel_scale():
        """Visualize mel scale vs linear frequency."""
        hz = np.linspace(0, 8000, 100)
        mel = MelScale.hz_to_mel(hz)

        print("Mel Scale Mapping:")
        print("-" * 40)
        for h in [100, 500, 1000, 2000, 4000, 8000]:
            m = MelScale.hz_to_mel(h)
            print(f"  {h:5d} Hz -> {m:7.1f} mel")


MelScale.visualize_mel_scale()
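
A short sketch of the classes above in use (random samples stand in for real audio; shapes assume the default AudioConfig):

PYTHON
config = AudioConfig()
mel = MelSpectrogram(config)

# The filterbank maps n_fft // 2 + 1 = 201 linear bins onto n_mels = 80 triangular bands
print(mel.mel_filterbank.shape)   # torch.Size([80, 201])

waveform = torch.randn(1, config.sample_rate)    # 1 second placeholder signal
log_mel = mel.log_mel_spectrogram(waveform)
print(log_mel.shape)              # torch.Size([1, 80, 101])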

MFCCs (Mel-Frequency Cepstral Coefficients)

MFCCs are a compact representation widely used in speech processing:

PYTHON
class MFCC(nn.Module):
    """
    Mel-Frequency Cepstral Coefficients.

    Compact representation that captures spectral envelope
    while removing fine harmonic structure.
    """

    def __init__(
        self,
        config: AudioConfig,
        n_mfcc: int = 13,
        include_deltas: bool = True
    ):
        super().__init__()
        self.config = config
        self.n_mfcc = n_mfcc
        self.include_deltas = include_deltas

        self.mel_spec = MelSpectrogram(config)

        # DCT matrix for cepstral coefficients
        dct_matrix = self._create_dct_matrix()
        self.register_buffer('dct_matrix', dct_matrix)

    def _create_dct_matrix(self) -> torch.Tensor:
        """
        Create DCT-II matrix.

        DCT decorrelates mel bands and compacts energy
        into lower coefficients.
        """
        n_mels = self.config.n_mels
        n_mfcc = self.n_mfcc

        # DCT-II matrix
        dct = np.zeros((n_mfcc, n_mels))
        for k in range(n_mfcc):
            for n in range(n_mels):
                dct[k, n] = np.cos(np.pi * k * (n + 0.5) / n_mels)

        # Orthonormalize
        dct[0] *= 1 / np.sqrt(2)
        dct *= np.sqrt(2 / n_mels)

        return torch.FloatTensor(dct)

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Compute MFCCs.

        Args:
            waveform: [batch, samples]

        Returns:
            mfcc: [batch, n_mfcc * (1 + 2*include_deltas), time_frames]
        """
        # Log mel spectrogram
        log_mel = self.mel_spec.log_mel_spectrogram(waveform)

        # Apply DCT
        mfcc = torch.matmul(self.dct_matrix, log_mel)

        if self.include_deltas:
            # First derivative (delta)
            delta = self._compute_delta(mfcc)
            # Second derivative (delta-delta)
            delta2 = self._compute_delta(delta)
            # Concatenate
            mfcc = torch.cat([mfcc, delta, delta2], dim=1)

        return mfcc

    def _compute_delta(
        self,
        features: torch.Tensor,
        order: int = 2
    ) -> torch.Tensor:
        """
        Compute delta (derivative) features.

        Uses regression over nearby frames.
        """
        # Pad for edge frames
        padded = F.pad(features, (order, order), mode='replicate')

        # Compute weighted sum
        delta = torch.zeros_like(features)
        norm = 2 * sum(i ** 2 for i in range(1, order + 1))

        for i in range(1, order + 1):
            delta += i * (
                padded[..., order + i:padded.size(-1) - order + i] -
                padded[..., order - i:padded.size(-1) - order - i]
            )

        return delta / norm


def feature_comparison():
    """Compare different audio features."""

    features = {
        'Waveform': {
            'dimensions': 'samples',
            'info': 'Raw time-domain signal',
            'pros': 'Complete information, end-to-end learning',
            'cons': 'Very high dimensional, hard to learn from',
            'use_cases': 'WaveNet, raw audio models'
        },
        'Spectrogram': {
            'dimensions': 'freq_bins x time_frames',
            'info': 'Time-frequency representation',
            'pros': 'Shows frequency content over time',
            'cons': 'Linear frequency scale not perceptually motivated',
            'use_cases': 'General audio analysis'
        },
        'Mel Spectrogram': {
            'dimensions': 'n_mels x time_frames (e.g., 80 x T)',
            'info': 'Perceptually-motivated frequency scale',
            'pros': 'Matches human perception, compact',
            'cons': 'Loses phase information',
            'use_cases': 'ASR, TTS, speaker recognition'
        },
        'Log Mel Spectrogram': {
            'dimensions': 'n_mels x time_frames',
            'info': 'Log-compressed mel spectrogram',
            'pros': 'Log matches loudness perception, better dynamic range',
            'cons': 'Still loses phase',
            'use_cases': 'Most modern speech models (Whisper, etc.)'
        },
        'MFCC': {
            'dimensions': 'n_mfcc x time_frames (e.g., 13-39 x T)',
            'info': 'Cepstral coefficients from mel spectrum',
            'pros': 'Very compact, decorrelated features',
            'cons': 'Loses some information, may hurt end-to-end learning',
            'use_cases': 'Traditional ASR, speaker verification'
        }
    }

    print("Audio Feature Comparison:")
    print("=" * 70)
    for name, info in features.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


feature_comparison()
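
To make the MFCC dimensions concrete, a minimal sketch reusing the MFCC class above (random samples stand in for speech; shapes assume the default AudioConfig):

PYTHON
config = AudioConfig()
mfcc = MFCC(config, n_mfcc=13, include_deltas=True)

waveform = torch.randn(1, config.sample_rate)   # 1 second placeholder signal
features = mfcc(waveform)

# 13 static + 13 delta + 13 delta-delta coefficients per frame
print(features.shape)  # torch.Size([1, 39, 101])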

Audio Data Augmentation

Augmentation improves model robustness:

PYTHON
class AudioAugmentation:
    """
    Data augmentation for audio signals.

    Improves model robustness to noise, variations in
    recording conditions, and speaker differences.
    """

    def __init__(self, config: AudioConfig):
        self.config = config

    def add_noise(
        self,
        waveform: torch.Tensor,
        snr_db: float = 20.0
    ) -> torch.Tensor:
        """
        Add Gaussian noise at specified SNR.

        Args:
            waveform: Input audio
            snr_db: Signal-to-noise ratio in dB
        """
        signal_power = waveform.pow(2).mean()
        noise_power = signal_power / (10 ** (snr_db / 10))

        noise = torch.randn_like(waveform) * noise_power.sqrt()
        return waveform + noise

    def time_stretch(
        self,
        waveform: torch.Tensor,
        rate: float = 1.0
    ) -> torch.Tensor:
        """
        Time stretch via naive resampling (this simple version also changes pitch).

        Args:
            rate: Stretch factor (>1 = slower, <1 = faster)
        """
        # Simple resampling (changes pitch too)
        # For true time stretch, use phase vocoder
        n_samples = int(waveform.size(-1) * rate)
        return F.interpolate(
            waveform.unsqueeze(0),
            size=n_samples,
            mode='linear',
            align_corners=False
        ).squeeze(0)

    def pitch_shift(
        self,
        waveform: torch.Tensor,
        semitones: float = 0.0
    ) -> torch.Tensor:
        """
        Shift pitch by semitones (naive approximation).

        Resamples and then interpolates back to the original duration; because
        both steps use plain interpolation they largely cancel, so a real
        implementation would pair resampling with a phase-vocoder time stretch
        (e.g., torchaudio.transforms.PitchShift).
        """
        ratio = 2 ** (semitones / 12)

        # Resample (changes pitch and duration)
        n_samples = int(waveform.size(-1) / ratio)
        resampled = F.interpolate(
            waveform.unsqueeze(0),
            size=n_samples,
            mode='linear',
            align_corners=False
        ).squeeze(0)

        # Time stretch back to original duration
        return F.interpolate(
            resampled.unsqueeze(0),
            size=waveform.size(-1),
            mode='linear',
            align_corners=False
        ).squeeze(0)

    def random_crop(
        self,
        waveform: torch.Tensor,
        crop_length: int
    ) -> torch.Tensor:
        """
        Random crop to fixed length.
        """
        if waveform.size(-1) <= crop_length:
            # Pad if too short
            padding = crop_length - waveform.size(-1)
            return F.pad(waveform, (0, padding))

        start = torch.randint(0, waveform.size(-1) - crop_length, (1,)).item()
        return waveform[..., start:start + crop_length]

    def volume_perturbation(
        self,
        waveform: torch.Tensor,
        gain_db_range: Tuple[float, float] = (-10, 10)
    ) -> torch.Tensor:
        """
        Random volume change.
        """
        gain_db = torch.empty(1).uniform_(*gain_db_range).item()
        gain = 10 ** (gain_db / 20)
        return waveform * gain


class SpecAugment(nn.Module):
    """
    SpecAugment: Spectrogram augmentation.

    Applies time and frequency masking to spectrograms
    during training for regularization.
    """

    def __init__(
        self,
        freq_mask_param: int = 27,
        time_mask_param: int = 100,
        n_freq_masks: int = 2,
        n_time_masks: int = 2
    ):
        super().__init__()
        self.freq_mask_param = freq_mask_param
        self.time_mask_param = time_mask_param
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks

    def forward(self, spectrogram: torch.Tensor) -> torch.Tensor:
        """
        Apply SpecAugment.

        Args:
            spectrogram: [batch, freq, time] or [freq, time]

        Returns:
            Augmented spectrogram
        """
        augmented = spectrogram.clone()

        if augmented.dim() == 2:
            augmented = augmented.unsqueeze(0)

        batch_size, n_freq, n_time = augmented.shape

        for _ in range(self.n_freq_masks):
            f = torch.randint(0, self.freq_mask_param, (1,)).item()
            f0 = torch.randint(0, max(1, n_freq - f), (1,)).item()
            augmented[:, f0:f0 + f, :] = 0

        for _ in range(self.n_time_masks):
            t = torch.randint(0, self.time_mask_param, (1,)).item()
            t = min(t, int(n_time * 0.5))  # Limit to 50% of length
            t0 = torch.randint(0, max(1, n_time - t), (1,)).item()
            augmented[:, :, t0:t0 + t] = 0

        return augmented.squeeze(0) if spectrogram.dim() == 2 else augmented


class RoomSimulator:
    """
    Simulate room acoustics with reverberation.
    """

    def __init__(self, sample_rate: int = 16000):
        self.sample_rate = sample_rate

    def generate_impulse_response(
        self,
        rt60: float = 0.5,
        room_dim: Tuple[float, float, float] = (5, 4, 3)
    ) -> torch.Tensor:
        """
        Generate synthetic room impulse response.

        Args:
            rt60: Reverberation time (time for 60dB decay)
            room_dim: Room dimensions in meters (length, width, height)

        Returns:
            Impulse response
        """
        # Simple exponential decay model (room_dim is unused in this approximation)
        duration = rt60 * 2
        n_samples = int(duration * self.sample_rate)

        # Generate noise and apply exponential decay
        ir = torch.randn(n_samples)
        decay = torch.exp(-6.9 * torch.arange(n_samples).float() / (rt60 * self.sample_rate))
        ir = ir * decay

        # Normalize
        ir = ir / ir.abs().max()

        return ir

    def apply_reverb(
        self,
        waveform: torch.Tensor,
        impulse_response: torch.Tensor,
        wet_level: float = 0.3
    ) -> torch.Tensor:
        """
        Apply reverberation using convolution.
        """
        # Convolve
        reverb = F.conv1d(
            waveform.unsqueeze(0).unsqueeze(0),
            impulse_response.unsqueeze(0).unsqueeze(0),
            padding=impulse_response.size(0) - 1
        ).squeeze()

        # Truncate to original length
        reverb = reverb[:waveform.size(-1)]

        # Mix dry and wet
        return (1 - wet_level) * waveform + wet_level * reverb

Key Takeaways

Audio signal processing provides the foundational representations for all speech and audio AI. Key concepts include: (1) digital audio fundamentals like sampling rate and Nyquist frequency that determine what frequencies can be captured, (2) the Short-Time Fourier Transform (STFT) that reveals how frequency content changes over time, (3) mel spectrograms that provide perceptually-motivated frequency scaling matching human hearing, and (4) MFCCs that offer compact, decorrelated features. Modern deep learning models typically use log mel spectrograms as input, with SpecAugment providing regularization during training. Understanding these representations is essential for building effective speech recognition, synthesis, and audio understanding systems.
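
As a recap, here is a minimal sketch of a typical training-time feature pipeline built from the classes in this section (a random tensor stands in for a loaded utterance):

PYTHON
config = AudioConfig()
augment = AudioAugmentation(config)
mel = MelSpectrogram(config)
spec_augment = SpecAugment()

waveform = torch.randn(1, config.sample_rate)        # placeholder for real audio
waveform = augment.add_noise(waveform, snr_db=15.0)  # waveform-level augmentation
features = mel.normalized_log_mel(waveform)          # [1, 80, T] normalized log mel
features = spec_augment(features)                    # time/frequency masking

print(features.shape)  # torch.Size([1, 80, 101])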

23.2 Automatic Speech Recognition

Automatic Speech Recognition (ASR) converts spoken language into text, enabling voice interfaces, transcription services, and audio understanding. Modern ASR has evolved from traditional Hidden Markov Model systems to end-to-end deep learning approaches that directly map audio to text. This section covers the core architectures, training objectives, and techniques that power state-of-the-art speech recognition systems.

ASR Fundamentals

The speech recognition task presents unique challenges due to variable-length sequences:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict, Union
from dataclasses import dataclass
import numpy as np

@dataclass
class ASRConfig:
    """Configuration for ASR model."""
    n_mels: int = 80
    encoder_dim: int = 512
    encoder_layers: int = 12
    decoder_dim: int = 512
    decoder_layers: int = 6
    n_heads: int = 8
    vocab_size: int = 5000  # Subword vocabulary
    dropout: float = 0.1
    max_audio_length: int = 3000  # Max spectrogram frames
    max_text_length: int = 448


class ASROverview:
    """
    Overview of ASR approaches and their characteristics.
    """

    @staticmethod
    def paradigms():
        """Main ASR paradigms."""
        paradigms = {
            'Hybrid (Traditional)': {
                'components': ['Acoustic Model (DNN)', 'Language Model (n-gram)', 'Lexicon'],
                'training': 'Separate training for each component',
                'decoder': 'WFST-based decoding',
                'pros': 'Well understood, interpretable',
                'cons': 'Complex pipeline, requires alignment'
            },
            'CTC (Connectionist Temporal Classification)': {
                'components': ['Encoder only'],
                'training': 'End-to-end with CTC loss',
                'decoder': 'Greedy or beam search with LM',
                'pros': 'Simple, no alignment needed',
                'cons': 'Conditional independence assumption'
            },
            'Attention-based Encoder-Decoder': {
                'components': ['Encoder', 'Decoder with attention'],
                'training': 'End-to-end with cross-entropy',
                'decoder': 'Autoregressive beam search',
                'pros': 'Strong modeling, handles alignments',
                'cons': 'Slower decoding, attention errors'
            },
            'Transducer (RNN-T)': {
                'components': ['Encoder', 'Prediction Network', 'Joint Network'],
                'training': 'Transducer loss',
                'decoder': 'Streaming-friendly beam search',
                'pros': 'Streaming capable, strong performance',
                'cons': 'Complex training'
            },
            'Large-scale encoder-decoder (Whisper-style)': {
                'components': ['Encoder', 'Decoder (for multitask)'],
                'training': 'Weak supervision at scale',
                'decoder': 'Greedy or beam search',
                'pros': 'Robust, multilingual, multitask',
                'cons': 'Large model required'
            }
        }

        print("ASR Paradigms:")
        print("=" * 70)
        for name, info in paradigms.items():
            print(f"\n{name}:")
            for k, v in info.items():
                if isinstance(v, list):
                    print(f"  {k}: {', '.join(v)}")
                else:
                    print(f"  {k}: {v}")

        return paradigms


ASROverview.paradigms()

CTC Loss and Decoding

CTC enables training without explicit alignment between audio and text:

PYTHON
class CTCLoss(nn.Module):
    """
    Connectionist Temporal Classification loss.

    Key insight: Sum over all possible alignments between
    input sequence and output sequence.
    """

    def __init__(self, blank_id: int = 0, reduction: str = 'mean'):
        super().__init__()
        self.blank_id = blank_id
        self.ctc_loss = nn.CTCLoss(
            blank=blank_id,
            reduction=reduction,
            zero_infinity=True
        )

    def forward(
        self,
        log_probs: torch.Tensor,
        targets: torch.Tensor,
        input_lengths: torch.Tensor,
        target_lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute CTC loss.

        Args:
            log_probs: [T, B, vocab_size] log probabilities
            targets: [B, S] target sequences (concatenated)
            input_lengths: [B] length of each input sequence
            target_lengths: [B] length of each target sequence

        Returns:
            CTC loss value
        """
        # nn.CTCLoss expects log_probs in [T, B, C] format, so transpose from [B, T, C]
        log_probs = log_probs.transpose(0, 1)

        return self.ctc_loss(log_probs, targets, input_lengths, target_lengths)


class CTCEncoder(nn.Module):
    """
    CTC-based ASR encoder.

    Outputs frame-level character/subword probabilities.
    """

    def __init__(self, config: ASRConfig):
        super().__init__()
        self.config = config

        # Mel spectrogram frontend
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )

        # Calculate output dimension after conv
        conv_out_dim = (config.n_mels // 4) * 32

        # Project to encoder dimension
        self.input_projection = nn.Linear(conv_out_dim, config.encoder_dim)

        # Conformer/Transformer encoder
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.encoder_dim,
                nhead=config.n_heads,
                dim_feedforward=config.encoder_dim * 4,
                dropout=config.dropout,
                batch_first=True
            ),
            num_layers=config.encoder_layers
        )

        # Output projection (vocab + blank)
        self.output_projection = nn.Linear(
            config.encoder_dim,
            config.vocab_size + 1  # +1 for CTC blank
        )

    def forward(
        self,
        mel_spectrogram: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            mel_spectrogram: [B, n_mels, T] input features
            lengths: [B] input lengths

        Returns:
            log_probs: [B, T', vocab_size+1] frame-level log probabilities
            output_lengths: [B] output sequence lengths
        """
        batch_size = mel_spectrogram.size(0)

        # Add channel dimension for conv
        x = mel_spectrogram.unsqueeze(1)  # [B, 1, n_mels, T]

        # Convolutional frontend
        x = self.conv_layers(x)  # [B, 32, n_mels/4, T/4]

        # Reshape for transformer
        x = x.permute(0, 3, 1, 2)  # [B, T/4, 32, n_mels/4]
        x = x.reshape(batch_size, -1, x.size(2) * x.size(3))  # [B, T/4, 32*n_mels/4]

        # Project
        x = self.input_projection(x)  # [B, T/4, encoder_dim]

        # Create attention mask if lengths provided
        mask = None
        if lengths is not None:
            output_lengths = lengths // 4  # Account for conv subsampling
            max_len = x.size(1)
            mask = torch.arange(max_len, device=x.device)[None, :] >= output_lengths[:, None]
        else:
            output_lengths = torch.full((batch_size,), x.size(1), device=x.device)

        # Transformer encoder
        x = self.encoder(x, src_key_padding_mask=mask)

        # Output projection
        logits = self.output_projection(x)
        log_probs = F.log_softmax(logits, dim=-1)

        return log_probs, output_lengths


class CTCDecoder:
    """
    CTC decoding strategies.
    """

    def __init__(
        self,
        blank_id: int = 0,
        space_id: Optional[int] = None
    ):
        self.blank_id = blank_id
        self.space_id = space_id

    def greedy_decode(
        self,
        log_probs: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> List[List[int]]:
        """
        Simple greedy (best path) decoding.

        Take most likely token at each frame, collapse repeats and remove blanks.
        """
        # Get most likely token at each frame
        predictions = log_probs.argmax(dim=-1)  # [B, T]

        batch_results = []
        batch_size = predictions.size(0)

        for b in range(batch_size):
            length = lengths[b].item() if lengths is not None else predictions.size(1)
            pred = predictions[b, :length].tolist()

            # Collapse repeats and remove blanks
            decoded = []
            prev_token = None
            for token in pred:
                if token != self.blank_id and token != prev_token:
                    decoded.append(token)
                prev_token = token

            batch_results.append(decoded)

        return batch_results

    def beam_search_decode(
        self,
        log_probs: torch.Tensor,
        beam_size: int = 10,
        length: Optional[int] = None
    ) -> List[Tuple[List[int], float]]:
        """
        Beam search decoding for single utterance.

        Maintains beam_size hypotheses and their scores.
        """
        T = length or log_probs.size(0)
        log_probs = log_probs[:T]

        # Initialize beams: (prefix, last_token, score)
        # prefix is collapsed sequence, last_token for repeat handling
        beams = [([], None, 0.0)]

        for t in range(T):
            new_beams = {}

            for prefix, last_token, score in beams:
                # For each possible token
                for token in range(log_probs.size(-1)):
                    token_score = log_probs[t, token].item()
                    new_score = score + token_score

                    if token == self.blank_id:
                        # Blank: keep prefix unchanged and reset last_token so a
                        # following repeat of the same label starts a new symbol
                        key = (tuple(prefix), None)
                        if key not in new_beams or new_beams[key][2] < new_score:
                            new_beams[key] = (prefix, None, new_score)
                    elif token == last_token:
                        # Repeat: keep prefix unchanged (collapse)
                        key = (tuple(prefix), token)
                        if key not in new_beams or new_beams[key][2] < new_score:
                            new_beams[key] = (prefix, token, new_score)
                    else:
                        # New token: extend prefix
                        new_prefix = prefix + [token]
                        key = (tuple(new_prefix), token)
                        if key not in new_beams or new_beams[key][2] < new_score:
                            new_beams[key] = (new_prefix, token, new_score)

            # Keep top beam_size
            beams = sorted(new_beams.values(), key=lambda x: -x[2])[:beam_size]

        # Return sorted results
        results = [(list(b[0]), b[2]) for b in beams]
        return results
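
A minimal sketch of how these pieces fit together for one training step and a greedy decode (arbitrary small sizes; random tensors stand in for a real batch, and index 0 is treated as the CTC blank, matching the defaults above):

PYTHON
config = ASRConfig(vocab_size=100, encoder_layers=2)   # small sizes for illustration
model = CTCEncoder(config)
criterion = CTCLoss(blank_id=0)
decoder = CTCDecoder(blank_id=0)

mel = torch.randn(2, config.n_mels, 400)               # [B, n_mels, T]
lengths = torch.tensor([400, 320])
targets = torch.randint(1, config.vocab_size, (2, 30)) # dummy token ids (0 reserved for blank)
target_lengths = torch.tensor([30, 25])

log_probs, out_lengths = model(mel, lengths)           # [B, T', vocab_size + 1]
loss = criterion(log_probs, targets, out_lengths, target_lengths)
hypotheses = decoder.greedy_decode(log_probs, out_lengths)

print(loss.item(), len(hypotheses[0]))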

Encoder-Decoder ASR

Attention-based models learn soft alignments between audio and text:

PYTHON
class Seq2SeqASR(nn.Module):
    """
    Sequence-to-sequence ASR with attention.

    Encoder processes audio, decoder generates text autoregressively.
    """

    def __init__(self, config: ASRConfig):
        super().__init__()
        self.config = config

        # Audio encoder
        self.encoder = AudioEncoder(config)

        # Text decoder
        self.decoder = TextDecoder(config)

        # Output projection
        self.output_projection = nn.Linear(config.decoder_dim, config.vocab_size)

    def forward(
        self,
        mel_spectrogram: torch.Tensor,
        audio_lengths: torch.Tensor,
        text_tokens: torch.Tensor,
        text_lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Training forward pass with teacher forcing.

        Args:
            mel_spectrogram: [B, n_mels, T] audio features
            audio_lengths: [B] audio lengths
            text_tokens: [B, S] target text tokens
            text_lengths: [B] text lengths

        Returns:
            logits: [B, S, vocab_size]
        """
        # Encode audio
        encoder_output, encoder_mask = self.encoder(mel_spectrogram, audio_lengths)

        # Decode with teacher forcing
        decoder_output = self.decoder(
            text_tokens[:, :-1],  # Shift right (exclude last token)
            encoder_output,
            encoder_mask
        )

        # Project to vocabulary
        logits = self.output_projection(decoder_output)

        return logits

    @torch.no_grad()
    def transcribe(
        self,
        mel_spectrogram: torch.Tensor,
        audio_lengths: torch.Tensor,
        max_length: int = 200,
        sos_token: int = 1,
        eos_token: int = 2
    ) -> torch.Tensor:
        """
        Greedy decoding for inference.
        """
        batch_size = mel_spectrogram.size(0)
        device = mel_spectrogram.device

        # Encode audio
        encoder_output, encoder_mask = self.encoder(mel_spectrogram, audio_lengths)

        # Start with SOS token
        generated = torch.full((batch_size, 1), sos_token, dtype=torch.long, device=device)

        for _ in range(max_length):
            # Decode current sequence
            decoder_output = self.decoder(generated, encoder_output, encoder_mask)

            # Get next token
            logits = self.output_projection(decoder_output[:, -1])
            next_token = logits.argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            # Check for EOS
            if (next_token == eos_token).all():
                break

        return generated


class AudioEncoder(nn.Module):
    """Transformer encoder for audio."""

    def __init__(self, config: ASRConfig):
        super().__init__()

        # Convolutional frontend for downsampling
        self.conv = nn.Sequential(
            nn.Conv1d(config.n_mels, config.encoder_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(config.encoder_dim, config.encoder_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )

        # Positional encoding
        self.pos_encoding = SinusoidalPositionalEncoding(config.encoder_dim, config.max_audio_length)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                config.encoder_dim,
                config.n_heads,
                config.encoder_dim * 4,
                config.dropout
            )
            for _ in range(config.encoder_layers)
        ])

        self.ln = nn.LayerNorm(config.encoder_dim)

    def forward(
        self,
        mel_spectrogram: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Encode audio features.

        Args:
            mel_spectrogram: [B, n_mels, T]
            lengths: [B] input lengths

        Returns:
            encoder_output: [B, T', encoder_dim]
            encoder_mask: [B, T'] padding mask
        """
        # Convolutional frontend
        x = self.conv(mel_spectrogram)  # [B, encoder_dim, T/4]
        x = x.transpose(1, 2)  # [B, T/4, encoder_dim]

        # Add positional encoding
        x = self.pos_encoding(x)

        # Create mask
        mask = None
        if lengths is not None:
            output_lengths = lengths // 4
            max_len = x.size(1)
            mask = torch.arange(max_len, device=x.device)[None, :] >= output_lengths[:, None]

        # Transformer layers
        for layer in self.layers:
            x = layer(x, mask)

        x = self.ln(x)

        return x, mask


class TextDecoder(nn.Module):
    """Transformer decoder for text generation."""

    def __init__(self, config: ASRConfig):
        super().__init__()

        self.embedding = nn.Embedding(config.vocab_size, config.decoder_dim)
        self.pos_encoding = SinusoidalPositionalEncoding(config.decoder_dim, config.max_text_length)

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(
                config.decoder_dim,
                config.n_heads,
                config.decoder_dim * 4,
                config.dropout
            )
            for _ in range(config.decoder_layers)
        ])

        self.ln = nn.LayerNorm(config.decoder_dim)

    def forward(
        self,
        text_tokens: torch.Tensor,
        encoder_output: torch.Tensor,
        encoder_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Decode text autoregressively.

        Args:
            text_tokens: [B, S] input tokens
            encoder_output: [B, T, encoder_dim] encoded audio
            encoder_mask: [B, T] encoder padding mask

        Returns:
            decoder_output: [B, S, decoder_dim]
        """
        # Embed tokens
        x = self.embedding(text_tokens)
        x = self.pos_encoding(x)

        # Causal mask for autoregressive decoding
        seq_len = text_tokens.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=text_tokens.device),
            diagonal=1
        ).bool()

        # Decoder layers
        for layer in self.layers:
            x = layer(x, encoder_output, causal_mask, encoder_mask)

        x = self.ln(x)

        return x


class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention
        residual = x
        x = self.ln1(x)
        x, _ = self.self_attn(x, x, x, key_padding_mask=mask)
        x = residual + self.dropout(x)

        # FFN
        residual = x
        x = self.ln2(x)
        x = residual + self.dropout(self.ffn(x))

        return x


class TransformerDecoderLayer(nn.Module):
    """Transformer decoder layer with cross-attention."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        causal_mask: torch.Tensor,
        encoder_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Masked self-attention
        residual = x
        x = self.ln1(x)
        x, _ = self.self_attn(x, x, x, attn_mask=causal_mask)
        x = residual + self.dropout(x)

        # Cross-attention
        residual = x
        x = self.ln2(x)
        x, _ = self.cross_attn(x, encoder_output, encoder_output, key_padding_mask=encoder_mask)
        x = residual + self.dropout(x)

        # FFN
        residual = x
        x = self.ln3(x)
        x = residual + self.dropout(self.ffn(x))

        return x


class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]
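
A brief sketch of a teacher-forced training step with this model (arbitrary small sizes and random tensors; token ids 0/1/2 are assumed to be reserved for padding, SOS, and EOS):

PYTHON
config = ASRConfig(vocab_size=100, encoder_layers=2, decoder_layers=2)
model = Seq2SeqASR(config)

mel = torch.randn(2, config.n_mels, 400)               # [B, n_mels, T]
audio_lengths = torch.tensor([400, 320])
text = torch.randint(3, config.vocab_size, (2, 20))    # dummy token ids
text_lengths = torch.tensor([20, 16])

logits = model(mel, audio_lengths, text, text_lengths)  # [B, S-1, vocab_size]
loss = F.cross_entropy(
    logits.reshape(-1, config.vocab_size),
    text[:, 1:].reshape(-1),    # targets: the input sequence shifted left by one
    ignore_index=0              # ignore padding positions
)
print(loss.item())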

Transducer (RNN-T)

Transducers combine CTC-style encoding with autoregressive decoding:

PYTHON
class RNNTransducer(nn.Module):
    """
    RNN Transducer (RNN-T) for streaming ASR.

    Components:
    - Encoder: Processes audio frames
    - Prediction Network: Processes previous tokens (like LM)
    - Joint Network: Combines encoder and predictor
    """

    def __init__(self, config: ASRConfig):
        super().__init__()
        self.config = config
        self.blank_id = 0

        # Encoder (processes audio)
        self.encoder = AudioEncoder(config)

        # Prediction network (processes previous tokens)
        self.predictor = PredictionNetwork(config)

        # Joint network (combines encoder + predictor)
        self.joint = JointNetwork(config)

    def forward(
        self,
        mel_spectrogram: torch.Tensor,
        audio_lengths: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for training.

        Returns log probabilities over vocabulary for transducer loss.
        """
        # Encode audio: [B, T, encoder_dim]
        encoder_output, _ = self.encoder(mel_spectrogram, audio_lengths)

        # Prediction network over the targets (conventionally prepended with a
        # blank/SOS token so U labels give U+1 predictor steps): [B, U+1, predictor_dim]
        predictor_output = self.predictor(targets)

        # Joint network: [B, T, U+1, vocab_size+1]
        joint_output = self.joint(encoder_output, predictor_output)

        return joint_output

    @torch.no_grad()
    def greedy_decode(
        self,
        mel_spectrogram: torch.Tensor,
        audio_lengths: torch.Tensor,
        max_symbols_per_step: int = 10,
        sos_token: int = 1
    ) -> List[List[int]]:
        """
        Greedy decoding for inference.
        """
        batch_size = mel_spectrogram.size(0)
        device = mel_spectrogram.device

        # Encode audio
        encoder_output, _ = self.encoder(mel_spectrogram, audio_lengths)
        T = encoder_output.size(1)

        results = []

        for b in range(batch_size):
            decoded = []
            # Start with SOS token
            pred_input = torch.tensor([[sos_token]], device=device)

            for t in range(T):
                encoder_frame = encoder_output[b:b+1, t:t+1]  # [1, 1, dim]

                symbols_emitted = 0
                while symbols_emitted < max_symbols_per_step:
                    # Predictor
                    predictor_out = self.predictor(pred_input)[:, -1:]  # [1, 1, dim]

                    # Joint
                    joint_out = self.joint(encoder_frame, predictor_out)  # [1, 1, 1, vocab]
                    log_probs = joint_out.squeeze()

                    # Greedy selection
                    token = log_probs.argmax().item()

                    if token == self.blank_id:
                        # Blank means move to next encoder frame
                        break
                    else:
                        # Emit token
                        decoded.append(token)
                        pred_input = torch.cat([
                            pred_input,
                            torch.tensor([[token]], device=device)
                        ], dim=1)
                        symbols_emitted += 1

            results.append(decoded)

        return results


class PredictionNetwork(nn.Module):
    """
    Prediction network (internal language model).

    Processes previously emitted tokens.
    """

    def __init__(self, config: ASRConfig):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size + 1, config.decoder_dim)  # +1 for blank
        self.lstm = nn.LSTM(
            config.decoder_dim,
            config.decoder_dim,
            num_layers=2,
            batch_first=True,
            dropout=config.dropout
        )

    def forward(
        self,
        tokens: torch.Tensor,
        hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Args:
            tokens: [B, U] previous tokens
            hidden: Optional LSTM hidden state

        Returns:
            predictor_output: [B, U, decoder_dim]
        """
        x = self.embedding(tokens)
        x, _ = self.lstm(x, hidden)
        return x


class JointNetwork(nn.Module):
    """
    Joint network combining encoder and predictor.
    """

    def __init__(self, config: ASRConfig):
        super().__init__()
        self.encoder_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
        self.predictor_proj = nn.Linear(config.decoder_dim, config.encoder_dim)
        self.output = nn.Linear(config.encoder_dim, config.vocab_size + 1)  # +1 for blank

    def forward(
        self,
        encoder_output: torch.Tensor,
        predictor_output: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            encoder_output: [B, T, encoder_dim]
            predictor_output: [B, U, decoder_dim]

        Returns:
            joint_output: [B, T, U, vocab_size+1]
        """
        # Project both
        enc = self.encoder_proj(encoder_output)  # [B, T, dim]
        pred = self.predictor_proj(predictor_output)  # [B, U, dim]

        # Broadcast and add
        # enc: [B, T, 1, dim]
        # pred: [B, 1, U, dim]
        enc = enc.unsqueeze(2)
        pred = pred.unsqueeze(1)

        joint = torch.tanh(enc + pred)  # [B, T, U, dim]

        return F.log_softmax(self.output(joint), dim=-1)
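
A short sketch of the tensor shapes involved (arbitrary small sizes and random tensors; in practice the joint output is fed to a transducer loss, such as the RNN-T loss in torchaudio, which is not shown here):

PYTHON
config = ASRConfig(vocab_size=100, encoder_layers=2)
model = RNNTransducer(config)

mel = torch.randn(2, config.n_mels, 400)
audio_lengths = torch.tensor([400, 320])

# Prepend a blank/SOS symbol so U target tokens give U+1 predictor steps
targets = torch.randint(2, config.vocab_size, (2, 20))
predictor_input = F.pad(targets, (1, 0), value=model.blank_id)

joint = model(mel, audio_lengths, predictor_input, torch.tensor([20, 16]))
print(joint.shape)   # [B, T', U+1, vocab_size+1], e.g. torch.Size([2, 100, 21, 101])

hyps = model.greedy_decode(mel, audio_lengths)
print(len(hyps[0]))  # number of tokens emitted for the first utterance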

Whisper Architecture

Whisper demonstrates the power of large-scale weakly supervised training:

PYTHON
class Whisper(nn.Module):
    """
    Whisper-style encoder-decoder ASR.

    Key innovations:
    - Trained on 680K hours of diverse audio
    - Multitask: transcription, translation, timestamps
    - Special tokens for task specification
    """

    def __init__(self, config: ASRConfig):
        super().__init__()

        # Convolutional stem
        self.conv1 = nn.Conv1d(config.n_mels, config.encoder_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.encoder_dim, config.encoder_dim, kernel_size=3, stride=2, padding=1)

        # Positional embedding
        self.encoder_pos_embed = nn.Embedding(config.max_audio_length // 2, config.encoder_dim)
        self.decoder_pos_embed = nn.Embedding(config.max_text_length, config.decoder_dim)

        # Transformer encoder
        self.encoder_layers = nn.ModuleList([
            WhisperEncoderBlock(config.encoder_dim, config.n_heads, config.dropout)
            for _ in range(config.encoder_layers)
        ])
        self.encoder_ln = nn.LayerNorm(config.encoder_dim)

        # Token embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.decoder_dim)

        # Transformer decoder
        self.decoder_layers = nn.ModuleList([
            WhisperDecoderBlock(config.decoder_dim, config.n_heads, config.dropout)
            for _ in range(config.decoder_layers)
        ])
        self.decoder_ln = nn.LayerNorm(config.decoder_dim)

        # Output projection (tied with token embeddings)
        self.output_proj = nn.Linear(config.decoder_dim, config.vocab_size, bias=False)

    def encode(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Encode audio features.

        Args:
            mel: [B, n_mels, T] mel spectrogram

        Returns:
            encoder_output: [B, T/2, encoder_dim]
        """
        # Convolutional frontend
        x = F.gelu(self.conv1(mel))
        x = F.gelu(self.conv2(x))  # [B, encoder_dim, T/2]
        x = x.permute(0, 2, 1)  # [B, T/2, encoder_dim]

        # Add positional embedding
        positions = torch.arange(x.size(1), device=x.device)
        x = x + self.encoder_pos_embed(positions)

        # Encoder layers
        for layer in self.encoder_layers:
            x = layer(x)

        x = self.encoder_ln(x)

        return x

    def decode(
        self,
        tokens: torch.Tensor,
        encoder_output: torch.Tensor,
        kv_cache: Optional[Dict] = None
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Decode text tokens.

        Args:
            tokens: [B, S] token ids
            encoder_output: [B, T, encoder_dim]
            kv_cache: Optional KV cache for efficient inference

        Returns:
            logits: [B, S, vocab_size]
            kv_cache: Updated KV cache
        """
        # Token embeddings
        x = self.token_embedding(tokens)

        # Positional embedding
        if kv_cache is None:
            positions = torch.arange(tokens.size(1), device=tokens.device)
        else:
            # During inference with cache, only add position for new tokens
            positions = torch.tensor([kv_cache.get('offset', 0)], device=tokens.device)

        x = x + self.decoder_pos_embed(positions)

        # Decoder layers with KV caching
        new_kv_cache = {}
        for i, layer in enumerate(self.decoder_layers):
            layer_cache = kv_cache.get(f'layer_{i}') if kv_cache else None
            x, new_cache = layer(x, encoder_output, layer_cache)
            new_kv_cache[f'layer_{i}'] = new_cache

        x = self.decoder_ln(x)

        # Project to vocabulary
        logits = self.output_proj(x)

        # Track the absolute position offset so subsequent cached calls
        # index the correct positional embedding (initialize it on the first call too)
        prev_offset = kv_cache.get('offset', 0) if kv_cache is not None else 0
        new_kv_cache['offset'] = prev_offset + tokens.size(1)

        return logits, new_kv_cache

    def forward(
        self,
        mel: torch.Tensor,
        tokens: torch.Tensor
    ) -> torch.Tensor:
        """Training forward pass."""
        encoder_output = self.encode(mel)
        logits, _ = self.decode(tokens[:, :-1], encoder_output)
        return logits

    @torch.no_grad()
    def transcribe(
        self,
        mel: torch.Tensor,
        task: str = "transcribe",
        language: str = "en",
        max_length: int = 224
    ) -> torch.Tensor:
        """
        Transcribe or translate audio via greedy decoding.

        Args:
            mel: [B, n_mels, T] mel spectrogram
            task: "transcribe" or "translate"
            language: Language code
            max_length: Maximum output length

        Returns:
            tokens: [B, <=max_length+1] generated token ids
                (decode to text with the tokenizer)
        """
        batch_size = mel.size(0)
        device = mel.device

        # Encode audio
        encoder_output = self.encode(mel)

        # Initialize with task tokens
        # Format: <|startoftranscript|><|lang|><|task|><|notimestamps|>
        # Simplified here - actual Whisper uses specific token IDs
        tokens = torch.tensor([[1]], device=device).expand(batch_size, -1)

        kv_cache = None

        for _ in range(max_length):
            logits, kv_cache = self.decode(tokens[:, -1:], encoder_output, kv_cache)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)

            # Check for end token
            if (next_token == 2).all():  # Assuming 2 is EOS
                break

        return tokens


class WhisperEncoderBlock(nn.Module):
    """Whisper encoder block."""

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        self.attn_ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.mlp_ln = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm residual blocks: normalize once, then self-attention and MLP
        x_norm = self.attn_ln(x)
        x = x + self.attn(x_norm, x_norm, x_norm)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class WhisperDecoderBlock(nn.Module):
    """Whisper decoder block with KV caching."""

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        self.self_attn_ln = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn_ln = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.mlp_ln = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x: torch.Tensor,
        encoder_output: torch.Tensor,
        kv_cache: Optional[Dict] = None
    ) -> Tuple[torch.Tensor, Dict]:
        # Self-attention with causal mask
        residual = x
        x = self.self_attn_ln(x)

        if kv_cache is not None:
            # Use cached keys/values
            k_cache = kv_cache.get('k')
            v_cache = kv_cache.get('v')
            if k_cache is not None:
                k = torch.cat([k_cache, x], dim=1)
                v = torch.cat([v_cache, x], dim=1)
            else:
                k, v = x, x
        else:
            k, v = x, x

        # Causal mask is only needed when decoding the full sequence at once;
        # with a KV cache the single new query may attend to all cached positions
        causal_mask = None
        if kv_cache is None:
            seq_len = k.size(1)
            causal_mask = torch.triu(torch.ones(x.size(1), seq_len, device=x.device), diagonal=1).bool()

        attn_out, _ = self.self_attn(x, k, v, attn_mask=causal_mask)
        x = residual + attn_out

        # Cross-attention
        residual = x
        x = self.cross_attn_ln(x)
        x = residual + self.cross_attn(x, encoder_output, encoder_output)[0]

        # MLP
        x = x + self.mlp(self.mlp_ln(x))

        # Update cache
        new_cache = {'k': k, 'v': v}

        return x, new_cache
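
A quick usage sketch of the greedy transcribe loop above. It assumes the ASRConfig dataclass defined earlier in this chapter, with encoder_dim equal to decoder_dim (as the cross-attention above requires) and position/length limits large enough for the dummy input; with random weights the generated ids are of course meaningless:

PYTHON
config = ASRConfig()
model = Whisper(config).eval()

mel = torch.randn(1, config.n_mels, 400)   # dummy log-mel input, 400 frames
token_ids = model.transcribe(mel)          # [1, <=max_length+1] greedy token ids
print(token_ids.shape)
# Mapping ids back to text requires the model's tokenizer (not shown here).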

Language Model Integration

External language models improve ASR accuracy:

PYTHON
class ShallowFusion:
    """
    Shallow fusion: combine ASR scores with external LM.

    Score = log P_ASR(y|x) + λ * log P_LM(y)
    """

    def __init__(
        self,
        lm_model,
        lm_weight: float = 0.3,
        word_score: float = 0.0
    ):
        self.lm = lm_model
        self.lm_weight = lm_weight
        self.word_score = word_score

    def rescore_beam(
        self,
        asr_log_probs: torch.Tensor,
        tokens: torch.Tensor
    ) -> torch.Tensor:
        """
        Rescore ASR beam with LM.

        Args:
            asr_log_probs: ASR log probabilities
            tokens: Token sequence

        Returns:
            Combined scores
        """
        # Get LM scores
        with torch.no_grad():
            lm_output = self.lm(tokens)
            lm_log_probs = F.log_softmax(lm_output, dim=-1)

        # Combine scores
        combined = asr_log_probs + self.lm_weight * lm_log_probs

        # Optional word insertion bonus: add word_score at each word boundary.
        # Detecting boundaries requires tokenizer knowledge, so it is omitted here.
        if self.word_score != 0:
            pass

        return combined
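
# Usage sketch (lm_model is any external LM whose forward returns logits over
# the same vocabulary as the ASR model; the names here are illustrative):
#   fusion = ShallowFusion(lm_model, lm_weight=0.3)
#   combined_scores = fusion.rescore_beam(asr_log_probs, beam_tokens)
# The beam candidates are then re-ranked by combined_scores.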


def asr_metrics():
    """Common ASR evaluation metrics."""

    metrics = {
        'WER (Word Error Rate)': {
            'formula': '(S + D + I) / N * 100',
            'description': 'Percentage of words that need editing',
            'components': 'S=substitutions, D=deletions, I=insertions, N=total words',
            'typical_good': '< 10% for clean speech',
            'note': 'Most common metric'
        },
        'CER (Character Error Rate)': {
            'formula': '(S + D + I) / N * 100 (at character level)',
            'description': 'Percentage of characters that need editing',
            'typical_good': '< 5% for clean speech',
            'note': 'Better for languages without clear word boundaries'
        },
        'SER (Sentence Error Rate)': {
            'formula': 'Sentences with errors / Total sentences * 100',
            'description': 'Percentage of sentences with any error',
            'note': 'Stricter than WER'
        },
        'RTF (Real-Time Factor)': {
            'formula': 'Processing time / Audio duration',
            'description': 'Speed of transcription',
            'typical_good': '< 1.0 for real-time',
            'note': 'Important for streaming'
        }
    }

    print("ASR Evaluation Metrics:")
    print("=" * 60)
    for name, info in metrics.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


asr_metrics()
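
The WER formula above reduces to a word-level edit distance. A minimal, dependency-free sketch (production toolkits also report the individual substitution/deletion/insertion counts):

PYTHON
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Compute WER = (S + D + I) / N via word-level edit distance."""
    ref, hyp = reference.split(), hypothesis.split()

    # d[i][j] = edit distance between the first i reference words
    # and the first j hypothesis words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j

    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,        # deletion
                d[i][j - 1] + 1,        # insertion
                d[i - 1][j - 1] + cost  # substitution or match
            )

    return d[len(ref)][len(hyp)] / max(len(ref), 1)


print(word_error_rate("the cat sat on the mat", "the cat sat on mat"))  # 1/6 ≈ 0.167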

Key Takeaways

Modern ASR has evolved from complex pipelines to elegant end-to-end systems. Key approaches include: (1) CTC for simple non-autoregressive decoding with blank symbols enabling alignment-free training, (2) attention-based encoder-decoder models that learn soft alignments and generate text autoregressively, (3) transducers (RNN-T) that combine the benefits of both for streaming applications, and (4) Whisper-style models trained on massive weakly-supervised data for robust multilingual recognition. Language model integration through shallow or deep fusion further improves accuracy. The choice of architecture depends on requirements: CTC for simplicity, seq2seq for quality, transducers for streaming, and large pretrained models for robustness across diverse conditions.

23.3 Text-to-Speech Synthesis Advanced

Text-to-Speech Synthesis

Text-to-Speech (TTS) synthesis converts written text into natural-sounding speech, enabling voice assistants, audiobook narration, accessibility tools, and countless other applications. Modern neural TTS systems have achieved remarkable naturalness, approaching human quality in many scenarios. This section explores the architectures and techniques that power state-of-the-art speech synthesis, from acoustic models to neural vocoders.

TTS Pipeline Overview

Modern TTS typically involves multiple stages transforming text to audio:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass
import numpy as np

@dataclass
class TTSConfig:
    """Configuration for TTS model."""
    vocab_size: int = 256  # Character vocabulary
    encoder_dim: int = 512
    decoder_dim: int = 1024
    n_mels: int = 80
    n_heads: int = 8
    encoder_layers: int = 6
    decoder_layers: int = 6
    dropout: float = 0.1
    max_text_length: int = 200
    max_mel_length: int = 1000


class TTSPipeline:
    """
    Overview of Text-to-Speech pipeline stages.
    """

    @staticmethod
    def stages():
        """TTS pipeline stages."""
        stages = {
            '1. Text Normalization': {
                'input': 'Raw text',
                'output': 'Normalized text',
                'examples': [
                    '$5.99 → five dollars ninety-nine cents',
                    'Dr. Smith → Doctor Smith',
                    '12/25/2024 → December twenty-fifth twenty twenty-four'
                ]
            },
            '2. Grapheme-to-Phoneme (G2P)': {
                'input': 'Normalized text',
                'output': 'Phoneme sequence',
                'examples': [
                    'hello → HH AH0 L OW1',
                    'read (present) → R IY1 D',
                    'read (past) → R EH1 D'
                ],
                'note': 'Handles pronunciation disambiguation'
            },
            '3. Acoustic Model': {
                'input': 'Phonemes (or characters)',
                'output': 'Mel spectrogram',
                'models': ['Tacotron 2', 'FastSpeech 2', 'VITS'],
                'note': 'Predicts acoustic features from linguistic features'
            },
            '4. Vocoder': {
                'input': 'Mel spectrogram',
                'output': 'Waveform',
                'models': ['WaveNet', 'WaveRNN', 'HiFi-GAN', 'UnivNet'],
                'note': 'Converts spectral features to audio'
            }
        }

        print("TTS Pipeline Stages:")
        print("=" * 60)
        for stage, info in stages.items():
            print(f"\n{stage}:")
            for k, v in info.items():
                if isinstance(v, list):
                    print(f"  {k}:")
                    for item in v:
                        print(f"    - {item}")
                else:
                    print(f"  {k}: {v}")

        return stages


TTSPipeline.stages()
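
As a toy illustration of stages 1-2, the sketch below looks words up in a tiny hand-written lexicon excerpt (real systems use full normalization rule sets and pronunciation dictionaries such as CMUdict, plus a neural G2P model for out-of-vocabulary words):

PYTHON
# Toy lexicon excerpt (ARPAbet phones with stress markers)
TOY_LEXICON = {
    "hello": ["HH", "AH0", "L", "OW1"],
    "doctor": ["D", "AA1", "K", "T", "ER0"],
}

def toy_g2p(text: str) -> List[str]:
    """Look up each word in the toy lexicon; fall back to spelling it out."""
    phonemes = []
    for word in text.lower().split():
        phonemes.extend(TOY_LEXICON.get(word, list(word.upper())))
    return phonemes

print(toy_g2p("hello doctor"))
# ['HH', 'AH0', 'L', 'OW1', 'D', 'AA1', 'K', 'T', 'ER0']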

Tacotron Architecture

Tacotron pioneered end-to-end neural TTS with attention:

PYTHON
class Tacotron2(nn.Module):
    """
    Tacotron 2: Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions.

    Architecture:
    - Character/phoneme encoder
    - Location-sensitive attention
    - Autoregressive mel decoder
    - Stop token predictor
    """

    def __init__(self, config: TTSConfig):
        super().__init__()
        self.config = config

        # Text encoder
        self.encoder = Tacotron2Encoder(config)

        # Decoder (owns the location-sensitive attention module)
        self.decoder = Tacotron2Decoder(config)

        # Postnet for spectrogram refinement
        self.postnet = Postnet(config.n_mels)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        mel_target: Optional[torch.Tensor] = None,
        mel_lengths: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward pass with teacher forcing.

        Args:
            text: [B, T_text] input text/phoneme indices
            text_lengths: [B] text lengths
            mel_target: [B, n_mels, T_mel] target mel spectrogram
            mel_lengths: [B] mel lengths

        Returns:
            Dictionary with mel outputs, stop tokens, and attention weights
        """
        # Encode text
        encoder_outputs = self.encoder(text, text_lengths)

        # Decode mel spectrogram
        mel_outputs, stop_outputs, attention_weights = self.decoder(
            encoder_outputs,
            text_lengths,
            mel_target
        )

        # Postnet refinement
        mel_postnet = mel_outputs + self.postnet(mel_outputs)

        return {
            'mel_outputs': mel_outputs,
            'mel_outputs_postnet': mel_postnet,
            'stop_outputs': stop_outputs,
            'attention_weights': attention_weights
        }

    @torch.no_grad()
    def inference(
        self,
        text: torch.Tensor,
        max_decoder_steps: int = 1000
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generate mel spectrogram from text.
        """
        self.eval()

        # Encode
        encoder_outputs = self.encoder(text, torch.tensor([text.size(1)]))

        # Decode without teacher forcing
        mel_outputs, attention_weights = self.decoder.inference(
            encoder_outputs,
            max_decoder_steps
        )

        # Postnet
        mel_postnet = mel_outputs + self.postnet(mel_outputs)

        return mel_postnet, attention_weights


class Tacotron2Encoder(nn.Module):
    """
    Tacotron 2 encoder: character embedding + 3 conv layers + BiLSTM.
    """

    def __init__(self, config: TTSConfig):
        super().__init__()

        self.embedding = nn.Embedding(config.vocab_size, config.encoder_dim)

        # Convolutional layers
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(
                    config.encoder_dim,
                    config.encoder_dim,
                    kernel_size=5,
                    padding=2
                ),
                nn.BatchNorm1d(config.encoder_dim),
                nn.ReLU(),
                nn.Dropout(config.dropout)
            )
            for _ in range(3)
        ])

        # BiLSTM
        self.lstm = nn.LSTM(
            config.encoder_dim,
            config.encoder_dim // 2,
            batch_first=True,
            bidirectional=True
        )

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode text sequence.

        Args:
            text: [B, T] text indices
            text_lengths: [B] text lengths

        Returns:
            encoder_outputs: [B, T, encoder_dim]
        """
        # Embed
        x = self.embedding(text)  # [B, T, encoder_dim]

        # Conv layers
        x = x.transpose(1, 2)  # [B, encoder_dim, T]
        for conv in self.convs:
            x = conv(x)
        x = x.transpose(1, 2)  # [B, T, encoder_dim]

        # Pack for LSTM
        x = nn.utils.rnn.pack_padded_sequence(
            x, text_lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        # BiLSTM
        x, _ = self.lstm(x)

        # Unpack
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return x


class LocationSensitiveAttention(nn.Module):
    """
    Location-sensitive attention for TTS.

    Extends additive attention with location features
    to encourage monotonic alignment progression.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        attention_dim: int,
        attention_location_n_filters: int,
        attention_location_kernel_size: int
    ):
        super().__init__()

        self.query_layer = nn.Linear(decoder_dim, attention_dim, bias=False)
        self.memory_layer = nn.Linear(encoder_dim, attention_dim, bias=False)
        self.location_layer = nn.Linear(attention_location_n_filters, attention_dim, bias=False)
        self.v = nn.Linear(attention_dim, 1, bias=False)

        # Location convolution
        self.location_conv = nn.Conv1d(
            2, attention_location_n_filters,
            kernel_size=attention_location_kernel_size,
            padding=(attention_location_kernel_size - 1) // 2,
            bias=False
        )

        self.score_mask_value = -float('inf')

    def forward(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
        processed_memory: torch.Tensor,
        attention_weights_cat: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute attention weights and context.

        Args:
            query: [B, decoder_dim] decoder hidden state
            memory: [B, T, encoder_dim] encoder outputs
            processed_memory: [B, T, attention_dim] pre-computed memory projection
            attention_weights_cat: [B, 2, T] previous attention weights (cumulative and current)
            mask: [B, T] encoder padding mask

        Returns:
            context: [B, encoder_dim] attention context
            attention_weights: [B, T] attention weights
        """
        # Query projection
        processed_query = self.query_layer(query.unsqueeze(1))  # [B, 1, attention_dim]

        # Location features
        processed_location = self.location_conv(attention_weights_cat)  # [B, filters, T]
        processed_location = self.location_layer(processed_location.transpose(1, 2))  # [B, T, attention_dim]

        # Compute attention scores
        energies = self.v(torch.tanh(
            processed_query + processed_memory + processed_location
        )).squeeze(-1)  # [B, T]

        # Apply mask
        if mask is not None:
            energies = energies.masked_fill(mask, self.score_mask_value)

        # Softmax
        attention_weights = F.softmax(energies, dim=-1)

        # Context
        context = torch.bmm(attention_weights.unsqueeze(1), memory).squeeze(1)

        return context, attention_weights


class Tacotron2Decoder(nn.Module):
    """
    Tacotron 2 autoregressive decoder.
    """

    def __init__(self, config: TTSConfig):
        super().__init__()
        self.config = config
        self.n_mels = config.n_mels

        # Prenet
        self.prenet = nn.Sequential(
            nn.Linear(config.n_mels, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        # Attention RNN
        self.attention_rnn = nn.LSTMCell(
            256 + config.encoder_dim,
            config.decoder_dim
        )

        # Attention
        self.attention = LocationSensitiveAttention(
            config.encoder_dim,
            config.decoder_dim,
            attention_dim=128,
            attention_location_n_filters=32,
            attention_location_kernel_size=31
        )

        # Decoder RNN
        self.decoder_rnn = nn.LSTMCell(
            config.decoder_dim + config.encoder_dim,
            config.decoder_dim
        )

        # Output projection
        self.linear_projection = nn.Linear(
            config.decoder_dim + config.encoder_dim,
            config.n_mels
        )

        # Stop token prediction
        self.stop_layer = nn.Linear(
            config.decoder_dim + config.encoder_dim,
            1
        )

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        text_lengths: torch.Tensor,
        mel_target: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Training forward with teacher forcing.
        """
        batch_size = encoder_outputs.size(0)
        max_len = mel_target.size(2)
        device = encoder_outputs.device

        # Initialize states
        attention_hidden = torch.zeros(batch_size, self.config.decoder_dim, device=device)
        attention_cell = torch.zeros(batch_size, self.config.decoder_dim, device=device)
        decoder_hidden = torch.zeros(batch_size, self.config.decoder_dim, device=device)
        decoder_cell = torch.zeros(batch_size, self.config.decoder_dim, device=device)

        attention_weights = torch.zeros(batch_size, encoder_outputs.size(1), device=device)
        attention_weights_cum = torch.zeros(batch_size, encoder_outputs.size(1), device=device)
        context = torch.zeros(batch_size, self.config.encoder_dim, device=device)

        # Pre-compute memory projection
        processed_memory = self.attention.memory_layer(encoder_outputs)

        # Create mask
        max_text_len = encoder_outputs.size(1)
        mask = torch.arange(max_text_len, device=device)[None, :] >= text_lengths[:, None]

        # First frame is zeros (go frame)
        go_frame = torch.zeros(batch_size, self.n_mels, device=device)

        # Outputs
        mel_outputs = []
        stop_outputs = []
        attention_weights_all = []

        for t in range(max_len):
            # Get input (previous frame or target with teacher forcing)
            if t == 0:
                decoder_input = go_frame
            else:
                decoder_input = mel_target[:, :, t - 1]

            # Prenet
            prenet_output = self.prenet(decoder_input)

            # Attention RNN
            attention_rnn_input = torch.cat([prenet_output, context], dim=-1)
            attention_hidden, attention_cell = self.attention_rnn(
                attention_rnn_input, (attention_hidden, attention_cell)
            )

            # Attention
            attention_weights_cat = torch.stack([attention_weights_cum, attention_weights], dim=1)
            context, attention_weights = self.attention(
                attention_hidden,
                encoder_outputs,
                processed_memory,
                attention_weights_cat,
                mask
            )
            attention_weights_cum = attention_weights_cum + attention_weights

            # Decoder RNN
            decoder_rnn_input = torch.cat([attention_hidden, context], dim=-1)
            decoder_hidden, decoder_cell = self.decoder_rnn(
                decoder_rnn_input, (decoder_hidden, decoder_cell)
            )

            # Output projection
            decoder_output = torch.cat([decoder_hidden, context], dim=-1)
            mel_output = self.linear_projection(decoder_output)
            stop_output = self.stop_layer(decoder_output)

            mel_outputs.append(mel_output)
            stop_outputs.append(stop_output)
            attention_weights_all.append(attention_weights)

        mel_outputs = torch.stack(mel_outputs, dim=2)  # [B, n_mels, T]
        stop_outputs = torch.cat(stop_outputs, dim=1)  # [B, T]
        attention_weights_all = torch.stack(attention_weights_all, dim=2)  # [B, T_enc, T_dec]

        return mel_outputs, stop_outputs, attention_weights_all


class Postnet(nn.Module):
    """
    Postnet: 5 conv layers to refine predicted mel spectrogram.
    """

    def __init__(self, n_mels: int, postnet_dim: int = 512, n_layers: int = 5):
        super().__init__()

        layers = []
        for i in range(n_layers):
            in_channels = n_mels if i == 0 else postnet_dim
            out_channels = n_mels if i == n_layers - 1 else postnet_dim

            layers.append(nn.Sequential(
                nn.Conv1d(
                    in_channels, out_channels,
                    kernel_size=5, padding=2
                ),
                nn.BatchNorm1d(out_channels),
                nn.Tanh() if i < n_layers - 1 else nn.Identity(),
                nn.Dropout(0.5)
            ))

        self.convs = nn.ModuleList(layers)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Refine mel spectrogram.

        Args:
            mel: [B, n_mels, T]

        Returns:
            residual: [B, n_mels, T] refinement to add to mel
        """
        x = mel
        for conv in self.convs:
            x = conv(x)
        return x
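
Tacotron 2 is trained with mean-squared error on the mel spectrogram before and after the postnet, plus binary cross-entropy on the stop token. A minimal loss sketch, assuming stop_target is a [B, T_mel] tensor of 0/1 flags marking the end of the utterance:

PYTHON
def tacotron2_loss(
    outputs: Dict[str, torch.Tensor],
    mel_target: torch.Tensor,   # [B, n_mels, T_mel]
    stop_target: torch.Tensor   # [B, T_mel], 1.0 at/after the last real frame
) -> torch.Tensor:
    """Minimal Tacotron 2 loss: pre/post-postnet MSE + stop-token BCE."""
    mel_loss = F.mse_loss(outputs['mel_outputs'], mel_target)
    postnet_loss = F.mse_loss(outputs['mel_outputs_postnet'], mel_target)
    stop_loss = F.binary_cross_entropy_with_logits(outputs['stop_outputs'], stop_target)
    return mel_loss + postnet_loss + stop_loss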

FastSpeech: Non-Autoregressive TTS

FastSpeech removes autoregressive decoding for faster synthesis:

PYTHON
class FastSpeech2(nn.Module):
    """
    FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.

    Key innovations:
    - Non-autoregressive parallel generation
    - Variance adaptor for duration, pitch, energy
    - Direct waveform training optional
    """

    def __init__(self, config: TTSConfig):
        super().__init__()

        # Encoder (phoneme/character)
        self.encoder = FastSpeechEncoder(config)

        # Variance adaptor
        self.variance_adaptor = VarianceAdaptor(config)

        # Decoder
        self.decoder = FastSpeechDecoder(config)

        # Mel output
        self.mel_linear = nn.Linear(config.decoder_dim, config.n_mels)

        # Optional postnet
        self.postnet = Postnet(config.n_mels)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        mel_lengths: Optional[torch.Tensor] = None,
        duration_targets: Optional[torch.Tensor] = None,
        pitch_targets: Optional[torch.Tensor] = None,
        energy_targets: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward pass.

        Args:
            text: [B, T_text] phoneme indices
            text_lengths: [B] text lengths
            mel_lengths: [B] mel lengths (for training)
            duration_targets: [B, T_text] ground truth durations
            pitch_targets: [B, T_mel] ground truth pitch
            energy_targets: [B, T_mel] ground truth energy
        """
        # Encode
        encoder_output, text_mask = self.encoder(text, text_lengths)

        # Variance adaptor (duration, pitch, energy)
        variance_output = self.variance_adaptor(
            encoder_output,
            text_mask,
            mel_lengths,
            duration_targets,
            pitch_targets,
            energy_targets
        )

        # Decode
        decoder_output = self.decoder(
            variance_output['output'],
            variance_output['mel_mask']
        )

        # Mel output
        mel_output = self.mel_linear(decoder_output)
        mel_postnet = mel_output + self.postnet(mel_output.transpose(1, 2)).transpose(1, 2)

        return {
            'mel_output': mel_output,
            'mel_postnet': mel_postnet,
            'duration_prediction': variance_output['duration_prediction'],
            'pitch_prediction': variance_output['pitch_prediction'],
            'energy_prediction': variance_output['energy_prediction'],
            'mel_mask': variance_output['mel_mask']
        }

    @torch.no_grad()
    def inference(
        self,
        text: torch.Tensor,
        duration_control: float = 1.0,
        pitch_control: float = 1.0,
        energy_control: float = 1.0
    ) -> torch.Tensor:
        """
        Generate mel spectrogram from text.

        Args:
            text: [B, T] phoneme indices
            duration_control: Speed factor (>1 = slower)
            pitch_control: Pitch factor
            energy_control: Energy factor
        """
        text_lengths = torch.tensor([text.size(1)], device=text.device)

        # Encode
        encoder_output, text_mask = self.encoder(text, text_lengths)

        # Variance adaptor with control
        variance_output = self.variance_adaptor.inference(
            encoder_output,
            text_mask,
            duration_control,
            pitch_control,
            energy_control
        )

        # Decode
        decoder_output = self.decoder(
            variance_output['output'],
            variance_output['mel_mask']
        )

        # Mel output
        mel_output = self.mel_linear(decoder_output)
        mel_postnet = mel_output + self.postnet(mel_output.transpose(1, 2)).transpose(1, 2)

        return mel_postnet


class FastSpeechEncoder(nn.Module):
    """FastSpeech encoder with FFT blocks."""

    def __init__(self, config: TTSConfig):
        super().__init__()

        self.embedding = nn.Embedding(config.vocab_size, config.encoder_dim)
        self.pos_encoding = SinusoidalPositionalEncoding(config.encoder_dim, config.max_text_length)

        self.layers = nn.ModuleList([
            FFTBlock(config.encoder_dim, config.n_heads, config.dropout)
            for _ in range(config.encoder_layers)
        ])

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Embed and add position
        x = self.embedding(text)
        x = self.pos_encoding(x)

        # Create mask
        max_len = text.size(1)
        mask = torch.arange(max_len, device=text.device)[None, :] >= text_lengths[:, None]

        # FFT layers
        for layer in self.layers:
            x = layer(x, mask)

        return x, mask


class VarianceAdaptor(nn.Module):
    """
    Variance adaptor: predicts and applies duration, pitch, energy.
    """

    def __init__(self, config: TTSConfig):
        super().__init__()

        # Duration predictor
        self.duration_predictor = VariancePredictor(config.encoder_dim)

        # Pitch predictor and embedding
        self.pitch_predictor = VariancePredictor(config.encoder_dim)
        self.pitch_embedding = nn.Conv1d(1, config.encoder_dim, kernel_size=9, padding=4)

        # Energy predictor and embedding
        self.energy_predictor = VariancePredictor(config.encoder_dim)
        self.energy_embedding = nn.Conv1d(1, config.encoder_dim, kernel_size=9, padding=4)

        # Length regulator
        self.length_regulator = LengthRegulator()

    def forward(
        self,
        encoder_output: torch.Tensor,
        text_mask: torch.Tensor,
        mel_lengths: Optional[torch.Tensor] = None,
        duration_targets: Optional[torch.Tensor] = None,
        pitch_targets: Optional[torch.Tensor] = None,
        energy_targets: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """Training forward with ground truth targets."""

        # Duration prediction
        duration_prediction = self.duration_predictor(encoder_output, text_mask)

        # Use ground-truth durations during training; otherwise convert the
        # predicted log-durations back to integer frame counts
        if duration_targets is not None:
            output, mel_mask = self.length_regulator(encoder_output, duration_targets, mel_lengths)
        else:
            output, mel_mask = self.length_regulator(encoder_output, duration_prediction.exp().round())

        # Pitch
        pitch_prediction = self.pitch_predictor(output, mel_mask)
        if pitch_targets is not None:
            pitch_embedding = self.pitch_embedding(pitch_targets.unsqueeze(1)).transpose(1, 2)
        else:
            pitch_embedding = self.pitch_embedding(pitch_prediction.unsqueeze(1)).transpose(1, 2)
        output = output + pitch_embedding

        # Energy
        energy_prediction = self.energy_predictor(output, mel_mask)
        if energy_targets is not None:
            energy_embedding = self.energy_embedding(energy_targets.unsqueeze(1)).transpose(1, 2)
        else:
            energy_embedding = self.energy_embedding(energy_prediction.unsqueeze(1)).transpose(1, 2)
        output = output + energy_embedding

        return {
            'output': output,
            'mel_mask': mel_mask,
            'duration_prediction': duration_prediction,
            'pitch_prediction': pitch_prediction,
            'energy_prediction': energy_prediction
        }


class VariancePredictor(nn.Module):
    """Variance predictor: 2 conv layers + layer norm + linear."""

    def __init__(self, hidden_dim: int):
        super().__init__()

        self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Conv1d expects [B, C, T]; LayerNorm normalizes the channel dim,
        # so transpose back to [B, T, C] before each normalization
        x = F.relu(self.conv1(x.transpose(1, 2))).transpose(1, 2)
        x = self.dropout(self.ln1(x))
        x = F.relu(self.conv2(x.transpose(1, 2))).transpose(1, 2)
        x = self.dropout(self.ln2(x))
        x = self.linear(x).squeeze(-1)

        if mask is not None:
            x = x.masked_fill(mask, 0.0)

        return x


class LengthRegulator(nn.Module):
    """
    Length regulator: expands encoder output based on durations.
    """

    def forward(
        self,
        encoder_output: torch.Tensor,
        durations: torch.Tensor,
        mel_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Expand encoder output according to durations.

        Args:
            encoder_output: [B, T_text, dim]
            durations: [B, T_text] duration for each encoder frame

        Returns:
            expanded: [B, T_mel, dim]
            mask: [B, T_mel]
        """
        durations = durations.long()
        batch_size = encoder_output.size(0)
        device = encoder_output.device

        # Calculate output lengths
        if mel_lengths is None:
            mel_lengths = durations.sum(dim=1)
        max_len = mel_lengths.max().item()

        # Expand
        expanded = []
        for b in range(batch_size):
            frames = []
            for i, d in enumerate(durations[b]):
                frames.extend([encoder_output[b, i]] * d.item())

            # Pad or truncate
            if len(frames) < max_len:
                frames.extend([torch.zeros_like(frames[0])] * (max_len - len(frames)))
            frames = frames[:max_len]

            expanded.append(torch.stack(frames))

        expanded = torch.stack(expanded)

        # Create mask
        mask = torch.arange(max_len, device=device)[None, :] >= mel_lengths[:, None]

        return expanded, mask
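
# Design note: the Python loop above is easy to follow but slow for large
# batches. torch.repeat_interleave performs the same per-item expansion, e.g.
#   expanded_b = encoder_output[b].repeat_interleave(durations[b], dim=0)
# followed by nn.utils.rnn.pad_sequence to re-batch the variable-length results.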


class FFTBlock(nn.Module):
    """Feed-Forward Transformer block."""

    def __init__(self, dim: int, n_heads: int, dropout: float):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(dim)

        self.conv1 = nn.Conv1d(dim, dim * 4, kernel_size=9, padding=4)
        self.conv2 = nn.Conv1d(dim * 4, dim, kernel_size=1)
        self.ln2 = nn.LayerNorm(dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention
        residual = x
        x = self.ln1(x)
        x, _ = self.self_attn(x, x, x, key_padding_mask=mask)
        x = residual + self.dropout(x)

        # Conv FFN
        residual = x
        x = self.ln2(x)
        x = x.transpose(1, 2)
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        x = x.transpose(1, 2)
        x = residual + self.dropout(x)

        return x


class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, dim: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]


class FastSpeechDecoder(nn.Module):
    """FastSpeech decoder."""

    def __init__(self, config: TTSConfig):
        super().__init__()

        self.pos_encoding = SinusoidalPositionalEncoding(config.decoder_dim, config.max_mel_length)

        self.layers = nn.ModuleList([
            FFTBlock(config.decoder_dim, config.n_heads, config.dropout)
            for _ in range(config.decoder_layers)
        ])

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x, mask)

        return x
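
FastSpeech 2 is trained with a reconstruction loss on the mel spectrogram plus regression losses on the variance predictors. A minimal sketch (padding masks omitted for brevity; it assumes the duration predictor regresses log-durations, matching the exp().round() conversion in the variance adaptor above):

PYTHON
def fastspeech2_loss(
    outputs: Dict[str, torch.Tensor],
    mel_target: torch.Tensor,            # [B, T_mel, n_mels]
    log_duration_target: torch.Tensor,   # [B, T_text]
    pitch_target: torch.Tensor,          # [B, T_mel]
    energy_target: torch.Tensor          # [B, T_mel]
) -> torch.Tensor:
    """Minimal FastSpeech 2 loss: mel L1 + MSE on duration/pitch/energy."""
    mel_loss = F.l1_loss(outputs['mel_output'], mel_target)
    postnet_loss = F.l1_loss(outputs['mel_postnet'], mel_target)
    duration_loss = F.mse_loss(outputs['duration_prediction'], log_duration_target)
    pitch_loss = F.mse_loss(outputs['pitch_prediction'], pitch_target)
    energy_loss = F.mse_loss(outputs['energy_prediction'], energy_target)
    return mel_loss + postnet_loss + duration_loss + pitch_loss + energy_loss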

Neural Vocoders

Vocoders convert mel spectrograms to waveforms:

PYTHON
class HiFiGAN(nn.Module):
    """
    HiFi-GAN: High Fidelity Neural Vocoder.

    Generative adversarial vocoder with:
    - Multi-receptive field fusion (MRF) generator
    - Multi-scale discriminator (MSD)
    - Multi-period discriminator (MPD)
    """

    def __init__(
        self,
        n_mels: int = 80,
        upsample_rates: List[int] = [8, 8, 2, 2],
        upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
        upsample_initial_channel: int = 512,
        resblock_kernel_sizes: List[int] = [3, 7, 11],
        resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
    ):
        super().__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        # Initial convolution
        self.conv_pre = nn.Conv1d(n_mels, upsample_initial_channel, kernel_size=7, padding=3)

        # Upsampling layers
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                nn.ConvTranspose1d(
                    upsample_initial_channel // (2 ** i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    kernel_size=k,
                    stride=u,
                    padding=(k - u) // 2
                )
            )

        # Residual blocks with multi-receptive field fusion
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        # Final convolution (channel count after the last upsampling stage)
        final_channels = upsample_initial_channel // (2 ** self.num_upsamples)
        self.conv_post = nn.Conv1d(final_channels, 1, kernel_size=7, padding=3)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Generate waveform from mel spectrogram.

        Args:
            mel: [B, n_mels, T_mel]

        Returns:
            waveform: [B, 1, T_audio]
        """
        x = self.conv_pre(mel)

        for i, up in enumerate(self.ups):
            x = F.leaky_relu(x, 0.1)
            x = up(x)

            # Multi-receptive field fusion
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x


class ResBlock(nn.Module):
    """Residual block for HiFi-GAN."""

    def __init__(self, channels: int, kernel_size: int, dilation: List[int]):
        super().__init__()

        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for d in dilation:
            self.convs1.append(
                nn.Conv1d(
                    channels, channels,
                    kernel_size=kernel_size,
                    dilation=d,
                    padding=self._get_padding(kernel_size, d)
                )
            )
            self.convs2.append(
                nn.Conv1d(
                    channels, channels,
                    kernel_size=kernel_size,
                    dilation=1,
                    padding=self._get_padding(kernel_size, 1)
                )
            )

    def _get_padding(self, kernel_size: int, dilation: int) -> int:
        return (kernel_size * dilation - dilation) // 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, 0.1)
            xt = c1(xt)
            xt = F.leaky_relu(xt, 0.1)
            xt = c2(xt)
            x = xt + x
        return x


class MultiScaleDiscriminator(nn.Module):
    """Multi-scale discriminator for HiFi-GAN."""

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([
            ScaleDiscriminator(use_spectral_norm=True),
            ScaleDiscriminator(),
            ScaleDiscriminator()
        ])
        self.pooling = nn.AvgPool1d(kernel_size=4, stride=2, padding=2)

    def forward(self, y: torch.Tensor) -> List[Tuple[torch.Tensor, List[torch.Tensor]]]:
        outputs = []
        for i, d in enumerate(self.discriminators):
            if i > 0:
                y = self.pooling(y)
            out, fmap = d(y)
            outputs.append((out, fmap))
        return outputs


class ScaleDiscriminator(nn.Module):
    """Single scale discriminator."""

    def __init__(self, use_spectral_norm: bool = False):
        super().__init__()
        norm_f = nn.utils.spectral_norm if use_spectral_norm else lambda x: x

        self.convs = nn.ModuleList([
            norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return x, fmap


def vocoder_comparison():
    """Compare neural vocoders."""
    vocoders = {
        'WaveNet': {
            'type': 'Autoregressive',
            'speed': 'Very slow (~0.1x real-time)',
            'quality': 'Excellent',
            'note': 'Original neural vocoder, dilated convolutions'
        },
        'WaveRNN': {
            'type': 'Autoregressive',
            'speed': 'Slow (~4x real-time with optimization)',
            'quality': 'Excellent',
            'note': 'Efficient RNN, sample-by-sample generation'
        },
        'WaveGlow': {
            'type': 'Flow-based',
            'speed': 'Fast (real-time on GPU)',
            'quality': 'Very good',
            'note': 'Parallel generation via normalizing flows'
        },
        'HiFi-GAN': {
            'type': 'GAN',
            'speed': 'Very fast (real-time on CPU)',
            'quality': 'Excellent',
            'note': 'Multi-scale + multi-period discriminators'
        },
        'UnivNet': {
            'type': 'GAN',
            'speed': 'Very fast',
            'quality': 'Excellent',
            'note': 'Multi-resolution spectrogram discriminator'
        },
        'BigVGAN': {
            'type': 'GAN',
            'speed': 'Fast',
            'quality': 'State-of-the-art',
            'note': 'Large-scale training, anti-aliased activations'
        }
    }

    print("Neural Vocoder Comparison:")
    print("=" * 60)
    for name, info in vocoders.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


vocoder_comparison()
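
The product of upsample_rates sets how many waveform samples each mel frame expands to; with the defaults above that is 8 × 8 × 2 × 2 = 256, matching a 256-sample hop length. A quick shape check with random weights (the output is noise, but the lengths line up):

PYTHON
generator = HiFiGAN(n_mels=80)
mel = torch.randn(1, 80, 100)      # 100 mel frames
audio = generator(mel)             # each frame becomes 256 samples
print(audio.shape)                 # torch.Size([1, 1, 25600])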

Key Takeaways

Modern TTS has achieved remarkable naturalness through neural approaches. Key components include: (1) sequence-to-sequence acoustic models like Tacotron 2 that learn attention-based alignment between text and speech, (2) non-autoregressive models like FastSpeech that enable parallel generation with explicit duration/pitch/energy control, (3) neural vocoders like HiFi-GAN that convert mel spectrograms to high-fidelity waveforms in real-time. The evolution from autoregressive to parallel architectures dramatically improved synthesis speed while maintaining quality. Modern systems can synthesize speech faster than real-time on CPUs, enabling widespread deployment. Key challenges remain in prosody modeling, expressiveness control, and reducing the gap to human naturalness in challenging scenarios.

23.4 Speaker Recognition and Voice Cloning Advanced

Speaker Recognition and Voice Cloning

Speaker recognition identifies individuals by their voice characteristics, while voice cloning synthesizes speech that mimics a target speaker's voice. These technologies enable personalized TTS systems, voice assistants, biometric authentication, and content creation tools. This section explores speaker embedding networks, verification/identification systems, and neural voice cloning techniques that capture and reproduce speaker identity.

Speaker Embeddings

Speaker embeddings are fixed-dimensional vectors that capture speaker identity:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass
import numpy as np

@dataclass
class SpeakerConfig:
    """Configuration for speaker recognition models."""
    n_mels: int = 80
    embedding_dim: int = 256
    n_speakers: int = 1000  # For training
    sample_rate: int = 16000
    encoder_channels: List[int] = None

    def __post_init__(self):
        if self.encoder_channels is None:
            self.encoder_channels = [512, 512, 512, 512, 1536]


class DVectorEncoder(nn.Module):
    """
    D-Vector speaker encoder based on GE2E (Generalized End-to-End) loss.

    Architecture:
    - 3-layer LSTM
    - Final hidden state projected to embedding
    - L2 normalization
    """

    def __init__(self, config: SpeakerConfig):
        super().__init__()

        # Mel input projection
        self.input_proj = nn.Linear(config.n_mels, 256)

        # LSTM layers
        self.lstm = nn.LSTM(
            input_size=256,
            hidden_size=768,
            num_layers=3,
            batch_first=True
        )

        # Embedding projection
        self.embedding = nn.Linear(768, config.embedding_dim)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Extract d-vector from mel spectrogram.

        Args:
            mel: [B, T, n_mels] mel spectrogram
            lengths: [B] actual lengths

        Returns:
            embeddings: [B, embedding_dim] L2-normalized speaker embedding
        """
        batch_size = mel.size(0)

        # Project input
        x = F.relu(self.input_proj(mel))

        # LSTM
        if lengths is not None:
            x = nn.utils.rnn.pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )

        _, (hidden, _) = self.lstm(x)

        # Take final hidden state from last layer
        x = hidden[-1]  # [B, 768]

        # Project to embedding
        x = self.embedding(x)

        # L2 normalize
        x = F.normalize(x, p=2, dim=-1)

        return x

    def embed_utterance(
        self,
        mel: torch.Tensor,
        partial_length: int = 160,
        min_coverage: float = 0.5
    ) -> torch.Tensor:
        """
        Embed utterance using partial utterance approach for variable length.

        Extracts embeddings from overlapping windows and averages them.
        """
        if mel.dim() == 2:
            mel = mel.unsqueeze(0)

        total_frames = mel.size(1)

        if total_frames < partial_length:
            # Pad short utterances
            padding = torch.zeros(1, partial_length - total_frames, mel.size(2), device=mel.device)
            mel = torch.cat([mel, padding], dim=1)
            partials = [mel]
        else:
            # Sliding window with overlap
            hop = int(partial_length * (1 - min_coverage))
            partials = []
            for start in range(0, total_frames - partial_length + 1, hop):
                partials.append(mel[:, start:start + partial_length, :])

        # Stack and embed
        partials = torch.cat(partials, dim=0)
        embeddings = self.forward(partials)

        # Average and normalize
        embedding = embeddings.mean(dim=0, keepdim=True)
        embedding = F.normalize(embedding, p=2, dim=-1)

        return embedding
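
# Usage sketch: given a [T, n_mels] log-mel tensor for one utterance,
#   encoder = DVectorEncoder(SpeakerConfig())
#   d_vector = encoder.embed_utterance(mel)   # [1, embedding_dim], L2-normalized
# Embeddings from the same speaker should have high cosine similarity.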


class GE2ELoss(nn.Module):
    """
    Generalized End-to-End Loss for speaker verification.

    Trains embeddings so that same-speaker utterances cluster together
    and different speakers are pushed apart.
    """

    def __init__(self, init_w: float = 10.0, init_b: float = -5.0):
        super().__init__()
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))

    def forward(
        self,
        embeddings: torch.Tensor,
        n_speakers: int,
        n_utterances: int
    ) -> torch.Tensor:
        """
        Compute GE2E loss.

        Args:
            embeddings: [N, embedding_dim] where N = n_speakers * n_utterances
            n_speakers: Number of speakers in batch
            n_utterances: Number of utterances per speaker

        Returns:
            loss: Scalar loss value
        """
        # Reshape: [n_speakers, n_utterances, embedding_dim]
        embeddings = embeddings.view(n_speakers, n_utterances, -1)

        # Compute centroids (excluding self for positive pairs)
        centroids_incl = embeddings.mean(dim=1)  # [n_speakers, embedding_dim]

        # Similarity matrix
        losses = []

        for speaker_idx in range(n_speakers):
            for utt_idx in range(n_utterances):
                # Centroid excluding current utterance
                mask = torch.ones(n_utterances, device=embeddings.device, dtype=torch.bool)
                mask[utt_idx] = False
                centroid_excl = embeddings[speaker_idx, mask].mean(dim=0)

                # Current embedding
                e = embeddings[speaker_idx, utt_idx]

                # Positive similarity (same speaker)
                sim_pos = self.w * F.cosine_similarity(e.unsqueeze(0), centroid_excl.unsqueeze(0)) + self.b

                # Negative similarities (different speakers)
                sim_neg = self.w * F.cosine_similarity(e.unsqueeze(0), centroids_incl) + self.b
                sim_neg[speaker_idx] = float('-inf')  # Exclude same speaker

                # Softmax loss
                sim_all = torch.cat([sim_pos, sim_neg])
                loss = -F.log_softmax(sim_all, dim=0)[0]
                losses.append(loss)

        return torch.stack(losses).mean()
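
# Verification sketch: a trained encoder compares two utterances by the cosine
# similarity of their embeddings against a threshold tuned on a dev set
# (the value below is an illustrative placeholder, not a recommended setting).
def same_speaker(emb_a: torch.Tensor, emb_b: torch.Tensor, threshold: float = 0.75) -> bool:
    """Return True if two L2-normalized [1, D] embeddings likely share a speaker."""
    score = F.cosine_similarity(emb_a, emb_b, dim=-1)
    return bool((score > threshold).item())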


class XVectorEncoder(nn.Module):
    """
    X-Vector speaker encoder (TDNN-based).

    Architecture:
    - TDNN layers with frame-level features
    - Statistics pooling (mean + std)
    - Segment-level DNN
    """

    def __init__(self, config: SpeakerConfig):
        super().__init__()

        # TDNN frame-level layers
        self.tdnn = nn.Sequential(
            TDNNBlock(config.n_mels, 512, context=[-2, -1, 0, 1, 2]),
            TDNNBlock(512, 512, context=[-2, 0, 2]),
            TDNNBlock(512, 512, context=[-3, 0, 3]),
            TDNNBlock(512, 512, context=[0]),
            TDNNBlock(512, 1500, context=[0])
        )

        # Statistics pooling
        self.stats_pool = StatisticsPooling()

        # Segment-level layers
        self.segment = nn.Sequential(
            nn.Linear(3000, 512),  # Mean + std = 2 * 1500
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, config.embedding_dim),
            nn.BatchNorm1d(config.embedding_dim)
        )

    def forward(
        self,
        mel: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Extract x-vector from mel spectrogram.

        Args:
            mel: [B, T, n_mels]

        Returns:
            embeddings: [B, embedding_dim]
        """
        # TDNN (expects [B, C, T])
        x = mel.transpose(1, 2)
        x = self.tdnn(x)

        # Statistics pooling
        x = self.stats_pool(x, lengths)

        # Segment-level
        x = self.segment(x)

        return F.normalize(x, p=2, dim=-1)


class TDNNBlock(nn.Module):
    """Time-Delay Neural Network block."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        context: List[int]
    ):
        super().__init__()
        self.context = context

        # Approximate the context window with a dense kernel spanning it
        # (a classic TDNN would use a small kernel with dilation instead)
        kernel_size = max(context) - min(context) + 1
        padding = (kernel_size - 1) // 2

        self.conv = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            padding=padding
        )
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return x


class StatisticsPooling(nn.Module):
    """Statistics pooling: compute mean and std over time."""

    def forward(
        self,
        x: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: [B, C, T]
            lengths: [B] optional lengths for masked pooling

        Returns:
            stats: [B, 2*C] concatenated mean and std
        """
        if lengths is not None:
            # Masked statistics
            batch_size = x.size(0)
            means = []
            stds = []

            for b in range(batch_size):
                valid = x[b, :, :lengths[b]]
                means.append(valid.mean(dim=-1))
                stds.append(valid.std(dim=-1))

            mean = torch.stack(means)
            std = torch.stack(stds)
        else:
            mean = x.mean(dim=-1)
            std = x.std(dim=-1)

        return torch.cat([mean, std], dim=-1)


class ECAPATDNNEncoder(nn.Module):
    """
    ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation.

    State-of-the-art speaker embedding network with:
    - SE-Res2Net blocks
    - Multi-layer feature aggregation
    - Attentive statistics pooling
    """

    def __init__(self, config: SpeakerConfig):
        super().__init__()
        channels = config.encoder_channels

        # Initial convolution
        self.conv1 = nn.Conv1d(config.n_mels, channels[0], kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(channels[0])

        # SE-Res2Net blocks
        self.blocks = nn.ModuleList([
            SERes2NetBlock(channels[0], channels[1], kernel_size=3, dilation=2),
            SERes2NetBlock(channels[1], channels[2], kernel_size=3, dilation=3),
            SERes2NetBlock(channels[2], channels[3], kernel_size=3, dilation=4)
        ])

        # Multi-layer feature aggregation
        self.mfa_conv = nn.Conv1d(
            channels[1] + channels[2] + channels[3],
            channels[4],
            kernel_size=1
        )

        # Attentive statistics pooling
        self.asp = AttentiveStatisticsPooling(channels[4], attention_dim=128)

        # Final embedding
        self.fc = nn.Linear(channels[4] * 2, config.embedding_dim)
        self.bn2 = nn.BatchNorm1d(config.embedding_dim)

    def forward(
        self,
        mel: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Extract ECAPA-TDNN embedding."""
        x = mel.transpose(1, 2)  # [B, n_mels, T]

        # Initial conv
        x = F.relu(self.bn1(self.conv1(x)))

        # SE-Res2Net blocks with skip connections
        outputs = []
        for block in self.blocks:
            x = block(x)
            outputs.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(outputs, dim=1)
        x = F.relu(self.mfa_conv(x))

        # Attentive statistics pooling
        x = self.asp(x, lengths)

        # Final embedding
        x = self.bn2(self.fc(x))

        return F.normalize(x, p=2, dim=-1)


class SERes2NetBlock(nn.Module):
    """SE-Res2Net block for ECAPA-TDNN."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dilation: int = 1,
        scale: int = 8
    ):
        super().__init__()
        self.scale = scale
        width = out_channels // scale

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(out_channels)

        # Res2Net: hierarchical residual-like connections
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(scale - 1):
            self.convs.append(nn.Conv1d(
                width, width, kernel_size, padding=dilation, dilation=dilation
            ))
            self.bns.append(nn.BatchNorm1d(width))

        self.conv3 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(out_channels)

        # SE block
        self.se = SqueezeExcitation(out_channels, reduction=8)

        # Shortcut
        self.shortcut = nn.Identity() if in_channels == out_channels else \
            nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)

        x = F.relu(self.bn1(self.conv1(x)))

        # Res2Net
        chunks = x.chunk(self.scale, dim=1)
        outputs = [chunks[0]]
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            if i == 0:
                y = conv(chunks[i + 1])
            else:
                y = conv(chunks[i + 1] + outputs[-1])
            outputs.append(F.relu(bn(y)))

        x = torch.cat(outputs, dim=1)
        x = F.relu(self.bn3(self.conv3(x)))

        # SE
        x = self.se(x)

        return x + residual


class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation channel attention."""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Global average pooling
        w = x.mean(dim=-1)
        w = self.fc(w).unsqueeze(-1)
        return x * w


class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling for ECAPA-TDNN."""

    def __init__(self, channels: int, attention_dim: int):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Conv1d(channels * 3, attention_dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(attention_dim, channels, kernel_size=1),
            nn.Softmax(dim=-1)
        )

    def forward(
        self,
        x: torch.Tensor,
        lengths: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Global statistics for attention input
        # (lengths is accepted for API consistency; this simplified version
        # pools over all frames without masking)
        mean = x.mean(dim=-1, keepdim=True).expand_as(x)
        std = x.std(dim=-1, keepdim=True).expand_as(x)

        attn_input = torch.cat([x, mean, std], dim=1)
        attn_weights = self.attention(attn_input)

        # Weighted statistics
        weighted_mean = (x * attn_weights).sum(dim=-1)
        weighted_std = torch.sqrt(((x ** 2) * attn_weights).sum(dim=-1) - weighted_mean ** 2 + 1e-6)

        return torch.cat([weighted_mean, weighted_std], dim=-1)
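

# Quick shape check for attentive statistics pooling (random features, illustrative
# only): attention weights are computed per channel, then mean and std are pooled.
_asp = AttentiveStatisticsPooling(channels=512, attention_dim=128)
print(_asp(torch.randn(2, 512, 200)).shape)  # torch.Size([2, 1024])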

Speaker Verification and Identification

Verification (1:1) decides whether a test utterance matches a single claimed identity; identification (1:N) ranks all enrolled speakers to decide who is speaking:

PYTHON
class SpeakerVerificationSystem:
    """
    Speaker verification: is this person who they claim to be?

    Compares enrollment embedding with test embedding using cosine similarity.
    """

    def __init__(self, encoder: nn.Module, threshold: float = 0.5):
        self.encoder = encoder
        self.threshold = threshold
        self.enrollments: Dict[str, torch.Tensor] = {}

    @torch.no_grad()
    def enroll(
        self,
        speaker_id: str,
        mel_spectrograms: List[torch.Tensor]
    ) -> None:
        """
        Enroll a speaker with multiple utterances.

        Args:
            speaker_id: Unique speaker identifier
            mel_spectrograms: List of [T, n_mels] mel spectrograms
        """
        self.encoder.eval()

        embeddings = []
        for mel in mel_spectrograms:
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)
            embedding = self.encoder(mel)
            embeddings.append(embedding)

        # Average embeddings
        avg_embedding = torch.stack(embeddings).mean(dim=0)
        avg_embedding = F.normalize(avg_embedding, p=2, dim=-1)

        self.enrollments[speaker_id] = avg_embedding
        print(f"Enrolled speaker '{speaker_id}' with {len(mel_spectrograms)} utterances")

    @torch.no_grad()
    def verify(
        self,
        speaker_id: str,
        test_mel: torch.Tensor
    ) -> Tuple[bool, float]:
        """
        Verify if test utterance matches claimed speaker.

        Returns:
            accepted: Whether verification passed
            score: Cosine similarity score
        """
        if speaker_id not in self.enrollments:
            raise ValueError(f"Speaker '{speaker_id}' not enrolled")

        self.encoder.eval()

        if test_mel.dim() == 2:
            test_mel = test_mel.unsqueeze(0)

        test_embedding = self.encoder(test_mel)
        enrolled_embedding = self.enrollments[speaker_id]

        # Cosine similarity
        score = F.cosine_similarity(test_embedding, enrolled_embedding).item()
        accepted = score >= self.threshold

        return accepted, score

    def compute_eer(
        self,
        positive_scores: List[float],
        negative_scores: List[float]
    ) -> Tuple[float, float]:
        """
        Compute Equal Error Rate (EER).

        EER is where False Acceptance Rate = False Rejection Rate.
        Lower is better.
        """
        all_scores = positive_scores + negative_scores
        labels = [1] * len(positive_scores) + [0] * len(negative_scores)

        # Sort by score
        sorted_pairs = sorted(zip(all_scores, labels), reverse=True)

        # Compute FAR and FRR at each threshold
        n_pos = len(positive_scores)
        n_neg = len(negative_scores)

        far_list = []
        frr_list = []
        thresholds = []

        true_accepts = 0
        false_accepts = 0

        for score, label in sorted_pairs:
            if label == 1:
                true_accepts += 1
            else:
                false_accepts += 1

            far = false_accepts / n_neg
            frr = 1 - (true_accepts / n_pos)

            far_list.append(far)
            frr_list.append(frr)
            thresholds.append(score)

        # Find EER (where FAR ≈ FRR)
        min_diff = float('inf')
        eer = 0
        eer_threshold = 0

        for far, frr, thr in zip(far_list, frr_list, thresholds):
            diff = abs(far - frr)
            if diff < min_diff:
                min_diff = diff
                eer = (far + frr) / 2
                eer_threshold = thr

        return eer, eer_threshold
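

# Illustrative EER computation on hand-picked synthetic scores. One negative
# trial (0.72) outscores one positive trial (0.67), so FAR and FRR cross at
# 1/5 = 0.2. The encoder is a placeholder; compute_eer only needs the scores.
_verifier = SpeakerVerificationSystem(encoder=nn.Identity())
_eer, _thr = _verifier.compute_eer(
    positive_scores=[0.82, 0.91, 0.74, 0.88, 0.67],
    negative_scores=[0.31, 0.45, 0.72, 0.28, 0.60]
)
print(f"EER: {_eer:.2f} at threshold {_thr:.2f}")  # EER: 0.20 at threshold 0.72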


class SpeakerIdentificationSystem:
    """
    Speaker identification: who is speaking?

    Compares test embedding against all enrolled speakers.
    """

    def __init__(self, encoder: nn.Module):
        self.encoder = encoder
        self.enrollments: Dict[str, torch.Tensor] = {}

    @torch.no_grad()
    def enroll(
        self,
        speaker_id: str,
        mel_spectrograms: List[torch.Tensor]
    ) -> None:
        """Enroll speaker (same as verification)."""
        self.encoder.eval()

        embeddings = []
        for mel in mel_spectrograms:
            if mel.dim() == 2:
                mel = mel.unsqueeze(0)
            embedding = self.encoder(mel)
            embeddings.append(embedding)

        avg_embedding = torch.stack(embeddings).mean(dim=0)
        avg_embedding = F.normalize(avg_embedding, p=2, dim=-1)
        self.enrollments[speaker_id] = avg_embedding

    @torch.no_grad()
    def identify(
        self,
        test_mel: torch.Tensor,
        top_k: int = 1
    ) -> List[Tuple[str, float]]:
        """
        Identify speaker from test utterance.

        Returns:
            Top-k (speaker_id, score) pairs ranked by similarity
        """
        self.encoder.eval()

        if test_mel.dim() == 2:
            test_mel = test_mel.unsqueeze(0)

        test_embedding = self.encoder(test_mel)

        # Compare with all enrolled speakers
        scores = []
        for speaker_id, enrolled_embedding in self.enrollments.items():
            score = F.cosine_similarity(test_embedding, enrolled_embedding).item()
            scores.append((speaker_id, score))

        # Sort by score descending
        scores.sort(key=lambda x: x[1], reverse=True)

        return scores[:top_k]

    def compute_accuracy(
        self,
        test_data: List[Tuple[torch.Tensor, str]]
    ) -> Dict[str, float]:
        """
        Compute identification accuracy metrics.

        Args:
            test_data: List of (mel, true_speaker_id) pairs
        """
        correct_top1 = 0
        correct_top5 = 0

        for mel, true_speaker in test_data:
            top_5 = self.identify(mel, top_k=5)

            if top_5[0][0] == true_speaker:
                correct_top1 += 1

            if any(speaker == true_speaker for speaker, _ in top_5):
                correct_top5 += 1

        n_total = len(test_data)

        return {
            'top1_accuracy': correct_top1 / n_total,
            'top5_accuracy': correct_top5 / n_total
        }
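

# Toy identification run with a stand-in encoder (mean-pooled mel projected to a
# unit-norm embedding). A real system would plug in ECAPATDNNEncoder; all tensors
# here are random, so the ranking itself is meaningless.
class _ToyEncoder(nn.Module):
    def __init__(self, n_mels: int = 80, dim: int = 192):
        super().__init__()
        self.proj = nn.Linear(n_mels, dim)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        # [B, T, n_mels] -> [B, dim]
        return F.normalize(self.proj(mel.mean(dim=1)), p=2, dim=-1)


_id_system = SpeakerIdentificationSystem(encoder=_ToyEncoder())
_id_system.enroll("alice", [torch.randn(120, 80) for _ in range(3)])
_id_system.enroll("bob", [torch.randn(120, 80) for _ in range(3)])
print(_id_system.identify(torch.randn(120, 80), top_k=2))  # [(speaker_id, score), ...]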


class AAMSoftmaxLoss(nn.Module):
    """
    Additive Angular Margin Softmax (ArcFace) loss for speaker recognition.

    Adds angular margin to enhance discriminative power.
    """

    def __init__(
        self,
        embedding_dim: int,
        n_classes: int,
        margin: float = 0.2,
        scale: float = 30.0
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_classes, embedding_dim))
        nn.init.xavier_uniform_(self.weight)

        self.margin = margin
        self.scale = scale

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute AAM-Softmax loss.

        Args:
            embeddings: [B, embedding_dim] L2-normalized embeddings
            labels: [B] speaker labels
        """
        # Normalize weight
        weight = F.normalize(self.weight, p=2, dim=-1)

        # Cosine similarity = dot product of normalized vectors
        cosine = F.linear(embeddings, weight)  # [B, n_classes]

        # Add margin to target class
        # cos(θ + m) = cos(θ)cos(m) - sin(θ)sin(m)
        theta = torch.acos(torch.clamp(cosine, -1 + 1e-7, 1 - 1e-7))
        target_logits = torch.cos(theta + self.margin)

        # Create one-hot mask
        one_hot = F.one_hot(labels, num_classes=self.weight.size(0)).float()

        # Replace target class logits with margin-added version
        logits = cosine * (1 - one_hot) + target_logits * one_hot

        # Scale and compute loss
        logits = logits * self.scale

        return F.cross_entropy(logits, labels)
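

# Minimal AAM-Softmax usage on random, L2-normalized embeddings (illustrative
# only): the margin is added to the target-class angle before scaling.
_aam = AAMSoftmaxLoss(embedding_dim=192, n_classes=10, margin=0.2, scale=30.0)
_emb = F.normalize(torch.randn(4, 192), p=2, dim=-1)
_labels = torch.randint(0, 10, (4,))
print(_aam(_emb, _labels))  # scalar loss tensor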


def speaker_embedding_comparison():
    """Compare speaker embedding methods."""
    methods = {
        'i-vector': {
            'type': 'Traditional',
            'approach': 'Factor analysis on GMM supervectors',
            'dim': '400-600',
            'pros': 'Well-understood, interpretable',
            'cons': 'Requires speaker-specific adaptation'
        },
        'd-vector': {
            'type': 'Neural (frame-level)',
            'approach': 'LSTM with GE2E loss',
            'dim': '256',
            'pros': 'Simple, effective for short utterances',
            'cons': 'Variable-length handling'
        },
        'x-vector': {
            'type': 'Neural (segment-level)',
            'approach': 'TDNN + statistics pooling',
            'dim': '512',
            'pros': 'Robust, well-suited for variable length',
            'cons': 'Larger model'
        },
        'ECAPA-TDNN': {
            'type': 'Neural (SOTA)',
            'approach': 'SE-Res2Net + attentive pooling',
            'dim': '192-256',
            'pros': 'State-of-the-art performance',
            'cons': 'More complex architecture'
        },
        'wav2vec 2.0 + ECAPA': {
            'type': 'Self-supervised + supervised',
            'approach': 'Pretrained features + speaker head',
            'dim': '192',
            'pros': 'Leverages unlabeled data',
            'cons': 'Large pretrained model required'
        }
    }

    print("Speaker Embedding Method Comparison:")
    print("=" * 60)
    for name, info in methods.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


speaker_embedding_comparison()

Voice Cloning

Voice cloning synthesizes speech in a target speaker's voice from a short reference recording:

PYTHON
class ZeroShotVoiceCloner(nn.Module):
    """
    Zero-shot voice cloning using speaker embeddings.

    Conditions TTS model on speaker embedding extracted from
    reference audio, enabling synthesis in any voice without
    fine-tuning.
    """

    def __init__(
        self,
        speaker_encoder: nn.Module,
        tts_model: nn.Module,
        embedding_dim: int = 256
    ):
        super().__init__()
        self.speaker_encoder = speaker_encoder
        self.tts = tts_model

        # Project speaker embedding to TTS conditioning space
        self.speaker_proj = nn.Linear(embedding_dim, tts_model.config.encoder_dim)

        # Freeze speaker encoder
        for param in self.speaker_encoder.parameters():
            param.requires_grad = False

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        reference_mel: torch.Tensor,
        mel_target: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward with speaker conditioning.
        """
        # Extract speaker embedding
        with torch.no_grad():
            speaker_embedding = self.speaker_encoder(reference_mel)

        # Project to conditioning space
        speaker_cond = self.speaker_proj(speaker_embedding)

        # Condition TTS
        outputs = self.tts(
            text,
            text_lengths,
            mel_target,
            speaker_embedding=speaker_cond
        )

        return outputs

    @torch.no_grad()
    def clone_voice(
        self,
        text: torch.Tensor,
        reference_audio_mel: torch.Tensor
    ) -> torch.Tensor:
        """
        Synthesize speech in target speaker's voice.

        Args:
            text: [1, T_text] input phonemes/characters
            reference_audio_mel: [1, T, n_mels] reference mel spectrogram

        Returns:
            mel_output: [1, n_mels, T_mel] synthesized mel spectrogram
        """
        self.eval()

        # Extract speaker embedding from reference
        speaker_embedding = self.speaker_encoder(reference_audio_mel)
        speaker_cond = self.speaker_proj(speaker_embedding)

        # Generate with speaker conditioning
        mel_output = self.tts.inference(
            text,
            speaker_embedding=speaker_cond
        )

        return mel_output


class SpeakerConditionedTTS(nn.Module):
    """
    FastSpeech 2 variant with speaker conditioning.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Standard TTS components
        self.encoder = FastSpeechEncoder(config)
        self.variance_adaptor = VarianceAdaptor(config)
        self.decoder = FastSpeechDecoder(config)
        self.mel_linear = nn.Linear(config.decoder_dim, config.n_mels)

        # Speaker conditioning layers
        self.speaker_encoder_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
        self.speaker_decoder_proj = nn.Linear(config.decoder_dim, config.decoder_dim)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        mel_target: Optional[torch.Tensor] = None,
        speaker_embedding: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Forward with optional speaker conditioning."""
        # Encode text
        encoder_output, text_mask = self.encoder(text, text_lengths)

        # Add speaker conditioning to encoder output
        if speaker_embedding is not None:
            speaker_cond_enc = self.speaker_encoder_proj(speaker_embedding)
            encoder_output = encoder_output + speaker_cond_enc.unsqueeze(1)

        # Variance adaptor
        variance_output = self.variance_adaptor(
            encoder_output, text_mask, **kwargs
        )

        # Add speaker conditioning to decoder input
        decoder_input = variance_output['output']
        if speaker_embedding is not None:
            speaker_cond_dec = self.speaker_decoder_proj(speaker_embedding)
            decoder_input = decoder_input + speaker_cond_dec.unsqueeze(1)

        # Decode
        decoder_output = self.decoder(decoder_input, variance_output['mel_mask'])
        mel_output = self.mel_linear(decoder_output)

        return {
            'mel_output': mel_output,
            **variance_output
        }


class VITSVoiceCloner(nn.Module):
    """
    VITS-based voice cloning system.

    VITS (Variational Inference with adversarial learning for
    end-to-end Text-to-Speech) enables high-quality zero-shot
    voice cloning with:
    - Variational autoencoder for latent representation
    - Normalizing flows for expressive synthesis
    - Adversarial training for quality
    """

    def __init__(
        self,
        vocab_size: int = 256,
        hidden_dim: int = 192,
        n_speakers: int = 1000,
        gin_channels: int = 256
    ):
        super().__init__()

        # Text encoder
        self.text_encoder = TextEncoder(
            vocab_size, hidden_dim, hidden_dim,
            n_layers=6, n_heads=2
        )

        # Posterior encoder (for training)
        self.posterior_encoder = PosteriorEncoder(
            in_channels=513,  # Linear spectrogram
            out_channels=192,
            hidden_channels=192,
            kernel_size=5,
            dilation_rate=1,
            n_layers=16,
            gin_channels=gin_channels
        )

        # Flow for variational inference
        self.flow = ResidualCouplingBlock(
            192, 192, 5, 1, 4,
            gin_channels=gin_channels
        )

        # HiFi-GAN decoder
        self.decoder = HiFiGANGenerator(
            192, gin_channels=gin_channels
        )

        # Speaker embedding
        self.speaker_embedding = nn.Embedding(n_speakers, gin_channels)

        # External speaker encoder for zero-shot
        self.external_speaker_encoder = None

    def set_speaker_encoder(self, encoder: nn.Module):
        """Set external speaker encoder for zero-shot cloning."""
        self.external_speaker_encoder = encoder
        for param in encoder.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def clone_voice(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        reference_mel: torch.Tensor,
        noise_scale: float = 0.667,
        length_scale: float = 1.0
    ) -> torch.Tensor:
        """
        Zero-shot voice cloning inference.
        """
        if self.external_speaker_encoder is None:
            raise ValueError("External speaker encoder not set")

        self.eval()

        # Extract speaker embedding from reference
        speaker_embedding = self.external_speaker_encoder(reference_mel)
        g = speaker_embedding.unsqueeze(-1)  # [B, gin_channels, 1]

        # Encode text
        x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)

        # Duration prediction and expansion
        # (simplified - full VITS uses stochastic duration predictor)

        # Sample from prior
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        # Reverse flow
        z = self.flow.reverse(z_p, x_mask, g=g)

        # Decode to waveform
        audio = self.decoder(z, g=g)

        return audio


class TextEncoder(nn.Module):
    """VITS text encoder."""

    def __init__(
        self,
        vocab_size: int,
        hidden_channels: int,
        out_channels: int,
        n_layers: int,
        n_heads: int
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_channels)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_channels,
                nhead=n_heads,
                dim_feedforward=hidden_channels * 4,
                batch_first=True
            ),
            num_layers=n_layers
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.embedding(text)

        # Create mask
        max_len = text.size(1)
        mask = torch.arange(max_len, device=text.device)[None, :] >= text_lengths[:, None]

        x = self.encoder(x, src_key_padding_mask=mask)

        # Project to mean and log variance
        stats = self.proj(x.transpose(1, 2))
        m, logs = stats.chunk(2, dim=1)

        x_mask = (~mask).unsqueeze(1).float()

        return x.transpose(1, 2), m, logs, x_mask
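

# Shape check for the VITS text encoder on random token ids (illustrative only):
# it returns hidden states plus the prior mean/log-variance and a frame mask.
_te = TextEncoder(vocab_size=100, hidden_channels=192, out_channels=192,
                  n_layers=2, n_heads=2)
_h, _m, _logs, _x_mask = _te(torch.randint(0, 100, (2, 30)), torch.tensor([30, 25]))
print(_h.shape, _m.shape, _logs.shape, _x_mask.shape)
# torch.Size([2, 192, 30]) for the first three, torch.Size([2, 1, 30]) for the mask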


class PosteriorEncoder(nn.Module):
    """VITS posterior encoder (WaveNet-style)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int
    ):
        super().__init__()
        self.out_channels = out_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)

        self.enc = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** (i % 4)
            self.enc.append(
                nn.Conv1d(
                    hidden_channels, hidden_channels * 2,
                    kernel_size, padding=(kernel_size - 1) * dilation // 2,
                    dilation=dilation
                )
            )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

        # Speaker conditioning
        self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.pre(x) * x_mask

        if g is not None:
            x = x + self.cond(g)

        for layer in self.enc:
            h = layer(x)
            h_a, h_b = h.chunk(2, dim=1)
            h = torch.tanh(h_a) * torch.sigmoid(h_b)
            x = x + h

        stats = self.proj(x) * x_mask
        m, logs = stats.chunk(2, dim=1)

        z = m + torch.randn_like(m) * torch.exp(logs)

        return z, m, logs


class ResidualCouplingBlock(nn.Module):
    """Normalizing flow block for VITS."""

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        n_flows: int = 4,
        gin_channels: int = 0
    ):
        super().__init__()

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels, hidden_channels, kernel_size,
                    dilation_rate, n_layers, gin_channels
                )
            )
            self.flows.append(Flip())

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        for flow in self.flows:
            x, _ = flow(x, x_mask, g=g)
        return x

    def reverse(
        self,
        z: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        for flow in reversed(self.flows):
            z = flow.reverse(z, x_mask, g=g)
        return z


class ResidualCouplingLayer(nn.Module):
    """Single residual coupling layer."""

    def __init__(
        self,
        channels: int,
        hidden_channels: int,
        kernel_size: int,
        dilation_rate: int,
        n_layers: int,
        gin_channels: int
    ):
        super().__init__()
        self.half_channels = channels // 2

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)

        self.enc = nn.ModuleList()
        for i in range(n_layers):
            dilation = dilation_rate ** i
            self.enc.append(
                nn.Conv1d(
                    hidden_channels, hidden_channels * 2,
                    kernel_size, padding=(kernel_size - 1) * dilation // 2,
                    dilation=dilation
                )
            )

        self.post = nn.Conv1d(hidden_channels, self.half_channels, 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

        if gin_channels > 0:
            self.cond = nn.Conv1d(gin_channels, hidden_channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x0, x1 = x.chunk(2, dim=1)

        h = self.pre(x0) * x_mask
        if g is not None:
            h = h + self.cond(g)

        for layer in self.enc:
            h_out = layer(h)
            h_a, h_b = h_out.chunk(2, dim=1)
            h = h + torch.tanh(h_a) * torch.sigmoid(h_b)

        m = self.post(h) * x_mask

        x1 = m + x1
        x = torch.cat([x0, x1], dim=1)

        log_det = torch.zeros(x.size(0), device=x.device)

        return x, log_det

    def reverse(
        self,
        z: torch.Tensor,
        x_mask: torch.Tensor,
        g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        z0, z1 = z.chunk(2, dim=1)

        h = self.pre(z0) * x_mask
        if g is not None:
            h = h + self.cond(g)

        for layer in self.enc:
            h_out = layer(h)
            h_a, h_b = h_out.chunk(2, dim=1)
            h = h + torch.tanh(h_a) * torch.sigmoid(h_b)

        m = self.post(h) * x_mask

        z1 = z1 - m
        z = torch.cat([z0, z1], dim=1)

        return z


class Flip(nn.Module):
    """Flip operation for flow."""

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.flip(x, [1])
        return x, torch.zeros(x.size(0), device=x.device)

    def reverse(
        self,
        z: torch.Tensor,
        x_mask: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        return torch.flip(z, [1])
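

# Sanity check that the coupling block is invertible: reverse(forward(x)) should
# recover x up to floating-point error, since additive coupling has an exact inverse.
_flow = ResidualCouplingBlock(channels=192, hidden_channels=192, kernel_size=5,
                              dilation_rate=1, n_layers=2, n_flows=2, gin_channels=0)
_zx = torch.randn(1, 192, 40)
_mask = torch.ones(1, 1, 40)
_zy = _flow(_zx, _mask)
print(torch.allclose(_flow.reverse(_zy, _mask), _zx, atol=1e-4))  # True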


class HiFiGANGenerator(nn.Module):
    """HiFi-GAN generator with speaker conditioning."""

    def __init__(
        self,
        in_channels: int,
        upsample_rates: List[int] = [8, 8, 2, 2],
        upsample_initial_channel: int = 512,
        gin_channels: int = 0
    ):
        super().__init__()

        self.conv_pre = nn.Conv1d(in_channels, upsample_initial_channel, 7, padding=3)

        # Upsampling layers (simplified)
        self.ups = nn.ModuleList()
        for i, rate in enumerate(upsample_rates):
            self.ups.append(
                nn.ConvTranspose1d(
                    upsample_initial_channel // (2 ** i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    rate * 2, stride=rate, padding=rate // 2
                )
            )

        self.conv_post = nn.Conv1d(
            upsample_initial_channel // (2 ** len(upsample_rates)),
            1, 7, padding=3
        )

        # Speaker conditioning
        if gin_channels > 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(
        self,
        x: torch.Tensor,
        g: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        x = self.conv_pre(x)

        if g is not None:
            x = x + self.cond(g)

        for up in self.ups:
            x = F.leaky_relu(x, 0.1)
            x = up(x)

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x
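

# Shape check: 50 latent frames are upsampled by 8*8*2*2 = 256x into waveform
# samples; the conditioning tensor stands in for a speaker embedding (random input).
_gen = HiFiGANGenerator(in_channels=192, gin_channels=256)
_lat = torch.randn(1, 192, 50)
_spk = torch.randn(1, 256, 1)
print(_gen(_lat, g=_spk).shape)  # torch.Size([1, 1, 12800])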

Neural Voice Conversion

Voice conversion changes the speaker identity of an utterance while preserving its linguistic content:

PYTHON
class VoiceConverter(nn.Module):
    """
    Voice conversion: change speaker identity while preserving content.

    Approaches:
    1. Parallel: requires paired source-target data
    2. Non-parallel: learns from unpaired data using:
       - Disentanglement (separate content from speaker)
       - Cycle consistency (inspired by CycleGAN)
       - Recognition-synthesis (ASR + TTS)
    """

    def __init__(
        self,
        content_encoder: nn.Module,
        speaker_encoder: nn.Module,
        decoder: nn.Module
    ):
        super().__init__()
        self.content_encoder = content_encoder
        self.speaker_encoder = speaker_encoder
        self.decoder = decoder

    def forward(
        self,
        source_mel: torch.Tensor,
        target_speaker_mel: torch.Tensor
    ) -> torch.Tensor:
        """
        Convert source speech to target speaker's voice.
        """
        # Extract content (speaker-independent)
        content = self.content_encoder(source_mel)

        # Extract target speaker embedding
        speaker = self.speaker_encoder(target_speaker_mel)

        # Decode with target speaker
        converted_mel = self.decoder(content, speaker)

        return converted_mel


class ContentEncoder(nn.Module):
    """
    Content encoder that removes speaker information.

    Uses instance normalization to remove speaker-specific
    statistics while preserving linguistic content.
    """

    def __init__(self, in_dim: int = 80, hidden_dim: int = 512):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv1d(in_dim, hidden_dim, 5, padding=2),
            nn.InstanceNorm1d(hidden_dim),  # Removes speaker statistics
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.InstanceNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, padding=2),
            nn.InstanceNorm1d(hidden_dim),
            nn.ReLU()
        )

        # Bottleneck to remove residual speaker info
        self.bottleneck = nn.Conv1d(hidden_dim, hidden_dim // 4, 1)
        self.expand = nn.Conv1d(hidden_dim // 4, hidden_dim, 1)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Extract speaker-independent content.

        Args:
            mel: [B, T, n_mels]

        Returns:
            content: [B, T, hidden_dim]
        """
        x = mel.transpose(1, 2)  # [B, n_mels, T]
        x = self.encoder(x)
        x = self.bottleneck(x)
        x = self.expand(x)
        return x.transpose(1, 2)


class AdaINDecoder(nn.Module):
    """
    Decoder with Adaptive Instance Normalization (AdaIN).

    Injects speaker style via learned affine transformations.
    """

    def __init__(
        self,
        content_dim: int = 512,
        speaker_dim: int = 256,
        n_mels: int = 80
    ):
        super().__init__()

        # AdaIN parameter generators
        self.style_fc = nn.Linear(speaker_dim, content_dim * 2 * 4)  # gamma, beta for 4 layers

        # Decoder layers
        self.decoder = nn.ModuleList([
            nn.Conv1d(content_dim, content_dim, 5, padding=2),
            nn.Conv1d(content_dim, content_dim, 5, padding=2),
            nn.Conv1d(content_dim, content_dim, 5, padding=2),
            nn.Conv1d(content_dim, content_dim, 5, padding=2)
        ])

        self.output = nn.Conv1d(content_dim, n_mels, 1)

    def forward(
        self,
        content: torch.Tensor,
        speaker_embedding: torch.Tensor
    ) -> torch.Tensor:
        """
        Decode content with speaker style.

        Args:
            content: [B, T, content_dim]
            speaker_embedding: [B, speaker_dim]

        Returns:
            mel: [B, T, n_mels]
        """
        # Generate style parameters
        style_params = self.style_fc(speaker_embedding)
        style_params = style_params.view(
            speaker_embedding.size(0), len(self.decoder), 2, -1
        )  # [B, n_layers, 2, content_dim]

        x = content.transpose(1, 2)  # [B, content_dim, T]

        for i, layer in enumerate(self.decoder):
            x = layer(x)

            # AdaIN
            gamma = style_params[:, i, 0, :].unsqueeze(-1)  # [B, content_dim, 1]
            beta = style_params[:, i, 1, :].unsqueeze(-1)

            x = F.instance_norm(x)
            x = gamma * x + beta

            x = F.relu(x)

        x = self.output(x)
        return x.transpose(1, 2)
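

# End-to-end shape check for the disentanglement-based converter. The stand-in
# speaker encoder (mean-pooled mel + linear projection) is a placeholder for a
# real embedding network such as ECAPA-TDNN; all inputs are random.
class _MeanSpeakerEncoder(nn.Module):
    def __init__(self, n_mels: int = 80, dim: int = 256):
        super().__init__()
        self.proj = nn.Linear(n_mels, dim)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        # [B, T, n_mels] -> [B, dim]
        return self.proj(mel.mean(dim=1))


_vc = VoiceConverter(
    content_encoder=ContentEncoder(in_dim=80, hidden_dim=512),
    speaker_encoder=_MeanSpeakerEncoder(),
    decoder=AdaINDecoder(content_dim=512, speaker_dim=256, n_mels=80)
)
print(_vc(torch.randn(1, 200, 80), torch.randn(1, 150, 80)).shape)  # [1, 200, 80]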


class AutoVCEncoder(nn.Module):
    """
    AutoVC content encoder with information bottleneck.

    Uses downsampling bottleneck to force content-only representation.
    """

    def __init__(self, n_mels: int = 80, hidden_dim: int = 512, bottleneck_dim: int = 32):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv1d(n_mels, hidden_dim, 5, padding=2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 5, stride=2, padding=2),  # Downsample
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU()
        )

        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, 2, batch_first=True, bidirectional=True)

        # Information bottleneck
        self.bottleneck = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.Linear(bottleneck_dim, hidden_dim)
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        x = mel.transpose(1, 2)
        x = self.encoder(x)
        x = x.transpose(1, 2)

        x, _ = self.lstm(x)
        x = self.bottleneck(x)

        # Upsample back to original rate
        x = x.repeat_interleave(2, dim=1)

        return x


def voice_cloning_systems():
    """Overview of voice cloning approaches."""
    systems = {
        'SV2TTS (Real-Time Voice Cloning)': {
            'components': ['Speaker encoder', 'Synthesizer (Tacotron)', 'Vocoder'],
            'reference_audio': '~5 seconds',
            'quality': 'Good',
            'speed': 'Real-time',
            'approach': 'Embedding-based conditioning'
        },
        'YourTTS': {
            'components': ['VITS backbone', 'Speaker encoder'],
            'reference_audio': '~3 seconds',
            'quality': 'Very good',
            'speed': 'Real-time',
            'approach': 'Zero-shot multilingual cloning'
        },
        'VALL-E': {
            'components': ['Neural codec', 'AR + NAR transformers'],
            'reference_audio': '3 seconds',
            'quality': 'Excellent',
            'speed': 'Slower',
            'approach': 'Audio language model with codec'
        },
        'Bark': {
            'components': ['GPT-style transformers', 'Codec decoder'],
            'reference_audio': 'Speaker prompts',
            'quality': 'Good',
            'speed': 'Moderate',
            'approach': 'Text-to-semantic-to-acoustic'
        },
        'OpenVoice': {
            'components': ['Base TTS', 'Tone color converter'],
            'reference_audio': 'Short clip',
            'quality': 'Good',
            'speed': 'Fast',
            'approach': 'Decoupled style/content'
        },
        'Coqui XTTS': {
            'components': ['GPT-based TTS', 'Speaker conditioning'],
            'reference_audio': '6+ seconds',
            'quality': 'Very good',
            'speed': 'Real-time streaming',
            'approach': 'Multilingual with voice cloning'
        }
    }

    print("Voice Cloning Systems Comparison:")
    print("=" * 70)
    for name, info in systems.items():
        print(f"\n{name}:")
        for k, v in info.items():
            if isinstance(v, list):
                print(f"  {k}: {', '.join(v)}")
            else:
                print(f"  {k}: {v}")


voice_cloning_systems()

Key Takeaways

Speaker recognition and voice cloning have advanced dramatically through neural approaches. Modern speaker embeddings like ECAPA-TDNN achieve remarkable accuracy for verification and identification tasks, learning discriminative representations that capture speaker identity. Voice cloning systems can now synthesize natural speech in any speaker's voice from just a few seconds of reference audio, using techniques like speaker embedding conditioning, variational inference, and normalizing flows. Key challenges include: (1) robustness to noise, channel mismatch, and short utterances, (2) preventing misuse through deepfake detection, (3) preserving prosody and speaking style, not just voice timbre, (4) real-time performance for interactive applications. The technology enables personalized voice assistants, accessible audiobooks, content creation, and language preservation, while raising important ethical considerations about consent and misuse.

23.5 Audio Language Models Advanced

Audio Language Models

Audio language models extend the success of large language models to the audio domain, treating audio as a sequence of discrete tokens that can be modeled autoregressively. This paradigm shift—from specialized signal processing pipelines to unified language modeling—enables remarkable capabilities including zero-shot TTS, music generation, and audio understanding. This section explores neural audio codecs, audio tokenization, and transformer-based audio generation systems.

Neural Audio Codecs

Neural codecs compress audio into discrete tokens for language modeling:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
from dataclasses import dataclass
import numpy as np

@dataclass
class CodecConfig:
    """Configuration for neural audio codec."""
    sample_rate: int = 24000
    channels: int = 1
    n_codebooks: int = 8
    codebook_size: int = 1024
    hidden_dim: int = 128
    compressed_dim: int = 128
    n_residual_layers: int = 3
    ratios: Optional[List[int]] = None  # Downsampling ratios

    def __post_init__(self):
        if self.ratios is None:
            self.ratios = [8, 5, 4, 2]  # Total: 320x compression


class EncodecModel(nn.Module):
    """
    Encodec: High Fidelity Neural Audio Compression.

    Encodes audio into discrete tokens using residual vector quantization.
    Enables audio language modeling by converting continuous signals
    to discrete sequences.

    Architecture:
    - Convolutional encoder (downsampling)
    - Residual Vector Quantizer (RVQ)
    - Convolutional decoder (upsampling)
    - Multi-scale discriminator (training)
    """

    def __init__(self, config: CodecConfig):
        super().__init__()
        self.config = config

        # Calculate hop length
        self.hop_length = np.prod(config.ratios)  # 320 for default

        # Encoder
        self.encoder = EncodecEncoder(config)

        # Residual Vector Quantizer
        self.quantizer = ResidualVectorQuantizer(
            dim=config.compressed_dim,
            n_codebooks=config.n_codebooks,
            codebook_size=config.codebook_size
        )

        # Decoder
        self.decoder = EncodecDecoder(config)

    def encode(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode audio to discrete codes.

        Args:
            audio: [B, 1, T] input audio

        Returns:
            codes: [B, n_codebooks, T//hop_length] discrete codes
            z_q: [B, compressed_dim, T//hop_length] quantized embeddings
        """
        # Encode to continuous latent
        z = self.encoder(audio)

        # Quantize (the commitment loss is only needed during training)
        z_q, codes, _ = self.quantizer(z)

        return codes, z_q

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode discrete codes to audio.

        Args:
            codes: [B, n_codebooks, T_codes] discrete codes

        Returns:
            audio: [B, 1, T_codes * hop_length] reconstructed audio
        """
        # Dequantize
        z_q = self.quantizer.decode(codes)

        # Decode to audio
        audio = self.decoder(z_q)

        return audio

    def forward(
        self,
        audio: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Full encode-decode forward pass.

        Returns:
            reconstructed: Reconstructed audio
            codes: Discrete codes
            losses: Dictionary of losses
        """
        # Encode
        z = self.encoder(audio)

        # Quantize
        z_q, codes, commit_loss = self.quantizer(z)

        # Decode
        reconstructed = self.decoder(z_q)

        losses = {
            'commitment_loss': commit_loss
        }

        return reconstructed, codes, losses


class EncodecEncoder(nn.Module):
    """Encodec encoder with progressive downsampling."""

    def __init__(self, config: CodecConfig):
        super().__init__()

        channels = config.hidden_dim
        self.conv_in = nn.Conv1d(config.channels, channels, 7, padding=3)

        # Downsampling blocks
        self.blocks = nn.ModuleList()
        for ratio in config.ratios:
            self.blocks.append(
                EncoderBlock(channels, channels * 2, ratio, config.n_residual_layers)
            )
            channels *= 2

        self.conv_out = nn.Conv1d(channels, config.compressed_dim, 7, padding=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_in(x)

        for block in self.blocks:
            x = block(x)

        x = self.conv_out(x)
        return x


class EncoderBlock(nn.Module):
    """Encoder block with residual units and downsampling."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        n_residual: int
    ):
        super().__init__()

        # Residual units
        self.residuals = nn.ModuleList([
            ResidualUnit(in_channels, in_channels, dilation=3**i)
            for i in range(n_residual)
        ])

        # Downsampling
        self.downsample = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=stride * 2,
            stride=stride,
            padding=stride // 2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for res in self.residuals:
            x = res(x)
        x = self.downsample(x)
        return F.elu(x)


class ResidualUnit(nn.Module):
    """Residual unit with dilated convolution."""

    def __init__(self, in_channels: int, out_channels: int, dilation: int = 1):
        super().__init__()

        self.conv1 = nn.Conv1d(
            in_channels, out_channels,
            kernel_size=7, padding=3 * dilation, dilation=dilation
        )
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1)

        self.shortcut = nn.Identity() if in_channels == out_channels else \
            nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)
        x = F.elu(self.conv1(x))
        x = self.conv2(x)
        return x + residual


class EncodecDecoder(nn.Module):
    """Encodec decoder with progressive upsampling."""

    def __init__(self, config: CodecConfig):
        super().__init__()

        # Calculate initial channels
        channels = config.hidden_dim * (2 ** len(config.ratios))

        self.conv_in = nn.Conv1d(config.compressed_dim, channels, 7, padding=3)

        # Upsampling blocks (reverse order)
        self.blocks = nn.ModuleList()
        for ratio in reversed(config.ratios):
            self.blocks.append(
                DecoderBlock(channels, channels // 2, ratio, config.n_residual_layers)
            )
            channels //= 2

        self.conv_out = nn.Conv1d(channels, config.channels, 7, padding=3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_in(x)

        for block in self.blocks:
            x = block(x)

        x = self.conv_out(x)
        return torch.tanh(x)


class DecoderBlock(nn.Module):
    """Decoder block with upsampling and residual units."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        n_residual: int
    ):
        super().__init__()

        # Upsampling
        self.upsample = nn.ConvTranspose1d(
            in_channels, out_channels,
            kernel_size=stride * 2,
            stride=stride,
            padding=stride // 2
        )

        # Residual units
        self.residuals = nn.ModuleList([
            ResidualUnit(out_channels, out_channels, dilation=3**i)
            for i in range(n_residual)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.elu(self.upsample(x))
        for res in self.residuals:
            x = res(x)
        return x


class ResidualVectorQuantizer(nn.Module):
    """
    Residual Vector Quantization (RVQ).

    Progressively quantizes residuals using multiple codebooks,
    enabling hierarchical audio representation.
    """

    def __init__(
        self,
        dim: int,
        n_codebooks: int,
        codebook_size: int,
        commitment_weight: float = 1.0
    ):
        super().__init__()

        self.n_codebooks = n_codebooks
        self.commitment_weight = commitment_weight

        # Create codebooks
        self.codebooks = nn.ModuleList([
            VectorQuantizer(dim, codebook_size)
            for _ in range(n_codebooks)
        ])

    def forward(
        self,
        z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize with RVQ.

        Args:
            z: [B, dim, T] continuous embeddings

        Returns:
            z_q: [B, dim, T] quantized embeddings
            codes: [B, n_codebooks, T] discrete codes
            commit_loss: Commitment loss
        """
        z_q = torch.zeros_like(z)
        residual = z

        codes = []
        commit_losses = []

        for codebook in self.codebooks:
            # Quantize residual
            z_q_i, codes_i, commit_loss_i = codebook(residual)

            # Add to quantized
            z_q = z_q + z_q_i

            # Update residual
            residual = residual - z_q_i

            codes.append(codes_i)
            commit_losses.append(commit_loss_i)

        codes = torch.stack(codes, dim=1)  # [B, n_codebooks, T]
        commit_loss = torch.stack(commit_losses).mean()

        return z_q, codes, commit_loss

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Decode codes to embeddings.

        Args:
            codes: [B, n_codebooks, T]

        Returns:
            z_q: [B, dim, T]
        """
        z_q = None

        for i, codebook in enumerate(self.codebooks):
            z_q_i = codebook.decode(codes[:, i])
            if z_q is None:
                z_q = z_q_i
            else:
                z_q = z_q + z_q_i

        return z_q


class VectorQuantizer(nn.Module):
    """Single codebook vector quantizer."""

    def __init__(self, dim: int, n_embeddings: int):
        super().__init__()

        self.n_embeddings = n_embeddings
        self.embedding = nn.Embedding(n_embeddings, dim)

        # Initialize codebook
        nn.init.uniform_(self.embedding.weight, -1/n_embeddings, 1/n_embeddings)

    def forward(
        self,
        z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize input.

        Args:
            z: [B, dim, T]

        Returns:
            z_q: [B, dim, T] quantized
            codes: [B, T] code indices
            commit_loss: Commitment loss
        """
        # Reshape: [B, dim, T] -> [B*T, dim]
        B, dim, T = z.shape
        z_flat = z.permute(0, 2, 1).reshape(-1, dim)

        # Find nearest neighbors
        distances = torch.cdist(z_flat, self.embedding.weight)
        codes = distances.argmin(dim=-1)  # [B*T]

        # Quantize
        z_q = self.embedding(codes)  # [B*T, dim]

        # Commitment loss
        commit_loss = F.mse_loss(z_flat.detach(), z_q) + F.mse_loss(z_flat, z_q.detach())

        # Straight-through gradient
        z_q = z_flat + (z_q - z_flat).detach()

        # Reshape back
        z_q = z_q.reshape(B, T, dim).permute(0, 2, 1)
        codes = codes.reshape(B, T)

        return z_q, codes, commit_loss

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codes to embeddings."""
        z_q = self.embedding(codes)  # [B, T, dim]
        return z_q.permute(0, 2, 1)  # [B, dim, T]
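

# Roundtrip shape/bitrate check with the default config (random audio, untrained
# weights, so the reconstruction is meaningless). Token-rate arithmetic:
#   hop = 8*5*4*2 = 320 samples -> 24000/320 = 75 frames/s,
#   8 codebooks x log2(1024) = 80 bits/frame -> roughly 6 kbps.
_codec = EncodecModel(CodecConfig())
_wav = torch.randn(1, 1, 24000)            # one second of "audio"
_codes, _zq = _codec.encode(_wav)
print(_codes.shape)                        # roughly [1, 8, 75] (padding shifts it slightly)
print(_codec.decode(_codes).shape)         # roughly [1, 1, 24000]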


def audio_codec_comparison():
    """Compare neural audio codecs."""
    codecs = {
        'Opus': {
            'type': 'Traditional',
            'bitrate': '6-510 kbps',
            'latency': '2.5-60 ms',
            'quality': 'Very good',
            'note': 'Standard codec, not neural'
        },
        'SoundStream': {
            'type': 'Neural (RVQ)',
            'bitrate': '3-18 kbps',
            'latency': '13 ms',
            'quality': 'Excellent at low bitrate',
            'note': 'Google, pioneered neural codec'
        },
        'Encodec': {
            'type': 'Neural (RVQ)',
            'bitrate': '1.5-24 kbps',
            'latency': '13 ms',
            'quality': 'Excellent',
            'note': 'Meta, enables audio LM'
        },
        'DAC (Descript)': {
            'type': 'Neural (RVQ)',
            'bitrate': '8-16 kbps',
            'latency': '~5 ms',
            'quality': 'High fidelity',
            'note': '44.1kHz support, music-focused'
        },
        'Lyra': {
            'type': 'Neural (Autoencoder)',
            'bitrate': '3-9 kbps',
            'latency': '20 ms',
            'quality': 'Good',
            'note': 'Google, mobile-optimized'
        }
    }

    print("Neural Audio Codec Comparison:")
    print("=" * 60)
    for name, info in codecs.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")


audio_codec_comparison()

Audio Language Models

Treating discrete audio tokens like text enables powerful autoregressive generative models:

PYTHON
class AudioLM(nn.Module):
    """
    AudioLM: Audio Language Model.

    Three-stage hierarchical generation:
    1. Semantic tokens (from w2v-BERT): capture content
    2. Coarse acoustic tokens (SoundStream): capture speaker/acoustic
    3. Fine acoustic tokens: high-fidelity reconstruction

    This enables text-free audio continuation and generation.
    """

    def __init__(
        self,
        semantic_vocab_size: int = 1024,
        acoustic_vocab_size: int = 1024,
        n_coarse_codebooks: int = 4,
        n_fine_codebooks: int = 8,
        hidden_dim: int = 1024,
        n_layers: int = 12,
        n_heads: int = 16
    ):
        super().__init__()

        # Stage 1: Semantic modeling
        self.semantic_model = AudioLMStage(
            vocab_size=semantic_vocab_size,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads
        )

        # Stage 2: Coarse acoustic modeling (conditioned on semantic)
        self.coarse_model = AudioLMStage(
            vocab_size=acoustic_vocab_size * n_coarse_codebooks,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            n_codebooks=n_coarse_codebooks,
            condition_dim=hidden_dim
        )

        # Stage 3: Fine acoustic modeling (conditioned on coarse)
        self.fine_model = AudioLMStage(
            vocab_size=acoustic_vocab_size * n_fine_codebooks,
            hidden_dim=hidden_dim,
            n_layers=n_layers // 2,
            n_heads=n_heads,
            n_codebooks=n_fine_codebooks,
            condition_dim=hidden_dim
        )

        self.n_coarse_codebooks = n_coarse_codebooks
        self.n_fine_codebooks = n_fine_codebooks

    @torch.no_grad()
    def generate(
        self,
        prompt_semantic: torch.Tensor,
        prompt_acoustic: Optional[torch.Tensor] = None,
        max_semantic_tokens: int = 500,
        max_acoustic_frames: int = 1000,
        temperature: float = 0.8,
        top_k: int = 50
    ) -> torch.Tensor:
        """
        Generate audio continuation.

        Args:
            prompt_semantic: [B, T_s] semantic token prompt
            prompt_acoustic: [B, n_codebooks, T_a] acoustic token prompt
            max_semantic_tokens: Maximum semantic tokens to generate
            max_acoustic_frames: Maximum acoustic frames

        Returns:
            acoustic_tokens: [B, n_fine_codebooks, T_a] generated tokens
        """
        # Stage 1: Generate semantic tokens
        semantic_tokens = self.semantic_model.generate(
            prompt_semantic,
            max_length=max_semantic_tokens,
            temperature=temperature,
            top_k=top_k
        )

        # Get semantic conditioning for acoustic
        semantic_hidden = self.semantic_model.get_hidden(semantic_tokens)

        # Stage 2: Generate coarse acoustic tokens
        if prompt_acoustic is not None:
            coarse_prompt = prompt_acoustic[:, :self.n_coarse_codebooks]
        else:
            coarse_prompt = None

        coarse_tokens = self.coarse_model.generate(
            coarse_prompt,
            max_length=max_acoustic_frames,
            temperature=temperature,
            top_k=top_k,
            condition=semantic_hidden
        )

        # Get coarse conditioning for fine
        coarse_hidden = self.coarse_model.get_hidden(coarse_tokens)

        # Stage 3: Generate fine acoustic tokens
        fine_tokens = self.fine_model.generate(
            coarse_tokens,  # Use coarse as prefix
            max_length=max_acoustic_frames,
            temperature=temperature,
            top_k=top_k,
            condition=coarse_hidden
        )

        return fine_tokens


class AudioLMStage(nn.Module):
    """Single stage of AudioLM (transformer decoder)."""

    def __init__(
        self,
        vocab_size: int,
        hidden_dim: int,
        n_layers: int,
        n_heads: int,
        n_codebooks: int = 1,
        condition_dim: Optional[int] = None
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim)

        # Optional conditioning projection
        if condition_dim is not None:
            self.condition_proj = nn.Linear(condition_dim, hidden_dim)
        else:
            self.condition_proj = None

        # Transformer layers
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            )
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

        self.n_codebooks = n_codebooks

    def forward(
        self,
        tokens: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        return_hidden: bool = False
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            tokens: [B, T] or [B, n_codebooks, T] input tokens
            condition: [B, T_c, hidden_dim] conditioning sequence
            return_hidden: If True, return pre-head hidden states instead of logits

        Returns:
            logits: [B, T, vocab_size] (or hidden states [B, T, hidden_dim])
        """
        # Handle multi-codebook input
        if tokens.dim() == 3:
            B, n_cb, T = tokens.shape
            # Interleave codebooks: [c1_t1, c2_t1, ..., c1_t2, c2_t2, ...]
            tokens = tokens.permute(0, 2, 1).reshape(B, T * n_cb)

        x = self.embedding(tokens)
        x = self.pos_encoding(x)

        # Create causal mask
        T = x.size(1)
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        # Apply conditioning
        if condition is not None and self.condition_proj is not None:
            condition = self.condition_proj(condition)
            # Cross-attention in each layer
            for layer in self.layers:
                x = layer(x, condition, tgt_mask=causal_mask)
        else:
            # Self-attention only
            for layer in self.layers:
                x = layer(x, x, tgt_mask=causal_mask)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    @torch.no_grad()
    def generate(
        self,
        prompt: Optional[torch.Tensor],
        max_length: int,
        temperature: float = 1.0,
        top_k: int = 50,
        condition: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Autoregressive generation."""
        if prompt is None:
            prompt = torch.zeros(1, 1, dtype=torch.long, device=next(self.parameters()).device)

        generated = prompt.clone()

        for _ in range(max_length - prompt.size(-1)):
            logits = self.forward(generated, condition)
            logits = logits[:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)

            generated = torch.cat([generated, next_token], dim=-1)

        return generated


class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""

    def __init__(self, dim: int, max_len: int = 10000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-np.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, :x.size(1)]
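

# --- Usage sketch (hypothetical helper, not part of AudioLM itself): build a tiny
# AudioLMStage with illustrative sizes and generate a short continuation conditioned
# on a random "semantic" hidden sequence, just to show the expected shapes.
def _audio_lm_stage_demo():
    torch.manual_seed(0)
    stage = AudioLMStage(
        vocab_size=256,   # toy token vocabulary
        hidden_dim=64,
        n_layers=2,
        n_heads=4,
        condition_dim=32  # dimensionality of the conditioning hidden states
    )
    condition = torch.randn(1, 10, 32)        # stand-in for semantic hidden states
    prompt = torch.randint(0, 256, (1, 4))    # short acoustic-token prompt
    tokens = stage.generate(prompt, max_length=12, top_k=50, condition=condition)
    print("Generated token sequence:", tokens.shape)  # torch.Size([1, 12])

# _audio_lm_stage_demo()  # uncomment to run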


class VALLE(nn.Module):
    """
    VALL-E: Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers.

    Uses discrete codec tokens as vocabulary and language model
    for zero-shot TTS from 3-second prompt.

    Architecture:
    - AR model: generates first codebook autoregressively
    - NAR model: generates remaining codebooks in parallel
    """

    def __init__(
        self,
        vocab_size: int = 1024,  # Phoneme vocabulary
        codec_vocab_size: int = 1024,  # Codec codebook size
        n_codebooks: int = 8,
        hidden_dim: int = 1024,
        n_layers: int = 12,
        n_heads: int = 16
    ):
        super().__init__()

        self.n_codebooks = n_codebooks

        # Phoneme embedding
        self.phoneme_embedding = nn.Embedding(vocab_size, hidden_dim)

        # Codec embeddings (one per codebook)
        self.codec_embeddings = nn.ModuleList([
            nn.Embedding(codec_vocab_size, hidden_dim)
            for _ in range(n_codebooks)
        ])

        # Autoregressive model for first codebook
        self.ar_model = VALLETransformer(
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            output_dim=codec_vocab_size
        )

        # Non-autoregressive model for remaining codebooks (bidirectional attention)
        self.nar_model = VALLETransformer(
            hidden_dim=hidden_dim,
            n_layers=n_layers // 2,
            n_heads=n_heads,
            output_dim=codec_vocab_size * (n_codebooks - 1),
            causal=False
        )

    def forward(
        self,
        phonemes: torch.Tensor,
        prompt_codes: torch.Tensor,
        target_codes: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Training forward.

        Args:
            phonemes: [B, T_p] phoneme sequence
            prompt_codes: [B, n_codebooks, T_prompt] acoustic prompt
            target_codes: [B, n_codebooks, T_target] target codes

        Returns:
            Dictionary of losses
        """
        # Combine prompt and target
        codes = torch.cat([prompt_codes, target_codes], dim=-1)

        # AR loss (first codebook): the logit at position t predicts the code at t+1,
        # so the slice starts one step before the target segment
        ar_logits = self._ar_forward(phonemes, codes[:, 0])
        ar_loss = F.cross_entropy(
            ar_logits[:, prompt_codes.size(-1) - 1:-1].reshape(-1, ar_logits.size(-1)),
            target_codes[:, 0].reshape(-1)
        )

        # NAR loss (remaining codebooks)
        nar_logits = self._nar_forward(phonemes, codes)
        nar_loss = 0
        for i in range(1, self.n_codebooks):
            loss_i = F.cross_entropy(
                nar_logits[i-1][:, prompt_codes.size(-1):].reshape(-1, nar_logits[i-1].size(-1)),
                target_codes[:, i].reshape(-1)
            )
            nar_loss = nar_loss + loss_i
        nar_loss = nar_loss / (self.n_codebooks - 1)

        return {
            'ar_loss': ar_loss,
            'nar_loss': nar_loss,
            'total_loss': ar_loss + nar_loss
        }

    def _ar_forward(
        self,
        phonemes: torch.Tensor,
        codes: torch.Tensor
    ) -> torch.Tensor:
        """AR model forward."""
        # Embed phonemes
        phone_emb = self.phoneme_embedding(phonemes)

        # Embed codes
        code_emb = self.codec_embeddings[0](codes)

        # Concatenate: [phonemes, codes]
        x = torch.cat([phone_emb, code_emb], dim=1)

        # AR transformer
        logits = self.ar_model(x)

        # Return only code part
        return logits[:, phonemes.size(1):]

    def _nar_forward(
        self,
        phonemes: torch.Tensor,
        codes: torch.Tensor
    ) -> List[torch.Tensor]:
        """NAR model forward."""
        # Embed phonemes
        phone_emb = self.phoneme_embedding(phonemes)

        # Sum code embeddings up to current level
        code_emb = self.codec_embeddings[0](codes[:, 0])
        for i in range(1, self.n_codebooks):
            code_emb = code_emb + self.codec_embeddings[i](codes[:, i])

        # Concatenate
        x = torch.cat([phone_emb, code_emb], dim=1)

        # NAR transformer
        logits = self.nar_model(x)

        # Split into per-codebook logits
        logits = logits[:, phonemes.size(1):]
        per_codebook_logits = logits.chunk(self.n_codebooks - 1, dim=-1)

        return per_codebook_logits

    @torch.no_grad()
    def generate(
        self,
        phonemes: torch.Tensor,
        prompt_codes: torch.Tensor,
        max_length: int = 1000,
        temperature: float = 0.8
    ) -> torch.Tensor:
        """
        Zero-shot TTS generation.

        Args:
            phonemes: [B, T_p] phoneme sequence
            prompt_codes: [B, n_codebooks, T_prompt] 3-second acoustic prompt

        Returns:
            codes: [B, n_codebooks, T] generated codec codes
        """
        B = phonemes.size(0)
        device = phonemes.device

        # AR generation for first codebook
        codes_0 = self._ar_generate(phonemes, prompt_codes[:, 0], max_length, temperature)

        # NAR generation for remaining codebooks
        all_codes = [codes_0]

        for i in range(1, self.n_codebooks):
            # Stack codebooks generated so far; pad not-yet-generated levels with zeros as placeholders
            current_codes = torch.stack(all_codes + [torch.zeros_like(codes_0)] * (self.n_codebooks - i), dim=1)

            # Get logits for codebook i
            nar_logits = self._nar_forward(phonemes, current_codes)

            # Greedy decoding per codebook (temperature sampling is an alternative,
            # but argmax would ignore the temperature scaling)
            codes_i = torch.argmax(nar_logits[i - 1], dim=-1)

            all_codes.append(codes_i)

        return torch.stack(all_codes, dim=1)

    def _ar_generate(
        self,
        phonemes: torch.Tensor,
        prompt_codes: torch.Tensor,
        max_length: int,
        temperature: float
    ) -> torch.Tensor:
        """AR generation for first codebook."""
        codes = prompt_codes.clone()

        for _ in range(max_length - prompt_codes.size(1)):
            logits = self._ar_forward(phonemes, codes)
            logits = logits[:, -1, :] / temperature

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)

            codes = torch.cat([codes, next_token], dim=-1)

            # Check for end token (simplified)
            if (next_token == 0).all():
                break

        return codes


class VALLETransformer(nn.Module):
    """Transformer for VALL-E."""

    def __init__(
        self,
        hidden_dim: int,
        n_layers: int,
        n_heads: int,
        output_dim: int,
        causal: bool = True
    ):
        super().__init__()

        self.causal = causal

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim)

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            )
            for _ in range(n_layers)
        ])

        self.ln_f = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pos_encoding(x)

        # Causal mask for the AR model; the NAR model attends bidirectionally
        mask = None
        if self.causal:
            T = x.size(1)
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        for layer in self.layers:
            x = layer(x, src_mask=mask)

        x = self.ln_f(x)
        return self.head(x)
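

# --- Usage sketch (hypothetical helper with illustrative sizes, far smaller than the
# published VALL-E): run the training forward pass on random phonemes and codec codes
# to check the AR/NAR loss plumbing and tensor shapes.
def _valle_training_demo():
    torch.manual_seed(0)
    model = VALLE(
        vocab_size=128,        # toy phoneme vocabulary
        codec_vocab_size=256,  # toy codec codebook size
        n_codebooks=4,
        hidden_dim=64,
        n_layers=2,
        n_heads=4
    )
    phonemes = torch.randint(0, 128, (2, 20))          # [B, T_p]
    prompt_codes = torch.randint(0, 256, (2, 4, 30))   # acoustic prompt codes
    target_codes = torch.randint(0, 256, (2, 4, 50))   # target utterance codes
    losses = model(phonemes, prompt_codes, target_codes)
    print({k: round(float(v), 3) for k, v in losses.items()})

# _valle_training_demo()  # uncomment to run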

Music and Sound Generation

The same codec-token language modeling approach extends naturally to music and general sound generation:

PYTHON
class MusicGen(nn.Module):
    """
    MusicGen-style music generation model.

    Generates music from text descriptions using:
    - Text encoder (T5 or similar)
    - Transformer decoder for codec tokens
    - Delay pattern for parallel codebook generation
    """

    def __init__(
        self,
        text_dim: int = 1024,
        hidden_dim: int = 1024,
        codec_vocab_size: int = 2048,
        n_codebooks: int = 4,
        n_layers: int = 24,
        n_heads: int = 16
    ):
        super().__init__()

        self.n_codebooks = n_codebooks
        self.codec_vocab_size = codec_vocab_size

        # Text conditioning
        self.text_proj = nn.Linear(text_dim, hidden_dim)

        # Codec embeddings
        self.codec_embedding = nn.Embedding(codec_vocab_size * n_codebooks, hidden_dim)

        # Transformer decoder
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            ),
            num_layers=n_layers
        )

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim)
        self.ln_f = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, codec_vocab_size * n_codebooks)

    def forward(
        self,
        text_features: torch.Tensor,
        codec_tokens: torch.Tensor
    ) -> torch.Tensor:
        """
        Training forward with teacher forcing.

        Args:
            text_features: [B, T_text, text_dim] text encoder output
            codec_tokens: [B, n_codebooks, T_audio] codec tokens

        Returns:
            logits: [B, T_flat, vocab_size * n_codebooks]
        """
        B, _, T = codec_tokens.shape

        # Project text
        text_cond = self.text_proj(text_features)

        # Flatten codec tokens with delay pattern
        flat_tokens = self._apply_delay_pattern(codec_tokens)

        # Embed: offset each codebook's tokens into its own slice of the shared vocabulary,
        # then flatten time-major so consecutive positions hold the codebooks of one frame
        offsets = torch.arange(self.n_codebooks, device=codec_tokens.device) * self.codec_vocab_size
        flat_tokens_offset = flat_tokens + offsets.view(1, -1, 1)  # broadcasts over the delayed length
        flat_tokens_offset = flat_tokens_offset.permute(0, 2, 1).reshape(B, -1)

        x = self.codec_embedding(flat_tokens_offset)
        x = self.pos_encoding(x)

        # Causal mask
        T_flat = x.size(1)
        causal_mask = torch.triu(torch.ones(T_flat, T_flat, device=x.device), diagonal=1).bool()

        # Decode with cross-attention to text
        x = self.decoder(x, text_cond, tgt_mask=causal_mask)

        x = self.ln_f(x)
        logits = self.head(x)

        return logits

    def _apply_delay_pattern(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Apply delay pattern for parallel codebook modeling.

        Codebook 0: [t0, t1, t2, t3, ...]
        Codebook 1: [PAD, t0, t1, t2, ...]
        Codebook 2: [PAD, PAD, t0, t1, ...]
        ...

        This allows autoregressive generation while generating
        multiple codebooks efficiently.
        """
        B, n_cb, T = tokens.shape
        device = tokens.device

        # Pad for delays
        padded = torch.zeros(B, n_cb, T + n_cb - 1, dtype=tokens.dtype, device=device)

        for i in range(n_cb):
            padded[:, i, i:i+T] = tokens[:, i]

        return padded

    @torch.no_grad()
    def generate(
        self,
        text_features: torch.Tensor,
        max_length: int = 1500,  # ~30 seconds at 50Hz
        temperature: float = 1.0,
        top_k: int = 250,
        cfg_scale: float = 3.0
    ) -> torch.Tensor:
        """
        Generate music from text description.

        Args:
            text_features: [B, T_text, text_dim] text encoding
            max_length: Maximum frames to generate
            cfg_scale: Classifier-free guidance scale

        Returns:
            codes: [B, n_codebooks, max_length] generated codes
        """
        B = text_features.size(0)
        device = text_features.device

        # Start from a single all-zero frame as a BOS placeholder (assumes token id 0
        # can act as pad/BOS); this avoids special-casing the very first decoding step
        generated = torch.zeros(B, self.n_codebooks, 1, dtype=torch.long, device=device)

        for _ in range(max_length):
            logits = self.forward(text_features, generated)
            logits = logits[:, -self.n_codebooks:]

            # Classifier-free guidance: push the prediction away from the unconditional one
            if cfg_scale > 1.0:
                uncond_logits = self.forward(torch.zeros_like(text_features), generated)
                uncond_logits = uncond_logits[:, -self.n_codebooks:]
                logits = uncond_logits + cfg_scale * (logits - uncond_logits)

            # Sample each codebook
            new_tokens = []
            for i in range(self.n_codebooks):
                # Get logits for codebook i from its slice of the shared output head
                start = i * self.codec_vocab_size
                end = (i + 1) * self.codec_vocab_size
                cb_logits = logits[:, i, start:end] / temperature

                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = cb_logits < torch.topk(cb_logits, top_k)[0][..., -1:]
                    cb_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(cb_logits, dim=-1)
                token = torch.multinomial(probs, 1)
                new_tokens.append(token)

            new_tokens = torch.stack(new_tokens, dim=1)  # [B, n_codebooks, 1]
            generated = torch.cat([generated, new_tokens], dim=-1)

        return generated[:, :, 1:]  # drop the BOS placeholder frame

    def _get_first_logits(self, text_features: torch.Tensor) -> torch.Tensor:
        """Get logits for first token."""
        text_cond = self.text_proj(text_features)

        # Just use text as context
        x = self.pos_encoding(text_cond[:, :1])

        for layer in self.decoder.layers:
            x = layer(x, text_cond)

        x = self.ln_f(x)
        return self.head(x)
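

# --- Illustrative check of the delay pattern on a toy tensor (all sizes arbitrary):
# each codebook row is shifted one step further to the right, padded with zeros.
def _delay_pattern_demo():
    toy = MusicGen(text_dim=16, hidden_dim=32, codec_vocab_size=32,
                   n_codebooks=3, n_layers=2, n_heads=4)
    tokens = torch.arange(1, 13).view(1, 3, 4)  # [B=1, n_codebooks=3, T=4]
    print(toy._apply_delay_pattern(tokens)[0])
    # tensor([[ 1,  2,  3,  4,  0,  0],
    #         [ 0,  5,  6,  7,  8,  0],
    #         [ 0,  0,  9, 10, 11, 12]])

# _delay_pattern_demo()  # uncomment to run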


class WhisperStyleASR(nn.Module):
    """
    Whisper-style universal speech model.

    Multitask model trained on:
    - Speech recognition (transcription)
    - Translation
    - Language identification
    - Voice activity detection

    Uses encoder-decoder architecture with special task tokens.
    """

    def __init__(
        self,
        n_mels: int = 80,
        vocab_size: int = 51865,
        hidden_dim: int = 1024,
        encoder_layers: int = 24,
        decoder_layers: int = 24,
        n_heads: int = 16
    ):
        super().__init__()

        # Audio encoder
        self.encoder = WhisperEncoder(n_mels, hidden_dim, encoder_layers, n_heads)

        # Text decoder
        self.decoder = WhisperDecoder(vocab_size, hidden_dim, decoder_layers, n_heads)

        # Special tokens
        self.register_buffer('sot', torch.tensor([50258]))  # Start of transcript
        self.register_buffer('eot', torch.tensor([50257]))  # End of transcript

    def forward(
        self,
        mel: torch.Tensor,
        tokens: torch.Tensor
    ) -> torch.Tensor:
        """
        Training forward.

        Args:
            mel: [B, n_mels, T_audio] mel spectrogram
            tokens: [B, T_text] target tokens (with task prefix)

        Returns:
            logits: [B, T_text, vocab_size]
        """
        # Encode audio
        audio_features = self.encoder(mel)

        # Decode
        logits = self.decoder(tokens, audio_features)

        return logits

    @torch.no_grad()
    def transcribe(
        self,
        mel: torch.Tensor,
        task: str = 'transcribe',
        language: str = 'en',
        max_length: int = 448
    ) -> torch.Tensor:
        """
        Transcribe audio with greedy decoding.

        Args:
            mel: [B, n_mels, T] mel spectrogram
            task: 'transcribe' or 'translate'
            language: Language code for the language token

        Returns:
            generated: [B, T_out] token ids (decode with the tokenizer to get text)
        """
        B = mel.size(0)
        device = mel.device

        # Encode
        audio_features = self.encoder(mel)

        # Build prompt: <|startoftranscript|><|lang|><|task|>
        prompt = self._build_prompt(task, language, device)
        prompt = prompt.unsqueeze(0).expand(B, -1)

        # Generate
        generated = prompt.clone()

        for _ in range(max_length):
            logits = self.decoder(generated, audio_features)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=-1)

            # Stop at end of transcript
            if (next_token == self.eot).all():
                break

        return generated

    def _build_prompt(self, task: str, language: str, device: torch.device) -> torch.Tensor:
        """Build task prompt tokens."""
        # Simplified - actual Whisper has complex tokenization
        tokens = [50258]  # <|startoftranscript|>

        # Language token (50259 + lang_id)
        lang_tokens = {'en': 50259, 'es': 50260, 'fr': 50261, 'de': 50262}
        tokens.append(lang_tokens.get(language, 50259))

        # Task token
        if task == 'transcribe':
            tokens.append(50359)  # <|transcribe|>
        else:
            tokens.append(50358)  # <|translate|>

        return torch.tensor(tokens, device=device)


class WhisperEncoder(nn.Module):
    """Whisper audio encoder."""

    def __init__(self, n_mels: int, hidden_dim: int, n_layers: int, n_heads: int):
        super().__init__()

        # Two conv layers for initial feature extraction
        self.conv1 = nn.Conv1d(n_mels, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, stride=2, padding=1)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=1500)

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            )
            for _ in range(n_layers)
        ])

        self.ln = nn.LayerNorm(hidden_dim)

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        x = F.gelu(self.conv1(mel))
        x = F.gelu(self.conv2(x))

        x = x.permute(0, 2, 1)  # [B, T, hidden_dim]
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x)

        return self.ln(x)


class WhisperDecoder(nn.Module):
    """Whisper text decoder."""

    def __init__(self, vocab_size: int, hidden_dim: int, n_layers: int, n_heads: int):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_encoding = SinusoidalPositionalEncoding(hidden_dim, max_len=448)

        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            )
            for _ in range(n_layers)
        ])

        self.ln = nn.LayerNorm(hidden_dim)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(
        self,
        tokens: torch.Tensor,
        encoder_output: torch.Tensor
    ) -> torch.Tensor:
        x = self.embedding(tokens)
        x = self.pos_encoding(x)

        # Causal mask
        T = x.size(1)
        causal_mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()

        for layer in self.layers:
            x = layer(x, encoder_output, tgt_mask=causal_mask)

        x = self.ln(x)
        return self.head(x)
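

# --- Usage sketch (hypothetical helper; the model is deliberately tiny and untrained,
# so the emitted token ids are meaningless and only the shapes are informative).
def _whisper_style_demo():
    torch.manual_seed(0)
    model = WhisperStyleASR(
        n_mels=80,
        vocab_size=51865,   # keep the full vocabulary so the special-token ids are valid
        hidden_dim=64,
        encoder_layers=2,
        decoder_layers=2,
        n_heads=4
    )
    mel = torch.randn(1, 80, 3000)  # ~30 s of audio at a 10 ms hop
    token_ids = model.transcribe(mel, task='transcribe', language='en', max_length=8)
    print("Generated token ids:", token_ids.shape)  # [1, 3 prompt tokens + up to 8 generated]

# _whisper_style_demo()  # uncomment to run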


def audio_lm_comparison():
    """Compare audio language models."""
    models = {
        'AudioLM': {
            'approach': 'Semantic + acoustic tokens',
            'capabilities': ['Audio continuation', 'Music generation'],
            'tokens': 'w2v-BERT semantic + SoundStream acoustic',
            'quality': 'High coherence',
            'note': 'Three-stage hierarchical'
        },
        'VALL-E': {
            'approach': 'AR + NAR codec modeling',
            'capabilities': ['Zero-shot TTS', 'Voice cloning'],
            'tokens': 'Encodec (8 codebooks)',
            'quality': 'Near human TTS',
            'note': '3-second prompt for cloning'
        },
        'MusicLM': {
            'approach': 'Hierarchical generation',
            'capabilities': ['Text-to-music'],
            'tokens': 'w2v-BERT semantic + SoundStream acoustic',
            'quality': 'High-quality music',
            'note': 'Conditioned on MuLan text-audio embeddings'
        },
        'MusicGen': {
            'approach': 'Single-stage with delay pattern',
            'capabilities': ['Text-to-music', 'Melody conditioning'],
            'tokens': 'Encodec (4 codebooks)',
            'quality': 'Excellent',
            'note': 'Efficient parallel generation'
        },
        'Bark': {
            'approach': 'GPT-style three-stage',
            'capabilities': ['TTS', 'Music', 'Sound effects'],
            'tokens': 'Semantic + Encodec acoustic',
            'quality': 'Good',
            'note': 'Highly controllable'
        },
        'Stable Audio': {
            'approach': 'Latent diffusion',
            'capabilities': ['Text-to-music', 'Text-to-audio'],
            'tokens': 'Continuous latents',
            'quality': 'High-fidelity',
            'note': 'Diffusion-based, not LM'
        }
    }

    print("Audio Language Model Comparison:")
    print("=" * 70)
    for name, info in models.items():
        print(f"\n{name}:")
        for k, v in info.items():
            if isinstance(v, list):
                print(f"  {k}: {', '.join(v)}")
            else:
                print(f"  {k}: {v}")


audio_lm_comparison()

Key Takeaways

Audio language models represent a paradigm shift in audio processing, unifying diverse tasks under a single language modeling framework. Neural codecs such as Encodec enable this by converting continuous audio into discrete tokens suitable for transformer architectures. Key innovations include: (1) hierarchical tokenization that separates semantic content from acoustic detail, (2) residual vector quantization for efficient high-fidelity compression, (3) AR + NAR generation strategies that balance quality and speed, and (4) classifier-free guidance for controllable generation.

These models achieve remarkable zero-shot capabilities: VALL-E can clone a voice from roughly 3 seconds of audio, while MusicGen creates coherent music from text descriptions. The approach scales with compute and data, showing scaling behavior similar to text LLMs. Open challenges include generation latency, long-form coherence, and preventing misuse through watermarking and detection. The convergence of audio and language modeling opens new frontiers in multimodal AI, enabling systems that understand and generate seamlessly across text, speech, and music.