
Chapter 19: Diffusion Models

The diffusion framework, DDPM training and sampling, score-based models and guidance, and applications.

Libraries covered: PyTorch

Learning Objectives

["Understand adversarial training", "Build GANs", "Evaluate generative models"]


19.1 Introduction to Diffusion Models Intermediate


Diffusion models represent a paradigm shift in generative modeling, achieving state-of-the-art results in image synthesis, audio generation, and numerous other domains. These models learn to generate data by reversing a gradual noising process, transforming pure noise into structured samples through iterative refinement.

The Diffusion Paradigm

The core insight behind diffusion models is surprisingly simple: destroying information is easy, but learning to reverse that destruction enables generation. The forward diffusion process gradually adds noise to data until it becomes indistinguishable from random noise. The reverse process learns to undo this corruption step by step, recovering structure from chaos.

Unlike GANs which learn through adversarial competition, or VAEs which compress data through a bottleneck, diffusion models operate through a sequence of small denoising steps. Each step only needs to remove a small amount of noise, making the learning problem tractable even for complex high-dimensional data.

PYTHON
import torch
import torch.nn as nn
import numpy as np

class DiffusionProcess:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps

        # Linear noise schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([
            torch.tensor([1.0]), self.alphas_cumprod[:-1]
        ])

        # Precompute useful quantities
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def q_sample(self, x_0, t, noise=None):
        """Forward process: add noise to data."""
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise

    def visualize_forward_process(self, x_0, steps=(0, 250, 500, 750, 999)):
        """Show progressive noise addition."""
        samples = []
        for t in steps:
            t_tensor = torch.tensor([t])
            noisy = self.q_sample(x_0, t_tensor)
            samples.append(noisy)
        return samples
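
As a quick check, the forward process can be probed directly; the snippet below uses random tensors in place of real images purely for illustration.

PYTHON
# Illustrative usage: watch the retained signal shrink as t grows
process = DiffusionProcess(num_timesteps=1000)
x_0 = torch.randn(8, 3, 32, 32)  # stand-in for a batch of images

for t in [0, 250, 500, 999]:
    t_batch = torch.full((8,), t, dtype=torch.long)
    x_t = process.q_sample(x_0, t_batch)
    print(f"t={t:4d}  signal coef={process.sqrt_alphas_cumprod[t]:.3f}  sample std={x_t.std():.3f}")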

Forward Process: Adding Noise

The forward diffusion process defines a Markov chain that gradually adds Gaussian noise to data. At each timestep, a small amount of noise is added according to a variance schedule. After enough steps, the data distribution converges to a standard Gaussian regardless of the initial data distribution.

The variance schedule determines how quickly noise is added. Common choices include linear schedules that increase noise uniformly, cosine schedules that add noise more slowly at the start, and learned schedules optimized during training. The schedule significantly impacts generation quality and training stability.

PYTHON
class NoiseScheduler:
    def __init__(self, num_timesteps=1000, schedule_type='linear'):
        self.num_timesteps = num_timesteps

        if schedule_type == 'linear':
            self.betas = self._linear_schedule()
        elif schedule_type == 'cosine':
            self.betas = self._cosine_schedule()
        elif schedule_type == 'quadratic':
            self.betas = self._quadratic_schedule()

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    def _linear_schedule(self, beta_start=0.0001, beta_end=0.02):
        return torch.linspace(beta_start, beta_end, self.num_timesteps)

    def _cosine_schedule(self, s=0.008):
        steps = self.num_timesteps + 1
        t = torch.linspace(0, self.num_timesteps, steps)
        alphas_cumprod = torch.cos((t / self.num_timesteps + s) / (1 + s) * np.pi / 2) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clamp(betas, 0.0001, 0.9999)

    def _quadratic_schedule(self, beta_start=0.0001, beta_end=0.02):
        return torch.linspace(beta_start**0.5, beta_end**0.5, self.num_timesteps) ** 2

    def add_noise(self, x_0, t):
        noise = torch.randn_like(x_0)
        sqrt_alpha = torch.sqrt(self.alphas_cumprod[t]).view(-1, 1, 1, 1)
        sqrt_one_minus = torch.sqrt(1 - self.alphas_cumprod[t]).view(-1, 1, 1, 1)
        return sqrt_alpha * x_0 + sqrt_one_minus * noise, noise
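
To compare schedules concretely, one can inspect the cumulative signal coefficient alphas_cumprod: the cosine schedule retains more signal at early timesteps than the linear one. A small illustrative check:

PYTHON
# Compare how much of the original signal survives under each schedule
linear = NoiseScheduler(schedule_type='linear')
cosine = NoiseScheduler(schedule_type='cosine')

for t in [0, 100, 500, 900]:
    print(f"t={t:4d}  linear alpha_bar={linear.alphas_cumprod[t]:.4f}  "
          f"cosine alpha_bar={cosine.alphas_cumprod[t]:.4f}")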

Reverse Process: Learning to Denoise

The reverse process learns to invert the forward diffusion, gradually removing noise to recover data. A neural network is trained to predict the noise added at each step, enabling iterative denoising from pure noise to clean samples. The network is conditioned on the timestep, allowing it to adapt its behavior to different noise levels.

Training minimizes the difference between predicted and actual noise across all timesteps. The objective is remarkably simple: given a noisy image and timestep, predict the noise that was added. This prediction is then used to take a small step toward cleaner data.

PYTHON
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, base_channels=64, time_embed_dim=256):
        super().__init__()

        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        # Project the time embedding so it can be added to the first feature map
        self.time_proj = nn.Linear(time_embed_dim, base_channels)

        # Encoder
        self.enc1 = self._conv_block(in_channels, base_channels)
        self.enc2 = self._conv_block(base_channels, base_channels * 2)
        self.enc3 = self._conv_block(base_channels * 2, base_channels * 4)

        # Bottleneck
        self.bottleneck = self._conv_block(base_channels * 4, base_channels * 8)

        # Decoder
        self.dec3 = self._conv_block(base_channels * 8 + base_channels * 4, base_channels * 4)
        self.dec2 = self._conv_block(base_channels * 4 + base_channels * 2, base_channels * 2)
        self.dec1 = self._conv_block(base_channels * 2 + base_channels, base_channels)

        self.final = nn.Conv2d(base_channels, in_channels, 1)

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def _conv_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU()
        )

    def forward(self, x, t):
        # Time embedding (normalized timestep -> embedding vector)
        t_embed = self.time_mlp(t.float().unsqueeze(-1) / 1000)

        # Encoder, with the time embedding injected into the first feature map
        e1 = self.enc1(x) + self.time_proj(t_embed)[:, :, None, None]
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        # Bottleneck
        b = self.bottleneck(self.pool(e3))

        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.upsample(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1))

        return self.final(d1)


def train_step(model, x_0, noise_scheduler, optimizer):
    batch_size = x_0.size(0)
    device = x_0.device

    # Sample random timesteps
    t = torch.randint(0, noise_scheduler.num_timesteps, (batch_size,), device=device)

    # Add noise
    noisy_x, noise = noise_scheduler.add_noise(x_0, t)

    # Predict noise
    predicted_noise = model(noisy_x, t)

    # Simple MSE loss
    loss = nn.functional.mse_loss(predicted_noise, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

Comparison with Other Generative Models

Diffusion models occupy a unique position in the landscape of generative models. Unlike GANs, they do not require adversarial training and avoid mode collapse. Unlike VAEs, they do not compress data through a bottleneck, allowing for higher fidelity generation. The iterative nature of sampling trades computation for quality.

GANs excel at fast single-shot generation but suffer from training instability and mode dropping. VAEs provide stable training and meaningful latent spaces but often produce blurry outputs. Diffusion models achieve both stable training and high-quality generation, though they require many function evaluations during sampling.

PYTHON
class GenerativeModelComparison:
    """Conceptual comparison of generative model approaches."""

    def gan_generation(self, generator, latent_dim, num_samples):
        """GAN: Single forward pass from noise to image."""
        z = torch.randn(num_samples, latent_dim)
        return generator(z)  # One step

    def vae_generation(self, decoder, latent_dim, num_samples):
        """VAE: Sample latent, decode in one pass."""
        z = torch.randn(num_samples, latent_dim)
        return decoder(z)  # One step

    def diffusion_generation(self, model, scheduler, shape, num_steps):
        """Diffusion: Iterative refinement from noise."""
        x = torch.randn(shape)  # Start from pure noise

        for t in reversed(range(num_steps)):
            # Predict and remove noise
            noise_pred = model(x, torch.tensor([t]))
            x = scheduler.step(noise_pred, t, x)

        return x  # Many steps


class DiffusionSampler:
    def __init__(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler

    @torch.no_grad()
    def sample(self, shape, device):
        self.model.eval()

        # Start from pure noise
        x = torch.randn(shape, device=device)

        for t in reversed(range(self.scheduler.num_timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

            # Predict noise
            noise_pred = self.model(x, t_batch)

            # Compute denoised estimate
            alpha = self.scheduler.alphas[t]
            alpha_cumprod = self.scheduler.alphas_cumprod[t]
            beta = self.scheduler.betas[t]

            if t > 0:
                noise = torch.randn_like(x)
            else:
                noise = torch.zeros_like(x)

            x = (1 / torch.sqrt(alpha)) * (
                x - (beta / torch.sqrt(1 - alpha_cumprod)) * noise_pred
            ) + torch.sqrt(beta) * noise

        self.model.train()
        return x

Key Advantages

Diffusion models offer several compelling advantages. Training is stable without requiring careful balancing between competing networks. The iterative sampling process enables trading computation for quality, with more steps producing better results. The models naturally support conditioning and guidance without architectural changes.

The framework is highly flexible, accommodating various data types including images, audio, video, and 3D structures. Recent advances have dramatically reduced sampling time while maintaining quality, making diffusion models practical for real-world applications.

PYTHON
class DiffusionAdvantages:
    """Demonstrating key diffusion model properties."""

    def stable_training(self, model, dataloader, scheduler, epochs=100):
        """Training is simple MSE minimization - no adversarial dynamics."""
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        losses = []

        for epoch in range(epochs):
            epoch_loss = 0
            for batch in dataloader:
                loss = train_step(model, batch, scheduler, optimizer)
                epoch_loss += loss

            losses.append(epoch_loss / len(dataloader))
            # No mode collapse, no vanishing gradients

        return losses

    def quality_vs_speed_tradeoff(self, model, scheduler, shape, device):
        """More steps = better quality."""
        results = {}

        for num_steps in [10, 50, 100, 250, 1000]:
            sampler = DiffusionSampler(model, scheduler)
            # A step-skipping sampler (e.g. DDIM, Section 19.2) would use num_steps here;
            # the full-length sampler is reused for simplicity
            sample = sampler.sample(shape, device)
            results[num_steps] = sample

        return results

    def easy_conditioning(self, model, condition, x_t, t):
        """Conditioning can be added through simple concatenation or cross-attention."""
        # Concatenation approach
        conditioned_input = torch.cat([x_t, condition], dim=1)
        return model(conditioned_input, t)

Key Takeaways

Diffusion models generate data by learning to reverse a gradual noising process. The forward process adds noise according to a variance schedule until data becomes Gaussian. The reverse process trains a neural network to predict and remove noise iteratively. Unlike GANs and VAEs, diffusion models combine stable training with high-quality generation through iterative refinement. The framework offers flexibility in sampling speed, conditioning, and application domains.

19.2 Denoising Diffusion Probabilistic Models Advanced


Denoising Diffusion Probabilistic Models (DDPM) established the foundational framework for modern diffusion-based generation. The formulation provides both theoretical grounding through variational inference and practical training objectives that scale to complex data distributions.

The DDPM Framework

DDPM defines generation as the reverse of a fixed Markov noising process. The forward process progressively corrupts data through a sequence of Gaussian transitions until reaching an isotropic Gaussian distribution. The reverse process, parameterized by a neural network, learns to invert these transitions step by step.

The forward process is fully specified by a variance schedule and requires no learning. At each step, a small amount of Gaussian noise is added, with the variance determined by the schedule. The reverse process learns the conditional distribution at each step, enabling generation by sampling backward through the chain.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class DDPM:
    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02, device='cuda'):
        self.num_timesteps = num_timesteps
        self.device = device

        # Define beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, device=device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Calculations for diffusion q(x_t | x_{t-1})
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_log_variance_clipped = torch.log(
            torch.clamp(self.posterior_variance, min=1e-20)
        )
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
        )

    def q_sample(self, x_0, t, noise=None):
        """Sample from q(x_t | x_0) - the forward process."""
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha_cumprod_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        return sqrt_alpha_cumprod_t * x_0 + sqrt_one_minus_alpha_cumprod_t * noise

    def q_posterior_mean_variance(self, x_0, x_t, t):
        """Compute the mean and variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            self.posterior_mean_coef1[t].view(-1, 1, 1, 1) * x_0 +
            self.posterior_mean_coef2[t].view(-1, 1, 1, 1) * x_t
        )
        posterior_variance = self.posterior_variance[t].view(-1, 1, 1, 1)
        posterior_log_variance = self.posterior_log_variance_clipped[t].view(-1, 1, 1, 1)

        return posterior_mean, posterior_variance, posterior_log_variance

Training Objective

The DDPM training objective derives from a variational bound on the log-likelihood. Through a series of simplifications, this reduces to training a network to predict the noise added at each timestep. The simplified objective uniformly weights all timesteps and predicts the noise rather than the mean or variance.

The connection to denoising score matching provides additional theoretical justification. Predicting the noise is equivalent to estimating the score function, the gradient of the log probability density. This perspective connects diffusion models to score-based generative modeling.
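
Written out, with ᾱ_t denoting the cumulative product stored as alphas_cumprod in the code, the simplified objective and its score-matching reading are:

LATEX
\mathcal{L}_{\text{simple}}
  = \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\,\big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t\big)\big\|^2\,\right],
\qquad
\nabla_{x_t}\log q(x_t) \approx -\,\frac{\epsilon_\theta(x_t, t)}{\sqrt{1-\bar{\alpha}_t}}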

PYTHON
class DDPMTrainer:
    def __init__(self, model, ddpm, optimizer, device='cuda'):
        self.model = model
        self.ddpm = ddpm
        self.optimizer = optimizer
        self.device = device

    def train_step(self, x_0):
        """Single training step for DDPM."""
        batch_size = x_0.size(0)

        # Sample random timesteps uniformly
        t = torch.randint(0, self.ddpm.num_timesteps, (batch_size,), device=self.device)

        # Sample noise
        noise = torch.randn_like(x_0)

        # Get noisy image
        x_t = self.ddpm.q_sample(x_0, t, noise)

        # Predict noise
        predicted_noise = self.model(x_t, t)

        # Simple MSE loss (equivalent to simplified ELBO)
        loss = F.mse_loss(predicted_noise, noise)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train_epoch(self, dataloader):
        """Train for one epoch."""
        total_loss = 0
        num_batches = 0

        for batch in dataloader:
            x_0 = batch[0].to(self.device)
            loss = self.train_step(x_0)
            total_loss += loss
            num_batches += 1

        return total_loss / num_batches


def get_loss_weights(num_timesteps, weighting='uniform'):
    """Different loss weighting schemes."""
    if weighting == 'uniform':
        return torch.ones(num_timesteps)
    elif weighting == 'snr':
        # Signal-to-noise ratio weighting
        betas = torch.linspace(1e-4, 0.02, num_timesteps)
        alphas_cumprod = torch.cumprod(1 - betas, dim=0)
        snr = alphas_cumprod / (1 - alphas_cumprod)
        return snr / snr.sum() * num_timesteps
    elif weighting == 'min_snr':
        # Min-SNR weighting for improved training
        betas = torch.linspace(1e-4, 0.02, num_timesteps)
        alphas_cumprod = torch.cumprod(1 - betas, dim=0)
        snr = alphas_cumprod / (1 - alphas_cumprod)
        return torch.minimum(snr, torch.tensor(5.0))
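
These weights are not used by the plain MSE loss in DDPMTrainer above; applying them amounts to computing a per-example loss and scaling it by the weight of the sampled timestep. A sketch of such a variant (weighted_train_step is a hypothetical helper, not part of the trainer shown earlier):

PYTHON
def weighted_train_step(trainer, x_0, loss_weights):
    """Variant of DDPMTrainer.train_step that applies per-timestep loss weights.

    e.g. loss_weights = get_loss_weights(trainer.ddpm.num_timesteps, 'min_snr')
    """
    batch_size = x_0.size(0)
    t = torch.randint(0, trainer.ddpm.num_timesteps, (batch_size,), device=trainer.device)

    noise = torch.randn_like(x_0)
    x_t = trainer.ddpm.q_sample(x_0, t, noise)
    predicted_noise = trainer.model(x_t, t)

    # Per-example MSE, scaled by the weight of each example's timestep
    per_example = F.mse_loss(predicted_noise, noise, reduction='none').flatten(1).mean(dim=1)
    loss = (loss_weights.to(x_0.device)[t] * per_example).mean()

    trainer.optimizer.zero_grad()
    loss.backward()
    trainer.optimizer.step()
    return loss.item()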

Noise Prediction vs Data Prediction

The network can be trained to predict different quantities: the noise added to the data, the original clean data, or the velocity (a combination of both). Each parameterization has different numerical properties and can affect training stability and sample quality.

Noise prediction works well for most timesteps but can be numerically unstable at very low noise levels. Data prediction (x0 prediction) provides cleaner gradients at low noise but may struggle at high noise levels. Velocity prediction offers a balanced approach that interpolates between these extremes.

PYTHON
class DiffusionPrediction:
    def __init__(self, ddpm):
        self.ddpm = ddpm

    def predict_x0_from_noise(self, x_t, t, noise_pred):
        """Convert noise prediction to x0 prediction."""
        sqrt_alpha = self.ddpm.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.ddpm.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return (x_t - sqrt_one_minus * noise_pred) / sqrt_alpha

    def predict_noise_from_x0(self, x_t, t, x0_pred):
        """Convert x0 prediction to noise prediction."""
        sqrt_alpha = self.ddpm.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.ddpm.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return (x_t - sqrt_alpha * x0_pred) / sqrt_one_minus

    def predict_velocity(self, x_0, noise, t):
        """Compute velocity for v-prediction."""
        sqrt_alpha = self.ddpm.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.ddpm.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha * noise - sqrt_one_minus * x_0

    def predict_x0_from_velocity(self, x_t, t, v_pred):
        """Convert velocity prediction to x0."""
        sqrt_alpha = self.ddpm.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.ddpm.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_alpha * x_t - sqrt_one_minus * v_pred


class VPredictionModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, x_t, t):
        return self.base_model(x_t, t)

    def get_x0_prediction(self, x_t, t, ddpm):
        v_pred = self.forward(x_t, t)
        predictor = DiffusionPrediction(ddpm)
        return predictor.predict_x0_from_velocity(x_t, t, v_pred)

DDPM Sampling Algorithm

Sampling from DDPM involves iteratively applying the learned reverse process starting from pure Gaussian noise. At each step, the model predicts the noise, which is used to estimate the mean of the reverse distribution. Random noise is added according to the posterior variance to maintain stochasticity.

The sampling process requires evaluating the neural network at every timestep, making it computationally expensive. With 1000 timesteps, generating a single image requires 1000 forward passes through the network.

PYTHON
class DDPMSampler:
    def __init__(self, model, ddpm):
        self.model = model
        self.ddpm = ddpm

    @torch.no_grad()
    def p_mean_variance(self, x_t, t):
        """Compute mean and variance for p(x_{t-1} | x_t)."""
        # Predict noise
        noise_pred = self.model(x_t, t)

        # Compute predicted x_0
        predictor = DiffusionPrediction(self.ddpm)
        x_0_pred = predictor.predict_x0_from_noise(x_t, t, noise_pred)

        # Clip x_0 prediction for stability
        x_0_pred = torch.clamp(x_0_pred, -1, 1)

        # Get posterior mean and variance
        mean, variance, log_variance = self.ddpm.q_posterior_mean_variance(x_0_pred, x_t, t)

        return mean, variance, log_variance

    @torch.no_grad()
    def p_sample(self, x_t, t):
        """Sample x_{t-1} from p(x_{t-1} | x_t)."""
        mean, variance, _ = self.p_mean_variance(x_t, t)

        noise = torch.randn_like(x_t)

        # No noise at t=0
        nonzero_mask = (t != 0).float().view(-1, 1, 1, 1)

        return mean + nonzero_mask * torch.sqrt(variance) * noise

    @torch.no_grad()
    def sample(self, shape):
        """Generate samples starting from noise."""
        self.model.eval()
        device = next(self.model.parameters()).device

        # Start from pure noise
        x = torch.randn(shape, device=device)

        for t in reversed(range(self.ddpm.num_timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            x = self.p_sample(x, t_batch)

        self.model.train()
        return x
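
Putting the pieces together, a minimal end-to-end sketch might wire up the DDPM, trainer, and sampler as follows. The SimpleUNet from Section 19.1 serves as the noise predictor and random tensors stand in for a real dataset; all sizes and hyperparameters are illustrative.

PYTHON
# Illustrative wiring only: random data, toy network, a few training steps
device = 'cuda' if torch.cuda.is_available() else 'cpu'

ddpm = DDPM(num_timesteps=1000, device=device)
model = SimpleUNet(in_channels=3, base_channels=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
trainer = DDPMTrainer(model, ddpm, optimizer, device=device)

fake_images = torch.randn(16, 3, 32, 32, device=device)  # stand-in for a real batch
for step in range(10):
    loss = trainer.train_step(fake_images)

# Full DDPM sampling runs num_timesteps network evaluations (slow on CPU)
sampler = DDPMSampler(model, ddpm)
samples = sampler.sample((4, 3, 32, 32))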

DDIM: Accelerated Sampling

Denoising Diffusion Implicit Models (DDIM) enable faster sampling by using a non-Markovian reverse process. By making the reverse process deterministic or nearly so, DDIM allows skipping timesteps while maintaining sample quality. This can reduce the number of function evaluations from 1000 to as few as 20-50.

DDIM introduces a parameter controlling the stochasticity of sampling. Setting this parameter to zero makes sampling completely deterministic, enabling consistent outputs from the same initial noise and interpolation in the latent space.

PYTHON
class DDIMSampler:
    def __init__(self, model, ddpm, eta=0.0):
        self.model = model
        self.ddpm = ddpm
        self.eta = eta  # 0 = deterministic, 1 = DDPM

    @torch.no_grad()
    def sample(self, shape, num_inference_steps=50):
        """Generate samples with DDIM."""
        self.model.eval()
        device = next(self.model.parameters()).device

        # Create timestep schedule (evenly spaced)
        step_size = self.ddpm.num_timesteps // num_inference_steps
        timesteps = list(range(0, self.ddpm.num_timesteps, step_size))
        timesteps = list(reversed(timesteps))

        # Start from noise
        x = torch.randn(shape, device=device)

        for i, t in enumerate(timesteps):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

            # Predict noise
            noise_pred = self.model(x, t_batch)

            # Get alpha values
            alpha_cumprod_t = self.ddpm.alphas_cumprod[t]

            if i < len(timesteps) - 1:
                t_prev = timesteps[i + 1]
                alpha_cumprod_t_prev = self.ddpm.alphas_cumprod[t_prev]
            else:
                alpha_cumprod_t_prev = torch.tensor(1.0, device=device)

            # Predict x_0
            x_0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_cumprod_t)
            x_0_pred = torch.clamp(x_0_pred, -1, 1)

            # Compute variance
            sigma_t = self.eta * torch.sqrt(
                (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) *
                (1 - alpha_cumprod_t / alpha_cumprod_t_prev)
            )

            # Direction pointing to x_t
            pred_dir = torch.sqrt(1 - alpha_cumprod_t_prev - sigma_t**2) * noise_pred

            # Random noise
            noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)

            # Update x
            x = torch.sqrt(alpha_cumprod_t_prev) * x_0_pred + pred_dir + sigma_t * noise

        self.model.train()
        return x

    @torch.no_grad()
    def interpolate(self, x1, x2, num_steps=10, num_inference_steps=50):
        """Interpolate between two samples using deterministic DDIM."""
        device = next(self.model.parameters()).device

        # Encode both images to noise (inversion)
        z1 = self.ddim_inversion(x1, num_inference_steps)
        z2 = self.ddim_inversion(x2, num_inference_steps)

        # Interpolate in latent space
        interpolations = []
        for alpha in torch.linspace(0, 1, num_steps):
            z_interp = (1 - alpha) * z1 + alpha * z2
            x_interp = self.sample_from_noise(z_interp, num_inference_steps)
            interpolations.append(x_interp)

        return torch.stack(interpolations)
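
    # NOTE: interpolate() relies on two helpers that are not defined in this
    # chapter: ddim_inversion (map a clean image back to a starting latent by
    # running the deterministic eta=0 update in the direction of increasing t)
    # and sample_from_noise (the same loop as sample(), but starting from a
    # given latent instead of fresh Gaussian noise). One possible sketch of
    # the inversion, under the same assumptions, is shown below.
    @torch.no_grad()
    def ddim_inversion(self, x_0, num_inference_steps=50):
        """Sketch of deterministic DDIM inversion (eta = 0)."""
        device = next(self.model.parameters()).device
        x = x_0.to(device)

        step_size = self.ddpm.num_timesteps // num_inference_steps
        timesteps = list(range(0, self.ddpm.num_timesteps, step_size))  # increasing t

        for i, t in enumerate(timesteps):
            t_batch = torch.full((x.size(0),), t, device=device, dtype=torch.long)
            noise_pred = self.model(x, t_batch)

            alpha_cumprod_t = self.ddpm.alphas_cumprod[t]
            t_next = timesteps[i + 1] if i + 1 < len(timesteps) else self.ddpm.num_timesteps - 1
            alpha_cumprod_next = self.ddpm.alphas_cumprod[t_next]

            # Deterministic DDIM relation, stepped toward higher noise levels
            x_0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_cumprod_t)
            x = torch.sqrt(alpha_cumprod_next) * x_0_pred + torch.sqrt(1 - alpha_cumprod_next) * noise_pred

        return x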

Improved Diffusion Training

Several techniques improve DDPM training and generation quality. Learning the variance in addition to the mean provides better sample quality, especially for fewer sampling steps. Improved noise schedules like the cosine schedule prevent information loss at the start of diffusion.

Classifier-free guidance trains a single model for both conditional and unconditional generation, enabling control over the strength of conditioning during sampling. This technique has become standard for conditional diffusion models.

PYTHON
class ImprovedDDPM(nn.Module):
    def __init__(self, base_channels=128, num_classes=None):
        super().__init__()
        self.num_classes = num_classes

        # Main UNet for noise prediction
        self.unet = UNetWithAttention(base_channels)

        # Additional head for variance prediction
        self.var_head = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels, 3, 1)  # Predict log variance
        )

        # Class embedding for conditional generation
        if num_classes is not None:
            self.class_embed = nn.Embedding(num_classes + 1, base_channels * 4)

    def forward(self, x, t, class_labels=None, drop_prob=0.1):
        # Get class embedding
        if self.num_classes is not None and class_labels is not None:
            # Classifier-free guidance: randomly drop class labels
            if self.training:
                mask = torch.rand(class_labels.size(0), device=class_labels.device) < drop_prob
                class_labels = class_labels.clone()
                class_labels[mask] = self.num_classes  # Null class

            class_emb = self.class_embed(class_labels)
        else:
            class_emb = None

        # Get features from UNet
        features, noise_pred = self.unet(x, t, class_emb)

        # Predict variance
        var_pred = self.var_head(features)

        return noise_pred, var_pred


class UNetWithAttention(nn.Module):
    def __init__(self, base_channels=128):
        super().__init__()

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(base_channels),
            nn.Linear(base_channels, base_channels * 4),
            nn.SiLU(),
            nn.Linear(base_channels * 4, base_channels * 4)
        )

        # Encoder, bottleneck, decoder with attention.
        # ResBlock (a time-conditioned residual block) and AttentionBlock (a
        # self-attention block) are assumed to be defined elsewhere.
        self.encoder = nn.ModuleList([
            ResBlock(3, base_channels, base_channels * 4),
            ResBlock(base_channels, base_channels * 2, base_channels * 4),
            AttentionBlock(base_channels * 2),
            ResBlock(base_channels * 2, base_channels * 4, base_channels * 4)
        ])

        self.bottleneck = nn.Sequential(
            ResBlock(base_channels * 4, base_channels * 4, base_channels * 4),
            AttentionBlock(base_channels * 4),
            ResBlock(base_channels * 4, base_channels * 4, base_channels * 4)
        )

        self.decoder = nn.ModuleList([
            ResBlock(base_channels * 8, base_channels * 2, base_channels * 4),
            AttentionBlock(base_channels * 2),
            ResBlock(base_channels * 4, base_channels, base_channels * 4),
            ResBlock(base_channels * 2, base_channels, base_channels * 4)
        ])

        self.final = nn.Conv2d(base_channels, 3, 1)

    def forward(self, x, t, class_emb=None):
        t_emb = self.time_embed(t)
        if class_emb is not None:
            t_emb = t_emb + class_emb

        # Encode, process, decode
        skips = []
        for layer in self.encoder:
            if isinstance(layer, ResBlock):
                x = layer(x, t_emb)
                skips.append(x)
                x = F.avg_pool2d(x, 2)
            else:
                x = layer(x)

        x = self.bottleneck[0](x, t_emb)
        x = self.bottleneck[1](x)
        x = self.bottleneck[2](x, t_emb)

        for layer in self.decoder:
            if isinstance(layer, ResBlock):
                x = F.interpolate(x, scale_factor=2)
                x = torch.cat([x, skips.pop()], dim=1)
                x = layer(x, t_emb)
            else:
                x = layer(x)

        return x, self.final(x)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=t.device) * -embeddings)
        embeddings = t.unsqueeze(-1) * embeddings.unsqueeze(0)
        embeddings = torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
        return embeddings

Key Takeaways

DDPM provides a principled framework for diffusion-based generation through variational inference. Training reduces to predicting noise added at random timesteps using a simple MSE loss. Different prediction targets (noise, data, velocity) offer tradeoffs in numerical stability. DDIM enables faster sampling by making the reverse process non-Markovian and deterministic. Improved techniques including learned variance and classifier-free guidance enhance quality and control.

19.3 Score-Based Models and Guidance Advanced


Score-based generative models provide an alternative perspective on diffusion that connects to statistical physics and energy-based modeling. This framework introduces powerful guidance techniques that enable precise control over generation without retraining.

Score Functions and Score Matching

The score function is the gradient of the log probability density with respect to the data. Learning this gradient rather than the density itself avoids the intractable normalization constant that makes direct density estimation difficult. Score matching provides techniques to learn the score function from samples without knowing the true density.

The connection between score functions and diffusion is fundamental. The noise prediction in diffusion models is proportional to the negative score function. This equivalence reveals that diffusion models are implicitly learning to estimate probability gradients, enabling generation through gradient-based sampling.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreNetwork(nn.Module):
    """Network that estimates the score function."""

    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim + 1, hidden_dim),  # +1 for noise level
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x, sigma):
        """Predict score at noise level sigma."""
        sigma_embed = sigma.view(-1, 1)
        x_input = torch.cat([x, sigma_embed.expand(-1, 1)], dim=-1)
        return self.net(x_input)


def denoising_score_matching_loss(score_net, x, sigma):
    """Denoising score matching objective."""
    noise = torch.randn_like(x)
    perturbed_x = x + sigma * noise

    # True score of Gaussian perturbation
    target_score = -noise / sigma

    # Predicted score
    predicted_score = score_net(perturbed_x, sigma)

    # Weighted MSE loss
    loss = 0.5 * ((predicted_score - target_score) ** 2).sum(dim=-1)
    return loss.mean()


def sliced_score_matching_loss(score_net, x, sigma, num_projections=1):
    """Sliced score matching - more efficient for high dimensions."""
    x.requires_grad_(True)
    perturbed_x = x + sigma * torch.randn_like(x)

    score = score_net(perturbed_x, sigma)

    # Random projection vectors
    v = torch.randn_like(score)
    v = v / v.norm(dim=-1, keepdim=True)

    # Compute directional derivative
    score_v = (score * v).sum()
    grad_score_v = torch.autograd.grad(score_v, perturbed_x, create_graph=True)[0]
    grad_v = (grad_score_v * v).sum(dim=-1)

    # Loss
    loss = 0.5 * (score * v).sum(dim=-1) ** 2 + grad_v
    return loss.mean()
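
As a concrete toy usage, the denoising objective above can fit a score network to simple 2-D data; every hyperparameter here (data, sizes, noise level, learning rate) is purely illustrative.

PYTHON
# Toy example: fit the score of a two-mode 2-D distribution
score_net = ScoreNetwork(input_dim=2, hidden_dim=128)
optimizer = torch.optim.Adam(score_net.parameters(), lr=1e-3)

centers = torch.tensor([[-2.0, -2.0], [2.0, 2.0]])

for step in range(1000):
    # Sample a batch from the toy mixture
    x = centers[torch.randint(0, 2, (256,))] + 0.3 * torch.randn(256, 2)

    sigma = torch.full((256, 1), 0.5)  # single fixed noise level for simplicity
    loss = denoising_score_matching_loss(score_net, x, sigma)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()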

Score-Based Generative Models

Score-based generative models use the learned score function to generate samples through Langevin dynamics. Starting from random noise, iterative updates follow the score function to move toward high-probability regions. Adding noise at each step enables exploration and ensures samples come from the correct distribution.

The noise-conditional score network learns scores at multiple noise levels. During generation, sampling proceeds from high noise levels where the score field is smooth to low noise levels where fine details emerge. This multi-scale approach addresses the difficulty of learning accurate scores in low-density regions.

PYTHON
class ScoreBasedGenerator:
    def __init__(self, score_net, sigma_min=0.01, sigma_max=50, num_scales=10):
        self.score_net = score_net
        self.sigmas = torch.exp(
            torch.linspace(
                torch.log(torch.tensor(sigma_max)),
                torch.log(torch.tensor(sigma_min)),
                num_scales
            )
        )

    @torch.no_grad()
    def langevin_dynamics(self, x, sigma, num_steps=100, step_size=None):
        """Run Langevin dynamics at a fixed noise level."""
        if step_size is None:
            step_size = (sigma ** 2) * 2

        for _ in range(num_steps):
            score = self.score_net(x, sigma.expand(x.size(0)))
            noise = torch.randn_like(x)
            x = x + step_size * score + torch.sqrt(2 * step_size) * noise

        return x

    @torch.no_grad()
    def annealed_langevin_dynamics(self, shape, device, steps_per_sigma=100):
        """Generate samples using annealed Langevin dynamics."""
        self.score_net.eval()

        # Start from noise
        x = torch.randn(shape, device=device) * self.sigmas[0]

        for sigma in self.sigmas:
            sigma_tensor = torch.full((shape[0],), sigma.item(), device=device)
            step_size = 0.5 * (sigma ** 2) / (self.sigmas[-1] ** 2)

            for _ in range(steps_per_sigma):
                score = self.score_net(x, sigma_tensor)
                noise = torch.randn_like(x)
                x = x + step_size * score + torch.sqrt(2 * step_size) * noise

        self.score_net.train()
        return x


class ContinuousScoreModel:
    """Score model with continuous noise levels (SDE framework)."""

    def __init__(self, score_net, sigma_min=0.01, sigma_max=50):
        self.score_net = score_net
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def sample_sigma(self, batch_size, device):
        """Sample noise levels log-uniformly."""
        log_sigma = torch.rand(batch_size, device=device)
        log_sigma = log_sigma * (
            torch.log(torch.tensor(self.sigma_max)) -
            torch.log(torch.tensor(self.sigma_min))
        ) + torch.log(torch.tensor(self.sigma_min))
        return torch.exp(log_sigma)

    def train_step(self, x, optimizer):
        """Training step with continuous noise levels."""
        batch_size = x.size(0)
        device = x.device

        sigma = self.sample_sigma(batch_size, device)
        loss = denoising_score_matching_loss(self.score_net, x, sigma.view(-1, 1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

Classifier Guidance

Classifier guidance steers diffusion generation toward a target class using gradients from a separately trained classifier. During sampling, the score function is augmented with the gradient of the classifier log-probability. This pushes samples toward regions that the classifier identifies as belonging to the target class.

The guidance scale controls how strongly generation is steered. Higher scales produce samples more strongly associated with the target class but may reduce diversity or sample quality. The classifier must be trained on noisy images to provide meaningful gradients throughout the diffusion process.

PYTHON
class ClassifierGuidance:
    def __init__(self, diffusion_model, classifier, num_classes):
        self.diffusion_model = diffusion_model
        self.classifier = classifier
        self.num_classes = num_classes

    def get_classifier_gradient(self, x, t, target_class):
        """Compute gradient of classifier log-prob."""
        x = x.detach().requires_grad_(True)
        logits = self.classifier(x, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(target_class)), target_class].sum()

        gradient = torch.autograd.grad(selected, x)[0]
        return gradient

    @torch.no_grad()
    def guided_sample(self, shape, target_class, guidance_scale=1.0, device='cuda'):
        """Sample with classifier guidance."""
        self.diffusion_model.eval()
        self.classifier.eval()

        x = torch.randn(shape, device=device)
        target = torch.full((shape[0],), target_class, device=device, dtype=torch.long)

        for t in reversed(range(self.diffusion_model.num_timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

            # Get unconditional score
            with torch.no_grad():
                noise_pred = self.diffusion_model(x, t_batch)

            # Get classifier gradient (need gradients here)
            with torch.enable_grad():
                x_grad = x.detach().requires_grad_(True)
                classifier_grad = self.get_classifier_gradient(x_grad, t_batch, target)

            # Combine predictions: eps_hat = eps - s * sqrt(1 - alpha_bar_t) * grad log p(y | x_t);
            # here the sqrt(1 - alpha_bar_t) factor is folded into guidance_scale for simplicity
            guided_noise_pred = noise_pred - guidance_scale * classifier_grad

            # Standard DDPM update with the guided prediction (_ddpm_step is assumed
            # to implement the same update as ClassifierFreeGuidance._ddpm_step below)
            x = self._ddpm_step(x, t, guided_noise_pred)

        return x


class NoisyClassifier(nn.Module):
    """Classifier trained on noisy images for guidance."""

    def __init__(self, num_classes, base_channels=64):
        super().__init__()

        self.time_embed = nn.Sequential(
            nn.Linear(1, base_channels * 4),
            nn.SiLU(),
            nn.Linear(base_channels * 4, base_channels * 4)
        )

        self.conv = nn.Sequential(
            nn.Conv2d(3, base_channels, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels, base_channels * 2, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels * 2, base_channels * 4, 3, stride=2, padding=1),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d(1)
        )

        self.classifier = nn.Linear(base_channels * 4 + base_channels * 4, num_classes)

    def forward(self, x, t):
        t_emb = self.time_embed(t.float().unsqueeze(-1) / 1000)
        features = self.conv(x).flatten(1)
        combined = torch.cat([features, t_emb], dim=-1)
        return self.classifier(combined)

Classifier-Free Guidance

Classifier-free guidance eliminates the need for a separate classifier by training the diffusion model for both conditional and unconditional generation. During training, class labels are randomly dropped, teaching the model to generate both with and without conditioning. During sampling, conditional and unconditional predictions are combined.

The guidance formula extrapolates away from the unconditional prediction toward the conditional prediction. This amplifies features that distinguish conditioned from unconditioned generation, producing samples more strongly aligned with the condition while maintaining coherence.

PYTHON
class ClassifierFreeGuidance:
    def __init__(self, model, num_classes):
        self.model = model
        self.num_classes = num_classes
        self.null_class = num_classes  # Use extra class index for unconditional

    def train_step(self, x, labels, ddpm, optimizer, drop_prob=0.1):
        """Training with random label dropping."""
        batch_size = x.size(0)
        device = x.device

        # Randomly drop labels
        drop_mask = torch.rand(batch_size, device=device) < drop_prob
        labels = labels.clone()
        labels[drop_mask] = self.null_class

        # Standard diffusion training
        t = torch.randint(0, ddpm.num_timesteps, (batch_size,), device=device)
        noise = torch.randn_like(x)
        x_t = ddpm.q_sample(x, t, noise)

        noise_pred = self.model(x_t, t, labels)
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def guided_sample(self, shape, target_class, ddpm, guidance_scale=7.5, device='cuda'):
        """Sample with classifier-free guidance."""
        self.model.eval()

        x = torch.randn(shape, device=device)
        target = torch.full((shape[0],), target_class, device=device, dtype=torch.long)
        null_target = torch.full((shape[0],), self.null_class, device=device, dtype=torch.long)

        for t in reversed(range(ddpm.num_timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

            # Conditional prediction
            noise_pred_cond = self.model(x, t_batch, target)

            # Unconditional prediction
            noise_pred_uncond = self.model(x, t_batch, null_target)

            # Classifier-free guidance formula
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # DDPM update
            x = self._ddpm_step(x, t, noise_pred, ddpm)

        self.model.train()
        return x

    def _ddpm_step(self, x, t, noise_pred, ddpm):
        """Single DDPM reverse step."""
        alpha = ddpm.alphas[t]
        alpha_cumprod = ddpm.alphas_cumprod[t]
        beta = ddpm.betas[t]

        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)

        x = (1 / torch.sqrt(alpha)) * (
            x - (beta / torch.sqrt(1 - alpha_cumprod)) * noise_pred
        ) + torch.sqrt(beta) * noise

        return x


class TextConditionedModel(nn.Module):
    """Diffusion model conditioned on text embeddings."""

    def __init__(self, unet, text_encoder, embed_dim=768):
        super().__init__()
        self.unet = unet
        self.text_encoder = text_encoder

        # Null embedding for unconditional generation
        self.null_embedding = nn.Parameter(torch.randn(1, 77, embed_dim))

    def forward(self, x, t, text_tokens=None, drop_prob=0.1):
        if text_tokens is not None:
            # Encode text
            text_embed = self.text_encoder(text_tokens)

            # Random dropping for classifier-free guidance training
            if self.training:
                batch_size = x.size(0)
                drop_mask = torch.rand(batch_size, device=x.device) < drop_prob
                text_embed = text_embed.clone()  # avoid modifying the encoder output in place
                text_embed[drop_mask] = self.null_embedding.expand(int(drop_mask.sum()), -1, -1)
        else:
            text_embed = self.null_embedding.expand(x.size(0), -1, -1)

        return self.unet(x, t, text_embed)

Guidance Scale Effects

The guidance scale dramatically affects generation quality and diversity. Low scales produce diverse but potentially unfocused samples. High scales produce samples strongly aligned with the condition but may reduce diversity or introduce artifacts. Finding the optimal scale balances these tradeoffs.

Guidance can be applied to any form of conditioning including text, class labels, images, or combinations. The same principle of extrapolating between conditional and unconditional predictions applies regardless of the conditioning modality.

PYTHON
class GuidanceAnalysis:
    """Tools for analyzing guidance scale effects."""

    def __init__(self, model, ddpm):
        self.model = model
        self.ddpm = ddpm

    @torch.no_grad()
    def sample_multiple_scales(self, shape, condition, scales, device='cuda'):
        """Generate samples at multiple guidance scales."""
        results = {}

        for scale in scales:
            samples = self._guided_sample(shape, condition, scale, device)
            results[scale] = samples

        return results

    def compute_diversity(self, samples):
        """Measure sample diversity via pairwise distances."""
        flat = samples.view(samples.size(0), -1)
        distances = torch.cdist(flat, flat)
        # Average off-diagonal distance
        mask = ~torch.eye(samples.size(0), dtype=torch.bool, device=samples.device)
        return distances[mask].mean().item()

    def compute_condition_alignment(self, samples, condition, classifier):
        """Measure how well samples match the condition."""
        with torch.no_grad():
            logits = classifier(samples)
            probs = F.softmax(logits, dim=-1)
            alignment = probs[:, condition].mean().item()
        return alignment

    @torch.no_grad()
    def analyze_scale_tradeoff(self, shape, condition, scales, classifier, device='cuda'):
        """Analyze diversity vs alignment tradeoff."""
        results = []

        for scale in scales:
            samples = self._guided_sample(shape, condition, scale, device)
            diversity = self.compute_diversity(samples)
            alignment = self.compute_condition_alignment(samples, condition, classifier)

            results.append({
                'scale': scale,
                'diversity': diversity,
                'alignment': alignment
            })

        return results


def dynamic_guidance(t, max_scale=7.5, min_scale=1.0, num_timesteps=1000):
    """Guidance scale that varies with timestep."""
    # Higher guidance at start (high noise), lower at end
    progress = t / num_timesteps
    return min_scale + (max_scale - min_scale) * progress
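
With the defaults above, the schedule moves from roughly 7.5 at the noisiest timesteps down to 1.0 at the end of sampling; inside a loop like ClassifierFreeGuidance.guided_sample, the fixed guidance_scale would simply be replaced by dynamic_guidance(t).

PYTHON
# Inspect the schedule at a few timesteps (defaults: max_scale=7.5, min_scale=1.0)
for t in [999, 750, 500, 250, 0]:
    print(f"t={t:4d}  guidance scale={dynamic_guidance(t):.2f}")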

Key Takeaways

Score-based models learn the gradient of the log probability density through score matching. This connects diffusion to energy-based models and enables Langevin dynamics sampling. Classifier guidance uses a separately trained classifier to steer generation toward target classes. Classifier-free guidance trains a single model for conditional and unconditional generation, combining predictions during sampling. The guidance scale controls the strength of conditioning, trading diversity for alignment. These guidance techniques have become essential for controllable high-quality generation.

19.4 Diffusion Model Applications Advanced


Diffusion models have revolutionized generative AI across multiple domains. From text-to-image synthesis to audio generation, the flexibility and quality of diffusion-based approaches have enabled applications that seemed impossible just years ago.

Text-to-Image Generation

Text-to-image diffusion models generate images from natural language descriptions. These models combine powerful vision architectures with text encoders to understand and visualize complex textual concepts. Cross-attention mechanisms allow the image generation process to attend to relevant parts of the text prompt at each denoising step.

Latent diffusion operates in a compressed latent space rather than pixel space, dramatically reducing computational requirements while maintaining image quality. A pretrained autoencoder compresses images to latents for training and generation, then decodes the final latents back to full-resolution images.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentDiffusion(nn.Module):
    def __init__(self, vae, unet, text_encoder, scheduler):
        super().__init__()
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.scheduler = scheduler

        # Freeze VAE and text encoder
        self.vae.requires_grad_(False)
        self.text_encoder.requires_grad_(False)

    def encode_image(self, images):
        """Encode images to latent space."""
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample()
            latents = latents * 0.18215  # Scaling factor
        return latents

    def decode_latents(self, latents):
        """Decode latents to images."""
        latents = latents / 0.18215
        with torch.no_grad():
            images = self.vae.decode(latents).sample
        return images

    def encode_text(self, text_tokens):
        """Encode text prompts."""
        with torch.no_grad():
            text_embeddings = self.text_encoder(text_tokens)[0]
        return text_embeddings

    def forward(self, latents, timesteps, text_embeddings):
        """Predict noise at given timesteps."""
        return self.unet(latents, timesteps, encoder_hidden_states=text_embeddings).sample


class TextToImagePipeline:
    def __init__(self, model, scheduler, device='cuda'):
        self.model = model.to(device)
        self.scheduler = scheduler
        self.device = device

    @torch.no_grad()
    def generate(self, prompt, negative_prompt="", height=512, width=512,
                 num_inference_steps=50, guidance_scale=7.5):
        """Generate image from text prompt."""

        # Encode prompts (_tokenize is an assumed helper that wraps the text
        # encoder's tokenizer and returns token id tensors)
        text_tokens = self._tokenize(prompt)
        neg_tokens = self._tokenize(negative_prompt)

        text_emb = self.model.encode_text(text_tokens.to(self.device))
        neg_emb = self.model.encode_text(neg_tokens.to(self.device))

        # Initialize latents
        latent_shape = (1, 4, height // 8, width // 8)
        latents = torch.randn(latent_shape, device=self.device)
        latents = latents * self.scheduler.init_noise_sigma

        # Sampling loop
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            latent_input = torch.cat([latents] * 2)
            text_input = torch.cat([neg_emb, text_emb])

            noise_pred = self.model(latent_input, t, text_input)

            # Split predictions
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

            # Classifier-free guidance
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # Scheduler step
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode to image
        images = self.model.decode_latents(latents)
        return images
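
A usage sketch, assuming latent_diffusion is an already-loaded LatentDiffusion model and scheduler a compatible noise scheduler:

PYTHON
# Illustrative usage; `latent_diffusion` and `scheduler` are assumed objects
pipe = TextToImagePipeline(latent_diffusion, scheduler)

images = pipe.generate(
    prompt="a watercolor painting of a lighthouse at sunset",
    negative_prompt="blurry, low quality",
    num_inference_steps=50,
    guidance_scale=7.5,
)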

Image Editing and Inpainting

Diffusion models excel at image editing by selectively regenerating parts of an image while preserving others. Inpainting fills masked regions with content that blends seamlessly with the surrounding context. The model can be conditioned on text to guide what should appear in the edited region.

SDEdit enables image-to-image translation by adding noise to a source image and then denoising with a different prompt. The amount of noise controls how much the original image influences the result versus the new prompt. This enables style transfer, object replacement, and artistic modifications.

PYTHON
class ImageEditor:
    def __init__(self, model, scheduler, device='cuda'):
        self.model = model
        self.scheduler = scheduler
        self.device = device

    @torch.no_grad()
    def inpaint(self, image, mask, prompt, num_inference_steps=50, guidance_scale=7.5):
        """Inpaint masked region guided by text prompt."""
        # Encode original image
        latents_original = self.model.encode_image(image.to(self.device))

        # Resize mask to latent size
        mask_latent = F.interpolate(mask, size=latents_original.shape[-2:])
        mask_latent = (mask_latent > 0.5).float()

        # Encode prompt
        text_emb = self.model.encode_text(self._tokenize(prompt).to(self.device))
        neg_emb = self.model.encode_text(self._tokenize("").to(self.device))

        # Initialize with noise
        latents = torch.randn_like(latents_original)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            # Predict noise
            latent_input = torch.cat([latents] * 2)
            text_input = torch.cat([neg_emb, text_emb])

            noise_pred = self.model(latent_input, t, text_input)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # Scheduler step
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Replace unmasked regions with original (noised appropriately)
            noise = torch.randn_like(latents_original)
            noised_original = self.scheduler.add_noise(latents_original, noise, t)
            latents = mask_latent * latents + (1 - mask_latent) * noised_original

        return self.model.decode_latents(latents)

    @torch.no_grad()
    def sdedit(self, image, prompt, strength=0.8, num_inference_steps=50, guidance_scale=7.5):
        """Image-to-image translation via SDEdit."""
        # Encode image
        latents = self.model.encode_image(image.to(self.device))

        # Determine starting timestep based on strength
        start_step = int(num_inference_steps * (1 - strength))
        self.scheduler.set_timesteps(num_inference_steps)

        # Add noise to latents
        noise = torch.randn_like(latents)
        timestep = self.scheduler.timesteps[start_step]
        latents = self.scheduler.add_noise(latents, noise, timestep)

        # Encode prompt
        text_emb = self.model.encode_text(self._tokenize(prompt).to(self.device))
        neg_emb = self.model.encode_text(self._tokenize("").to(self.device))

        # Denoise from intermediate step
        for t in self.scheduler.timesteps[start_step:]:
            latent_input = torch.cat([latents] * 2)
            text_input = torch.cat([neg_emb, text_emb])

            noise_pred = self.model(latent_input, t, text_input)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.model.decode_latents(latents)

Super-Resolution

Diffusion-based super-resolution upscales low-resolution images while adding realistic high-frequency details. Unlike interpolation methods that produce blurry results, diffusion models learn to generate plausible details that match the image content. Cascaded approaches progressively increase resolution through multiple diffusion stages.

Conditioning on the low-resolution image can be done by channel-wise concatenation, where the upsampled low-res image is stacked with the noisy high-res target, or through cross-attention mechanisms that let the model reference the source image.

PYTHON
class SuperResolutionDiffusion(nn.Module):
    def __init__(self, unet, scale_factor=4):
        super().__init__()
        self.unet = unet
        self.scale_factor = scale_factor

    def forward(self, x_noisy, t, low_res):
        """Predict noise conditioned on low-resolution image."""
        # Upsample low-res to match noisy high-res size
        low_res_up = F.interpolate(
            low_res, scale_factor=self.scale_factor, mode='bilinear'
        )

        # Concatenate along channel dimension
        x_input = torch.cat([x_noisy, low_res_up], dim=1)

        return self.unet(x_input, t)


class SuperResolutionPipeline:
    def __init__(self, model, scheduler, device='cuda'):
        self.model = model.to(device)
        self.scheduler = scheduler
        self.device = device

    @torch.no_grad()
    def upscale(self, low_res_image, num_inference_steps=50):
        """Upscale low-resolution image."""
        low_res = low_res_image.to(self.device)

        # Determine output shape
        b, c, h, w = low_res.shape
        high_res_shape = (b, c, h * self.model.scale_factor, w * self.model.scale_factor)

        # Start from noise
        x = torch.randn(high_res_shape, device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            noise_pred = self.model(x, t, low_res)
            x = self.scheduler.step(noise_pred, t, x).prev_sample

        return x


class CascadedSuperResolution:
    """Multi-stage super-resolution for very high upscaling factors."""

    def __init__(self, stages, schedulers, device='cuda'):
        self.stages = [s.to(device) for s in stages]
        self.schedulers = schedulers
        self.device = device

    @torch.no_grad()
    def upscale(self, image, num_inference_steps=50):
        """Progressively upscale through multiple stages."""
        current = image.to(self.device)

        for stage, scheduler in zip(self.stages, self.schedulers):
            b, c, h, w = current.shape
            target_shape = (b, c, h * stage.scale_factor, w * stage.scale_factor)

            x = torch.randn(target_shape, device=self.device)
            scheduler.set_timesteps(num_inference_steps)

            for t in scheduler.timesteps:
                noise_pred = stage(x, t, current)
                x = scheduler.step(noise_pred, t, x).prev_sample

            current = x

        return current

Audio Generation

Diffusion models generate high-quality audio including speech, music, and sound effects. Audio diffusion typically operates on mel-spectrograms or learned audio representations, then uses a vocoder to convert to waveforms. Text-to-audio models can generate sound effects, music, or speech from descriptions.

The temporal structure of audio requires handling long sequences. Architectures often use U-Net variants with attention mechanisms that capture both local patterns and long-range dependencies across the time dimension.

PYTHON
class AudioDiffusion(nn.Module):
    def __init__(self, unet, vocoder):
        super().__init__()
        self.unet = unet
        self.vocoder = vocoder

    def forward(self, mel_noisy, t, text_embeddings=None):
        """Predict noise for mel spectrogram."""
        return self.unet(mel_noisy, t, text_embeddings)

    def mel_to_audio(self, mel):
        """Convert mel spectrogram to audio waveform."""
        return self.vocoder(mel)


class TextToAudioPipeline:
    def __init__(self, model, text_encoder, scheduler, device='cuda'):
        self.model = model.to(device)
        self.text_encoder = text_encoder.to(device)
        self.scheduler = scheduler
        self.device = device

    @torch.no_grad()
    def generate(self, prompt, duration_seconds=5.0, num_inference_steps=100,
                 guidance_scale=3.0):
        """Generate audio from text description."""
        # Encode text
        text_emb = self.text_encoder(prompt)
        null_emb = self.text_encoder("")

        # Calculate mel spectrogram shape
        mel_frames = int(duration_seconds * 100)  # 100 frames per second
        mel_shape = (1, 80, mel_frames)  # 80 mel bins

        # Initialize with noise
        mel = torch.randn(mel_shape, device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.scheduler.timesteps:
            mel_input = torch.cat([mel] * 2)
            text_input = torch.cat([null_emb, text_emb])

            noise_pred = self.model(mel_input, t, text_input)
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            mel = self.scheduler.step(noise_pred, t, mel).prev_sample

        # Convert to audio
        audio = self.model.mel_to_audio(mel)
        return audio


class MusicGenerationModel(nn.Module):
    """Diffusion model for music generation."""

    def __init__(self, base_channels=128, num_layers=12):
        super().__init__()

        self.time_embed = nn.Sequential(
            nn.Linear(1, base_channels * 4),
            nn.SiLU(),
            nn.Linear(base_channels * 4, base_channels * 4)
        )

        # 1D convolutions for audio
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        channels = [base_channels * (2 ** i) for i in range(num_layers // 2)]

        for i, ch in enumerate(channels):
            in_ch = 1 if i == 0 else channels[i-1]
            self.encoder.append(nn.Sequential(
                nn.Conv1d(in_ch, ch, 3, stride=2, padding=1),
                nn.GroupNorm(8, ch),
                nn.SiLU()
            ))

        for i, ch in enumerate(reversed(channels)):
            out_ch = 1 if i == len(channels) - 1 else channels[-(i+2)]
            self.decoder.append(nn.Sequential(
                nn.ConvTranspose1d(ch * 2, out_ch, 4, stride=2, padding=1),
                nn.GroupNorm(8, out_ch) if out_ch > 1 else nn.Identity(),
                nn.SiLU() if out_ch > 1 else nn.Tanh()
            ))

    def forward(self, x, t):
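        # NOTE: in this simplified sketch the timestep embedding is computed but
        # not injected into the conv blocks; a full model would add it per block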
        t_emb = self.time_embed(t.float().unsqueeze(-1) / 1000)

        skips = []
        for layer in self.encoder:
            x = layer(x)
            skips.append(x)

        for layer, skip in zip(self.decoder, reversed(skips)):
            x = torch.cat([x, skip], dim=1)
            x = layer(x)

        return x

Video Generation

Video diffusion extends image generation to the temporal domain, producing coherent sequences of frames. The challenge lies in maintaining both spatial quality within frames and temporal consistency across frames. Approaches include 3D convolutions, temporal attention, and autoregressive frame generation.

Video models often leverage pretrained image diffusion models, adding temporal layers that learn motion and consistency. This transfer enables high-quality video generation without training from scratch on video data.

PYTHON
class VideoDiffusion(nn.Module):
    def __init__(self, image_unet, temporal_layers):
        super().__init__()
        self.image_unet = image_unet
        self.temporal_layers = temporal_layers

    def forward(self, x, t, text_embeddings=None):
        """Process video batch with spatial and temporal modeling."""
        b, f, c, h, w = x.shape  # batch, frames, channels, height, width

        # Process each frame spatially; timesteps are repeated per frame with
        # repeat_interleave so they line up with the frame-major flattening
        x = x.view(b * f, c, h, w)
        x = self.image_unet.encode(x, t.repeat_interleave(f), text_embeddings)

        # Reshape for temporal processing
        x = x.view(b, f, -1, h // 8, w // 8)

        # Apply temporal attention/convolutions
        x = self.temporal_layers(x)

        # Decode back to pixel space
        x = x.view(b * f, -1, h // 8, w // 8)
        x = self.image_unet.decode(x, t.repeat_interleave(f), text_embeddings)

        return x.view(b, f, c, h, w)


class TemporalAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):
        # x: (batch, frames, channels, height, width)
        b, f, c, h, w = x.shape

        # Reshape for attention over frames
        x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, f, c)

        # Self-attention across temporal dimension
        residual = x
        x = self.norm(x)
        x, _ = self.attention(x, x, x)
        x = x + residual

        # Reshape back
        x = x.reshape(b, h, w, f, c).permute(0, 3, 4, 1, 2)
        return x

Key Takeaways

Diffusion models power diverse applications across images, audio, and video. Text-to-image models combine vision and language through cross-attention in latent space. Image editing uses inpainting and SDEdit to selectively modify regions while preserving context. Super-resolution adds realistic details through conditioning on low-resolution inputs. Audio diffusion generates speech and music via mel-spectrogram generation. Video diffusion extends to temporal modeling with 3D convolutions and temporal attention. These applications demonstrate the flexibility of the diffusion framework across modalities.

19.5 Advanced Topics and Future Directions Advanced

Advanced Topics and Future Directions

The rapid evolution of diffusion models has spawned numerous innovations addressing efficiency, controllability, and new modalities. These advances point toward a future where high-quality generation becomes increasingly accessible and versatile.

Consistency Models

Consistency models enable single-step or few-step generation by training the model to map any point along the diffusion trajectory directly to the clean data. Rather than requiring hundreds of iterative denoising steps, they perform the entire denoising in one forward pass, yielding an orders-of-magnitude speedup.

The key insight is that all points along a probability flow trajectory map to the same origin. Training enforces this consistency by requiring predictions from adjacent timesteps to match when mapped to the data. This self-consistency provides supervision without needing the full iterative process.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConsistencyModel(nn.Module):
    def __init__(self, base_model, sigma_min=0.002, sigma_max=80):
        super().__init__()
        self.model = base_model
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def forward(self, x, sigma):
        """Map noisy input directly to clean data estimate."""
        # Boundary-preserving skip/output scalings: at sigma = sigma_min the
        # network reduces to the identity map (sigma_data = 0.5 is an assumed
        # data standard deviation)
        sigma_data = 0.5
        c_skip = sigma_data ** 2 / ((sigma - self.sigma_min) ** 2 + sigma_data ** 2)
        c_out = sigma_data * (sigma - self.sigma_min) / torch.sqrt(sigma ** 2 + sigma_data ** 2)

        # Model prediction
        model_output = self.model(x, sigma)

        # Consistency function output
        return c_skip.view(-1, 1, 1, 1) * x + c_out.view(-1, 1, 1, 1) * model_output


class ConsistencyTrainer:
    def __init__(self, model, ema_model, sigma_min=0.002, sigma_max=80):
        self.model = model
        self.ema_model = ema_model
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def consistency_loss(self, x_0, optimizer):
        """Training loss enforcing self-consistency."""
        batch_size = x_0.size(0)
        device = x_0.device

        # Sample timesteps
        t = torch.rand(batch_size, device=device)
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        sigma_next = self.sigma_min * (self.sigma_max / self.sigma_min) ** (t - 1/1000)
        sigma_next = torch.clamp(sigma_next, min=self.sigma_min)

        # Add noise
        noise = torch.randn_like(x_0)
        x_t = x_0 + sigma.view(-1, 1, 1, 1) * noise
        x_t_next = x_0 + sigma_next.view(-1, 1, 1, 1) * noise

        # Consistency targets (from EMA model)
        with torch.no_grad():
            target = self.ema_model(x_t_next, sigma_next)

        # Model prediction
        pred = self.model(x_t, sigma)

        # Consistency loss
        loss = F.mse_loss(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update EMA
        self._update_ema()

        return loss.item()

    def _update_ema(self, decay=0.9999):
        for p, p_ema in zip(self.model.parameters(), self.ema_model.parameters()):
            p_ema.data.mul_(decay).add_(p.data, alpha=1 - decay)

    @torch.no_grad()
    def sample(self, shape, device, num_steps=1):
        """Generate samples with single or few steps."""
        self.model.eval()

        # Start from noise at max sigma
        x = torch.randn(shape, device=device) * self.sigma_max

        if num_steps == 1:
            # Single step generation
            sigma = torch.full((shape[0],), self.sigma_max, device=device)
            x = self.model(x, sigma)
        else:
            # Few-step generation
            sigmas = torch.linspace(self.sigma_max, self.sigma_min, num_steps + 1, device=device)
            for i in range(num_steps):
                sigma = sigmas[i].expand(shape[0])
                x = self.model(x, sigma)
                if i < num_steps - 1:
                    noise = torch.randn_like(x)
                    x = x + torch.sqrt(sigmas[i+1]**2 - self.sigma_min**2) * noise

        self.model.train()
        return x

Flow Matching

Flow matching provides an alternative to the diffusion framework by learning continuous normalizing flows through simulation-free training. Rather than defining a noising process and learning its reverse, flow matching directly learns a vector field that transports samples from noise to data. This approach often results in straighter trajectories requiring fewer integration steps.

Rectified flow further improves efficiency by learning to straighten the transport paths. The model is trained on pairs of noise and data samples connected by straight lines, encouraging direct paths that can be traversed in fewer steps.

PYTHON
class FlowMatchingModel(nn.Module):
    def __init__(self, vector_field):
        super().__init__()
        self.vector_field = vector_field

    def forward(self, x, t):
        """Predict velocity at time t."""
        return self.vector_field(x, t)


class FlowMatchingTrainer:
    def __init__(self, model, sigma_min=0.001):
        self.model = model
        self.sigma_min = sigma_min

    def train_step(self, x_0, x_1, optimizer):
        """Train with conditional flow matching objective."""
        batch_size = x_0.size(0)
        device = x_0.device

        # Sample time uniformly
        t = torch.rand(batch_size, device=device).view(-1, 1, 1, 1)

        # Interpolate between noise (x_0) and data (x_1)
        x_t = (1 - t) * x_0 + t * x_1

        # Target velocity: direction from noise to data
        target_velocity = x_1 - x_0

        # Predict velocity
        predicted_velocity = self.model(x_t, t.view(-1))

        # MSE loss
        loss = F.mse_loss(predicted_velocity, target_velocity)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sample(self, shape, device, num_steps=50):
        """Generate samples by integrating the flow."""
        self.model.eval()

        # Start from noise
        x = torch.randn(shape, device=device)

        # Euler integration
        dt = 1.0 / num_steps
        for i in range(num_steps):
            t = torch.full((shape[0],), i * dt, device=device)
            velocity = self.model(x, t)
            x = x + velocity * dt

        self.model.train()
        return x


class RectifiedFlow:
    """Rectified flow for straighter trajectories."""

    def __init__(self, model):
        self.model = model

    def reflow(self, x_0_samples, x_1_samples, num_iterations=1):
        """Iteratively straighten flow paths."""
        for _ in range(num_iterations):
            # Generate new pairs using current model
            x_0_new = torch.randn_like(x_0_samples)
            x_1_new = self._sample_from_noise(x_0_new)

            # Train on straightened pairs
            self._train_on_pairs(x_0_new, x_1_new)

    def _sample_from_noise(self, x_0, num_steps=100):
        """Sample using current flow model."""
        x = x_0.clone()
        dt = 1.0 / num_steps
        for i in range(num_steps):
            t = torch.full((x.size(0),), i * dt, device=x.device)
            velocity = self.model(x, t)
            x = x + velocity * dt
        return x
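
    def _train_on_pairs(self, x_0, x_1, num_steps=100, lr=1e-4):
        # Minimal sketch: refit the flow on (noise, generated sample) pairs
        # with the conditional flow-matching objective defined above
        # (the optimizer settings here are placeholders)
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        trainer = FlowMatchingTrainer(self.model)
        for _ in range(num_steps):
            trainer.train_step(x_0, x_1, optimizer)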

Distillation Techniques

Knowledge distillation compresses the knowledge from a slow teacher diffusion model into a faster student. Progressive distillation halves the number of sampling steps iteratively, training the student to match the teacher's two-step output in a single step. This can reduce sampling from hundreds of steps to just a few.

Guided distillation incorporates classifier-free guidance into the distillation process, allowing the student to perform guided generation without the computational overhead of running two forward passes per step.

PYTHON
class ProgressiveDistillation:
    def __init__(self, teacher, student, scheduler):
        self.teacher = teacher
        self.student = student
        self.scheduler = scheduler

        # Freeze teacher
        for param in self.teacher.parameters():
            param.requires_grad = False

    def distill_step(self, x_0, optimizer, current_steps):
        """One step of progressive distillation."""
        batch_size = x_0.size(0)
        device = x_0.device

        # Sample a student timestep on the coarse grid (t >= 1 so that the
        # teacher's two fine-grid steps t*2 and t*2 - 1 are both valid)
        t_start = torch.randint(1, current_steps, (batch_size,), device=device)

        # Add noise to get x_t at the teacher's (fine-grid) timestep
        noise = torch.randn_like(x_0)
        x_t = self.scheduler.add_noise(x_0, noise, t_start * 2)

        # Teacher takes two denoising steps toward lower timesteps
        with torch.no_grad():
            teacher_pred_1 = self.teacher(x_t, t_start * 2)
            x_mid = self.scheduler.step(teacher_pred_1, t_start * 2, x_t).prev_sample
            teacher_pred_2 = self.teacher(x_mid, t_start * 2 - 1)
            teacher_target = self.scheduler.step(
                teacher_pred_2, t_start * 2 - 1, x_mid
            ).prev_sample

        # Student takes a single step on the halved (coarse) grid; for
        # simplicity the same scheduler object is reused for both grids
        student_pred = self.student(x_t, t_start)
        student_output = self.scheduler.step(student_pred, t_start, x_t).prev_sample

        # Match teacher's two-step result
        loss = F.mse_loss(student_output, teacher_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()


class GuidedDistillation:
    """Distill guidance into the model itself."""

    def __init__(self, teacher, student, guidance_scale=7.5):
        self.teacher = teacher
        self.student = student
        self.guidance_scale = guidance_scale

    def distill_step(self, x_0, condition, optimizer, scheduler):
        """Distill classifier-free guidance."""
        batch_size = x_0.size(0)
        device = x_0.device

        t = torch.randint(0, scheduler.num_timesteps, (batch_size,), device=device)
        noise = torch.randn_like(x_0)
        x_t = scheduler.add_noise(x_0, noise, t)

        # Teacher with guidance (two forward passes)
        with torch.no_grad():
            noise_pred_cond = self.teacher(x_t, t, condition)
            noise_pred_uncond = self.teacher(x_t, t, None)
            teacher_target = noise_pred_uncond + self.guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

        # Student single pass (learns guided behavior)
        student_pred = self.student(x_t, t, condition)

        loss = F.mse_loss(student_pred, teacher_target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

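A sketch of the outer loop that repeatedly halves the sampling budget follows. The dataloader, number of rounds, and optimizer settings are placeholders; the student of each round is initialized from the teacher and then becomes the teacher for the next round.

PYTHON
import copy

def run_progressive_distillation(teacher, scheduler, dataloader,
                                 initial_steps=1024, num_rounds=4,
                                 epochs_per_round=1, lr=1e-4, device='cuda'):
    """Repeatedly halve the number of sampling steps (sketch, not a recipe)."""
    current_steps = initial_steps

    for _ in range(num_rounds):
        student_steps = current_steps // 2

        # Student starts as a copy of the teacher and learns to match the
        # teacher's two-step behavior in a single step
        student = copy.deepcopy(teacher)
        for param in student.parameters():
            param.requires_grad = True

        distiller = ProgressiveDistillation(teacher, student, scheduler)
        optimizer = torch.optim.AdamW(student.parameters(), lr=lr)

        for _ in range(epochs_per_round):
            for x_0 in dataloader:
                distiller.distill_step(x_0.to(device), optimizer, student_steps)

        # The distilled student becomes the teacher for the next round
        teacher, current_steps = student, student_steps

    return teacher, current_steps
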
ControlNet and Adapters

ControlNet adds spatial conditioning to pretrained diffusion models without modifying the original weights. A trainable copy of the encoder is created and connected to the frozen model through zero-initialized convolutions. This enables conditioning on edges, depth maps, poses, and other spatial signals.

Adapters provide lightweight modules that can be inserted into pretrained models for task-specific fine-tuning. LoRA (Low-Rank Adaptation) decomposes weight updates into low-rank matrices, dramatically reducing the number of trainable parameters while maintaining quality.

PYTHON
class ControlNet(nn.Module):
    def __init__(self, pretrained_unet, hint_channels=3):
        super().__init__()

        # Frozen copy of original model
        self.unet = pretrained_unet
        for param in self.unet.parameters():
            param.requires_grad = False

        # Trainable control encoder (copy of UNet encoder)
        self.control_encoder = self._copy_encoder(pretrained_unet)

        # Input hint processing
        self.hint_conv = nn.Conv2d(hint_channels, 320, 3, padding=1)

        # Zero-initialized connections
        self.zero_convs = nn.ModuleList([
            nn.Conv2d(ch, ch, 1) for ch in [320, 320, 640, 640, 1280, 1280]
        ])

        # Initialize zero convs to zero
        for conv in self.zero_convs:
            nn.init.zeros_(conv.weight)
            nn.init.zeros_(conv.bias)

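    def _copy_encoder(self, unet):
        # Trainable deep copy of the UNet's downsampling blocks (assumes they
        # are exposed as `unet.encoder`; adapt to the architecture in use)
        import copy
        encoder_copy = copy.deepcopy(unet.encoder)
        for param in encoder_copy.parameters():
            param.requires_grad = True
        return encoder_copy
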
    def forward(self, x, t, text_emb, hint):
        # Process hint
        hint_features = self.hint_conv(hint)

        # Get control features (assumes x has already been projected into the
        # 320-channel feature space by the UNet's input convolution, so the
        # hint embedding can be added directly)
        control_features = []
        h = x + hint_features
        for block, zero_conv in zip(self.control_encoder, self.zero_convs):
            h = block(h, t)
            control_features.append(zero_conv(h))

        # Run frozen UNet with control added
        return self.unet(x, t, text_emb, control_features=control_features)


class LoRALayer(nn.Module):
    """Low-Rank Adaptation layer."""

    def __init__(self, original_layer, rank=4, alpha=1.0):
        super().__init__()
        self.original_layer = original_layer
        self.rank = rank
        self.alpha = alpha

        # Freeze original
        for param in original_layer.parameters():
            param.requires_grad = False

        # Low-rank decomposition
        in_features = original_layer.in_features
        out_features = original_layer.out_features

        self.lora_A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x):
        original_output = self.original_layer(x)
        lora_output = (x @ self.lora_A.T) @ self.lora_B.T
        return original_output + self.alpha * lora_output


class LoRADiffusion:
    """Apply LoRA to diffusion model attention layers."""

    def __init__(self, model, rank=4, target_modules=['to_q', 'to_v']):
        self.model = model
        self.lora_layers = []

        # Collect the target linear layers first, then swap them out
        # (avoids mutating the module tree while iterating over it)
        targets = [
            (name, module) for name, module in model.named_modules()
            if isinstance(module, nn.Linear)
            and any(target in name for target in target_modules)
        ]
        for name, module in targets:
            lora_layer = LoRALayer(module, rank=rank)
            self.lora_layers.append((name, lora_layer))
            # Replace the original layer in the model
            parent = self._get_parent(model, name)
            setattr(parent, name.split('.')[-1], lora_layer)

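    def _get_parent(self, model, name):
        # Walk the module tree to the parent of a dotted module name
        parent = model
        for part in name.split('.')[:-1]:
            parent = getattr(parent, part)
        return parent
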
    def get_trainable_params(self):
        """Return only LoRA parameters for training."""
        params = []
        for _, layer in self.lora_layers:
            params.extend([layer.lora_A, layer.lora_B])
        return params

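To make the adapter concrete, here is a minimal usage sketch. ToyAttention is a hypothetical stand-in for a real attention block; it exists only to show which layers get wrapped and that only the low-rank matrices are handed to the optimizer.

PYTHON
class ToyAttention(nn.Module):
    """Hypothetical attention block used only to illustrate LoRA wrapping."""

    def __init__(self, dim=64):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        attn = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
        return attn @ v


model = ToyAttention()
lora = LoRADiffusion(model, rank=4, target_modules=['to_q', 'to_v'])

# Only the low-rank matrices are optimized; the base weights stay frozen
optimizer = torch.optim.AdamW(lora.get_trainable_params(), lr=1e-4)
print(sum(p.numel() for p in lora.get_trainable_params()), "trainable parameters")
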
3D Generation

Diffusion models extend to 3D generation, including point clouds, neural radiance fields, and mesh-based representations. Multi-view diffusion generates consistent images from multiple viewpoints that can be fused into a 3D reconstruction. Score distillation optimizes 3D representations using gradients from a 2D diffusion prior.

PYTHON
class PointCloudDiffusion(nn.Module):
    """Diffusion model for 3D point cloud generation."""

    def __init__(self, num_points=2048, point_dim=3, hidden_dim=256):
        super().__init__()

        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Point-wise MLP with time conditioning
        self.layers = nn.ModuleList([
            nn.Linear(point_dim + hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, point_dim)
        ])

    def forward(self, points, t):
        # points: (batch, num_points, 3)
        t_emb = self.time_embed(t.float().unsqueeze(-1) / 1000)
        t_emb = t_emb.unsqueeze(1).expand(-1, points.size(1), -1)

        x = torch.cat([points, t_emb], dim=-1)

        for layer in self.layers[:-1]:
            x = F.silu(layer(x))

        return self.layers[-1](x)


class ScoreDistillation:
    """Score Distillation Sampling for 3D generation."""

    def __init__(self, diffusion_model, guidance_scale=100):
        self.model = diffusion_model
        self.guidance_scale = guidance_scale

    def compute_sds_loss(self, rendered_images, text_embeddings, scheduler):
        """Compute SDS gradient for updating 3D representation."""
        batch_size = rendered_images.size(0)
        device = rendered_images.device

        # Random timestep
        t = torch.randint(20, 980, (batch_size,), device=device)

        # Add noise to rendered images
        noise = torch.randn_like(rendered_images)
        noisy_images = scheduler.add_noise(rendered_images, noise, t)

        # Get noise prediction
        with torch.no_grad():
            noise_pred_cond = self.model(noisy_images, t, text_embeddings)
            noise_pred_uncond = self.model(noisy_images, t, None)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )

        # SDS gradient: difference between predicted and actual noise
        # This gradient is backpropagated to update 3D parameters
        w = (1 - scheduler.alphas_cumprod[t]).view(-1, 1, 1, 1)
        grad = w * (noise_pred - noise)

        # Create pseudo-loss for backprop
        loss = (grad * rendered_images).sum()
        return loss

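The sketch below shows how the SDS loss drives the parameters being optimized. To stay self-contained, the differentiable renderer is replaced by a trivial identity over a learnable image-shaped tensor; diffusion_model, scheduler, and text_embeddings are assumed to come from a text-to-image pipeline like the ones above, and the resolution is a placeholder.

PYTHON
def optimize_with_sds(diffusion_model, scheduler, text_embeddings,
                      num_iters=1000, lr=0.01, device='cuda'):
    """Toy SDS loop: a learnable tensor stands in for a rendered 3D scene."""
    sds = ScoreDistillation(diffusion_model, guidance_scale=100)

    # Learnable stand-in for the rendered view of a 3D representation
    scene = torch.randn(1, 3, 64, 64, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([scene], lr=lr)

    for _ in range(num_iters):
        rendered = scene  # a real pipeline would render a random camera view here
        loss = sds.compute_sds_loss(rendered, text_embeddings, scheduler)

        optimizer.zero_grad()
        loss.backward()   # the SDS gradient flows into the scene parameters
        optimizer.step()

    return scene.detach()
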
Key Takeaways

Advanced diffusion techniques continue pushing the boundaries of generative AI. Consistency models enable single-step generation through self-consistency training. Flow matching provides simulation-free training with straighter sampling trajectories. Distillation compresses slow models into fast students while preserving quality. ControlNet and LoRA enable efficient adaptation of pretrained models for new conditions and tasks. 3D generation leverages diffusion priors through score distillation for neural radiance fields and other representations. These advances point toward increasingly efficient, controllable, and versatile generative systems.