Intermediate to Advanced | 90 min read

Chapter 17: Vision Transformers

ViT, Swin Transformer, and vision-only applications.

Learning Objectives

["Build vision transformers", "Apply to image tasks", "Use timm library"]


17.3 VAE Variants Advanced

VAE Variants

Variational autoencoders (VAEs) learn latent representations through probabilistic encoding. Several extensions improve on the basic VAE, targeting better generation quality, more structured latent spaces, and specific applications.

Standard VAE

The basic VAE learns to encode and decode through a variational bottleneck.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, latent_dim=20):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


def vae_loss(recon, x, mu, logvar, beta=1.0):
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss
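
A minimal training loop ties the model and loss together. The sketch below is illustrative rather than a tuned recipe; it assumes a hypothetical train_loader that yields flattened inputs in [0, 1] (for example 28x28 MNIST digits, so input_dim=784).

PYTHON
# Training sketch; `train_loader` is an assumed DataLoader of (image, label)
# batches with pixel values in [0, 1].
model = VAE(input_dim=784)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    for x, _ in train_loader:
        x = x.view(x.size(0), -1)              # flatten images to vectors
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()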

Convolutional VAE

Convolutional encoders and decoders provide a stronger inductive bias for image data than fully connected layers. The architecture below assumes 64x64 inputs, which four stride-2 convolutions reduce to a 4x4 feature map.

PYTHON
class ConvVAE(nn.Module):
    def __init__(self, in_channels=3, latent_dim=256):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

        # Decoder
        self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = self.fc_decode(z)
        h = h.view(-1, 256, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

Beta-VAE

Beta-VAE encourages disentangled representations by increasing the weight on the KL divergence term (beta > 1).

PYTHON
class BetaVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=10, beta=4.0):
        super().__init__()
        self.beta = beta
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        recon = self.decoder(z)
        return recon, mu, logvar

    def loss(self, recon, x, mu, logvar):
        recon_loss = F.mse_loss(recon, x, reduction="sum")
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + self.beta * kl_loss
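
One common refinement in practice (not part of the Beta-VAE objective itself) is to warm up the KL weight from zero early in training, which can help avoid posterior collapse. A minimal schedule sketch, with hypothetical step counts:

PYTHON
# Linear KL warm-up: beta ramps from 0 to the target value over the first
# `warmup_steps` optimizer steps, then stays constant.
def beta_schedule(step, warmup_steps=10_000, target_beta=4.0):
    return target_beta * min(1.0, step / warmup_steps)

# Inside a training loop one might write:
#   model.beta = beta_schedule(global_step)
#   loss = model.loss(recon, x, mu, logvar)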

VQ-VAE

Vector Quantized VAE (VQ-VAE) replaces the continuous latent space with a discrete codebook: each encoder output is snapped to its nearest codebook embedding.

PYTHON
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def forward(self, z):
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        distances = (
            z_flat.pow(2).sum(dim=1, keepdim=True)
            - 2 * z_flat @ self.embedding.weight.t()
            + self.embedding.weight.pow(2).sum(dim=1)
        )

        indices = distances.argmin(dim=1)
        z_q = self.embedding(indices).view(z.shape[0], z.shape[2], z.shape[3], -1)
        z_q = z_q.permute(0, 3, 1, 2)

        commitment_loss = F.mse_loss(z_q.detach(), z)
        codebook_loss = F.mse_loss(z_q, z.detach())
        loss = codebook_loss + self.commitment_cost * commitment_loss

        z_q = z + (z_q - z).detach()  # straight-through estimator: gradients bypass the quantization step

        return z_q, loss, indices


class VQVAE(nn.Module):
    def __init__(self, in_channels=3, hidden_dim=128, num_embeddings=512, embedding_dim=64):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, embedding_dim, 1)
        )

        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)

        self.decoder = nn.Sequential(
            nn.Conv2d(embedding_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, hidden_dim, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, in_channels, 4, stride=2, padding=1)
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, vq_loss, indices = self.quantizer(z_e)
        recon = self.decoder(z_q)
        return recon, vq_loss, indices

    def loss(self, recon, x, vq_loss):
        recon_loss = F.mse_loss(recon, x)
        return recon_loss + vq_loss
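
Because the forward pass already returns the quantizer loss, a training step only needs the reconstruction term from the loss method above. A usage sketch, assuming a hypothetical image_loader of batches scaled to [0, 1]:

PYTHON
# Illustrative VQ-VAE training step; `image_loader` is an assumed DataLoader
# of image batches shaped [B, 3, H, W].
model = VQVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

model.train()
for images, _ in image_loader:
    recon, vq_loss, _ = model(images)
    loss = model.loss(recon, images, vq_loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()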

Key Takeaways

VAE variants address different limitations of the standard VAE. Convolutional VAEs provide better image modeling through spatial inductive bias. Beta-VAE encourages disentangled representations with increased KL weighting. VQ-VAE uses discrete latent codes for sharper reconstructions and easier autoregressive modeling. Each variant trades off between reconstruction quality, latent structure, and generation diversity.

17.4 Applications of VAEs Advanced

Applications of VAEs

Variational Autoencoders enable numerous applications beyond simple reconstruction. Their learned latent spaces support generation, interpolation, and structured data manipulation.

Image Generation

Sampling from the latent space produces new images.

PYTHON
import torch
import torch.nn as nn

class VAEGenerator:
    def __init__(self, model, device="cuda"):
        self.model = model.to(device)
        self.device = device

    def sample(self, num_samples, temperature=1.0):
        self.model.eval()
        with torch.no_grad():
            z = torch.randn(num_samples, self.model.latent_dim, device=self.device)
            z = z * temperature
            samples = self.model.decode(z)
        return samples

    def interpolate(self, x1, x2, steps=10):
        self.model.eval()
        with torch.no_grad():
            mu1, _ = self.model.encode(x1)
            mu2, _ = self.model.encode(x2)

            alphas = torch.linspace(0, 1, steps, device=self.device)
            interpolations = []

            for alpha in alphas:
                z = (1 - alpha) * mu1 + alpha * mu2
                recon = self.model.decode(z)
                interpolations.append(recon)

        return torch.stack(interpolations)

    def traverse_latent(self, x, dim, range_val=3, steps=10):
        self.model.eval()
        with torch.no_grad():
            mu, _ = self.model.encode(x)

            values = torch.linspace(-range_val, range_val, steps)
            traversals = []

            for val in values:
                z = mu.clone()
                z[:, dim] = val
                recon = self.model.decode(z)
                traversals.append(recon)

        return torch.stack(traversals)
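
Given a trained model such as the ConvVAE from the previous section, the generator wraps sampling, interpolation, and latent traversal behind one small interface. A usage sketch; trained_vae, img_a, and img_b are assumed to exist:

PYTHON
# Usage sketch; assumes a trained (Conv)VAE and two image batches on the GPU.
generator = VAEGenerator(trained_vae, device="cuda")

samples = generator.sample(num_samples=16, temperature=0.8)     # new images
morphs = generator.interpolate(img_a, img_b, steps=8)           # smooth blend
sweep = generator.traverse_latent(img_a, dim=3, range_val=3.0)  # vary one latent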

Anomaly Detection

Reconstruction error identifies out-of-distribution samples.

PYTHON
class VAEAnomalyDetector:
    def __init__(self, model, threshold=None):
        self.model = model
        self.threshold = threshold

    def compute_reconstruction_error(self, x):
        self.model.eval()
        with torch.no_grad():
            recon, mu, logvar = self.model(x)
            recon_error = ((x - recon) ** 2).sum(dim=tuple(range(1, x.dim())))
        return recon_error

    def compute_elbo_anomaly_score(self, x):
        self.model.eval()
        with torch.no_grad():
            recon, mu, logvar = self.model(x)

            recon_loss = ((x - recon) ** 2).sum(dim=tuple(range(1, x.dim())))
            kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)

            return recon_loss + kl_loss

    def fit_threshold(self, train_data, percentile=95):
        scores = []
        for batch in train_data:
            score = self.compute_elbo_anomaly_score(batch)
            scores.append(score)

        all_scores = torch.cat(scores)
        self.threshold = torch.quantile(all_scores, percentile / 100)
        return self.threshold

    def detect(self, x):
        score = self.compute_elbo_anomaly_score(x)
        return score > self.threshold, score


class VAEEnsembleDetector:
    def __init__(self, models):
        self.models = models

    def compute_disagreement(self, x):
        reconstructions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                recon, _, _ = model(x)
                reconstructions.append(recon)

        reconstructions = torch.stack(reconstructions)
        variance = reconstructions.var(dim=0).mean(dim=tuple(range(1, x.dim())))
        return variance
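
Using the detector involves two phases: fit a threshold on normal (in-distribution) data, then score new inputs against it. A sketch, assuming a normal_loader that yields plain tensors and a test_batch to score:

PYTHON
# Usage sketch; `trained_vae`, `normal_loader`, and `test_batch` are assumed.
detector = VAEAnomalyDetector(trained_vae)
detector.fit_threshold(normal_loader, percentile=95)

is_anomaly, scores = detector.detect(test_batch)
print(f"Flagged {is_anomaly.sum().item()} of {test_batch.size(0)} samples")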

Data Augmentation

VAEs generate synthetic training samples.

PYTHON
class VAEAugmenter:
    def __init__(self, model, device="cuda"):
        self.model = model.to(device)
        self.device = device

    def augment_batch(self, x, num_augmentations=5, noise_std=0.1):
        self.model.eval()
        augmented = [x]

        with torch.no_grad():
            mu, logvar = self.model.encode(x)

            for _ in range(num_augmentations):
                noise = torch.randn_like(mu) * noise_std
                z_aug = mu + noise
                recon = self.model.decode(z_aug)
                augmented.append(recon)

        return torch.cat(augmented, dim=0)

    def mixup_latent(self, x1, x2, alpha=0.5):
        self.model.eval()
        with torch.no_grad():
            mu1, _ = self.model.encode(x1)
            mu2, _ = self.model.encode(x2)

            z_mixed = alpha * mu1 + (1 - alpha) * mu2
            return self.model.decode(z_mixed)

    def generate_class_samples(self, class_samples, num_generate=100):
        self.model.eval()
        with torch.no_grad():
            mu, logvar = self.model.encode(class_samples)

            class_mu = mu.mean(dim=0)
            class_std = mu.std(dim=0)

            z_samples = class_mu + torch.randn(num_generate, mu.size(1), device=self.device) * class_std
            return self.model.decode(z_samples)

Representation Learning

VAE latent spaces provide useful features.

PYTHON
class VAEFeatureExtractor:
    def __init__(self, model):
        self.model = model

    def extract_features(self, x):
        self.model.eval()
        with torch.no_grad():
            mu, _ = self.model.encode(x)
        return mu

    def extract_full_representation(self, x):
        self.model.eval()
        with torch.no_grad():
            mu, logvar = self.model.encode(x)
            std = torch.exp(0.5 * logvar)
        return torch.cat([mu, std], dim=1)


class SemiSupervisedVAE(nn.Module):
    def __init__(self, input_dim, latent_dim, num_classes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )

        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )

        self.classifier = nn.Linear(latent_dim, num_classes)

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def classify(self, x):
        mu, _ = self.encode(x)
        return self.classifier(mu)

    def decode(self, z, y):
        y_onehot = torch.zeros(z.size(0), self.num_classes, device=z.device)
        y_onehot.scatter_(1, y.unsqueeze(1), 1)
        z_y = torch.cat([z, y_onehot], dim=1)
        return self.decoder(z_y)

    def forward(self, x, y=None):
        mu, logvar = self.encode(x)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

        if y is None:
            y_pred = self.classifier(mu).argmax(dim=1)
            y = y_pred

        recon = self.decode(z, y)
        return recon, mu, logvar, self.classifier(mu)

Key Takeaways

VAEs enable diverse applications through their structured latent spaces. Image generation produces new samples by decoding random latent vectors. Anomaly detection uses reconstruction error to identify outliers. Data augmentation creates synthetic training examples. Representation learning extracts useful features for downstream tasks. Semi-supervised learning combines generative and discriminative objectives.

17.1 Introduction to Vision Transformers Intermediate

Introduction to Vision Transformers

Vision Transformers (ViT) apply the transformer architecture to image understanding, challenging the dominance of convolutional neural networks. By treating images as sequences of patches, ViT achieves state-of-the-art results on image classification and beyond.

From CNNs to Transformers

Convolutional networks have dominated computer vision for a decade, but transformers offer compelling advantages.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNvsTransformerComparison:
    """Compare characteristics of CNNs and Vision Transformers."""

    CNN_PROPERTIES = {
        "inductive_bias": "Strong (locality, translation equivariance)",
        "receptive_field": "Local, grows with depth",
        "parameter_sharing": "Kernel weights shared across positions",
        "data_efficiency": "High (works well with less data)",
        "scalability": "Limited by architecture constraints"
    }

    VIT_PROPERTIES = {
        "inductive_bias": "Weak (learned from data)",
        "receptive_field": "Global from first layer",
        "parameter_sharing": "Attention weights computed per input",
        "data_efficiency": "Lower (needs more data or pretraining)",
        "scalability": "Excellent (performance scales with data/compute)"
    }

    @classmethod
    def print_comparison(cls):
        print("CNN vs Vision Transformer")
        print("-" * 60)
        for prop in cls.CNN_PROPERTIES:
            print(f"{prop}:")
            print(f"  CNN: {cls.CNN_PROPERTIES[prop]}")
            print(f"  ViT: {cls.VIT_PROPERTIES[prop]}")
            print()

CNNvsTransformerComparison.print_comparison()

Image Patch Embedding

The key insight of ViT is treating an image as a sequence of patches.

PYTHON
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        # x: [batch, channels, height, width]
        # Output: [batch, num_patches, embed_dim]
        x = self.projection(x)  # [batch, embed_dim, h/patch, w/patch]
        x = x.flatten(2)         # [batch, embed_dim, num_patches]
        x = x.transpose(1, 2)    # [batch, num_patches, embed_dim]
        return x


class LinearPatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        self.projection = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        batch, channels, height, width = x.shape
        p = self.patch_size

        # Reshape to patches
        x = x.unfold(2, p, p).unfold(3, p, p)
        x = x.contiguous().view(batch, channels, -1, p, p)
        x = x.permute(0, 2, 1, 3, 4)
        x = x.contiguous().view(batch, -1, channels * p * p)

        return self.projection(x)


# Visualize patch extraction
def visualize_patches(img_size=224, patch_size=16):
    num_patches_per_side = img_size // patch_size
    total_patches = num_patches_per_side ** 2
    print(f"Image size: {img_size}x{img_size}")
    print(f"Patch size: {patch_size}x{patch_size}")
    print(f"Patches per side: {num_patches_per_side}")
    print(f"Total patches: {total_patches}")
    print(f"Sequence length: {total_patches + 1} (with CLS token)")

visualize_patches()

Positional Encoding for Images

Since transformers have no inherent notion of position, we add positional information.

PYTHON
class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, num_patches, embed_dim):
        super().__init__()
        # +1 for CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        return x + self.pos_embed


class SinusoidalPositionalEmbedding2D(nn.Module):
    def __init__(self, embed_dim, height, width, temperature=10000):
        super().__init__()
        self.embed_dim = embed_dim

        y_pos = torch.arange(height).unsqueeze(1).repeat(1, width)
        x_pos = torch.arange(width).unsqueeze(0).repeat(height, 1)

        dim_t = torch.arange(embed_dim // 4)
        dim_t = temperature ** (2 * dim_t / (embed_dim // 4))

        pos_x = x_pos.flatten().unsqueeze(1) / dim_t
        pos_y = y_pos.flatten().unsqueeze(1) / dim_t

        pos_embed = torch.cat([
            pos_x.sin(), pos_x.cos(),
            pos_y.sin(), pos_y.cos()
        ], dim=1)

        self.register_buffer("pos_embed", pos_embed.unsqueeze(0))

    def forward(self, x):
        return x + self.pos_embed[:, :x.size(1)]

The CLS Token

A learnable classification token aggregates global image information.

PYTHON
class CLSTokenEmbedding(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        batch_size = x.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        return torch.cat([cls_tokens, x], dim=1)


class VisionTransformerEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        batch_size = x.size(0)

        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        x = self.dropout(x)

        return x
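
A quick shape check makes the tokenization concrete: a 224x224 image cut into 16x16 patches yields 196 patch tokens, and prepending the CLS token gives a sequence of 197.

PYTHON
# Shape sanity check for the embedding pipeline above.
embed = VisionTransformerEmbedding(img_size=224, patch_size=16, embed_dim=768)
dummy = torch.randn(2, 3, 224, 224)
tokens = embed(dummy)
print(tokens.shape)  # torch.Size([2, 197, 768]) -> 1 CLS token + 196 patches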

Key Takeaways

Vision Transformers apply transformer architecture to images by treating them as sequences of patches. Patch embedding converts image regions into token vectors. Positional embeddings encode spatial information lost during tokenization. The CLS token provides a global representation for classification. ViT demonstrates that with sufficient data, transformers can match or exceed CNN performance on vision tasks.

17.2 ViT Architecture Intermediate

ViT Architecture

The Vision Transformer architecture closely follows the original transformer encoder design. Understanding each component reveals how transformers process visual information.

Complete ViT Implementation

A full Vision Transformer combines patch embedding, transformer blocks, and a classification head.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.1,
        attn_dropout=0.0
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

        # CLS token and positional embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, attn_dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):
        batch_size = x.size(0)

        # Patch embedding: [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embedding
        x = self.pos_drop(x + self.pos_embed)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Classification from CLS token
        return self.head(x[:, 0])

Transformer Block

Each block contains multi-head attention and a feed-forward network with residual connections.

PYTHON
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, attn_dropout)
        self.drop1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        x = x + self.drop1(self.attn(self.norm1(x)))
        x = x + self.drop2(self.mlp(self.norm2(x)))
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        batch, seq_len, embed_dim = x.shape

        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(batch, seq_len, embed_dim)
        return self.proj(x)


class MLP(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
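
With the blocks in place, a small forward pass confirms the end-to-end shapes. The configuration below is deliberately tiny so it runs quickly on CPU; it is a smoke test, not a recommended model size.

PYTHON
# Smoke test with a deliberately small configuration.
tiny_vit = VisionTransformer(
    img_size=64, patch_size=8, num_classes=10,
    embed_dim=128, depth=2, num_heads=4
)
logits = tiny_vit(torch.randn(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 10])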

ViT Model Variants

Different ViT configurations trade off between capacity and efficiency.

PYTHON
VIT_CONFIGS = {
    "vit_tiny": {
        "embed_dim": 192,
        "depth": 12,
        "num_heads": 3,
        "params": "5.7M"
    },
    "vit_small": {
        "embed_dim": 384,
        "depth": 12,
        "num_heads": 6,
        "params": "22M"
    },
    "vit_base": {
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "params": "86M"
    },
    "vit_large": {
        "embed_dim": 1024,
        "depth": 24,
        "num_heads": 16,
        "params": "307M"
    },
    "vit_huge": {
        "embed_dim": 1280,
        "depth": 32,
        "num_heads": 16,
        "params": "632M"
    }
}


def create_vit(variant="vit_base", img_size=224, patch_size=16, num_classes=1000):
    config = VIT_CONFIGS[variant]
    return VisionTransformer(
        img_size=img_size,
        patch_size=patch_size,
        num_classes=num_classes,
        embed_dim=config["embed_dim"],
        depth=config["depth"],
        num_heads=config["num_heads"]
    )


def print_model_summary():
    print("ViT Model Variants")
    print("-" * 50)
    for name, config in VIT_CONFIGS.items():
        print(f"{name}: {config['params']} parameters")
        print(f"  embed_dim={config['embed_dim']}, depth={config['depth']}, heads={config['num_heads']}")

print_model_summary()
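
In practice these variants are rarely reimplemented from scratch: the timm library provides pretrained weights for most of them. A brief usage sketch (model names follow timm's naming convention; assumes timm is installed):

PYTHON
# Loading a pretrained ViT with timm (pip install timm).
import timm

model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])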

Attention Visualization

Understanding what ViT attends to provides interpretability.

PYTHON
class AttentionVisualization:
    def __init__(self, model):
        self.model = model
        self.attention_maps = []

    def register_hooks(self):
        for block in self.model.blocks:
            block.attn.register_forward_hook(self._save_attention)

    def _save_attention(self, module, input, output):
        with torch.no_grad():
            qkv = module.qkv(input[0])
            batch, seq_len, _ = qkv.shape
            qkv = qkv.reshape(batch, seq_len, 3, module.num_heads, module.head_dim)
            q, k, _ = qkv.permute(2, 0, 3, 1, 4).unbind(0)
            attn = (q @ k.transpose(-2, -1)) * module.scale
            attn = attn.softmax(dim=-1)
            self.attention_maps.append(attn.cpu())

    def get_cls_attention(self, layer=-1):
        attn = self.attention_maps[layer]
        cls_attn = attn[:, :, 0, 1:]
        return cls_attn.mean(dim=1)

    def visualize_attention(self, image, patch_size=16):
        cls_attn = self.get_cls_attention()
        num_patches = int(cls_attn.size(-1) ** 0.5)
        attn_map = cls_attn.reshape(num_patches, num_patches)
        return attn_map

Key Takeaways

The ViT architecture applies standard transformer blocks to image patches. Multi-head self-attention enables global receptive fields from the first layer. The MLP ratio of 4x expands hidden dimensions in feed-forward layers. Model variants scale from tiny (5.7M) to huge (632M) parameters. Pre-norm with residual connections ensures stable training of deep models.

17.5 Advanced Vision Transformer Techniques Advanced

Advanced Vision Transformer Techniques

Building on the basic ViT architecture, numerous innovations improve efficiency, performance, and applicability to diverse vision tasks.

DeiT: Data-Efficient Training

The Data-efficient Image Transformer (DeiT) introduces a distillation token and a training recipe that let ViT reach competitive accuracy when trained on ImageNet-1k alone, without large-scale pretraining data.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeiT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = nn.Conv2d(3, embed_dim, patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, embed_dim))

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        dist_tokens = self.dist_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)
        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_out = self.head(x[:, 0])
        dist_out = self.head_dist(x[:, 1])

        if self.training:
            return cls_out, dist_out
        return (cls_out + dist_out) / 2


class DistillationLoss(nn.Module):
    def __init__(self, teacher, temperature=3.0, alpha=0.5):
        super().__init__()
        self.teacher = teacher
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_cls, student_dist, labels, images):
        ce_loss = F.cross_entropy(student_cls, labels)

        with torch.no_grad():
            teacher_logits = self.teacher(images)

        soft_targets = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_dist / self.temperature, dim=-1)
        distill_loss = F.kl_div(soft_student, soft_targets, reduction="batchmean")
        distill_loss = distill_loss * (self.temperature ** 2)

        return self.alpha * ce_loss + (1 - self.alpha) * distill_loss
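
A training step wires the two heads into the distillation loss. The sketch below assumes a pretrained CNN teacher (pretrained_cnn) and a train_loader; the full DeiT recipe additionally relies on strong augmentation and a hard-label distillation variant.

PYTHON
# Illustrative distillation training step; `pretrained_cnn` and `train_loader`
# are assumed to exist.
student = DeiT(num_classes=1000)
criterion = DistillationLoss(teacher=pretrained_cnn, temperature=3.0, alpha=0.5)
optimizer = torch.optim.AdamW(student.parameters(), lr=5e-4, weight_decay=0.05)

student.train()
for images, labels in train_loader:
    cls_out, dist_out = student(images)
    loss = criterion(cls_out, dist_out, labels, images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()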

Swin Transformer

The Swin Transformer introduces a hierarchical feature pyramid and shifted-window attention, restricting self-attention to local windows so that cost grows linearly rather than quadratically with image size.

PYTHON
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads)
        )
        self._init_relative_position_index()

    def _init_relative_position_index(self):
        coords = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords, coords], indexing="ij"))
        coords_flatten = coords.flatten(1)

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1

        self.register_buffer("relative_position_index", relative_coords.sum(-1))

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        relative_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1).permute(2, 0, 1)
        attn = attn + relative_bias.unsqueeze(0)

        if mask is not None:
            attn = attn + mask.unsqueeze(1).unsqueeze(0)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(x)


def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size * window_size, C)


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
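
A round trip through the two helpers verifies the partition logic: partitioning a feature map into windows and reversing the operation should reproduce the original tensor exactly.

PYTHON
# Round-trip check (H and W must be divisible by the window size).
feat = torch.randn(2, 56, 56, 96)                # [B, H, W, C]
windows = window_partition(feat, window_size=7)  # [B * 64, 49, 96]
restored = window_reverse(windows, 7, 56, 56)
print(torch.allclose(feat, restored))  # True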

Efficient Attention Mechanisms

Various approaches reduce the quadratic complexity of attention.

PYTHON
class LinearAttention(nn.Module):
    def __init__(self, dim, num_heads, kernel="elu"):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.kernel = kernel

    def feature_map(self, x):
        if self.kernel == "elu":
            return F.elu(x) + 1
        return F.softmax(x, dim=-1)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        q = self.feature_map(q)
        k = self.feature_map(k)

        # Associativity trick: aggregate K^T V (head_dim x head_dim) first,
        # then apply Q, avoiding the full N x N attention matrix.
        kv = torch.einsum("bhnd,bhnv->bhdv", k, v)
        num = torch.einsum("bhnd,bhdv->bhnv", q, kv)

        # Normalizer analogous to the softmax denominator
        z = torch.einsum("bhnd,bhd->bhn", q, k.sum(dim=2))
        out = num / (z.unsqueeze(-1) + 1e-6)

        return self.proj(out.transpose(1, 2).reshape(B, N, C))


class PoolingAttention(nn.Module):
    def __init__(self, dim, num_heads, pool_size=7):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.pool_size = pool_size

        self.q_proj = nn.Linear(dim, dim)
        self.kv_proj = nn.Linear(dim, dim * 2)
        self.out_proj = nn.Linear(dim, dim)

        self.pool = nn.AdaptiveAvgPool2d(pool_size)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        x_2d = x.transpose(1, 2).reshape(B, C, H, W)
        x_pooled = self.pool(x_2d).flatten(2).transpose(1, 2)

        kv = self.kv_proj(x_pooled).reshape(B, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4).unbind(0)

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)

        return self.out_proj(out)

Self-Supervised Vision Transformers

Masked autoencoders (MAE) pre-train ViTs without labels by masking most of the image patches and reconstructing them from the visible ones.

PYTHON
class MAE(nn.Module):
    def __init__(self, encoder, decoder_dim=512, decoder_depth=8, mask_ratio=0.75):
        super().__init__()
        # The encoder is expected to expose embed_dim, patch_size, patch_embed,
        # pos_embed, blocks, and norm (as the VisionTransformer above does).
        self.encoder = encoder
        self.mask_ratio = mask_ratio

        self.decoder_embed = nn.Linear(encoder.embed_dim, decoder_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_dim))

        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_dim, num_heads=8)
            for _ in range(decoder_depth)
        ])

        self.decoder_norm = nn.LayerNorm(decoder_dim)
        self.decoder_pred = nn.Linear(decoder_dim, encoder.patch_size ** 2 * 3)

    def random_masking(self, x):
        B, N, D = x.shape
        num_keep = int(N * (1 - self.mask_ratio))

        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = noise.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1)

        ids_keep = ids_shuffle[:, :num_keep]
        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

        return x_masked, ids_restore

    def forward(self, images):
        x = self.encoder.patch_embed(images).flatten(2).transpose(1, 2)
        x = x + self.encoder.pos_embed[:, 1:]

        x, ids_restore = self.random_masking(x)

        for block in self.encoder.blocks:
            x = block(x)
        x = self.encoder.norm(x)

        x = self.decoder_embed(x)

        mask_tokens = self.mask_token.expand(x.size(0), ids_restore.size(1) - x.size(1), -1)
        x = torch.cat([x, mask_tokens], dim=1)
        x = torch.gather(x, 1, ids_restore.unsqueeze(-1).expand(-1, -1, x.size(-1)))

        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        pred = self.decoder_pred(x)

        return pred
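
The model above returns per-patch pixel predictions but no loss. A hedged sketch of the reconstruction objective: patchify the input and compare it to the prediction with MSE. Note that the original MAE computes the loss only on masked patches, which would additionally require random_masking to return the binary mask.

PYTHON
# Reconstruction-loss sketch; averages over all patches for simplicity,
# whereas MAE proper restricts the loss to masked patches.
def patchify(images, patch_size=16):
    B, C, H, W = images.shape
    p = patch_size
    x = images.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1).reshape(B, (H // p) * (W // p), p * p * C)
    return x

def mae_loss(pred, images, patch_size=16):
    target = patchify(images, patch_size)  # [B, N, p*p*3]
    return F.mse_loss(pred, target)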

Key Takeaways

Advanced ViT techniques address efficiency and data requirements. DeiT enables training on smaller datasets through distillation. Swin Transformer uses hierarchical windows for linear complexity. Linear attention approximations reduce computational cost. Self-supervised methods like MAE pre-train ViTs without labels. These innovations make vision transformers practical across diverse applications.