Intermediate/Advanced · 90 min read

Chapter 18: Generative Adversarial Networks

GANs, adversarial training, GAN architectures, conditional GANs, and applications.

Libraries covered: PyTorch

Learning Objectives

Understand the adversarial training paradigm and the minimax objective
Implement and stabilize GAN training loops
Build and evaluate generative models


18.1 Introduction to GANs (Intermediate)

Introduction to Generative Adversarial Networks

Generative Adversarial Networks represent one of the most influential innovations in deep learning, introducing an adversarial training paradigm where two neural networks compete against each other to produce increasingly realistic synthetic data.

The Adversarial Paradigm

Traditional generative models learn to approximate a data distribution directly through maximum likelihood estimation or variational inference. GANs take a fundamentally different approach by framing generation as a game between two players. The generator network attempts to produce fake samples that are indistinguishable from real data, while the discriminator network tries to identify which samples are real and which are generated. This adversarial dynamic drives both networks to improve continuously.

The insight behind GANs comes from game theory. In a zero-sum game, one player's gain equals another's loss. The generator wins by fooling the discriminator, while the discriminator wins by correctly classifying real versus fake samples. At equilibrium, the generator produces samples so realistic that the discriminator cannot do better than random guessing.

PYTHON
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)


class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

The Minimax Objective

The GAN training objective formalizes the adversarial game mathematically. The discriminator maximizes its ability to distinguish real from fake samples, while the generator minimizes the discriminator's success. This creates a minimax optimization problem where we simultaneously optimize both networks with opposing goals.

The value function that defines this game involves expectations over the real data distribution and the generator's output distribution. The discriminator receives high reward for assigning high probability to real samples and low probability to generated samples. The generator receives reward when the discriminator assigns high probability to its generated samples.
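
Written out, the value function for this two-player game is

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
\]

where the discriminator D maximizes V and the generator G minimizes it. The discriminator loss below implements the two expectations directly; the generator loss uses the common non-saturating variant that maximizes log D(G(z)) rather than minimizing log(1 - D(G(z))).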

PYTHON
import torch.nn.functional as F

def discriminator_loss(real_output, fake_output):
    real_loss = F.binary_cross_entropy(real_output, torch.ones_like(real_output))
    fake_loss = F.binary_cross_entropy(fake_output, torch.zeros_like(fake_output))
    return real_loss + fake_loss


def generator_loss(fake_output):
    return F.binary_cross_entropy(fake_output, torch.ones_like(fake_output))


def train_step(generator, discriminator, real_data, latent_dim, g_optimizer, d_optimizer):
    batch_size = real_data.size(0)
    device = real_data.device

    # Train discriminator
    d_optimizer.zero_grad()

    real_output = discriminator(real_data)

    z = torch.randn(batch_size, latent_dim, device=device)
    fake_data = generator(z)
    fake_output = discriminator(fake_data.detach())

    d_loss = discriminator_loss(real_output, fake_output)
    d_loss.backward()
    d_optimizer.step()

    # Train generator
    g_optimizer.zero_grad()

    z = torch.randn(batch_size, latent_dim, device=device)
    fake_data = generator(z)
    fake_output = discriminator(fake_data)

    g_loss = generator_loss(fake_output)
    g_loss.backward()
    g_optimizer.step()

    return d_loss.item(), g_loss.item()

Latent Space and Generation

The generator transforms samples from a simple prior distribution, typically a multivariate Gaussian, into samples that approximate the complex data distribution. This latent space serves as the source of variation in generated samples. Each point in the latent space corresponds to a potential output, and the generator learns a mapping that transforms this simple space into the manifold of realistic data.

The dimensionality of the latent space affects generation quality and diversity. Too few dimensions may limit the generator's ability to capture all modes of variation in the data. Too many dimensions can make training more difficult and may not improve quality. Practitioners typically choose latent dimensions between 64 and 512 for image generation tasks.

PYTHON
class GAN:
    def __init__(self, latent_dim, data_dim, device):
        self.latent_dim = latent_dim
        self.device = device

        self.generator = Generator(latent_dim, data_dim).to(device)
        self.discriminator = Discriminator(data_dim).to(device)

        self.g_optimizer = torch.optim.Adam(
            self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999)
        )
        self.d_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)
        )

    def sample(self, num_samples):
        self.generator.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim, device=self.device)
            samples = self.generator(z)
        self.generator.train()
        return samples

    def interpolate(self, z1, z2, steps=10):
        self.generator.eval()
        with torch.no_grad():
            alphas = torch.linspace(0, 1, steps, device=self.device)
            interpolations = []
            for alpha in alphas:
                z = (1 - alpha) * z1 + alpha * z2
                sample = self.generator(z)
                interpolations.append(sample)
        self.generator.train()
        return torch.stack(interpolations)

Understanding the Discriminator

The discriminator functions as a learned loss function for the generator. Rather than using a fixed metric like mean squared error or pixel-wise differences, GANs learn what makes samples realistic through the discriminator's evolving judgment. This adaptive loss function can capture complex, high-level features that distinguish real from fake samples.

The discriminator must balance between being too easy and too hard to fool. If the discriminator is too weak, it provides no useful signal for the generator to improve. If the discriminator is too strong, it rejects generated samples so confidently that the generator's loss saturates, leaving almost no gradient to follow. Maintaining this balance is crucial for stable training.

PYTHON
class SpectralNormDiscriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(input_dim, 1024)),
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Linear(1024, 512)),
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Linear(512, 256)),
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Linear(256, 1))
        )

    def forward(self, x):
        return self.model(x)


def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    batch_size = real_samples.size(0)
    device = real_samples.device

    alpha = torch.rand(batch_size, 1, device=device)
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.requires_grad_(True)

    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True
    )[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

The Nash Equilibrium Goal

The theoretical goal of GAN training is to reach a Nash equilibrium where neither network can improve by changing its strategy unilaterally. At this point, the generator produces samples from the true data distribution, and the discriminator outputs probability 0.5 for all inputs, unable to distinguish real from fake.

In practice, reaching this equilibrium is challenging. The optimization landscape is non-convex, and gradient descent on the minimax objective does not guarantee convergence. The training dynamics can exhibit oscillations, mode collapse where the generator produces limited variety, or divergence where one network dominates completely.

PYTHON
def train_gan(gan, dataloader, epochs):
    history = {"d_loss": [], "g_loss": []}

    for epoch in range(epochs):
        epoch_d_loss = 0
        epoch_g_loss = 0
        num_batches = 0

        for real_data in dataloader:
            real_data = real_data[0].to(gan.device)

            d_loss, g_loss = train_step(
                gan.generator,
                gan.discriminator,
                real_data,
                gan.latent_dim,
                gan.g_optimizer,
                gan.d_optimizer
            )

            epoch_d_loss += d_loss
            epoch_g_loss += g_loss
            num_batches += 1

        history["d_loss"].append(epoch_d_loss / num_batches)
        history["g_loss"].append(epoch_g_loss / num_batches)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: D_loss={history['d_loss'][-1]:.4f}, "
                  f"G_loss={history['g_loss'][-1]:.4f}")

    return history
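
As a quick end-to-end sketch, the pieces above can be wired together on toy data. The dataset, dimensions, and hyperparameters below are illustrative assumptions rather than values from the text.

PYTHON
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy data: four small Gaussian clusters in 2D, kept inside the generator's Tanh range
centers = torch.tensor([[0.5, 0.0], [-0.5, 0.0], [0.0, 0.5], [0.0, -0.5]])
real_points = centers[torch.randint(0, 4, (4096,))] + 0.05 * torch.randn(4096, 2)

dataloader = DataLoader(TensorDataset(real_points), batch_size=128, shuffle=True)

gan = GAN(latent_dim=16, data_dim=2, device=device)
history = train_gan(gan, dataloader, epochs=50)

samples = gan.sample(num_samples=1000)  # (1000, 2) generated points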

Key Takeaways

Generative Adversarial Networks introduce adversarial training where a generator and discriminator compete in a minimax game. The generator transforms latent vectors into synthetic samples while the discriminator learns to distinguish real from fake. This framework enables learning complex data distributions without explicit likelihood computation. The discriminator serves as a learned loss function that adapts during training. While theoretically elegant, GANs present significant training challenges that subsequent sections will address.

18.2 Training GANs (Advanced)

Training Generative Adversarial Networks

Training GANs presents unique challenges not found in standard supervised learning. The adversarial dynamic creates complex optimization landscapes where both networks must improve in tandem while neither should dominate the other.

The Mode Collapse Problem

Mode collapse occurs when the generator learns to produce only a limited variety of outputs, ignoring large portions of the data distribution. Instead of capturing all the diversity in the training data, the generator finds a small set of samples that consistently fool the discriminator. The generator essentially "collapses" to producing the same or very similar outputs regardless of the input latent vector.

This problem arises because the generator optimizes for fooling the current discriminator rather than matching the true data distribution. If producing one type of sample works well, the generator has no incentive to explore other modes. The discriminator eventually learns to reject these repeated samples, but the generator may simply collapse to a different mode rather than learning to produce diverse outputs.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModeCollapseDetector:
    def __init__(self, num_samples=1000, threshold=0.1):
        self.num_samples = num_samples
        self.threshold = threshold

    def compute_diversity(self, generator, latent_dim, device):
        generator.eval()
        with torch.no_grad():
            z = torch.randn(self.num_samples, latent_dim, device=device)
            samples = generator(z)

            # Compute pairwise distances
            samples_flat = samples.view(self.num_samples, -1)
            distances = torch.cdist(samples_flat, samples_flat)

            # Average distance excluding diagonal
            mask = ~torch.eye(self.num_samples, dtype=torch.bool, device=device)
            avg_distance = distances[mask].mean()

        generator.train()
        return avg_distance.item()

    def is_collapsed(self, generator, latent_dim, device):
        diversity = self.compute_diversity(generator, latent_dim, device)
        return diversity < self.threshold


class MinibatchDiscrimination(nn.Module):
    """Appends minibatch similarity statistics to each sample's features so the
    discriminator can detect a generator that produces near-identical outputs."""

    def __init__(self, feature_dim, num_kernels=5, kernel_dim=3):
        super().__init__()
        # Learned projection into kernel space; defined once so it is actually trained,
        # rather than re-initialized on every forward pass
        self.T = nn.Parameter(torch.randn(feature_dim, num_kernels * kernel_dim))
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim

    def forward(self, features):
        batch_size = features.size(0)

        # Project features to kernel space
        M = torch.mm(features, self.T).view(batch_size, self.num_kernels, self.kernel_dim)

        # L1 distance between all pairs of samples in the batch
        diffs = torch.abs(M.unsqueeze(0) - M.unsqueeze(1)).sum(dim=3)

        # Sum of exponentiated negative distances, excluding each sample's distance to itself
        exp_diffs = torch.exp(-diffs)
        mask = 1 - torch.eye(batch_size, device=features.device)
        o = (exp_diffs * mask.unsqueeze(2)).sum(dim=1)

        return torch.cat([features, o], dim=1)

Vanishing Gradients

When the discriminator becomes too accurate, it provides near-zero gradients to the generator. The discriminator outputs probabilities very close to 0 for all generated samples, and the gradient of the binary cross-entropy loss vanishes in this regime. Without meaningful gradients, the generator cannot learn which direction to improve.

This problem is fundamental to the original minimax formulation, in which the generator minimizes log(1 - D(G(z))). When D(G(z)) is near zero, this term saturates and its gradient with respect to the generator's parameters becomes vanishingly small. The widely used non-saturating alternative instead maximizes log(D(G(z))), which keeps gradients large precisely when the discriminator rejects the generator's samples; other formulations go further and change the divergence measure entirely.

PYTHON
def non_saturating_generator_loss(fake_output):
    """Non-saturating loss helps with vanishing gradients."""
    return -torch.log(fake_output + 1e-8).mean()


def least_squares_losses(real_output, fake_output):
    """LSGAN uses least squares instead of log loss."""
    d_loss_real = ((real_output - 1) ** 2).mean()
    d_loss_fake = (fake_output ** 2).mean()
    d_loss = 0.5 * (d_loss_real + d_loss_fake)

    g_loss = 0.5 * ((fake_output - 1) ** 2).mean()

    return d_loss, g_loss


def hinge_losses(real_output, fake_output):
    """Hinge loss provides stable training."""
    d_loss_real = torch.relu(1.0 - real_output).mean()
    d_loss_fake = torch.relu(1.0 + fake_output).mean()
    d_loss = d_loss_real + d_loss_fake

    g_loss = -fake_output.mean()

    return d_loss, g_loss

Wasserstein GAN

The Wasserstein GAN fundamentally changes the GAN objective by using the Earth Mover's distance (Wasserstein-1 distance) instead of the Jensen-Shannon divergence implicit in the original formulation. This distance measures how much "work" is required to transform one distribution into another, providing meaningful gradients even when distributions have non-overlapping support.

The discriminator in WGAN, called the critic, outputs unbounded real values rather than probabilities. The critic must satisfy a Lipschitz constraint, originally enforced through weight clipping. The critic maximizes the difference between expected outputs on real and fake samples, while the generator minimizes the critic's output on fake samples.
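
In the Kantorovich-Rubinstein dual form that WGAN optimizes, the critic D is constrained to be 1-Lipschitz and the objective becomes

\[
\min_G \max_{\lVert D \rVert_L \le 1} \; \mathbb{E}_{x \sim p_r}\big[D(x)\big] - \mathbb{E}_{z \sim p_z}\big[D(G(z))\big]
\]

which the critic and generator losses below implement as simple means of the critic outputs.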

PYTHON
class WGANCritic(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        return self.model(x)


def wgan_critic_loss(real_output, fake_output):
    return fake_output.mean() - real_output.mean()


def wgan_generator_loss(fake_output):
    return -fake_output.mean()


def train_wgan_step(generator, critic, real_data, latent_dim,
                    g_optimizer, c_optimizer, clip_value=0.01, n_critic=5):
    batch_size = real_data.size(0)
    device = real_data.device

    # Train critic multiple times per generator update
    for _ in range(n_critic):
        c_optimizer.zero_grad()

        real_output = critic(real_data)

        z = torch.randn(batch_size, latent_dim, device=device)
        fake_data = generator(z).detach()
        fake_output = critic(fake_data)

        c_loss = wgan_critic_loss(real_output, fake_output)
        c_loss.backward()
        c_optimizer.step()

        # Weight clipping for Lipschitz constraint
        for param in critic.parameters():
            param.data.clamp_(-clip_value, clip_value)

    # Train generator
    g_optimizer.zero_grad()

    z = torch.randn(batch_size, latent_dim, device=device)
    fake_data = generator(z)
    fake_output = critic(fake_data)

    g_loss = wgan_generator_loss(fake_output)
    g_loss.backward()
    g_optimizer.step()

    return c_loss.item(), g_loss.item()

Gradient Penalty

Weight clipping in WGAN can lead to pathological behavior where critic weights concentrate at the clipping boundaries. WGAN-GP replaces weight clipping with a gradient penalty that directly enforces the Lipschitz constraint by penalizing gradients with norm different from one.

The penalty is computed on interpolated samples between real and fake data. This encourages the critic to have unit gradient norm throughout the space where it matters for distinguishing distributions, providing a more principled enforcement of the Lipschitz constraint.
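
The penalty term added to the critic loss is

\[
\lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big],
\qquad \hat{x} = \alpha x + (1 - \alpha)\tilde{x},\ \alpha \sim U[0, 1]
\]

where x is a real sample, \tilde{x} a generated sample, and \lambda is commonly set to 10, matching the lambda_gp default in the training step below.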

PYTHON
def compute_gradient_penalty(critic, real_samples, fake_samples, device):
    batch_size = real_samples.size(0)

    # Random interpolation coefficient
    alpha = torch.rand(batch_size, 1, device=device)
    alpha = alpha.expand_as(real_samples)

    # Interpolated samples
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates.requires_grad_(True)

    # Critic output on interpolates
    d_interpolates = critic(interpolates)

    # Compute gradients
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Gradient penalty
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    gradient_penalty = ((gradient_norm - 1) ** 2).mean()

    return gradient_penalty


def train_wgan_gp_step(generator, critic, real_data, latent_dim,
                       g_optimizer, c_optimizer, lambda_gp=10, n_critic=5):
    batch_size = real_data.size(0)
    device = real_data.device

    for _ in range(n_critic):
        c_optimizer.zero_grad()

        real_output = critic(real_data)

        z = torch.randn(batch_size, latent_dim, device=device)
        fake_data = generator(z).detach()
        fake_output = critic(fake_data)

        # WGAN loss plus gradient penalty
        gp = compute_gradient_penalty(critic, real_data, fake_data, device)
        c_loss = wgan_critic_loss(real_output, fake_output) + lambda_gp * gp

        c_loss.backward()
        c_optimizer.step()

    g_optimizer.zero_grad()

    z = torch.randn(batch_size, latent_dim, device=device)
    fake_data = generator(z)
    fake_output = critic(fake_data)

    g_loss = wgan_generator_loss(fake_output)
    g_loss.backward()
    g_optimizer.step()

    return c_loss.item(), g_loss.item()

Training Heuristics

Practitioners have developed numerous heuristics to stabilize GAN training. These include architectural choices like using batch normalization in the generator but not the discriminator, employing LeakyReLU activations, and avoiding sparse gradients from maxpooling. Learning rate and optimizer settings also significantly impact stability.

Label smoothing replaces the hard target of 1 for real samples with a softer target such as 0.9, preventing the discriminator from becoming overconfident. Adding noise to discriminator inputs, especially early in training, helps prevent the discriminator from finding trivial solutions. The two-timescale update rule (TTUR) gives the discriminator a higher learning rate than the generator; a related heuristic simply updates the discriminator more often per generator step.

PYTHON
class StableGANTrainer:
    def __init__(self, generator, discriminator, latent_dim, device,
                 g_lr=0.0001, d_lr=0.0004):
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.device = device

        # Different learning rates (TTUR)
        self.g_optimizer = torch.optim.Adam(
            generator.parameters(), lr=g_lr, betas=(0.0, 0.9)
        )
        self.d_optimizer = torch.optim.Adam(
            discriminator.parameters(), lr=d_lr, betas=(0.0, 0.9)
        )

        self.label_smoothing = 0.1
        self.noise_std = 0.1

    def add_instance_noise(self, x, decay_factor=1.0):
        noise = torch.randn_like(x) * self.noise_std * decay_factor
        return x + noise

    def smooth_labels(self, labels, smoothing=None):
        if smoothing is None:
            smoothing = self.label_smoothing
        return labels * (1 - smoothing) + 0.5 * smoothing

    def train_step(self, real_data, epoch):
        batch_size = real_data.size(0)

        # Decay noise over training
        noise_decay = max(0, 1 - epoch / 100)

        # Train discriminator
        self.d_optimizer.zero_grad()

        real_noisy = self.add_instance_noise(real_data, noise_decay)
        real_output = self.discriminator(real_noisy)
        real_labels = self.smooth_labels(torch.ones_like(real_output))

        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_data = self.generator(z).detach()
        fake_noisy = self.add_instance_noise(fake_data, noise_decay)
        fake_output = self.discriminator(fake_noisy)
        fake_labels = torch.zeros_like(fake_output)

        d_loss = (F.binary_cross_entropy(real_output, real_labels) +
                  F.binary_cross_entropy(fake_output, fake_labels))
        d_loss.backward()
        self.d_optimizer.step()

        # Train generator
        self.g_optimizer.zero_grad()

        z = torch.randn(batch_size, self.latent_dim, device=self.device)
        fake_data = self.generator(z)
        fake_output = self.discriminator(fake_data)

        g_loss = F.binary_cross_entropy(fake_output, torch.ones_like(fake_output))
        g_loss.backward()
        self.g_optimizer.step()

        return d_loss.item(), g_loss.item()

Evaluation Metrics

Evaluating GAN quality is challenging because we cannot directly compute the likelihood of generated samples. The Inception Score (IS) measures both quality and diversity by passing generated images through a pretrained classifier. High-quality samples should yield confident class predictions, while diversity means the marginal class distribution should be close to uniform.

The Fréchet Inception Distance (FID) compares statistics of real and generated samples in a learned feature space. It computes the Fréchet distance between multivariate Gaussians fitted to features extracted from a pretrained network. Lower FID indicates that generated samples more closely match the distribution of real samples.
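
In symbols, with (\mu_r, \Sigma_r) and (\mu_g, \Sigma_g) the means and covariances of the real and generated feature distributions,

\[
\text{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\big(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\big)
\]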

PYTHON
import numpy as np
from scipy import linalg

def compute_inception_score(samples, classifier, num_splits=10):
    """Compute Inception Score for generated samples."""
    with torch.no_grad():
        preds = torch.softmax(classifier(samples), dim=1)

    scores = []
    chunk_size = len(preds) // num_splits

    for i in range(num_splits):
        chunk = preds[i * chunk_size:(i + 1) * chunk_size]
        p_yx = chunk
        p_y = p_yx.mean(dim=0, keepdim=True)
        kl_div = p_yx * (torch.log(p_yx + 1e-10) - torch.log(p_y + 1e-10))
        kl_div = kl_div.sum(dim=1).mean()
        scores.append(torch.exp(kl_div).item())

    return np.mean(scores), np.std(scores)


def compute_fid(real_features, fake_features):
    """Compute Frechet Inception Distance."""
    mu_real = real_features.mean(dim=0).cpu().numpy()
    mu_fake = fake_features.mean(dim=0).cpu().numpy()

    sigma_real = np.cov(real_features.cpu().numpy(), rowvar=False)
    sigma_fake = np.cov(fake_features.cpu().numpy(), rowvar=False)

    diff = mu_real - mu_fake
    covmean, _ = linalg.sqrtm(sigma_real @ sigma_fake, disp=False)

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_real + sigma_fake - 2 * covmean)
    return fid

Key Takeaways

GAN training faces unique challenges including mode collapse where generators produce limited variety, and vanishing gradients when discriminators become too accurate. Wasserstein GAN addresses these issues by using Earth Mover's distance with gradient penalty for stable Lipschitz enforcement. Practical heuristics like label smoothing, instance noise, and two-timescale update rules improve stability. Evaluation relies on metrics like Inception Score and FID that measure both quality and diversity of generated samples.

18.3 GAN Architectures (Advanced)

GAN Architectures

The evolution of GAN architectures has driven dramatic improvements in generation quality. From early fully-connected networks to sophisticated convolutional designs, architectural innovations address specific challenges in image synthesis.

Deep Convolutional GAN (DCGAN)

DCGAN established foundational principles for convolutional GAN architectures that remain influential today. The key insight was replacing pooling layers with strided convolutions, using batch normalization throughout, and carefully designing the generator and discriminator to be architectural mirrors of each other.

The generator uses transposed convolutions to progressively upsample from a latent vector to a full-resolution image. Each upsampling block doubles the spatial dimensions while reducing the number of channels. Batch normalization and ReLU activations follow each transposed convolution, with a final Tanh activation to produce pixel values in the range [-1, 1].

PYTHON
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    def __init__(self, latent_dim=100, num_channels=3, feature_maps=64):
        super().__init__()

        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, feature_maps * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(True),
            # State: (feature_maps*8) x 4 x 4

            nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(True),
            # State: (feature_maps*4) x 8 x 8

            nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(True),
            # State: (feature_maps*2) x 16 x 16

            nn.ConvTranspose2d(feature_maps * 2, feature_maps, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps),
            nn.ReLU(True),
            # State: feature_maps x 32 x 32

            nn.ConvTranspose2d(feature_maps, num_channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: num_channels x 64 x 64
        )

    def forward(self, z):
        z = z.view(z.size(0), -1, 1, 1)
        return self.main(z)


class DCGANDiscriminator(nn.Module):
    def __init__(self, num_channels=3, feature_maps=64):
        super().__init__()

        self.main = nn.Sequential(
            # Input: num_channels x 64 x 64
            nn.Conv2d(num_channels, feature_maps, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_maps * 4, feature_maps * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_maps * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(feature_maps * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x).view(-1, 1)

Progressive Growing GAN

Progressive GAN addresses the challenge of generating high-resolution images by starting with low resolution and progressively adding layers. Training begins at 4x4 resolution and gradually increases to the target resolution, with each stage adding new convolutional layers to both generator and discriminator. This curriculum approach stabilizes training and enables generation at resolutions previously unattainable.

The key innovation is smooth layer transitions using alpha blending. When adding new layers, their contribution fades in gradually rather than appearing abruptly. This prevents sudden changes that could destabilize the learned representations at lower resolutions.

PYTHON
class ProgressiveGenerator(nn.Module):
    def __init__(self, latent_dim=512, max_resolution=1024):
        super().__init__()
        self.latent_dim = latent_dim

        # Initial 4x4 block
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.LeakyReLU(0.2)
        )

        # Progressive blocks (double resolution each time)
        self.blocks = nn.ModuleList()
        self.to_rgb = nn.ModuleList()

        channels = [512, 512, 512, 512, 256, 128, 64, 32, 16]
        self.to_rgb.append(nn.Conv2d(512, 3, 1))

        for i in range(int(torch.log2(torch.tensor(max_resolution / 4)))):
            in_ch = channels[i]
            out_ch = channels[i + 1]
            self.blocks.append(self._make_block(in_ch, out_ch))
            self.to_rgb.append(nn.Conv2d(out_ch, 3, 1))

        self.current_stage = 0
        self.alpha = 1.0

    def _make_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(0.2)
        )

    def forward(self, z):
        x = z.view(-1, self.latent_dim, 1, 1)
        x = self.initial(x)

        if self.current_stage == 0:
            return self.to_rgb[0](x)

        for i in range(self.current_stage - 1):
            x = self.blocks[i](x)

        # Fade in new layer
        old_x = nn.functional.interpolate(x, scale_factor=2, mode='nearest')
        old_rgb = self.to_rgb[self.current_stage - 1](old_x)

        new_x = self.blocks[self.current_stage - 1](x)
        new_rgb = self.to_rgb[self.current_stage](new_x)

        return (1 - self.alpha) * old_rgb + self.alpha * new_rgb

    def grow(self):
        self.current_stage += 1
        self.alpha = 0.0

    def update_alpha(self, delta):
        self.alpha = min(1.0, self.alpha + delta)
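
The grow and update_alpha hooks above are driven from the training loop. The sketch below shows one plausible schedule; the per-stage image budget, the 50 percent fade-in fraction, and the dataloader_for_resolution helper are assumptions for illustration, not part of the original method description.

PYTHON
def progressive_schedule(generator, dataloader_for_resolution, num_stages, images_per_stage=800_000):
    # dataloader_for_resolution(stage) is a hypothetical helper that yields batches
    # at the resolution matching the current stage.
    for stage in range(num_stages):
        if stage > 0:
            generator.grow()  # add the next resolution block and reset alpha to 0

        fade_in_images = images_per_stage // 2  # fade the new block in over the first half
        images_seen = 0

        for batch in dataloader_for_resolution(stage):
            if stage > 0:
                generator.update_alpha(batch.size(0) / fade_in_images)

            # ... usual discriminator/generator updates go here ...

            images_seen += batch.size(0)
            if images_seen >= images_per_stage:
                break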

StyleGAN Architecture

StyleGAN revolutionized image synthesis by introducing a style-based generator that provides unprecedented control over generated images. Instead of feeding the latent vector directly into the generator, StyleGAN first maps it through a mapping network to an intermediate latent space W. This learned transformation allows the W space to become disentangled, meaning different dimensions control independent aspects of the image.

Styles are injected at each resolution level through adaptive instance normalization. The generator also incorporates learned per-pixel noise inputs that add stochastic variation to fine details like hair strands and skin pores. These noise inputs control aspects that vary randomly across different samples of the same underlying content.
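
Concretely, adaptive instance normalization first standardizes each feature map per sample and then applies a style-dependent scale and shift:

\[
\text{AdaIN}(x_i, w) = \gamma_i(w)\,\frac{x_i - \mu(x_i)}{\sigma(x_i)} + \beta_i(w)
\]

where \mu and \sigma are computed per channel and per sample, and \gamma_i, \beta_i are produced from the intermediate latent vector w by a learned affine layer, as in the AdaIN module below.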

PYTHON
class MappingNetwork(nn.Module):
    def __init__(self, latent_dim=512, num_layers=8):
        super().__init__()

        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Linear(latent_dim, latent_dim),
                nn.LeakyReLU(0.2)
            ])

        self.mapping = nn.Sequential(*layers)

    def forward(self, z):
        return self.mapping(z)


class AdaIN(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.style = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, w):
        style = self.style(w)
        gamma, beta = style.chunk(2, dim=1)
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta = beta.unsqueeze(2).unsqueeze(3)

        x = self.norm(x)
        return gamma * x + beta


class StyleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)

        self.adain1 = AdaIN(style_dim, out_channels)
        self.adain2 = AdaIN(style_dim, out_channels)

        self.noise_scale1 = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
        self.noise_scale2 = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w, noise1=None, noise2=None):
        batch_size, _, h, w_dim = x.shape

        x = self.conv1(x)
        if noise1 is None:
            noise1 = torch.randn(batch_size, 1, h, w_dim, device=x.device)
        x = x + self.noise_scale1 * noise1
        x = self.adain1(x, w)
        x = self.activation(x)

        x = self.conv2(x)
        if noise2 is None:
            noise2 = torch.randn(batch_size, 1, h, w_dim, device=x.device)
        x = x + self.noise_scale2 * noise2
        x = self.adain2(x, w)
        x = self.activation(x)

        return x

StyleGAN2 Improvements

StyleGAN2 addresses artifacts present in the original StyleGAN, particularly the characteristic "blob" artifacts and phase artifacts in generated images. The key changes include weight demodulation instead of instance normalization, path length regularization for smoother latent space, and architectural refinements that eliminate progressive growing in favor of skip connections.

Weight demodulation normalizes convolution weights based on the input statistics expected from the style modulation. This achieves the same decorrelation as instance normalization but without processing the activations directly, eliminating droplet artifacts caused by the normalization.

PYTHON
class ModulatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, style_dim, demodulate=True):
        super().__init__()

        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate

        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.style = nn.Linear(style_dim, in_channels)

    def forward(self, x, w):
        batch_size = x.size(0)

        # Modulate weights by style
        style = self.style(w).view(batch_size, 1, -1, 1, 1)
        weight = self.weight.unsqueeze(0) * style

        if self.demodulate:
            # Demodulate weights
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch_size, self.out_channels, 1, 1, 1)

        # Reshape for grouped convolution
        weight = weight.view(
            batch_size * self.out_channels,
            weight.size(2),
            self.kernel_size,
            self.kernel_size
        )

        x = x.view(1, batch_size * x.size(1), x.size(2), x.size(3))
        padding = self.kernel_size // 2
        x = nn.functional.conv2d(x, weight, padding=padding, groups=batch_size)

        return x.view(batch_size, self.out_channels, x.size(2), x.size(3))


class StyleGAN2Block(nn.Module):
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()

        self.conv1 = ModulatedConv2d(in_channels, out_channels, 3, style_dim)
        self.conv2 = ModulatedConv2d(out_channels, out_channels, 3, style_dim)

        self.noise_scale1 = nn.Parameter(torch.zeros(1))
        self.noise_scale2 = nn.Parameter(torch.zeros(1))

        self.bias1 = nn.Parameter(torch.zeros(out_channels))
        self.bias2 = nn.Parameter(torch.zeros(out_channels))

        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, w, noise1=None, noise2=None):
        x = self.conv1(x, w)
        if noise1 is not None:
            x = x + self.noise_scale1 * noise1
        x = x + self.bias1.view(1, -1, 1, 1)
        x = self.activation(x)

        x = self.conv2(x, w)
        if noise2 is not None:
            x = x + self.noise_scale2 * noise2
        x = x + self.bias2.view(1, -1, 1, 1)
        x = self.activation(x)

        return x

Self-Attention in GANs

Self-Attention GAN (SAGAN) incorporates attention mechanisms to capture long-range dependencies in images. Standard convolutional layers have limited receptive fields, making it difficult to maintain consistency across distant image regions. Self-attention allows the generator to reference any spatial location when generating each pixel, enabling globally coherent structures.

The attention mechanism computes queries, keys, and values from feature maps, then weights values based on query-key similarity. This enables the network to learn which distant features are relevant for generating each local region, improving consistency in global structures like symmetry and repeated patterns.
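
With the flattened spatial positions as rows of the query, key, and value matrices Q, K, V, the layer computes

\[
y = \gamma \,\operatorname{softmax}\big(Q K^{\top}\big)\, V + x
\]

where \gamma is a learned scalar initialized to zero, so the block starts as an identity mapping and gradually learns how much attention to mix in.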

PYTHON
class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.query = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv2d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)

        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()

        # Compute query, key, value
        q = self.query(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        k = self.key(x).view(batch_size, -1, H * W)
        v = self.value(x).view(batch_size, -1, H * W)

        # Attention weights
        attention = torch.softmax(torch.bmm(q, k), dim=-1)

        # Apply attention to values
        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)

        return self.gamma * out + x


class SAGANGenerator(nn.Module):
    def __init__(self, latent_dim=128, num_channels=3, base_channels=64):
        super().__init__()

        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, base_channels * 16, 4, 1, 0),
            nn.BatchNorm2d(base_channels * 16),
            nn.ReLU()
        )

        self.block1 = self._make_block(base_channels * 16, base_channels * 8)
        self.block2 = self._make_block(base_channels * 8, base_channels * 4)

        self.attention1 = SelfAttention(base_channels * 4)

        self.block3 = self._make_block(base_channels * 4, base_channels * 2)

        self.attention2 = SelfAttention(base_channels * 2)

        self.block4 = self._make_block(base_channels * 2, base_channels)

        self.final = nn.Sequential(
            nn.Conv2d(base_channels, num_channels, 3, 1, 1),
            nn.Tanh()
        )

    def _make_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, z):
        x = z.view(z.size(0), -1, 1, 1)
        x = self.initial(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.attention1(x)
        x = self.block3(x)
        x = self.attention2(x)
        x = self.block4(x)
        return self.final(x)

Key Takeaways

GAN architectures have evolved from simple fully-connected networks to sophisticated designs enabling photorealistic image synthesis. DCGAN established convolutional principles including strided convolutions, batch normalization, and LeakyReLU activations. Progressive GAN enables high-resolution synthesis through gradual resolution increase. StyleGAN introduces style-based generation with mapping networks and adaptive normalization for disentangled control. StyleGAN2 eliminates artifacts through weight demodulation. Self-attention mechanisms capture long-range dependencies for globally coherent generation.

18.4 Conditional GANs (Advanced)

Conditional Generative Adversarial Networks

While standard GANs generate samples from the learned data distribution without control over specific attributes, conditional GANs incorporate additional information to guide generation. This conditioning enables targeted synthesis based on class labels, input images, text descriptions, or other auxiliary information.

Class-Conditional GANs

The simplest form of conditioning provides class labels to both generator and discriminator. The generator learns to produce samples of a specific class when given the corresponding label, while the discriminator learns to identify whether samples match their claimed class. This transforms generation from unconditional sampling to controlled synthesis.

Class information can be incorporated through concatenation, where labels are embedded and concatenated with features, or through projection, where label embeddings interact multiplicatively with features. Projection-based conditioning typically produces better results by allowing class information to modulate features rather than simply adding another input channel.

PYTHON
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, output_dim):
        super().__init__()

        self.label_embedding = nn.Embedding(num_classes, latent_dim)

        self.model = nn.Sequential(
            nn.Linear(latent_dim * 2, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, output_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        label_embed = self.label_embedding(labels)
        combined = torch.cat([z, label_embed], dim=1)
        return self.model(combined)


class ConditionalDiscriminator(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()

        self.label_embedding = nn.Embedding(num_classes, input_dim)

        self.model = nn.Sequential(
            nn.Linear(input_dim * 2, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        label_embed = self.label_embedding(labels)
        combined = torch.cat([x, label_embed], dim=1)
        return self.model(combined)


class ProjectionDiscriminator(nn.Module):
    def __init__(self, input_channels, num_classes, feature_dim=512):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, feature_dim, 4, 2, 1),
            nn.LeakyReLU(0.2)
        )

        self.fc = nn.Linear(feature_dim * 4 * 4, 1)
        self.embed = nn.Embedding(num_classes, feature_dim * 4 * 4)

    def forward(self, x, labels):
        features = self.features(x)
        features = features.view(features.size(0), -1)

        output = self.fc(features)
        label_embed = self.embed(labels)
        projection = (features * label_embed).sum(dim=1, keepdim=True)

        return output + projection
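
A short sampling helper makes the class-conditional generator concrete. The dimensions and class count below are illustrative assumptions.

PYTHON
latent_dim, num_classes, output_dim = 100, 10, 784  # assumed sizes for demonstration

G = ConditionalGenerator(latent_dim, num_classes, output_dim)

def sample_class(generator, class_idx, num_samples, latent_dim):
    """Generate num_samples outputs all conditioned on the same class label."""
    generator.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim)
        labels = torch.full((num_samples,), class_idx, dtype=torch.long)
        samples = generator(z, labels)
    generator.train()
    return samples

class_three_samples = sample_class(G, class_idx=3, num_samples=16, latent_dim=latent_dim)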

Image-to-Image Translation with Pix2Pix

Pix2pix learns mappings between paired images, enabling applications like converting sketches to photographs, semantic labels to realistic images, or day scenes to night. The generator takes an input image and produces a corresponding output image, while the discriminator judges whether input-output pairs are real or generated.

The generator uses a U-Net architecture with skip connections that preserve fine details from input to output. Skip connections allow low-level information like edges and textures to bypass the bottleneck, while deeper layers transform the semantic content. The discriminator uses a PatchGAN design that classifies overlapping patches rather than the entire image, encouraging sharp local details.

PYTHON
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()

        # Encoder
        self.enc1 = self._encoder_block(in_channels, features, normalize=False)
        self.enc2 = self._encoder_block(features, features * 2)
        self.enc3 = self._encoder_block(features * 2, features * 4)
        self.enc4 = self._encoder_block(features * 4, features * 8)
        self.enc5 = self._encoder_block(features * 8, features * 8)
        self.enc6 = self._encoder_block(features * 8, features * 8)
        self.enc7 = self._encoder_block(features * 8, features * 8)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1),
            nn.ReLU()
        )

        # Decoder with skip connections
        self.dec1 = self._decoder_block(features * 8, features * 8, dropout=True)
        self.dec2 = self._decoder_block(features * 16, features * 8, dropout=True)
        self.dec3 = self._decoder_block(features * 16, features * 8, dropout=True)
        self.dec4 = self._decoder_block(features * 16, features * 8)
        self.dec5 = self._decoder_block(features * 16, features * 4)
        self.dec6 = self._decoder_block(features * 8, features * 2)
        self.dec7 = self._decoder_block(features * 4, features)

        self.final = nn.Sequential(
            nn.ConvTranspose2d(features * 2, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def _encoder_block(self, in_ch, out_ch, normalize=True):
        layers = [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def _decoder_block(self, in_ch, out_ch, dropout=False):
        layers = [
            nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_ch)
        ]
        if dropout:
            layers.append(nn.Dropout(0.5))
        layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encode
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)

        b = self.bottleneck(e7)

        # Decode with skip connections
        d1 = self.dec1(b)
        d2 = self.dec2(torch.cat([d1, e7], dim=1))
        d3 = self.dec3(torch.cat([d2, e6], dim=1))
        d4 = self.dec4(torch.cat([d3, e5], dim=1))
        d5 = self.dec5(torch.cat([d4, e4], dim=1))
        d6 = self.dec6(torch.cat([d5, e3], dim=1))
        d7 = self.dec7(torch.cat([d6, e2], dim=1))

        return self.final(torch.cat([d7, e1], dim=1))


class PatchDiscriminator(nn.Module):
    def __init__(self, in_channels=6):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, 4, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 1, 4, 1, 1)
        )

    def forward(self, x, y):
        combined = torch.cat([x, y], dim=1)
        return self.model(combined)

Unpaired Translation with CycleGAN

CycleGAN enables image translation without paired training examples by learning bidirectional mappings between domains. Two generators learn transformations in opposite directions, while cycle consistency losses ensure that translating an image to the other domain and back recovers the original. This enables applications like style transfer, season conversion, and domain adaptation without requiring corresponding images.

The cycle consistency constraint prevents mode collapse and ensures meaningful translations. If generator G maps horses to zebras and generator F maps zebras to horses, then F(G(horse)) should recover the original horse image. This constraint, combined with adversarial losses for both directions, guides learning without paired supervision.
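
Formally, the cycle consistency term combines reconstruction errors in both directions,

\[
\mathcal{L}_{\text{cyc}}(G, F) = \mathbb{E}_{x}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y}\big[\lVert G(F(y)) - y \rVert_1\big]
\]

and the full objective adds this term, weighted by \lambda, to the adversarial losses for both mapping directions, as in the training step below.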

PYTHON
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)


class CycleGANGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, num_residual=9):
        super().__init__()

        # Initial convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        ]

        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual blocks
        for _ in range(num_residual):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2,
                                   padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, out_channels, 7),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


def cycle_consistency_loss(real_images, reconstructed_images):
    return torch.mean(torch.abs(real_images - reconstructed_images))


def train_cyclegan_step(G_AB, G_BA, D_A, D_B, real_A, real_B,
                        opt_G, opt_D_A, opt_D_B, lambda_cycle=10):
    # Generate fake images
    fake_B = G_AB(real_A)
    fake_A = G_BA(real_B)

    # Reconstruct images
    reconstructed_A = G_BA(fake_B)
    reconstructed_B = G_AB(fake_A)

    # Generator losses
    opt_G.zero_grad()

    loss_GAN_AB = -torch.mean(D_B(fake_B))
    loss_GAN_BA = -torch.mean(D_A(fake_A))

    loss_cycle_A = cycle_consistency_loss(real_A, reconstructed_A)
    loss_cycle_B = cycle_consistency_loss(real_B, reconstructed_B)

    loss_G = loss_GAN_AB + loss_GAN_BA + lambda_cycle * (loss_cycle_A + loss_cycle_B)
    loss_G.backward()
    opt_G.step()

    # Discriminator A
    opt_D_A.zero_grad()
    loss_D_A = torch.mean(torch.relu(1 - D_A(real_A))) + torch.mean(torch.relu(1 + D_A(fake_A.detach())))
    loss_D_A.backward()
    opt_D_A.step()

    # Discriminator B
    opt_D_B.zero_grad()
    loss_D_B = torch.mean(torch.relu(1 - D_B(real_B))) + torch.mean(torch.relu(1 + D_B(fake_B.detach())))
    loss_D_B.backward()
    opt_D_B.step()

    return loss_G.item(), loss_D_A.item(), loss_D_B.item()

Semantic Image Synthesis with SPADE

SPADE (Spatially-Adaptive Denormalization) generates photorealistic images from semantic segmentation maps. Unlike approaches that simply concatenate the segmentation map with features, SPADE uses the semantic layout to modulate normalized activations throughout the generator. This allows the network to adaptively adjust its behavior based on the semantic class at each spatial location.

The SPADE layer learns class-specific scale and shift parameters that vary spatially according to the input segmentation. This is more expressive than global conditioning because different regions of the image receive different transformations based on their semantic content.

PYTHON
class SPADE(nn.Module):
    def __init__(self, norm_channels, label_channels):
        super().__init__()

        self.norm = nn.InstanceNorm2d(norm_channels, affine=False)

        hidden_channels = 128
        self.shared = nn.Sequential(
            nn.Conv2d(label_channels, hidden_channels, 3, 1, 1),
            nn.ReLU()
        )
        self.gamma = nn.Conv2d(hidden_channels, norm_channels, 3, 1, 1)
        self.beta = nn.Conv2d(hidden_channels, norm_channels, 3, 1, 1)

    def forward(self, x, segmap):
        normalized = self.norm(x)

        segmap = nn.functional.interpolate(segmap, size=x.shape[2:], mode='nearest')
        shared = self.shared(segmap)
        gamma = self.gamma(shared)
        beta = self.beta(shared)

        return normalized * (1 + gamma) + beta


class SPADEResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, label_channels):
        super().__init__()

        self.spade1 = SPADE(in_channels, label_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

        self.spade2 = SPADE(out_channels, label_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)

        self.shortcut_spade = SPADE(in_channels, label_channels)
        self.shortcut_conv = nn.Conv2d(in_channels, out_channels, 1)

        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, segmap):
        shortcut = self.shortcut_conv(self.activation(self.shortcut_spade(x, segmap)))

        x = self.conv1(self.activation(self.spade1(x, segmap)))
        x = self.conv2(self.activation(self.spade2(x, segmap)))

        return x + shortcut


class SPADEGenerator(nn.Module):
    def __init__(self, label_channels, z_dim=256):
        super().__init__()

        self.fc = nn.Linear(z_dim, 16384)

        self.head = SPADEResBlock(1024, 1024, label_channels)
        self.up1 = SPADEResBlock(1024, 1024, label_channels)
        self.up2 = SPADEResBlock(1024, 512, label_channels)
        self.up3 = SPADEResBlock(512, 256, label_channels)
        self.up4 = SPADEResBlock(256, 128, label_channels)
        self.up5 = SPADEResBlock(128, 64, label_channels)

        self.final = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, segmap, z=None):
        if z is None:
            z = torch.randn(segmap.size(0), 256, device=segmap.device)

        x = self.fc(z).view(-1, 1024, 4, 4)

        x = self.head(x, segmap)
        x = nn.functional.interpolate(x, scale_factor=2)
        x = self.up1(x, segmap)
        x = nn.functional.interpolate(x, scale_factor=2)
        x = self.up2(x, segmap)
        x = nn.functional.interpolate(x, scale_factor=2)
        x = self.up3(x, segmap)
        x = nn.functional.interpolate(x, scale_factor=2)
        x = self.up4(x, segmap)
        x = nn.functional.interpolate(x, scale_factor=2)
        x = self.up5(x, segmap)

        return self.final(x)

Key Takeaways

Conditional GANs extend generation from random sampling to controlled synthesis. Class conditioning through embedding concatenation or projection enables targeted class generation. Pix2pix learns paired image translation using U-Net generators and PatchGAN discriminators. CycleGAN enables unpaired translation through bidirectional generators and cycle consistency constraints. SPADE provides spatially-varying conditioning for semantic image synthesis, modulating features based on segmentation layouts. These techniques enable precise control over generated content for diverse applications.

18.5 GAN Applications Advanced

GAN Applications

Generative Adversarial Networks have found widespread application across diverse domains, from enhancing image resolution to generating entirely new content. These applications leverage the adversarial training paradigm to produce results that match or exceed traditional methods.

Image Super-Resolution

Super-resolution GANs reconstruct high-resolution images from low-resolution inputs, recovering fine details that cannot be obtained through simple interpolation. The generator learns to add plausible high-frequency content while the discriminator ensures generated details look realistic rather than blurry or artificial.

SRGAN introduced perceptual losses that compare features from a pretrained network rather than pixel values directly. This encourages the generator to produce images that are perceptually similar to ground truth even if individual pixels differ. The combination of adversarial loss, perceptual loss, and content loss produces sharp, detailed upscaled images.

PYTHON
import math

import torch
import torch.nn as nn
import torchvision.models as models

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)


class SRGenerator(nn.Module):
    def __init__(self, scale_factor=4, num_residual=16):
        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(3, 64, 9, 1, 4),
            nn.PReLU()
        )

        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual)]
        )

        self.mid_conv = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64)
        )

        # Upsampling blocks: each PixelShuffle(2) stage doubles the resolution,
        # so log2(scale_factor) stages are needed
        upsample_blocks = []
        for _ in range(int(math.log2(scale_factor))):
            upsample_blocks.extend([
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.PixelShuffle(2),
                nn.PReLU()
            ])
        self.upsample = nn.Sequential(*upsample_blocks)

        self.final = nn.Conv2d(64, 3, 9, 1, 4)

    def forward(self, x):
        initial = self.initial(x)
        residual = self.residual_blocks(initial)
        mid = self.mid_conv(residual) + initial
        upsampled = self.upsample(mid)
        return self.final(upsampled)


class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(pretrained=True).features[:36]
        self.vgg = nn.Sequential(*vgg).eval()

        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, generated, target):
        gen_features = self.vgg(generated)
        target_features = self.vgg(target)
        return nn.functional.mse_loss(gen_features, target_features)


def sr_generator_loss(fake_output, generated, target, perceptual_loss_fn,
                      adversarial_weight=0.001, perceptual_weight=0.006):
    adversarial_loss = -torch.log(fake_output + 1e-8).mean()
    content_loss = nn.functional.mse_loss(generated, target)
    perceptual_loss = perceptual_loss_fn(generated, target)

    return (content_loss +
            adversarial_weight * adversarial_loss +
            perceptual_weight * perceptual_loss)
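
To show how these pieces fit together, here is a sketch of a single generator update. The discriminator `disc` and the `low_res`/`high_res` batches are placeholders not defined above (`disc` is assumed to output probabilities in (0, 1)), and the VGG weights download on first use.

PYTHON
import torch

# Sketch of one SRGAN generator step; `disc`, `low_res`, and `high_res`
# are hypothetical placeholders standing in for a real training loop.
generator = SRGenerator(scale_factor=4)
perceptual_loss_fn = VGGPerceptualLoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4)

def generator_step(low_res, high_res, disc):
    optimizer_g.zero_grad()
    fake_high_res = generator(low_res)        # 4x upscaled prediction
    fake_output = disc(fake_high_res)
    loss = sr_generator_loss(fake_output, fake_high_res, high_res,
                             perceptual_loss_fn)
    loss.backward()
    optimizer_g.step()
    return loss.item()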

Image Inpainting

Inpainting GANs fill in missing or corrupted regions of images with plausible content. The generator must understand both local texture patterns and global semantic structure to produce coherent completions. This requires reasoning about what objects might exist in masked regions based on surrounding context.

Partial convolutions address the challenge of handling arbitrary masks by normalizing convolution outputs based on the proportion of valid pixels in each receptive field. The mask is updated through the network to track which regions have been filled. Gated convolutions learn to dynamically select features, providing even more flexibility in handling irregular masks.

PYTHON
class PartialConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.mask_conv = nn.Conv2d(1, 1, kernel_size, stride, padding, bias=False)

        nn.init.constant_(self.mask_conv.weight, 1.0)
        for param in self.mask_conv.parameters():
            param.requires_grad = False

        self.kernel_area = kernel_size ** 2

    def forward(self, x, mask):
        # mask: (batch, 1, H, W) with 1 for valid pixels and 0 for holes
        with torch.no_grad():
            mask_output = self.mask_conv(mask)

        # A location is valid after this layer if its window saw any valid pixel
        new_mask = (mask_output > 0).float()

        # Re-normalize by the fraction of valid pixels in each window,
        # zeroing locations whose window contained no valid pixels
        valid_ratio = self.kernel_area / (mask_output + 1e-8)
        valid_ratio = valid_ratio * new_mask

        x = x * mask
        output = self.conv(x) * valid_ratio

        return output, new_mask


class InpaintingGenerator(nn.Module):
    def __init__(self, in_channels=4):
        super().__init__()

        # Encoder with partial convolutions
        self.enc1 = PartialConv2d(in_channels, 64, 7, 2, 3)
        self.enc2 = PartialConv2d(64, 128, 5, 2, 2)
        self.enc3 = PartialConv2d(128, 256, 5, 2, 2)
        self.enc4 = PartialConv2d(256, 512, 3, 2, 1)

        # Decoder
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2)
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2)
        )
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.Tanh()
        )

        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x, mask):
        # Concatenate image and mask
        x_in = torch.cat([x * mask, mask], dim=1)
        init_mask = mask

        # Encode
        e1, m1 = self.enc1(x_in, init_mask)
        e1 = self.activation(e1)

        e2, m2 = self.enc2(e1, m1)
        e2 = self.activation(e2)

        e3, m3 = self.enc3(e2, m2)
        e3 = self.activation(e3)

        e4, m4 = self.enc4(e3, m3)
        e4 = self.activation(e4)

        # Decode with skip connections
        d4 = self.dec4(e4)
        d3 = self.dec3(torch.cat([d4, e3], dim=1))
        d2 = self.dec2(torch.cat([d3, e2], dim=1))
        d1 = self.dec1(torch.cat([d2, e1], dim=1))

        return d1


def inpainting_loss(generated, target, mask, discriminator):
    # Reconstruction loss on valid regions
    valid_loss = torch.abs(generated * mask - target * mask).mean()

    # Reconstruction loss on hole regions
    hole_loss = torch.abs(generated * (1 - mask) - target * (1 - mask)).mean()

    # Adversarial loss
    fake_output = discriminator(generated)
    adversarial_loss = -torch.log(fake_output + 1e-8).mean()

    return valid_loss + 6 * hole_loss + 0.1 * adversarial_loss
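
A minimal sketch of running the generator on a masked batch, assuming 256x256 inputs and an illustrative rectangular hole:

PYTHON
import torch

# Illustrative sketch: 256x256 images with a rectangular hole in the mask
# (1 marks valid pixels, 0 marks the region to be filled in).
generator = InpaintingGenerator()

images = torch.randn(2, 3, 256, 256)
mask = torch.ones(2, 1, 256, 256)
mask[:, :, 96:160, 96:160] = 0.0

completed = generator(images, mask)
print(completed.shape)   # torch.Size([2, 3, 256, 256])

# In practice the prediction is usually pasted back only into the hole
result = images * mask + completed * (1 - mask)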

Face Generation and Editing

Face synthesis has become one of the most visible GAN applications, with models capable of generating photorealistic faces of non-existent people. Beyond generation, GANs enable semantic editing of faces including changing attributes like age, expression, pose, and accessories while preserving identity.

Latent space manipulation provides intuitive controls for face editing. By finding directions in the latent space that correspond to semantic attributes, we can modify faces by moving along these directions. The disentangled nature of StyleGAN's W space makes it particularly amenable to such manipulations.

PYTHON
class FaceEncoder(nn.Module):
    def __init__(self, latent_dim=512):
        super().__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )

        self.fc = nn.Linear(512 * 4 * 4, latent_dim)

    def forward(self, x):
        features = self.features(x)
        features = features.view(features.size(0), -1)
        return self.fc(features)


class AttributeEditor:
    def __init__(self, generator, encoder, attribute_directions):
        self.generator = generator
        self.encoder = encoder
        self.directions = attribute_directions

    def encode(self, image):
        with torch.no_grad():
            return self.encoder(image)

    def edit(self, latent, attribute, strength=1.0):
        if attribute not in self.directions:
            raise ValueError(f"Unknown attribute: {attribute}")

        direction = self.directions[attribute]
        edited_latent = latent + strength * direction
        return edited_latent

    def generate(self, latent):
        with torch.no_grad():
            return self.generator(latent)

    def edit_image(self, image, attribute, strength=1.0):
        latent = self.encode(image)
        edited_latent = self.edit(latent, attribute, strength)
        return self.generate(edited_latent)


def find_attribute_direction(encoder, positive_images, negative_images):
    # positive_images / negative_images: lists of (1, 3, H, W) image tensors
    with torch.no_grad():
        positive_latents = torch.cat([encoder(img) for img in positive_images])
        negative_latents = torch.cat([encoder(img) for img in negative_images])

        positive_mean = positive_latents.mean(dim=0)
        negative_mean = negative_latents.mean(dim=0)

        direction = positive_mean - negative_mean
        direction = direction / direction.norm()

    return direction
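
The sketch below exercises this API end to end with random tensors standing in for real face crops and a single linear layer standing in for a trained generator; it only demonstrates the plumbing, not meaningful edits, and the "smiling" attribute name is illustrative.

PYTHON
import torch
import torch.nn as nn

# Random 128x128 "faces" and a placeholder generator, purely to exercise
# find_attribute_direction and AttributeEditor.
encoder = FaceEncoder(latent_dim=512).eval()
generator = nn.Linear(512, 3 * 128 * 128)   # stands in for a trained generator

smiling = [torch.randn(1, 3, 128, 128) for _ in range(8)]
neutral = [torch.randn(1, 3, 128, 128) for _ in range(8)]

smile_direction = find_attribute_direction(encoder, smiling, neutral)
editor = AttributeEditor(generator, encoder,
                         attribute_directions={"smiling": smile_direction})

photo = torch.randn(1, 3, 128, 128)
edited = editor.edit_image(photo, "smiling", strength=1.5)
print(edited.shape)   # torch.Size([1, 49152])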

Data Augmentation

GANs can generate synthetic training data to augment limited datasets, particularly valuable in domains where collecting real data is expensive or raises privacy concerns. Medical imaging, autonomous driving, and rare event detection all benefit from GAN-based augmentation.

The key challenge is ensuring generated samples are diverse and representative of the true data distribution rather than memorizing training examples. Conditional generation allows targeting specific underrepresented classes or scenarios, while careful evaluation ensures synthetic data improves rather than harms model performance.

PYTHON
class AugmentationGAN:
    def __init__(self, generator, num_classes, latent_dim, device):
        self.generator = generator
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        self.device = device

    def generate_samples(self, num_samples, class_label=None):
        self.generator.eval()

        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim, device=self.device)

            if class_label is not None:
                labels = torch.full((num_samples,), class_label,
                                    dtype=torch.long, device=self.device)
            else:
                labels = torch.randint(0, self.num_classes, (num_samples,),
                                       device=self.device)

            samples = self.generator(z, labels)

        self.generator.train()
        return samples, labels

    def augment_batch(self, real_images, real_labels, augment_ratio=0.5):
        batch_size = real_images.size(0)
        num_synthetic = int(batch_size * augment_ratio)

        synthetic_images, synthetic_labels = self.generate_samples(num_synthetic)

        combined_images = torch.cat([real_images, synthetic_images], dim=0)
        combined_labels = torch.cat([real_labels, synthetic_labels], dim=0)

        # Shuffle
        perm = torch.randperm(combined_images.size(0))
        return combined_images[perm], combined_labels[perm]


class BalancedAugmenter:
    def __init__(self, generator, latent_dim, device):
        self.generator = generator
        self.latent_dim = latent_dim
        self.device = device

    def balance_dataset(self, labels, target_count):
        unique_labels, counts = torch.unique(labels, return_counts=True)

        synthetic_samples = []
        synthetic_labels = []

        for label, count in zip(unique_labels.tolist(), counts.tolist()):
            if count < target_count:
                num_needed = target_count - count
                samples, labs = self._generate_class(label, num_needed)
                synthetic_samples.append(samples)
                synthetic_labels.append(labs)

        if synthetic_samples:
            return (torch.cat(synthetic_samples, dim=0),
                    torch.cat(synthetic_labels, dim=0))
        return None, None

    def _generate_class(self, class_label, num_samples):
        self.generator.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim, device=self.device)
            labels = torch.full((num_samples,), class_label,
                                dtype=torch.long, device=self.device)
            samples = self.generator(z, labels)
        self.generator.train()
        return samples, labels
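
A short driving sketch; the toy conditional generator is a stand-in for a trained model, and the class counts are made up for illustration.

PYTHON
import torch
import torch.nn as nn

# Placeholder conditional generator: embeds the label, concatenates it with z,
# and maps the result to a flat 784-dimensional "image".
class ToyConditionalGenerator(nn.Module):
    def __init__(self, latent_dim=64, num_classes=3):
        super().__init__()
        self.embed = nn.Embedding(num_classes, 16)
        self.net = nn.Linear(latent_dim + 16, 784)

    def forward(self, z, labels):
        return self.net(torch.cat([z, self.embed(labels)], dim=1))

device = "cpu"
augmenter = BalancedAugmenter(ToyConditionalGenerator(), latent_dim=64, device=device)

# Imbalanced labels: 50 / 10 / 5 samples for classes 0 / 1 / 2
labels = torch.cat([torch.zeros(50), torch.ones(10), torch.full((5,), 2.0)]).long()
synthetic_x, synthetic_y = augmenter.balance_dataset(labels, target_count=50)
print(synthetic_x.shape, synthetic_y.shape)   # torch.Size([85, 784]) torch.Size([85])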

Video Generation

Extending GANs to video requires modeling temporal coherence in addition to spatial realism. Generated videos must not only contain realistic individual frames but also exhibit smooth, physically plausible motion. This presents challenges in maintaining consistency across time while allowing appropriate variation.

Video GANs typically use 3D convolutions or recurrent architectures to capture temporal dependencies. Discriminators may evaluate individual frames, short clips, or both to ensure quality at multiple temporal scales. Motion can be modeled explicitly through optical flow or implicitly through learned temporal representations.

PYTHON
class VideoGenerator(nn.Module):
    def __init__(self, latent_dim=256, num_frames=16):
        super().__init__()

        # Stored for reference; the output frame count is fixed by the four
        # ConvTranspose3d stages below (1 -> 16 frames)
        self.num_frames = num_frames

        # Temporal latent processing
        self.temporal_fc = nn.Sequential(
            nn.Linear(latent_dim, 512 * 4),
            nn.LeakyReLU(0.2)
        )

        # 3D convolutions for spatiotemporal generation
        self.conv3d_1 = nn.Sequential(
            nn.ConvTranspose3d(512, 256, (4, 4, 4), (2, 2, 2), (1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU()
        )

        self.conv3d_2 = nn.Sequential(
            nn.ConvTranspose3d(256, 128, (4, 4, 4), (2, 2, 2), (1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.ReLU()
        )

        self.conv3d_3 = nn.Sequential(
            nn.ConvTranspose3d(128, 64, (4, 4, 4), (2, 2, 2), (1, 1, 1)),
            nn.BatchNorm3d(64),
            nn.ReLU()
        )

        self.conv3d_4 = nn.Sequential(
            nn.ConvTranspose3d(64, 3, (4, 4, 4), (2, 2, 2), (1, 1, 1)),
            nn.Tanh()
        )

    def forward(self, z):
        x = self.temporal_fc(z)
        x = x.view(-1, 512, 1, 2, 2)

        x = self.conv3d_1(x)
        x = self.conv3d_2(x)
        x = self.conv3d_3(x)
        x = self.conv3d_4(x)

        return x


class VideoDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.spatial_disc = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )

        self.temporal_disc = nn.Sequential(
            nn.Conv3d(256, 512, (4, 4, 4), (2, 2, 2), (1, 1, 1)),
            nn.BatchNorm3d(512),
            nn.LeakyReLU(0.2),
            # Final kernel collapses the remaining 8x2x2 volume (for 16-frame,
            # 32x32 clips) into a single score per clip
            nn.Conv3d(512, 1, (8, 2, 2), (1, 1, 1), (0, 0, 0))
        )

    def forward(self, video):
        batch_size, channels, frames, height, width = video.shape

        # Process each frame spatially
        video_flat = video.permute(0, 2, 1, 3, 4).reshape(-1, channels, height, width)
        spatial_features = self.spatial_disc(video_flat)

        # Reshape back to video format
        _, c, h, w = spatial_features.shape
        spatial_features = spatial_features.view(batch_size, frames, c, h, w)
        spatial_features = spatial_features.permute(0, 2, 1, 3, 4)

        # Temporal discrimination
        output = self.temporal_disc(spatial_features)
        return output.view(batch_size, -1)
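
Finally, a quick shape check under the assumptions baked into the two networks above (clips of 16 frames at 32x32 resolution):

PYTHON
import torch

# Shape check: the generator emits 16-frame 32x32 clips, which is exactly the
# geometry the discriminator's final 8x2x2 kernel assumes.
gen = VideoGenerator(latent_dim=256, num_frames=16)
disc = VideoDiscriminator()

z = torch.randn(2, 256)
video = gen(z)
print(video.shape)    # torch.Size([2, 3, 16, 32, 32])

score = disc(video)
print(score.shape)    # torch.Size([2, 1])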

Key Takeaways

GANs enable diverse applications across image and video domains. Super-resolution recovers high-frequency details using perceptual losses for sharp results. Inpainting fills missing regions using partial or gated convolutions to handle arbitrary masks. Face generation and editing leverage latent space manipulation for semantic control. Data augmentation generates synthetic training samples to balance datasets or address data scarcity. Video generation extends the framework to temporal modeling with 3D convolutions and multi-scale discrimination. These applications demonstrate the versatility of adversarial training for generation tasks.