
Chapter 21: Vision-Language Foundations

CLIP, BLIP, LLaVA, and multimodal architectures.

Libraries covered: Hugging Face Transformers

Learning Objectives

["Use CLIP for zero-shot classification", "Build vision-language models", "Handle multimodal data"]


21.1 Vision Transformer Architecture

The Vision Transformer (ViT) marked a paradigm shift in computer vision by demonstrating that the transformer architecture, originally designed for natural language processing, could achieve state-of-the-art results on image classification when trained on sufficient data. Rather than relying on the inductive biases of convolutional neural networks—locality and translation equivariance—ViT treats images as sequences of patches and learns spatial relationships entirely through attention mechanisms. This approach opened new possibilities for unified architectures across modalities and sparked intense research into adapting transformers for visual understanding.

From Sequences to Images

Transformers process sequences of tokens, but images are 2D grids of pixels. The key insight of ViT is to convert images into sequences by dividing them into fixed-size patches, flattening each patch into a vector, and treating these vectors as the input sequence. A 224×224 image divided into 16×16 patches yields 196 patch tokens, a manageable sequence length for transformer processing.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def image_to_patches(image, patch_size=16):
    """
    Convert an image into a sequence of flattened patches.

    Args:
        image: Tensor of shape (B, C, H, W)
        patch_size: Size of each square patch

    Returns:
        patches: Tensor of shape (B, num_patches, patch_dim)
        where patch_dim = C * patch_size * patch_size
    """
    B, C, H, W = image.shape
    assert H % patch_size == 0 and W % patch_size == 0, \
        f"Image dimensions ({H}x{W}) must be divisible by patch size {patch_size}"

    # Number of patches along each dimension
    num_patches_h = H // patch_size
    num_patches_w = W // patch_size
    num_patches = num_patches_h * num_patches_w

    # Reshape: (B, C, H, W) -> (B, C, n_h, p_h, n_w, p_w)
    patches = image.reshape(B, C, num_patches_h, patch_size, num_patches_w, patch_size)

    # Rearrange: (B, C, n_h, p_h, n_w, p_w) -> (B, n_h, n_w, p_h, p_w, C)
    patches = patches.permute(0, 2, 4, 3, 5, 1)

    # Flatten patches: (B, n_h, n_w, p_h, p_w, C) -> (B, num_patches, patch_dim)
    patches = patches.reshape(B, num_patches, -1)

    return patches


# Demonstrate patch extraction
def visualize_patching():
    """Show how an image is converted to patch sequence."""
    # Example: 224x224 RGB image with 16x16 patches
    B, C, H, W = 1, 3, 224, 224
    patch_size = 16

    image = torch.randn(B, C, H, W)
    patches = image_to_patches(image, patch_size)

    num_patches = (H // patch_size) * (W // patch_size)
    patch_dim = C * patch_size * patch_size

    print("Vision Transformer Patching:")
    print(f"  Input image: {B} x {C} x {H} x {W}")
    print(f"  Patch size: {patch_size} x {patch_size}")
    print(f"  Number of patches: {num_patches} ({H//patch_size} x {W//patch_size})")
    print(f"  Patch dimension: {patch_dim} ({C} x {patch_size} x {patch_size})")
    print(f"  Output sequence: {patches.shape}")

visualize_patching()

This patching strategy preserves local structure within each patch while allowing the transformer to learn global relationships between patches through attention. The patch size represents a trade-off: smaller patches capture finer details but create longer sequences with quadratic attention cost, while larger patches reduce computational cost but may miss fine-grained features.
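
To make this trade-off concrete, the short sketch below tabulates the sequence length and the size of the per-head attention matrix for a 224×224 input at a few illustrative patch sizes.

PYTHON
# Sketch: how patch size affects sequence length and attention cost
# for a 224x224 input (patch sizes chosen for illustration)
def patch_size_tradeoff(img_size=224, patch_sizes=(8, 16, 32)):
    print(f"{'patch':>6} {'tokens':>8} {'attention entries':>18}")
    for p in patch_sizes:
        n = (img_size // p) ** 2                # number of patch tokens
        print(f"{p:>6} {n:>8} {n * n:>18,}")    # attention matrix is N x N per head

patch_size_tradeoff()
# patch  8 ->  784 tokens,  614,656 attention entries per layer per head
# patch 16 ->  196 tokens,   38,416 attention entries
# patch 32 ->   49 tokens,    2,401 attention entries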

Patch Embedding Layer

The patch embedding layer projects flattened patches to the transformer's hidden dimension. While ViT can use a simple linear projection, implementing this as a convolution with kernel and stride equal to the patch size is more efficient and produces identical results.

PYTHON
class PatchEmbedding(nn.Module):
    """
    Convert image to patch embeddings.
    Uses convolution for efficient implementation.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Convolution with kernel_size=stride=patch_size extracts non-overlapping patches
        # and projects them to embed_dim in one operation
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

    def forward(self, x):
        """
        Args:
            x: Input image (B, C, H, W)
        Returns:
            Patch embeddings (B, num_patches, embed_dim)
        """
        B, C, H, W = x.shape
        assert H == self.img_size and W == self.img_size, \
            f"Input size ({H}x{W}) doesn't match model ({self.img_size}x{self.img_size})"

        # (B, C, H, W) -> (B, embed_dim, H/P, W/P)
        x = self.projection(x)

        # (B, embed_dim, H/P, W/P) -> (B, embed_dim, num_patches)
        x = x.flatten(2)

        # (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
        x = x.transpose(1, 2)

        return x


class PatchEmbeddingLinear(nn.Module):
    """
    Alternative implementation using reshape and linear layer.
    Mathematically equivalent to convolution approach.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size * patch_size

        self.projection = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        p = self.patch_size

        # Reshape to patches
        x = x.reshape(B, C, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 3, 5, 1)  # (B, H/p, W/p, p, p, C)
        x = x.reshape(B, self.num_patches, -1)  # (B, num_patches, patch_dim)

        # Linear projection
        x = self.projection(x)

        return x
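
The equivalence between the convolutional and linear patch embeddings can be checked numerically. The sketch below copies the convolution weights into the linear layer; the permutation assumes the (p_h, p_w, C) flattening order used by PatchEmbeddingLinear above.

PYTHON
# Sketch: verify conv and linear patch embeddings agree when sharing weights.
# The conv weight has shape (embed_dim, C, p_h, p_w); the linear layer expects
# patches flattened in (p_h, p_w, C) order, hence the permute before reshaping.
torch.manual_seed(0)
conv_embed = PatchEmbedding(img_size=32, patch_size=8, in_channels=3, embed_dim=16)
linear_embed = PatchEmbeddingLinear(img_size=32, patch_size=8, in_channels=3, embed_dim=16)

with torch.no_grad():
    w = conv_embed.projection.weight                       # (16, 3, 8, 8)
    linear_embed.projection.weight.copy_(w.permute(0, 2, 3, 1).reshape(16, -1))
    linear_embed.projection.bias.copy_(conv_embed.projection.bias)

x = torch.randn(2, 3, 32, 32)
print(torch.allclose(conv_embed(x), linear_embed(x), atol=1e-5))  # True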

Positional Embeddings

Unlike CNNs that inherently encode spatial information through local connectivity, transformers have no notion of position. ViT adds learnable position embeddings to patch embeddings, allowing the model to learn spatial relationships. These embeddings are typically initialized randomly and learned during training.

PYTHON
class PositionalEmbedding(nn.Module):
    """
    Learnable 1D positional embeddings for ViT.
    Added to patch embeddings to encode spatial position.
    """

    def __init__(self, num_patches, embed_dim):
        super().__init__()
        # +1 for the [CLS] token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        # Initialize with small random values
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Add positional embeddings to input."""
        return x + self.pos_embed


class SinusoidalPositionalEmbedding(nn.Module):
    """
    Fixed sinusoidal positional embeddings (alternative to learned).
    Similar to original Transformer, adapted for 2D positions.
    """

    def __init__(self, num_patches, embed_dim, temperature=10000):
        super().__init__()
        self.embed_dim = embed_dim

        # Compute grid of positions
        grid_size = int(num_patches ** 0.5)
        pos_h = torch.arange(grid_size).float()
        pos_w = torch.arange(grid_size).float()
        grid = torch.stack(torch.meshgrid(pos_h, pos_w, indexing='ij'), dim=-1)
        grid = grid.reshape(-1, 2)  # (num_patches, 2)

        # Compute sinusoidal embeddings
        dim_t = torch.arange(embed_dim // 4).float()
        dim_t = temperature ** (2 * dim_t / (embed_dim // 2))

        pos_embed = torch.zeros(num_patches, embed_dim)
        pos_embed[:, 0::4] = torch.sin(grid[:, 0:1] / dim_t)
        pos_embed[:, 1::4] = torch.cos(grid[:, 0:1] / dim_t)
        pos_embed[:, 2::4] = torch.sin(grid[:, 1:2] / dim_t)
        pos_embed[:, 3::4] = torch.cos(grid[:, 1:2] / dim_t)

        # Add CLS token position (zeros)
        cls_pos = torch.zeros(1, embed_dim)
        pos_embed = torch.cat([cls_pos, pos_embed], dim=0)

        self.register_buffer('pos_embed', pos_embed.unsqueeze(0))

    def forward(self, x):
        return x + self.pos_embed


def visualize_position_embeddings():
    """
    Show the learned spatial structure in position embeddings.
    Trained ViT position embeddings show clear 2D structure.
    """
    num_patches = 196  # 14x14 grid
    embed_dim = 768

    pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

    # Compute similarity between position embeddings
    # In trained models, nearby patches have similar embeddings
    pos_only = pos_embed[0, 1:]  # Exclude CLS token
    similarity = torch.matmul(pos_only, pos_only.T)

    print("Position Embedding Similarity Matrix:")
    print(f"  Shape: {similarity.shape}")
    print("  In trained ViT, this shows 2D spatial structure")
    print("  Nearby patches have high similarity")

The learned position embeddings in trained ViT models show remarkable structure: embeddings for spatially close patches are similar, and the model learns row and column structure despite receiving only 1D position indices. This demonstrates that transformers can discover spatial relationships from data.

The CLS Token

Following BERT, ViT prepends a learnable [CLS] token to the sequence of patch embeddings. After processing through transformer layers, the final representation of this token serves as the aggregate image representation for classification. The CLS token attends to all patches, effectively learning to summarize the entire image.

PYTHON
class CLSToken(nn.Module):
    """
    Learnable [CLS] token prepended to patch sequence.
    Its final representation is used for classification.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """
        Prepend CLS token to input sequence.

        Args:
            x: Patch embeddings (B, num_patches, embed_dim)
        Returns:
            Sequence with CLS token (B, num_patches + 1, embed_dim)
        """
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        return torch.cat([cls_tokens, x], dim=1)


def alternative_pooling_strategies():
    """
    Alternatives to CLS token for image representation.
    """
    # 1. Global Average Pooling (GAP)
    # Average all patch representations
    def gap_pooling(x):
        # x: (B, num_patches + 1, embed_dim)
        # Exclude CLS token, average patches
        return x[:, 1:].mean(dim=1)  # (B, embed_dim)

    # 2. Global Max Pooling
    def gmp_pooling(x):
        return x[:, 1:].max(dim=1)[0]  # (B, embed_dim)

    # 3. Attention Pooling
    class AttentionPooling(nn.Module):
        def __init__(self, embed_dim):
            super().__init__()
            self.attention = nn.Linear(embed_dim, 1)

        def forward(self, x):
            # x: (B, seq_len, embed_dim)
            weights = F.softmax(self.attention(x), dim=1)  # (B, seq_len, 1)
            return (weights * x).sum(dim=1)  # (B, embed_dim)

    print("Pooling Strategies for ViT:")
    print("  CLS Token: Standard ViT approach, learns to aggregate")
    print("  GAP: Simple averaging, works well for some tasks")
    print("  Attention Pooling: Learned weighted average")

Multi-Head Self-Attention for Vision

The core of the transformer is multi-head self-attention, which allows each patch to attend to all other patches. For vision, this means every part of the image can directly influence every other part, unlike CNNs where information flows only through local connections.

PYTHON
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention mechanism for Vision Transformer.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k)

        # Combined QKV projection for efficiency
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=bias)

        # Output projection
        self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x, return_attention=False):
        """
        Args:
            x: Input tensor (B, seq_len, embed_dim)
            return_attention: Whether to return attention weights

        Returns:
            Output tensor (B, seq_len, embed_dim)
            Optionally: attention weights (B, num_heads, seq_len, seq_len)
        """
        B, N, C = x.shape

        # Compute Q, K, V in one projection
        qkv = self.qkv(x)  # (B, N, 3 * embed_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        # (B, num_heads, N, head_dim) @ (B, num_heads, head_dim, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        # Apply attention to values
        # (B, num_heads, N, N) @ (B, num_heads, N, head_dim)
        x = attn @ v  # (B, num_heads, N, head_dim)

        # Reshape back
        x = x.transpose(1, 2).reshape(B, N, C)  # (B, N, embed_dim)

        # Output projection
        x = self.proj(x)
        x = self.proj_dropout(x)

        if return_attention:
            return x, attn
        return x


def visualize_attention_patterns():
    """
    Demonstrate how attention connects patches across the image.
    """
    print("ViT Attention Characteristics:")
    print("  - Every patch can attend to every other patch")
    print("  - Global receptive field from the first layer")
    print("  - Different heads learn different attention patterns:")
    print("    * Some attend to nearby patches (local)")
    print("    * Some attend to semantically similar regions")
    print("    * Some attend to specific image locations")
    print("  - CLS token attends to all patches, aggregating information")

visualize_attention_patterns()

The attention mechanism computes pairwise relationships between all patches with complexity $O(N^2)$ where $N$ is the number of patches. For a 224×224 image with 16×16 patches, this is $196^2 = 38,416$ attention computations per layer per head—substantial but manageable.

MLP Block

Each transformer block contains a multi-layer perceptron (MLP) applied independently to each token. The MLP typically expands the hidden dimension by a factor of 4, applies a non-linearity, then projects back to the original dimension.

PYTHON
class MLP(nn.Module):
    """
    Multi-Layer Perceptron block in Vision Transformer.
    Applied independently to each token position.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0, activation=nn.GELU):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)

        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = activation()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class GLU_MLP(nn.Module):
    """
    Gated Linear Unit MLP variant.
    Used in some modern ViT variants for improved performance.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)

        # Two parallel projections for gating
        self.w1 = nn.Linear(embed_dim, hidden_dim)
        self.w2 = nn.Linear(embed_dim, hidden_dim)
        self.w3 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # SwiGLU: Swish(W1(x)) * W2(x)
        hidden = F.silu(self.w1(x)) * self.w2(x)
        hidden = self.dropout(hidden)
        output = self.w3(hidden)
        return self.dropout(output)

Transformer Encoder Block

A transformer encoder block combines multi-head self-attention and MLP with residual connections and layer normalization. ViT uses pre-normalization (LayerNorm before attention/MLP) rather than the original post-normalization, which improves training stability.

PYTHON
class TransformerEncoderBlock(nn.Module):
    """
    Single Transformer Encoder block for ViT.
    Combines self-attention and MLP with residual connections.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0,
                 attention_dropout=0.0, drop_path=0.0):
        super().__init__()

        # Pre-normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(
            embed_dim, num_heads, dropout=attention_dropout
        )

        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

        # Stochastic depth (drop path) for regularization
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x):
        # Self-attention with residual
        x = x + self.drop_path(self.attn(self.norm1(x)))

        # MLP with residual
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class DropPath(nn.Module):
    """
    Stochastic Depth: randomly drop entire residual branches during training.
    Regularization technique that improves generalization.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0 or not self.training:
            return x

        keep_prob = 1 - self.drop_prob
        # Generate random tensor, same for all elements in batch
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor = random_tensor.floor()

        # Scale to maintain expected value
        output = x / keep_prob * random_tensor
        return output

Complete Vision Transformer

Combining all components yields the complete Vision Transformer architecture:

PYTHON
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) for image classification.

    Architecture:
    1. Patch embedding: Convert image to sequence of patch embeddings
    2. Add CLS token and positional embeddings
    3. Process through transformer encoder blocks
    4. Extract CLS token representation
    5. Classification head
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0,
        attention_dropout=0.0,
        drop_path_rate=0.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attention_dropout=attention_dropout,
                drop_path=dpr[i]
            )
            for i in range(depth)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        # Initialize patch embedding
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Initialize other layers
        self.apply(self._init_layer_weights)

    def _init_layer_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward_features(self, x):
        """Extract features without classification head."""
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, num_patches + 1, embed_dim)

        # Add positional embeddings
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final normalization
        x = self.norm(x)

        return x

    def forward(self, x):
        """Forward pass with classification."""
        x = self.forward_features(x)

        # Extract CLS token
        cls_token_final = x[:, 0]

        # Classification
        x = self.head(cls_token_final)

        return x


# Standard ViT configurations
def vit_tiny(num_classes=1000, **kwargs):
    return VisionTransformer(
        embed_dim=192, depth=12, num_heads=3,
        num_classes=num_classes, **kwargs
    )

def vit_small(num_classes=1000, **kwargs):
    return VisionTransformer(
        embed_dim=384, depth=12, num_heads=6,
        num_classes=num_classes, **kwargs
    )

def vit_base(num_classes=1000, **kwargs):
    return VisionTransformer(
        embed_dim=768, depth=12, num_heads=12,
        num_classes=num_classes, **kwargs
    )

def vit_large(num_classes=1000, **kwargs):
    return VisionTransformer(
        embed_dim=1024, depth=24, num_heads=16,
        num_classes=num_classes, **kwargs
    )

def vit_huge(num_classes=1000, **kwargs):
    return VisionTransformer(
        embed_dim=1280, depth=32, num_heads=16,
        num_classes=num_classes, **kwargs
    )


# Compare model sizes
def compare_vit_variants():
    """Compare different ViT model sizes."""
    variants = {
        'ViT-Ti': vit_tiny,
        'ViT-S': vit_small,
        'ViT-B': vit_base,
        'ViT-L': vit_large,
    }

    print("Vision Transformer Variants:")
    print("-" * 50)
    for name, factory in variants.items():
        model = factory()
        params = sum(p.numel() for p in model.parameters())
        print(f"{name:10} Parameters: {params / 1e6:.1f}M")

compare_vit_variants()
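
As a quick shape check, the sketch below runs a dummy batch through ViT-Tiny (batch size and class count are arbitrary) and inspects the outputs.

PYTHON
# Sanity check: dummy forward pass through ViT-Tiny (arbitrary batch / classes)
model = vit_tiny(num_classes=10)
images = torch.randn(2, 3, 224, 224)

logits = model(images)                    # (2, 10) class scores from the CLS token
tokens = model.forward_features(images)   # (2, 197, 192): CLS token + 196 patch tokens

print(f"Logits: {tuple(logits.shape)}")
print(f"Token representations: {tuple(tokens.shape)}")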

Key Takeaways

The Vision Transformer represents a fundamental departure from convolutional neural networks by treating images as sequences of patches and learning all spatial relationships through attention. The architecture consists of patch embedding to convert images to sequences, learnable position embeddings to encode spatial information, a CLS token for aggregating image-level representations, and stacked transformer encoder blocks with self-attention and MLP layers. ViT requires large-scale pretraining to match CNN performance because it lacks the inductive biases (locality, translation equivariance) that CNNs build in architecturally. However, when trained on sufficient data, ViT scales better than CNNs and achieves superior results. The attention mechanism provides global receptive fields from the first layer, enabling direct communication between distant image regions. Understanding ViT architecture is essential for working with modern vision systems, as transformers have become the dominant architecture for large-scale vision models.

21.2 ViT Variants and Improvements

Following the introduction of the original Vision Transformer, researchers rapidly developed numerous improvements addressing its limitations: the need for massive datasets, quadratic attention complexity, lack of hierarchical features, and training instability at scale. These variants have made vision transformers practical for a wider range of applications, from data-efficient training on ImageNet alone to efficient processing of high-resolution images. Understanding these improvements is essential for selecting and deploying modern vision transformers effectively.

DeiT: Data-Efficient Image Transformers

DeiT (Data-efficient Image Transformers) demonstrated that ViT could achieve competitive results when trained only on ImageNet, without the massive JFT-300M dataset used in the original paper. The key was a combination of strong data augmentation, regularization, and knowledge distillation from CNN teachers.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeiT(nn.Module):
    """
    DeiT: Data-efficient Image Transformer.
    Adds distillation token for knowledge distillation from CNN teacher.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

        # CLS token and distillation token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings (+2 for CLS and distillation tokens)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 2, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Separate heads for classification and distillation
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Prepend CLS and distillation tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)

        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Extract CLS and distillation token outputs
        cls_output = x[:, 0]
        dist_output = x[:, 1]

        # Classification from both tokens
        x_cls = self.head(cls_output)
        x_dist = self.head_dist(dist_output)

        if self.training:
            return x_cls, x_dist
        else:
            # Average predictions at inference
            return (x_cls + x_dist) / 2


class DeiTDistillationLoss(nn.Module):
    """
    Distillation loss for DeiT training.
    Combines hard label loss with soft distillation from teacher.
    """

    def __init__(self, teacher_model, temperature=3.0, alpha=0.5):
        super().__init__()
        self.teacher = teacher_model
        self.teacher.eval()  # Teacher is frozen
        self.temperature = temperature
        self.alpha = alpha

        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_outputs, targets, images):
        cls_logits, dist_logits = student_outputs

        # Hard label loss on CLS token
        loss_cls = self.ce_loss(cls_logits, targets)

        # Get teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(images)

        # Soft distillation loss on distillation token
        # KL divergence between softened distributions
        student_soft = F.log_softmax(dist_logits / self.temperature, dim=1)
        teacher_soft = F.softmax(teacher_logits / self.temperature, dim=1)
        loss_dist = self.kl_loss(student_soft, teacher_soft) * (self.temperature ** 2)

        # Combined loss
        return self.alpha * loss_cls + (1 - self.alpha) * loss_dist


def deit_training_recipe():
    """
    DeiT training improvements that enable training on ImageNet alone.
    """
    training_config = {
        # Data augmentation
        'augmentation': {
            'RandAugment': {'num_ops': 2, 'magnitude': 9},
            'RandomErasing': {'probability': 0.25},
            'Mixup': {'alpha': 0.8},
            'CutMix': {'alpha': 1.0},
            'ColorJitter': {'strength': 0.3},
        },

        # Regularization
        'regularization': {
            'dropout': 0.0,  # No dropout in attention
            'drop_path': 0.1,  # Stochastic depth
            'label_smoothing': 0.1,
            'weight_decay': 0.05,
        },

        # Optimization
        'optimization': {
            'optimizer': 'AdamW',
            'lr': 1e-3,
            'warmup_epochs': 5,
            'epochs': 300,
            'batch_size': 1024,
            'lr_scheduler': 'cosine',
        },

        # Knowledge distillation
        'distillation': {
            'teacher': 'RegNetY-16GF',  # CNN teacher
            'temperature': 3.0,
            'alpha': 0.5,
        }
    }

    print("DeiT Training Recipe:")
    for category, settings in training_config.items():
        print(f"\n{category.upper()}:")
        for key, value in settings.items():
            print(f"  {key}: {value}")

    return training_config

deit_training_recipe()

DeiT's key insight was that careful training recipes could compensate for ViT's lack of inductive biases. The distillation token provides a mechanism for the model to learn from a CNN teacher, effectively inheriting some of the CNN's inductive biases through soft labels.
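
A minimal sketch of one DeiT distillation step follows; the teacher here is a small stand-in classifier (the paper uses a RegNetY CNN), the student uses a reduced configuration, and the batch is random data.

PYTHON
# Sketch: one DeiT distillation training step (teacher, config, and data are placeholders)
student = DeiT(num_classes=10, embed_dim=192, depth=4, num_heads=3)  # small illustrative config
teacher = nn.Sequential(                                             # stand-in for a CNN teacher
    nn.Conv2d(3, 10, kernel_size=32, stride=32),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
criterion = DeiTDistillationLoss(teacher, temperature=3.0, alpha=0.5)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3, weight_decay=0.05)

images = torch.randn(4, 3, 224, 224)
labels = torch.randint(0, 10, (4,))

student.train()
cls_logits, dist_logits = student(images)            # two heads during training
loss = criterion((cls_logits, dist_logits), labels, images)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Distillation loss: {loss.item():.4f}")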

Swin Transformer: Hierarchical Vision Transformer

The Swin Transformer addresses ViT's limitations through hierarchical feature maps and local attention windows, making it suitable for dense prediction tasks and more efficient for high-resolution images.

PYTHON
class WindowAttention(nn.Module):
    """
    Window-based Multi-Head Self-Attention for Swin Transformer.
    Computes attention within local windows rather than globally.
    """

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # Relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Compute relative position index
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))
        coords_flatten = coords.flatten(1)

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += window_size[0] - 1
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer('relative_position_index', relative_position_index)

    def forward(self, x, mask=None):
        """
        Args:
            x: (num_windows * B, window_size * window_size, C)
            mask: Attention mask for shifted windows
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
        attn = attn + relative_position_bias.unsqueeze(0)

        # Apply mask for shifted window attention
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x


def window_partition(x, window_size):
    """
    Partition feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C)
        window_size: Window size

    Returns:
        windows: (num_windows * B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Reverse window partition.

    Args:
        windows: (num_windows * B, window_size, window_size, C)
        window_size: Window size
        H, W: Original feature map dimensions

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block with window attention and optional shift.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4.0, dropout=0.0, drop_path=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, (window_size, window_size), num_heads)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout)
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift for cross-window connections
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = self._create_mask(H, W, x.device)
        else:
            shifted_x = x
            attn_mask = None

        # Partition into windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # Window attention
        attn_windows = self.attn(x_windows, mask=attn_mask)

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)

        # Residual connection for attention output, then MLP (FFN) with residual
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def _create_mask(self, H, W, device):
        """Create attention mask for shifted window attention."""
        img_mask = torch.zeros((1, H, W, 1), device=device)
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))
        return attn_mask


class PatchMerging(nn.Module):
    """
    Patch Merging layer for Swin Transformer.
    Reduces spatial resolution by 2x while doubling channels.
    Creates hierarchical feature maps like CNN downsampling.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x, H, W):
        B, L, C = x.shape
        assert L == H * W, "Input feature has wrong size"

        x = x.view(B, H, W, C)

        # Merge 2x2 patches
        x0 = x[:, 0::2, 0::2, :]  # (B, H/2, W/2, C)
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]

        x = torch.cat([x0, x1, x2, x3], dim=-1)  # (B, H/2, W/2, 4*C)
        x = x.view(B, -1, 4 * C)

        x = self.norm(x)
        x = self.reduction(x)  # (B, H/2 * W/2, 2*C)

        return x

Swin Transformer's hierarchical design produces multi-scale features essential for dense prediction tasks. The shifted window mechanism enables cross-window connections without the quadratic cost of global attention.
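
The window utilities can be exercised directly. The sketch below partitions a dummy stage-1-sized feature map into 7×7 windows, applies window attention, and reverses the partition; the sizes are illustrative.

PYTHON
# Sketch: round trip through window partitioning and window attention
B, H, W, C = 2, 56, 56, 96          # illustrative stage-1 feature map
window_size, num_heads = 7, 3

x = torch.randn(B, H, W, C)
windows = window_partition(x, window_size)                 # (B * 64, 7, 7, 96)
windows = windows.view(-1, window_size * window_size, C)   # (B * 64, 49, 96)

attn = WindowAttention(C, (window_size, window_size), num_heads)
out = attn(windows)                                        # attention within each 49-token window
out = out.view(-1, window_size, window_size, C)
restored = window_reverse(out, window_size, H, W)          # back to (B, 56, 56, 96)

print(windows.shape, restored.shape)
# Each window attends over only 49 tokens, so cost grows with the number of
# windows (linear in H * W) rather than quadratically with the full token count.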

CaiT: Class-Attention in Image Transformers

CaiT separates the self-attention among patches from the class attention that aggregates patch features for classification. This architectural change improves training stability and performance.

PYTHON
class ClassAttention(nn.Module):
    """
    Class Attention layer from CaiT.
    CLS token attends to patch tokens but patches don't attend to CLS.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Q from CLS, K and V from patches
        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        Args:
            x: (B, 1 + num_patches, dim) where first token is CLS
        """
        B, N, C = x.shape

        # CLS token query
        cls_token = x[:, 0:1]  # (B, 1, C)
        patch_tokens = x[:, 1:]  # (B, N-1, C)

        q = self.q(cls_token).reshape(B, 1, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # (B, heads, 1, head_dim)

        kv = self.kv(patch_tokens).reshape(B, N - 1, 2, self.num_heads, self.head_dim)
        kv = kv.permute(2, 0, 3, 1, 4)  # (2, B, heads, N-1, head_dim)
        k, v = kv[0], kv[1]

        # Attention: CLS attends to all patches
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)

        # Aggregate
        cls_token = (attn @ v).transpose(1, 2).reshape(B, 1, C)
        cls_token = self.proj(cls_token)

        # Return updated CLS with unchanged patches
        return torch.cat([cls_token, patch_tokens], dim=1)


class CaiTBlock(nn.Module):
    """
    CaiT Transformer Block.
    Processes patches with self-attention, without CLS token involvement.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        # LayerScale: learnable per-channel scaling
        self.gamma1 = nn.Parameter(1e-4 * torch.ones(dim))
        self.gamma2 = nn.Parameter(1e-4 * torch.ones(dim))

    def forward(self, x):
        # Self-attention with LayerScale
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        # MLP with LayerScale
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x


class CaiTClassAttentionBlock(nn.Module):
    """
    Class Attention Block for aggregating patch information to CLS token.
    Used in final layers after patch self-attention.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = ClassAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        self.gamma1 = nn.Parameter(1e-4 * torch.ones(dim))
        self.gamma2 = nn.Parameter(1e-4 * torch.ones(dim))

    def forward(self, x, cls_token):
        # Prepend CLS token
        x_with_cls = torch.cat([cls_token, x], dim=1)

        # Class attention
        u = self.norm1(x_with_cls)
        u = self.attn(u)
        cls_token = cls_token + self.drop_path(self.gamma1 * u[:, 0:1])

        # MLP on CLS token only
        cls_token = cls_token + self.drop_path(self.gamma2 * self.mlp(self.norm2(cls_token)))

        return x, cls_token


class CaiT(nn.Module):
    """
    CaiT: Class-Attention in Image Transformers.
    Separates patch processing from class token aggregation.
    """

    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 embed_dim=768, depth=24, num_heads=12, mlp_ratio=4.0,
                 class_attention_layers=2):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable CLS token (added later, not in patch processing)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        # Patch processing blocks (no CLS token)
        self.blocks = nn.ModuleList([
            CaiTBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        # Class attention blocks (CLS aggregates from patches)
        self.class_blocks = nn.ModuleList([
            CaiTClassAttentionBlock(embed_dim, num_heads, mlp_ratio)
            for _ in range(class_attention_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]

        # Patch embedding (no CLS token yet)
        x = self.patch_embed(x) + self.pos_embed

        # Patch self-attention
        for block in self.blocks:
            x = block(x)

        # Class attention: CLS token aggregates from patches
        cls_token = self.cls_token.expand(B, -1, -1)
        for block in self.class_blocks:
            x, cls_token = block(x, cls_token)

        # Classification
        cls_token = self.norm(cls_token)
        return self.head(cls_token.squeeze(1))

LayerScale and Other Stabilization Techniques

Training deep vision transformers requires careful attention to optimization stability. LayerScale, introduced with CaiT, helps by initializing residual branch contributions to near-zero values.

PYTHON
class LayerScale(nn.Module):
    """
    LayerScale: Learnable per-channel scaling of residual branch.
    Initialized to small values (e.g., 1e-4) for training stability.
    """

    def __init__(self, dim, init_value=1e-4):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        return self.gamma * x


class StabilizedTransformerBlock(nn.Module):
    """
    Transformer block with modern stabilization techniques.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, drop_path=0.0,
                 layer_scale_init=1e-4, use_layer_scale=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

        # LayerScale for stability
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.ls1 = LayerScale(dim, layer_scale_init)
            self.ls2 = LayerScale(dim, layer_scale_init)

    def forward(self, x):
        if self.use_layer_scale:
            x = x + self.drop_path(self.ls1(self.attn(self.norm1(x))))
            x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        else:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


def training_stabilization_techniques():
    """
    Summary of techniques for stable ViT training.
    """
    techniques = {
        'LayerScale': {
            'description': 'Learnable per-channel scaling of residual branches',
            'init_value': '1e-4 to 1e-6 for deep models',
            'benefit': 'Allows deeper models to train stably'
        },
        'Drop Path (Stochastic Depth)': {
            'description': 'Randomly drop entire residual branches',
            'typical_values': '0.1 to 0.5 depending on depth',
            'benefit': 'Regularization and implicit ensemble'
        },
        'Warmup': {
            'description': 'Linear LR increase from 0 to target',
            'typical_duration': '5-10 epochs or 10k steps',
            'benefit': 'Prevents early training instability'
        },
        'Gradient Clipping': {
            'description': 'Clip gradient norm to max value',
            'typical_value': '1.0',
            'benefit': 'Prevents gradient explosion'
        },
        'Pre-Normalization': {
            'description': 'LayerNorm before attention/MLP, not after',
            'benefit': 'More stable gradients'
        }
    }

    print("ViT Training Stabilization Techniques:")
    print("=" * 60)
    for name, info in techniques.items():
        print(f"\n{name}:")
        for key, value in info.items():
            print(f"  {key}: {value}")

training_stabilization_techniques()
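
Several of these techniques live in the training loop rather than the architecture. Below is a minimal sketch combining AdamW, linear warmup into a cosine schedule, and gradient clipping; the model, batch, and schedule lengths are placeholders.

PYTHON
import math

# Sketch: warmup + cosine LR schedule and gradient clipping (placeholder model and data)
model = StabilizedTransformerBlock(dim=192, num_heads=3)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

warmup_steps, total_steps = 10, 100                        # illustrative schedule lengths

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)                 # linear warmup from 0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))        # cosine decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step in range(total_steps):
    x = torch.randn(8, 197, 192)                           # dummy token batch
    loss = model(x).pow(2).mean()                          # dummy objective
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    scheduler.step()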

PVT: Pyramid Vision Transformer

PVT introduces a pyramid structure to vision transformers, producing multi-scale features essential for dense prediction tasks like detection and segmentation.

PYTHON
class SpatialReductionAttention(nn.Module):
    """
    Spatial Reduction Attention from PVT.
    Reduces K, V spatial resolution to save computation.
    """

    def __init__(self, dim, num_heads, sr_ratio=1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.sr_ratio = sr_ratio

        self.q = nn.Linear(dim, dim)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        # Spatial reduction for K, V
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape

        # Query at full resolution
        q = self.q(x).reshape(B, N, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)  # (B, heads, N, head_dim)

        # Spatially reduce K, V
        if self.sr_ratio > 1:
            x_2d = x.transpose(1, 2).reshape(B, C, H, W)
            x_2d = self.sr(x_2d).flatten(2).transpose(1, 2)
            x_2d = self.norm(x_2d)
        else:
            x_2d = x

        kv = self.kv(x_2d).reshape(B, -1, 2, self.num_heads, self.head_dim)
        kv = kv.permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        # Attention
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x


class PVTBlock(nn.Module):
    """PVT Transformer Block with spatial reduction attention."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, sr_ratio=1, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SpatialReductionAttention(dim, num_heads, sr_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class PVTStage(nn.Module):
    """
    Single stage of PVT with patch embedding and transformer blocks.
    """

    def __init__(self, in_channels, out_channels, patch_size, num_blocks,
                 num_heads, mlp_ratio=4.0, sr_ratio=1):
        super().__init__()

        # Patch embedding / downsampling
        self.patch_embed = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(out_channels)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            PVTBlock(out_channels, num_heads, mlp_ratio, sr_ratio)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        x = self.norm(x)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, H, W)

        # Reshape back to spatial
        x = x.transpose(1, 2).reshape(B, C, H, W)
        return x


def compare_vit_variants():
    """Compare different ViT variants and their characteristics."""
    variants = {
        'ViT': {
            'attention': 'Global (full image)',
            'hierarchy': 'Single scale',
            'complexity': 'O(N²) where N = num_patches',
            'best_for': 'Classification with pretraining'
        },
        'DeiT': {
            'attention': 'Global (full image)',
            'hierarchy': 'Single scale',
            'complexity': 'O(N²)',
            'best_for': 'Data-efficient training on ImageNet'
        },
        'Swin': {
            'attention': 'Local windows + shifted windows',
            'hierarchy': 'Multi-scale (like CNN)',
            'complexity': 'O(N) linear in image size',
            'best_for': 'Dense prediction (detection, segmentation)'
        },
        'PVT': {
            'attention': 'Global with spatial reduction',
            'hierarchy': 'Multi-scale pyramid',
            'complexity': 'O(N²/r²) where r = reduction ratio',
            'best_for': 'Dense prediction with global context'
        },
        'CaiT': {
            'attention': 'Separate patch and class attention',
            'hierarchy': 'Single scale',
            'complexity': 'O(N²)',
            'best_for': 'Deep models, improved stability'
        }
    }

    print("\nViT Variant Comparison:")
    print("=" * 70)
    for name, info in variants.items():
        print(f"\n{name}:")
        for key, value in info.items():
            print(f"  {key:15}: {value}")

compare_vit_variants()
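
The stages above can be stacked into a pyramid backbone. The configuration in the sketch below is illustrative (channel counts, depths, and reduction ratios roughly in the spirit of the smallest PVT variant) and shows resolution halving while channels grow.

PYTHON
# Sketch: stacking PVTStage modules into a 4-stage pyramid backbone
# (channel counts / depths / sr_ratios are illustrative)
def build_pvt_backbone():
    return nn.ModuleList([
        PVTStage(3,   64,  patch_size=4, num_blocks=2, num_heads=1, sr_ratio=8),
        PVTStage(64,  128, patch_size=2, num_blocks=2, num_heads=2, sr_ratio=4),
        PVTStage(128, 320, patch_size=2, num_blocks=2, num_heads=5, sr_ratio=2),
        PVTStage(320, 512, patch_size=2, num_blocks=2, num_heads=8, sr_ratio=1),
    ])

x = torch.randn(1, 3, 224, 224)
features = []
for stage in build_pvt_backbone():
    x = stage(x)
    features.append(x)

for i, f in enumerate(features):
    print(f"stage {i + 1}: {tuple(f.shape)}")
# stage 1: (1, 64, 56, 56) ... stage 4: (1, 512, 7, 7)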

Key Takeaways

Vision Transformer variants have addressed the original ViT's key limitations through diverse architectural innovations. DeiT showed that careful training recipes with strong augmentation, regularization, and knowledge distillation enable competitive ImageNet-only training. Swin Transformer introduced hierarchical features and efficient window attention, making transformers practical for dense prediction tasks. CaiT separated patch processing from class aggregation and introduced LayerScale for training stability. PVT combined pyramid structures with spatial reduction attention for efficient multi-scale processing. These variants demonstrate that the transformer architecture is highly adaptable: by modifying attention patterns, adding hierarchy, or changing the training procedure, vision transformers can be tailored to different computational budgets, data regimes, and downstream tasks. Understanding these variations enables practitioners to select the most appropriate architecture for their specific requirements.

21.3 Self-Supervised Learning for Vision

Self-supervised learning (SSL) has revolutionized computer vision by enabling models to learn powerful representations from unlabeled data. For Vision Transformers, SSL techniques are particularly effective because transformers can capture complex global relationships. This section explores contrastive learning, masked image modeling, and self-distillation.

The SSL Paradigm

SSL creates supervision signals from the data itself:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import List

class SSLFramework(nn.Module):
    """General SSL framework with encoder and projector."""
    def __init__(self, encoder, projector):
        super().__init__()
        self.encoder = encoder
        self.projector = projector

    def forward(self, x):
        h = self.encoder(x)
        return h, self.projector(h)

class Projector(nn.Module):
    """MLP projector for contrastive learning."""
    def __init__(self, in_dim=768, hidden=2048, out_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
            nn.Linear(hidden, out_dim)
        )
    def forward(self, x): return self.net(x)

SimCLR: Contrastive Learning

SimCLR pulls two augmented views of the same image together while pushing views of different images apart:

PYTHON
class NTXentLoss(nn.Module):
    """InfoNCE / NT-Xent loss for contrastive learning."""
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temp = temperature

    def forward(self, z_i, z_j):
        B = z_i.size(0)
        z_i, z_j = F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)
        z = torch.cat([z_i, z_j], dim=0)
        sim = torch.mm(z, z.t()) / self.temp
        mask = torch.eye(2*B, device=z.device).bool()
        sim = sim.masked_fill(mask, float('-inf'))
        labels = torch.cat([torch.arange(B)+B, torch.arange(B)]).to(z.device)
        return F.cross_entropy(sim, labels)

class SimCLR(nn.Module):
    """SimCLR with ViT encoder."""
    def __init__(self, encoder, dim=768, proj_dim=128):
        super().__init__()
        self.encoder = encoder
        self.projector = Projector(dim, 2048, proj_dim)
        self.loss_fn = NTXentLoss()

    def forward(self, v1, v2):
        h1, h2 = self.encoder(v1), self.encoder(v2)
        z1, z2 = self.projector(h1), self.projector(h2)
        return self.loss_fn(z1, z2), h1, h2
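
A minimal training-step sketch for the SimCLR module above. ToyEncoder is a hypothetical stand-in for a real ViT backbone, and the random tensors stand in for two augmented views from a real augmentation pipeline (random crop, color jitter, blur); the sketch only illustrates the API and expected shapes.

PYTHON
class ToyEncoder(nn.Module):
    """Stand-in backbone: maps (B, 3, H, W) images to (B, dim) features."""
    def __init__(self, dim=768):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, dim)
        )
    def forward(self, x): return self.net(x)

model = SimCLR(ToyEncoder(), dim=768, proj_dim=128)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# In practice v1 and v2 are two random augmentations of the same image batch
v1, v2 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
loss, h1, h2 = model(v1, v2)
opt.zero_grad(); loss.backward(); opt.step()
print(f"NT-Xent loss: {loss.item():.3f}, features: {tuple(h1.shape)}")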

MoCo v3: Momentum Contrast

MoCo v3 uses a momentum-updated key encoder and a prediction head for stable contrastive training:

PYTHON
class MoCoV3(nn.Module):
    """MoCo v3 for Vision Transformers."""
    def __init__(self, encoder, dim=768, proj_dim=256, momentum=0.99):
        super().__init__()
        self.encoder_q = encoder
        self.encoder_k = copy.deepcopy(encoder)
        for p in self.encoder_k.parameters(): p.requires_grad = False

        self.proj_q = nn.Sequential(
            nn.Linear(dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, proj_dim)
        )
        self.proj_k = copy.deepcopy(self.proj_q)
        for p in self.proj_k.parameters(): p.requires_grad = False

        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, 4096), nn.BatchNorm1d(4096), nn.ReLU(),
            nn.Linear(4096, proj_dim)
        )
        self.m, self.temp = momentum, 0.2

    @torch.no_grad()
    def update_momentum(self):
        for pq, pk in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            pk.data = self.m * pk.data + (1-self.m) * pq.data
        for pq, pk in zip(self.proj_q.parameters(), self.proj_k.parameters()):
            pk.data = self.m * pk.data + (1-self.m) * pq.data

    def forward(self, v1, v2):
        q1 = self.predictor(self.proj_q(self.encoder_q(v1)))
        q2 = self.predictor(self.proj_q(self.encoder_q(v2)))
        with torch.no_grad():
            self.update_momentum()
            k1 = self.proj_k(self.encoder_k(v1))
            k2 = self.proj_k(self.encoder_k(v2))
        return self._loss(q1,k2) + self._loss(q2,k1)

    def _loss(self, q, k):
        # InfoNCE within the batch: the positive for query i is key i (the other
        # view of the same image); all other keys in the batch act as negatives.
        q, k = F.normalize(q, dim=1), F.normalize(k, dim=1)
        logits = torch.mm(q, k.t()) / self.temp
        labels = torch.arange(q.size(0), device=q.device)
        return F.cross_entropy(logits, labels)
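
A single optimization step with the MoCoV3 module above, reusing the hypothetical ToyEncoder from the SimCLR sketch. Only the query branch carries gradients; the key branch is refreshed by the momentum rule inside forward().

PYTHON
moco = MoCoV3(ToyEncoder(), dim=768, proj_dim=256, momentum=0.99)
opt = torch.optim.AdamW([p for p in moco.parameters() if p.requires_grad], lr=1e-4)

v1, v2 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)  # two augmented views
loss = moco(v1, v2)
opt.zero_grad(); loss.backward(); opt.step()
print(f"MoCo v3 loss: {loss.item():.3f}")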

MAE: Masked Autoencoder

MAE learns by reconstructing randomly masked patches (75% masking):

PYTHON
class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim*mlp_ratio)), nn.GELU(),
            nn.Linear(int(dim*mlp_ratio), dim)
        )

    def forward(self, x):
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        return x + self.mlp(self.norm2(x))

class MAE(nn.Module):
    """Masked Autoencoder for self-supervised ViT pretraining."""
    def __init__(self, img_size=224, patch_size=16, dim=768,
                 enc_depth=12, dec_dim=512, dec_depth=8, mask_ratio=0.75):
        super().__init__()
        self.ps, self.mr = patch_size, mask_ratio
        num_patches = (img_size // patch_size) ** 2

        self.patch_embed = nn.Conv2d(3, dim, patch_size, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1,1,dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, dim))
        self.encoder = nn.ModuleList([TransformerBlock(dim, 12) for _ in range(enc_depth)])
        self.enc_norm = nn.LayerNorm(dim)

        self.dec_embed = nn.Linear(dim, dec_dim)
        self.mask_token = nn.Parameter(torch.zeros(1,1,dec_dim))
        self.dec_pos = nn.Parameter(torch.zeros(1, num_patches+1, dec_dim))
        self.decoder = nn.ModuleList([TransformerBlock(dec_dim, 16) for _ in range(dec_depth)])
        self.dec_norm = nn.LayerNorm(dec_dim)
        self.pred = nn.Linear(dec_dim, patch_size**2 * 3)

    def random_mask(self, x):
        B, N, D = x.shape
        keep = int(N * (1-self.mr))
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        x_masked = torch.gather(x, 1, ids_shuffle[:,:keep].unsqueeze(-1).expand(-1,-1,D))
        mask = torch.ones(B, N, device=x.device)
        mask[:,:keep] = 0
        mask = torch.gather(mask, 1, ids_restore)
        return x_masked, mask, ids_restore

    def forward(self, imgs):
        # Encode
        x = self.patch_embed(imgs).flatten(2).transpose(1,2)
        x = x + self.pos_embed[:,1:,:]
        x, mask, ids_restore = self.random_mask(x)
        cls = (self.cls_token + self.pos_embed[:,:1,:]).expand(x.size(0),-1,-1)
        x = torch.cat([cls, x], dim=1)
        for blk in self.encoder: x = blk(x)
        x = self.enc_norm(x)

        # Decode
        x = self.dec_embed(x)
        B, N_enc, D = x.shape
        N = ids_restore.size(1)
        mask_tokens = self.mask_token.repeat(B, N+1-N_enc, 1)
        x_full = torch.cat([x[:,1:,:], mask_tokens], dim=1)
        x_full = torch.gather(x_full, 1, ids_restore.unsqueeze(-1).expand(-1,-1,D))
        x = torch.cat([x[:,:1,:], x_full], dim=1) + self.dec_pos
        for blk in self.decoder: x = blk(x)
        pred = self.pred(self.dec_norm(x)[:,1:,:])

        # Loss on masked patches
        target = self._patchify(imgs)
        target = (target - target.mean(-1, keepdim=True)) / (target.var(-1, keepdim=True)+1e-6).sqrt()
        loss = ((pred - target)**2).mean(-1)
        return (loss * mask).sum() / mask.sum()

    def _patchify(self, imgs):
        B,C,H,W = imgs.shape
        P = self.ps
        return imgs.reshape(B,C,H//P,P,W//P,P).permute(0,2,4,3,5,1).reshape(B,-1,P*P*C)
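
A quick smoke test of the MAE module above. The configuration is deliberately tiny (real MAE pretraining uses a ViT-Base or larger encoder at 224x224, as in the constructor defaults); the point is only to verify that masking, reconstruction, and the masked loss wire together.

PYTHON
mae = MAE(img_size=32, patch_size=8, dim=48, enc_depth=2, dec_dim=32, dec_depth=1)
imgs = torch.randn(4, 3, 32, 32)
loss = mae(imgs)      # MSE on the 75% of patches that were masked out
loss.backward()
print(f"MAE reconstruction loss: {loss.item():.3f}")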

DINO: Self-Distillation

DINO learns by matching student outputs to a momentum-updated teacher:

PYTHON
class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, hidden=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(),
            nn.Linear(hidden, 256)
        )
        self.proj = nn.Linear(256, out_dim, bias=False)

    def forward(self, x):
        return self.proj(F.normalize(self.mlp(x), dim=-1))

class DINO(nn.Module):
    """DINO: Self-Distillation with No Labels."""
    def __init__(self, encoder, dim=768, out_dim=65536, momentum=0.996):
        super().__init__()
        self.student = encoder
        self.student_head = DINOHead(dim, out_dim)
        self.teacher = copy.deepcopy(encoder)
        self.teacher_head = DINOHead(dim, out_dim)
        for p in list(self.teacher.parameters()) + list(self.teacher_head.parameters()):
            p.requires_grad = False

        self.register_buffer('center', torch.zeros(1, out_dim))
        self.m = momentum
        self.teacher_temp, self.student_temp = 0.04, 0.1

    @torch.no_grad()
    def update_teacher(self):
        for ps, pt in zip(self.student.parameters(), self.teacher.parameters()):
            pt.data = self.m * pt.data + (1-self.m) * ps.data
        for ps, pt in zip(self.student_head.parameters(), self.teacher_head.parameters()):
            pt.data = self.m * pt.data + (1-self.m) * ps.data

    def forward(self, global_views: List[torch.Tensor], local_views: List[torch.Tensor]):
        all_views = global_views + local_views
        s_out = torch.stack([self.student_head(self.student(v)) for v in all_views])

        with torch.no_grad():
            t_raw = torch.stack([self.teacher_head(self.teacher(v)) for v in global_views])
            t_out = F.softmax((t_raw - self.center) / self.teacher_temp, dim=-1)

        s_out = F.log_softmax(s_out / self.student_temp, dim=-1)

        loss, n = 0, 0
        for ti, t in enumerate(t_out):
            for si, s in enumerate(s_out):
                if si != ti:
                    loss += -torch.sum(t * s, dim=-1).mean()
                    n += 1

        self.update_teacher()
        # Center is an EMA of the teacher's raw (pre-softmax) outputs, preventing collapse
        self.center = 0.9 * self.center + 0.1 * t_raw.mean(dim=(0, 1))
        return loss / n
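
A multi-crop usage sketch for the DINO module above: two global crops pass through both networks, while additional local crops pass through the student only. The hypothetical ToyEncoder from the SimCLR sketch is resolution-agnostic, so smaller local crops are fine; out_dim is reduced from 65536 to keep the example light.

PYTHON
dino = DINO(ToyEncoder(), dim=768, out_dim=1024, momentum=0.996)
opt = torch.optim.AdamW([p for p in dino.parameters() if p.requires_grad], lr=1e-4)

global_views = [torch.randn(8, 3, 32, 32) for _ in range(2)]  # large crops
local_views = [torch.randn(8, 3, 16, 16) for _ in range(4)]   # small crops
loss = dino(global_views, local_views)
opt.zero_grad(); loss.backward(); opt.step()
print(f"DINO loss: {loss.item():.3f}")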

DINOv2: Combined Objectives

DINOv2 combines DINO with masked image modeling:

PYTHON
class DINOv2Loss(nn.Module):
    """DINOv2: DINO + iBOT patch-level objectives."""
    def __init__(self, out_dim=65536, patch_dim=8192):
        super().__init__()
        self.register_buffer('center', torch.zeros(1, out_dim))
        self.register_buffer('patch_center', torch.zeros(1, patch_dim))
        self.t_temp, self.s_temp = 0.04, 0.1

    def forward(self, s_cls, t_cls, s_patch, t_patch, mask):
        # CLS token loss
        t_cls = F.softmax((t_cls - self.center) / self.t_temp, dim=-1)
        s_cls = F.log_softmax(s_cls / self.s_temp, dim=-1)
        cls_loss = -torch.sum(t_cls * s_cls, dim=-1).mean()

        # Patch loss on masked positions
        t_patch = F.softmax((t_patch - self.patch_center) / self.t_temp, dim=-1)
        s_patch = F.log_softmax(s_patch / self.s_temp, dim=-1)
        patch_loss = (-torch.sum(t_patch * s_patch, dim=-1) * mask).sum() / mask.sum()

        return cls_loss + patch_loss

Linear Probing Evaluation

PYTHON
def linear_probe(encoder, train_loader, val_loader, num_classes, device, epochs=100):
    """Evaluate representations with frozen encoder + linear classifier."""
    encoder.eval()
    for p in encoder.parameters(): p.requires_grad = False

    # Get feature dimension
    with torch.no_grad():
        feat_dim = encoder(next(iter(train_loader))[0][:1].to(device)).size(-1)

    classifier = nn.Linear(feat_dim, num_classes).to(device)
    opt = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, epochs)

    best_acc = 0
    for epoch in range(epochs):
        classifier.train()
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            with torch.no_grad(): feats = encoder(imgs)
            loss = F.cross_entropy(classifier(feats), labels)
            opt.zero_grad(); loss.backward(); opt.step()
        sched.step()

        classifier.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                pred = classifier(encoder(imgs)).argmax(1)
                correct += (pred == labels).sum().item()
                total += labels.size(0)
        best_acc = max(best_acc, correct/total)

    return best_acc
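
A hypothetical call to linear_probe with random tensors standing in for a labelled dataset. In practice the encoder would be an SSL-pretrained ViT and the loaders would wrap a real downstream classification dataset; this only checks that the probe runs end to end.

PYTHON
from torch.utils.data import DataLoader, TensorDataset

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = ToyEncoder().to(device)  # stand-in for a pretrained, frozen backbone
data = TensorDataset(torch.randn(64, 3, 32, 32), torch.randint(0, 10, (64,)))
loader = DataLoader(data, batch_size=16, shuffle=True)

acc = linear_probe(encoder, loader, loader, num_classes=10, device=device, epochs=2)
print(f"Linear-probe accuracy (random data, sanity check only): {acc:.3f}")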

SSL Methods Comparison

| Method | Paradigm | Key Innovation | Batch Size |
|--------|----------|----------------|------------|
| SimCLR | Contrastive | NT-Xent loss, strong augmentation | 4096+ |
| MoCo v3 | Contrastive | Momentum encoder, memory efficient | 256-1024 |
| MAE | Masked Modeling | 75% masking, pixel reconstruction | Standard |
| BEiT | Masked Modeling | Discrete visual tokens | Standard |
| DINO | Self-Distillation | Teacher-student, centering | Standard |
| DINOv2 | Combined | DINO + iBOT objectives | Large |

Key Takeaways

Self-supervised learning for vision has evolved through several paradigms. Contrastive methods (SimCLR, MoCo) learn by comparing augmented views but require large batches of negatives or a momentum encoder. Masked image modeling (MAE) reconstructs masked patches efficiently with an asymmetric encoder-decoder. Self-distillation (DINO) matches a student to a momentum teacher without explicit negatives, producing features with emergent semantic properties. Modern methods such as DINOv2 combine multiple objectives, image-level distillation and patch-level masked modeling, to achieve state-of-the-art universal visual features that transfer well across tasks.

21.4 ViT for Dense Prediction Advanced

ViT for Dense Prediction

While Vision Transformers were initially designed for image classification, they have been successfully adapted for dense prediction tasks including semantic segmentation, object detection, and depth estimation. This section explores architectures that transform ViT's single-scale global features into multi-scale hierarchical representations suitable for pixel-level predictions.

Dense Prediction Challenges for ViT

Standard ViT outputs a single-scale feature map, but dense prediction tasks require:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional

class DensePredictionChallenges:
    """
    Key challenges adapting ViT for dense prediction:

    1. Single-scale features: ViT produces features at one resolution
       (e.g., 14x14 for 224px image with 16x16 patches)

    2. No hierarchical features: Unlike CNNs, ViT lacks multi-scale
       feature pyramids needed for detecting objects at various sizes

    3. High resolution: Dense prediction often needs higher resolution
       than typical ViT patch sizes provide

    4. Positional encoding: Fixed positional encodings limit
       flexibility for varying input resolutions
    """

    @staticmethod
    def demonstrate_resolution_issue():
        # ViT-B with 16x16 patches on 224x224 image
        img_size, patch_size = 224, 16
        num_patches = (img_size // patch_size) ** 2  # 196 patches
        feature_resolution = img_size // patch_size  # 14x14

        print(f"Input: {img_size}x{img_size}")
        print(f"Patches: {num_patches} ({feature_resolution}x{feature_resolution})")
        print(f"For segmentation, need to upsample 16x to original resolution")

        # For 512x512 input (common in segmentation)
        img_size_seg = 512
        num_patches_seg = (img_size_seg // patch_size) ** 2  # 1024 patches
        print(f"\nFor {img_size_seg}x{img_size_seg}: {num_patches_seg} patches")
        print(f"Attention complexity: O({num_patches_seg}²) = O(1M)")

SETR: Segmentation Transformer

SETR (SEgmentation TRansformer) was one of the first to use pure ViT for semantic segmentation:

PYTHON
class SETR(nn.Module):
    """
    SETR: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective.

    Uses ViT encoder with different decoder heads for upsampling.
    """

    def __init__(
        self,
        encoder: nn.Module,
        embed_dim: int = 768,
        num_classes: int = 21,
        img_size: int = 512,
        patch_size: int = 16,
        decoder_type: str = 'pup'  # 'naive', 'pup', or 'mla'
    ):
        super().__init__()
        self.encoder = encoder
        self.patch_size = patch_size
        self.feat_size = img_size // patch_size

        if decoder_type == 'naive':
            self.decoder = SETRNaiveDecoder(embed_dim, num_classes, self.feat_size)
        elif decoder_type == 'pup':
            self.decoder = SETRPUPDecoder(embed_dim, num_classes, self.feat_size)
        elif decoder_type == 'mla':
            self.decoder = SETRMLADecoder(embed_dim, num_classes, self.feat_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        # Get ViT features (excluding CLS token)
        features = self.encoder(x, return_all_tokens=True)[:, 1:, :]  # [B, N, D]

        # Reshape to spatial
        features = features.transpose(1, 2).reshape(B, -1, self.feat_size, self.feat_size)

        # Decode to full resolution
        logits = self.decoder(features)
        return F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=False)


class SETRNaiveDecoder(nn.Module):
    """Simple 1x1 conv + bilinear upsample."""

    def __init__(self, embed_dim: int, num_classes: int, feat_size: int):
        super().__init__()
        self.conv = nn.Conv2d(embed_dim, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class SETRPUPDecoder(nn.Module):
    """Progressive UPsampling decoder."""

    def __init__(self, embed_dim: int, num_classes: int, feat_size: int):
        super().__init__()
        # Progressive 2x upsampling stages
        self.stages = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(embed_dim if i == 0 else 256, 256, 3, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            )
            for i in range(4)  # 4 stages: 2^4 = 16x upsample
        ])
        self.head = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in self.stages:
            x = stage(x)
        return self.head(x)


class SETRMLADecoder(nn.Module):
    """Multi-Level feature Aggregation decoder."""

    def __init__(self, embed_dim: int, num_classes: int, feat_size: int):
        super().__init__()
        self.reduce = nn.ModuleList([
            nn.Conv2d(embed_dim, 256, kernel_size=1) for _ in range(4)
        ])
        self.fuse = nn.Conv2d(256 * 4, 256, kernel_size=1)
        self.head = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        # Aggregate multi-level features from different ViT layers
        reduced = [self.reduce[i](f) for i, f in enumerate(features)]
        # Upsample all to same size and concatenate
        size = reduced[0].shape[2:]
        upsampled = [F.interpolate(f, size=size, mode='bilinear') for f in reduced]
        fused = self.fuse(torch.cat(upsampled, dim=1))
        return self.head(fused)
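
The decoders above assume an encoder that accepts return_all_tokens=True and returns a (B, 1 + N, D) token sequence including a CLS token; that keyword is an assumption of these sketches, not a standard ViT API. The MLA variant additionally expects a list of features from several encoder layers, so the shape check below uses the PUP decoder with a minimal, attention-free stand-in encoder.

PYTHON
class ToyTokenEncoder(nn.Module):
    """Hypothetical stand-in: patchify with a strided conv and prepend a CLS token."""
    def __init__(self, dim=768, patch_size=16):
        super().__init__()
        self.patch = nn.Conv2d(3, dim, patch_size, patch_size)
        self.cls = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x, return_all_tokens=False):
        tokens = self.patch(x).flatten(2).transpose(1, 2)  # (B, N, D)
        return torch.cat([self.cls.expand(x.size(0), -1, -1), tokens], dim=1)

setr = SETR(ToyTokenEncoder(), embed_dim=768, num_classes=21,
            img_size=512, decoder_type='pup')
with torch.no_grad():
    logits = setr(torch.randn(1, 3, 512, 512))
print(logits.shape)  # expected: torch.Size([1, 21, 512, 512])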

Segmenter: Mask Transformer

Segmenter uses a transformer decoder to generate class masks:

PYTHON
class Segmenter(nn.Module):
    """
    Segmenter: Transformer for Semantic Segmentation.

    Uses mask transformer decoder that generates per-class mask embeddings.
    """

    def __init__(
        self,
        encoder: nn.Module,
        embed_dim: int = 768,
        num_classes: int = 21,
        decoder_depth: int = 2,
        num_heads: int = 12,
        patch_size: int = 16
    ):
        super().__init__()
        self.encoder = encoder
        self.patch_size = patch_size

        # Learnable class embeddings
        self.cls_emb = nn.Parameter(torch.randn(1, num_classes, embed_dim))

        # Mask transformer decoder
        self.decoder = MaskTransformerDecoder(
            embed_dim, num_heads, decoder_depth, num_classes
        )

        # Project to masks
        self.mask_head = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h, w = H // self.patch_size, W // self.patch_size

        # Encode image patches
        patch_embeddings = self.encoder(x, return_all_tokens=True)[:, 1:, :]

        # Decode class embeddings with cross-attention to patches
        cls_emb = self.cls_emb.expand(B, -1, -1)
        class_embeddings = self.decoder(cls_emb, patch_embeddings)

        # Generate masks via dot product
        patch_proj = self.mask_head(patch_embeddings)
        masks = torch.einsum('bnd,bcd->bnc', patch_proj, class_embeddings)
        masks = masks.transpose(1, 2).reshape(B, -1, h, w)

        return F.interpolate(masks, size=(H, W), mode='bilinear')


class MaskTransformerDecoder(nn.Module):
    """Transformer decoder for mask generation."""

    def __init__(self, embed_dim: int, num_heads: int, depth: int, num_classes: int):
        super().__init__()
        self.layers = nn.ModuleList([
            MaskDecoderLayer(embed_dim, num_heads) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, cls_emb: torch.Tensor, patch_emb: torch.Tensor) -> torch.Tensor:
        x = cls_emb
        for layer in self.layers:
            x = layer(x, patch_emb)
        return self.norm(x)


class MaskDecoderLayer(nn.Module):
    """Single decoder layer with self and cross attention."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        x = x + self.self_attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.cross_attn(self.norm2(x), memory, memory)[0]
        x = x + self.mlp(self.norm3(x))
        return x

DPT: Dense Prediction Transformer

DPT creates multi-scale features by reassembling ViT tokens:

PYTHON
class DPT(nn.Module):
    """
    DPT: Vision Transformers for Dense Prediction.

    Reassembles tokens from different ViT layers into multi-scale features.
    """

    def __init__(
        self,
        encoder: nn.Module,
        embed_dim: int = 768,
        features: List[int] = [256, 512, 1024, 1024],
        readout: str = 'project',  # 'ignore', 'add', or 'project'
        hooks: List[int] = [2, 5, 8, 11]  # Which ViT layers to tap
    ):
        super().__init__()
        self.encoder = encoder
        self.hooks = hooks
        self.readout = readout

        # Readout projections for CLS token
        if readout == 'project':
            self.readout_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(2 * embed_dim, embed_dim),
                    nn.GELU()
                ) for _ in hooks
            ])

        # Reassemble blocks to create multi-scale features
        self.reassemble = nn.ModuleList([
            ReassembleBlock(embed_dim, features[i], scale=2**(3-i))
            for i in range(4)
        ])

        # Fusion blocks
        self.fusion = nn.ModuleList([
            FeatureFusionBlock(features[i]) for i in range(4)
        ])

    def forward(self, x: torch.Tensor, return_features: bool = False):
        B, C, H, W = x.shape

        # Get intermediate features from ViT
        activations = self.encoder.get_intermediate_layers(x, self.hooks)

        # Process each scale
        features = []
        for i, (act, hook_idx) in enumerate(zip(activations, self.hooks)):
            # Handle readout (CLS token)
            if self.readout == 'ignore':
                tokens = act[:, 1:, :]
            elif self.readout == 'add':
                tokens = act[:, 1:, :] + act[:, 0:1, :].expand_as(act[:, 1:, :])
            elif self.readout == 'project':
                cls_token = act[:, 0:1, :].expand(-1, act.size(1) - 1, -1)
                tokens = self.readout_proj[i](torch.cat([act[:, 1:, :], cls_token], dim=-1))

            # Reassemble to 2D feature map
            feat = self.reassemble[i](tokens, H // 16, W // 16)
            features.append(feat)

        # Fuse features (reverse order, coarse to fine)
        fused = features[-1]
        for i in range(len(features) - 2, -1, -1):
            fused = self.fusion[i](features[i], fused)

        if return_features:
            return fused, features
        return fused


class ReassembleBlock(nn.Module):
    """Reassemble tokens into spatial feature maps at different scales."""

    def __init__(self, in_dim: int, out_dim: int, scale: int):
        super().__init__()
        self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)

        if scale > 1:
            self.upsample = nn.ConvTranspose2d(
                out_dim, out_dim, kernel_size=scale, stride=scale
            )
        elif scale < 1:
            self.upsample = nn.Conv2d(
                out_dim, out_dim,
                kernel_size=int(1/scale), stride=int(1/scale)
            )
        else:
            self.upsample = nn.Identity()

    def forward(self, tokens: torch.Tensor, h: int, w: int) -> torch.Tensor:
        B, N, D = tokens.shape
        x = tokens.transpose(1, 2).reshape(B, D, h, w)
        x = self.proj(x)
        x = self.upsample(x)
        return x


class FeatureFusionBlock(nn.Module):
    """Fuse features from different scales."""

    def __init__(self, features: int):
        super().__init__()
        self.res1 = nn.Sequential(
            nn.Conv2d(features, features, 3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, features, 3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU()
        )
        self.res2 = nn.Sequential(
            nn.Conv2d(features, features, 3, padding=1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU()
        )

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        # Upsample skip connection
        skip = F.interpolate(skip, size=x.shape[2:], mode='bilinear')
        x = self.res1(x) + x
        x = self.res2(x + skip)
        return x

ViTDet: ViT for Object Detection

ViTDet adapts plain ViT for object detection:

PYTHON
class ViTDet(nn.Module):
    """
    ViTDet: Exploring Plain Vision Transformer Backbones for Object Detection.

    Uses windowed attention during training and simple feature pyramid.
    """

    def __init__(
        self,
        encoder: nn.Module,
        embed_dim: int = 768,
        out_channels: int = 256,
        num_classes: int = 80,
        window_size: int = 14,
        use_checkpoint: bool = True
    ):
        super().__init__()
        self.encoder = encoder
        self.window_size = window_size
        self.use_checkpoint = use_checkpoint

        # Simple Feature Pyramid (SFP)
        self.fpn = SimpleFeaturePyramid(embed_dim, out_channels)

        # Detection head (e.g., for Mask R-CNN)
        self.roi_heads = None  # Would be RoI pooling + heads

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape
        h, w = H // 16, W // 16

        # Get ViT features with windowed attention for efficiency
        features = self.encoder(x, return_all_tokens=True)[:, 1:, :]
        features = features.transpose(1, 2).reshape(B, -1, h, w)

        # Build feature pyramid
        pyramid = self.fpn(features)

        return pyramid


class SimpleFeaturePyramid(nn.Module):
    """Simple Feature Pyramid for ViTDet."""

    def __init__(self, in_dim: int, out_dim: int = 256):
        super().__init__()

        # Lateral connections at multiple scales
        self.scale_factors = [4, 2, 1, 0.5]
        self.lateral_convs = nn.ModuleList([
            nn.Conv2d(in_dim, out_dim, kernel_size=1)
            for _ in self.scale_factors
        ])

        # Output convolutions
        self.output_convs = nn.ModuleList([
            nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
            for _ in self.scale_factors
        ])

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        outputs = []
        for i, scale in enumerate(self.scale_factors):
            if scale > 1:
                feat = F.interpolate(x, scale_factor=scale, mode='bilinear')
            elif scale < 1:
                feat = F.avg_pool2d(x, kernel_size=int(1/scale))
            else:
                feat = x

            feat = self.lateral_convs[i](feat)
            feat = self.output_convs[i](feat)
            outputs.append(feat)

        return outputs


class WindowedAttention(nn.Module):
    """Windowed attention for efficient high-resolution processing."""

    def __init__(self, dim: int, num_heads: int, window_size: int):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, C = x.shape
        ws = self.window_size

        # Reshape to 2D
        x = x.reshape(B, H, W, C)

        # Pad to multiple of window size
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = H + pad_h, W + pad_w

        # Partition into windows
        x = x.reshape(B, Hp // ws, ws, Wp // ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, C)

        # Attention within windows
        qkv = self.qkv(x).reshape(-1, ws * ws, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(-1, ws * ws, C)
        x = self.proj(x)

        # Merge windows back
        x = x.reshape(B, Hp // ws, Wp // ws, ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, C)

        # Remove padding
        x = x[:, :H, :W, :].reshape(B, N, C)
        return x
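
A quick shape check for WindowedAttention: because of the internal padding, the feature map does not need to be a multiple of the window size, which matters when detection inputs have arbitrary aspect ratios.

PYTHON
win_attn = WindowedAttention(dim=256, num_heads=8, window_size=14)
H, W = 20, 30                          # deliberately not multiples of 14
tokens = torch.randn(2, H * W, 256)
with torch.no_grad():
    out = win_attn(tokens, H, W)
print(out.shape)  # expected: torch.Size([2, 600, 256])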

ViT-Adapter: Adapting ViT for Dense Tasks

ViT-Adapter adds spatial priors to plain ViT:

PYTHON
class ViTAdapter(nn.Module):
    """
    ViT-Adapter: Vision Transformer Adapter for Dense Predictions.

    Injects multi-scale spatial features into ViT via adapters.
    """

    def __init__(
        self,
        encoder: nn.Module,
        embed_dim: int = 768,
        num_heads: int = 12,
        adapter_dims: List[int] = [64, 128, 320, 512]
    ):
        super().__init__()
        self.encoder = encoder

        # Spatial prior module (CNN backbone)
        self.spatial_prior = SpatialPriorModule(adapter_dims)

        # Interaction modules at selected layers
        self.interactions = nn.ModuleList([
            SpatialInteraction(embed_dim, adapter_dims[-1])
            for _ in range(4)
        ])

        # Inject adapter at these layer indices
        self.inject_layers = [2, 5, 8, 11]

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape

        # Extract multi-scale spatial priors
        spatial_features = self.spatial_prior(x)

        # Forward through ViT with interactions
        patch_embed = self.encoder.patch_embed(x)
        x = patch_embed + self.encoder.pos_embed[:, 1:, :]

        interaction_idx = 0
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)

            if i in self.inject_layers:
                # Inject spatial information
                x = self.interactions[interaction_idx](
                    x, spatial_features[-1], H // 16, W // 16
                )
                interaction_idx += 1

        return x, spatial_features


class SpatialPriorModule(nn.Module):
    """CNN module to extract multi-scale spatial features."""

    def __init__(self, dims: List[int]):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, dims[0], 3, stride=2, padding=1),
            nn.BatchNorm2d(dims[0]),
            nn.ReLU(),
            nn.Conv2d(dims[0], dims[0], 3, padding=1),
            nn.BatchNorm2d(dims[0]),
            nn.ReLU()
        )

        self.stages = nn.ModuleList()
        for i in range(len(dims) - 1):
            self.stages.append(nn.Sequential(
                nn.Conv2d(dims[i], dims[i+1], 3, stride=2, padding=1),
                nn.BatchNorm2d(dims[i+1]),
                nn.ReLU(),
                nn.Conv2d(dims[i+1], dims[i+1], 3, padding=1),
                nn.BatchNorm2d(dims[i+1]),
                nn.ReLU()
            ))

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        features = []
        x = self.stem(x)
        features.append(x)
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features


class SpatialInteraction(nn.Module):
    """Interaction between ViT tokens and spatial features."""

    def __init__(self, vit_dim: int, spatial_dim: int):
        super().__init__()
        self.spatial_proj = nn.Conv2d(spatial_dim, vit_dim, 1)
        self.norm = nn.LayerNorm(vit_dim)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, tokens: torch.Tensor, spatial: torch.Tensor,
                h: int, w: int) -> torch.Tensor:
        B, N, D = tokens.shape

        # Project spatial features
        spatial = self.spatial_proj(spatial)
        spatial = F.interpolate(spatial, size=(h, w), mode='bilinear')
        spatial = spatial.flatten(2).transpose(1, 2)

        # Add to tokens with learnable weight
        tokens = tokens + self.gamma * self.norm(spatial)
        return tokens

Depth Estimation with ViT

PYTHON
class ViTDepth(nn.Module):
    """ViT-based monocular depth estimation."""

    def __init__(
        self,
        encoder: nn.Module,
        embed_dim: int = 768,
        decoder_features: List[int] = [256, 128, 64, 32],
        max_depth: float = 10.0
    ):
        super().__init__()
        self.encoder = encoder
        self.max_depth = max_depth

        # Decoder with skip connections
        self.decoder = DepthDecoder(embed_dim, decoder_features)

        # Final depth head
        self.depth_head = nn.Sequential(
            nn.Conv2d(decoder_features[-1], 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape

        # Encode
        features = self.encoder(x, return_all_tokens=True)[:, 1:, :]
        h, w = H // 16, W // 16
        features = features.transpose(1, 2).reshape(B, -1, h, w)

        # Decode
        decoded = self.decoder(features)

        # Predict depth
        depth = self.depth_head(decoded)
        depth = F.interpolate(depth, size=(H, W), mode='bilinear')
        return depth * self.max_depth


class DepthDecoder(nn.Module):
    """Decoder for depth estimation."""

    def __init__(self, in_dim: int, features: List[int]):
        super().__init__()
        self.layers = nn.ModuleList()
        prev_dim = in_dim
        for feat in features:
            self.layers.append(nn.Sequential(
                nn.Conv2d(prev_dim, feat, 3, padding=1),
                nn.BatchNorm2d(feat),
                nn.ReLU(),
                nn.Upsample(scale_factor=2, mode='bilinear')
            ))
            prev_dim = feat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

Architecture Comparison

PYTHON
def compare_dense_prediction_architectures():
    """Compare ViT architectures for dense prediction."""
    architectures = {
        'SETR': {
            'approach': 'Simple decoder on ViT features',
            'multi_scale': 'No (single scale from ViT)',
            'modifications': 'None to ViT',
            'tasks': 'Semantic segmentation',
            'pros': 'Simple, preserves ViT design',
            'cons': 'Limited multi-scale, high memory'
        },
        'Segmenter': {
            'approach': 'Mask transformer decoder',
            'multi_scale': 'No',
            'modifications': 'Adds decoder transformer',
            'tasks': 'Semantic segmentation',
            'pros': 'End-to-end transformer',
            'cons': 'Single scale features'
        },
        'DPT': {
            'approach': 'Reassemble tokens to FPN',
            'multi_scale': 'Yes (from different ViT layers)',
            'modifications': 'None to ViT',
            'tasks': 'Segmentation, depth, surface normals',
            'pros': 'Good multi-scale features',
            'cons': 'Complex reassembly'
        },
        'ViTDet': {
            'approach': 'Plain ViT with simple FPN',
            'multi_scale': 'Yes (from single-scale features)',
            'modifications': 'Windowed attention',
            'tasks': 'Object detection, instance segmentation',
            'pros': 'Minimal changes, strong results',
            'cons': 'Needs windowed attention for efficiency'
        },
        'ViT-Adapter': {
            'approach': 'Inject spatial priors via adapters',
            'multi_scale': 'Yes (from CNN spatial prior)',
            'modifications': 'Adds adapter modules',
            'tasks': 'Detection, segmentation',
            'pros': 'Best of CNN and ViT',
            'cons': 'More parameters'
        }
    }

    for name, info in architectures.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")

compare_dense_prediction_architectures()

Key Takeaways

Adapting Vision Transformers for dense prediction requires addressing the mismatch between ViT's single-scale global features and the multi-scale, high-resolution requirements of tasks like segmentation and detection. SETR pioneered the direct use of ViT features with simple decoders. DPT introduced reassembling tokens from different layers into a feature pyramid. ViTDet showed that plain ViT with windowed attention and simple FPN achieves strong detection results. ViT-Adapter bridges CNN and ViT by injecting spatial priors through adapter modules. The choice of architecture depends on the task: semantic segmentation benefits from global context (SETR, Segmenter), while detection and instance segmentation need efficient multi-scale features (ViTDet, ViT-Adapter).

21.5 Efficient Vision Transformers Advanced

Efficient Vision Transformers

While Vision Transformers achieve state-of-the-art results, their quadratic attention complexity and high computational costs limit deployment on resource-constrained devices. This section explores techniques to make ViTs more efficient: linear attention, token reduction, knowledge distillation, and architectures designed for mobile deployment.

Attention Efficiency

The self-attention mechanism has O(N²) complexity where N is the number of tokens. For high-resolution images, this becomes prohibitive:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math

def attention_complexity_analysis():
    """Analyze attention complexity for different image sizes."""
    patch_size = 16

    for img_size in [224, 384, 512, 1024]:
        num_patches = (img_size // patch_size) ** 2
        # Attention: Q @ K^T is [N, N]
        attention_ops = num_patches ** 2  # For each head
        print(f"Image {img_size}x{img_size}: {num_patches} patches, "
              f"{attention_ops/1e6:.1f}M attention ops")

# Output:
# Image 224x224: 196 patches, 0.04M attention ops
# Image 384x384: 576 patches, 0.33M attention ops
# Image 512x512: 1024 patches, 1.05M attention ops
# Image 1024x1024: 4096 patches, 16.78M attention ops


class LinearAttention(nn.Module):
    """
    Linear attention with kernel feature maps.
    Complexity: O(N) instead of O(N²)
    """

    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def feature_map(self, x: torch.Tensor) -> torch.Tensor:
        """ELU-based feature map for positive features."""
        return F.elu(x) + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply feature map to get positive features
        q = self.feature_map(q)
        k = self.feature_map(k)

        # Linear attention: avoid computing N×N attention matrix
        # Instead: (Q @ (K^T @ V)) which is O(N * D * D)
        kv = torch.einsum('bhnd,bhnv->bhdv', k, v)
        qkv = torch.einsum('bhnd,bhdv->bhnv', q, kv)

        # Normalize
        k_sum = k.sum(dim=2, keepdim=True)
        normalizer = torch.einsum('bhnd,bhkd->bhnk', q, k_sum)
        x = qkv / (normalizer + 1e-6)

        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)


class PerformerAttention(nn.Module):
    """
    Performer: Fast Attention via positive Orthogonal Random features (FAVOR+).
    Uses random feature approximation for linear complexity.
    """

    def __init__(self, dim: int, num_heads: int = 8, num_features: int = 256):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.num_features = num_features

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # Random features for kernel approximation
        self.register_buffer(
            'random_features',
            torch.randn(self.head_dim, num_features) / math.sqrt(self.head_dim)
        )

    def softmax_kernel(self, x: torch.Tensor) -> torch.Tensor:
        """Approximate softmax kernel with random features."""
        # Project to random feature space
        x_proj = torch.einsum('bhnd,df->bhnf', x, self.random_features)
        # exp for softmax approximation
        return torch.exp(x_proj - x_proj.max(dim=-1, keepdim=True)[0]) / math.sqrt(self.num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply random feature map
        q = self.softmax_kernel(q)
        k = self.softmax_kernel(k)

        # Linear attention computation
        kv = torch.einsum('bhnf,bhnd->bhfd', k, v)
        out = torch.einsum('bhnf,bhfd->bhnd', q, kv)

        # Normalize
        k_sum = k.sum(dim=2)
        normalizer = torch.einsum('bhnf,bhf->bhn', q, k_sum)
        out = out / normalizer.unsqueeze(-1)

        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
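
Both approximations above keep the token count and embedding dimension while avoiding the explicit N×N attention matrix, so long sequences remain tractable. A quick sanity check of the output shapes on a 4096-token sequence (roughly a 1024x1024 image with 16x16 patches):

PYTHON
x = torch.randn(1, 4096, 256)
lin_attn = LinearAttention(dim=256, num_heads=8)
perf_attn = PerformerAttention(dim=256, num_heads=8, num_features=128)
with torch.no_grad():
    print(lin_attn(x).shape, perf_attn(x).shape)  # both: torch.Size([1, 4096, 256])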

Token Reduction: ToMe (Token Merging)

Token merging reduces the number of tokens by progressively merging similar ones:

PYTHON
class TokenMerging(nn.Module):
    """
    ToMe: Token Merging for efficient ViT inference.
    Merges similar tokens to reduce computation.
    """

    def __init__(self, reduction_ratio: float = 0.5):
        super().__init__()
        self.reduction_ratio = reduction_ratio

    def forward(
        self,
        x: torch.Tensor,
        num_tokens_to_merge: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Merge similar tokens.

        Returns:
            merged: Merged tokens [B, N_new, C]
            unmerge_info: Information for unmerging if needed
        """
        B, N, C = x.shape

        if num_tokens_to_merge is None:
            num_tokens_to_merge = int(N * self.reduction_ratio)

        if num_tokens_to_merge == 0:
            return x, None

        # Bipartite soft matching: split tokens into source and target
        num_src = N // 2
        src = x[:, :num_src, :]
        tgt = x[:, num_src:, :].clone()  # clone so merging below does not modify the input

        # Compute similarity
        src_norm = F.normalize(src, dim=-1)
        tgt_norm = F.normalize(tgt, dim=-1)
        similarity = torch.bmm(src_norm, tgt_norm.transpose(1, 2))

        # Find matches: each src token finds best tgt token
        node_max, node_idx = similarity.max(dim=-1)

        # Sort by similarity and merge top-r pairs
        edge_idx = node_max.argsort(dim=-1, descending=True)

        # Merge the top-r most similar (source, target) pairs
        num_merge = min(num_tokens_to_merge, num_src)

        # Build output: average each merged source token into its target, keep the rest
        output = []
        for b in range(B):
            kept_src = []

            for i in range(num_src):
                src_i = edge_idx[b, i].item()
                if i < num_merge:
                    # Merge into target
                    tgt_i = node_idx[b, src_i].item()
                    tgt[b, tgt_i] = (tgt[b, tgt_i] + src[b, src_i]) / 2
                else:
                    kept_src.append(src[b, src_i])

            batch_tokens = []
            if kept_src:
                batch_tokens.append(torch.stack(kept_src))
            batch_tokens.append(tgt[b])
            output.append(torch.cat(batch_tokens, dim=0))

        output = torch.stack(output)
        return output, (edge_idx, node_idx, num_merge)


class EfficientViTWithToMe(nn.Module):
    """ViT with Token Merging for efficiency."""

    def __init__(
        self,
        base_vit: nn.Module,
        merge_schedule: list = [0.0, 0.0, 0.5, 0.5, 0.5, 0.5]
    ):
        super().__init__()
        self.vit = base_vit
        self.merge_schedule = merge_schedule
        self.token_merger = TokenMerging()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Patch embed
        x = self.vit.patch_embed(x)
        B, N, C = x.shape

        # Add CLS token and position
        cls_token = self.vit.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.vit.pos_embed

        # Process through blocks with token merging
        for i, block in enumerate(self.vit.blocks):
            x = block(x)

            if i < len(self.merge_schedule) and self.merge_schedule[i] > 0:
                # Separate CLS token
                cls = x[:, :1, :]
                tokens = x[:, 1:, :]

                # Merge tokens
                tokens, _ = self.token_merger(
                    tokens,
                    num_tokens_to_merge=int(tokens.size(1) * self.merge_schedule[i])
                )

                x = torch.cat([cls, tokens], dim=1)

        x = self.vit.norm(x)
        return self.vit.head(x[:, 0])
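
A shape demo for the TokenMerging module above: merging half of the source tokens shrinks the sequence from 196 to 98 tokens, which directly reduces the cost of every subsequent attention layer.

PYTHON
tome = TokenMerging(reduction_ratio=0.5)
tokens = torch.randn(2, 196, 768)        # 14x14 patch tokens from a 224px image
merged, _ = tome(tokens)
print(tokens.shape, '->', merged.shape)  # (2, 196, 768) -> (2, 98, 768)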

Token Pruning

Dynamically prune unimportant tokens during inference:

PYTHON
class DynamicTokenPruning(nn.Module):
    """
    Prune tokens based on attention scores or learned importance.
    """

    def __init__(self, keep_ratio: float = 0.7):
        super().__init__()
        self.keep_ratio = keep_ratio
        self.importance_predictor = None  # Can be learned

    def compute_importance(
        self,
        x: torch.Tensor,
        attn_weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute token importance scores.

        Options:
        1. Attention to CLS token
        2. Attention entropy
        3. Learned predictor
        """
        if attn_weights is not None:
            # Use attention from CLS to other tokens
            # attn_weights: [B, heads, N, N]
            cls_attn = attn_weights[:, :, 0, 1:]  # [B, heads, N-1]
            importance = cls_attn.mean(dim=1)  # Average across heads
        else:
            # Use L2 norm as proxy for importance
            importance = x[:, 1:, :].norm(dim=-1)

        return importance

    def forward(
        self,
        x: torch.Tensor,
        attn_weights: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B, N, C = x.shape

        # Separate CLS token
        cls_token = x[:, :1, :]
        tokens = x[:, 1:, :]

        # Compute importance
        importance = self.compute_importance(x, attn_weights)

        # Keep top-k tokens
        num_keep = int((N - 1) * self.keep_ratio)
        _, keep_idx = importance.topk(num_keep, dim=-1)

        # Gather kept tokens
        keep_idx = keep_idx.unsqueeze(-1).expand(-1, -1, C)
        kept_tokens = torch.gather(tokens, dim=1, index=keep_idx)

        # Reconstruct
        x = torch.cat([cls_token, kept_tokens], dim=1)

        return x, keep_idx


class AdaptiveTokenPruning(nn.Module):
    """
    Learn when and what to prune adaptively.
    """

    def __init__(self, dim: int, min_keep_ratio: float = 0.5):
        super().__init__()
        self.min_keep_ratio = min_keep_ratio

        # Learn importance scores
        self.importance_net = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.ReLU(),
            nn.Linear(dim // 4, 1)
        )

        # Learn pruning threshold
        self.threshold = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B, N, C = x.shape

        cls_token = x[:, :1, :]
        tokens = x[:, 1:, :]

        # Predict importance
        importance = self.importance_net(tokens).squeeze(-1)  # [B, N-1]
        importance = torch.sigmoid(importance)

        # Soft (differentiable) masking during training, hard thresholding at inference
        if self.training:
            # Soft selection: tokens are scaled by their predicted importance
            mask = importance
        else:
            # Hard selection: keep tokens above threshold
            mask = (importance > torch.sigmoid(self.threshold)).float()

            # Ensure minimum tokens kept
            num_keep = max(int((N-1) * self.min_keep_ratio),
                          mask.sum(dim=-1).min().int().item())
            if mask.sum(dim=-1).min() < num_keep:
                _, topk_idx = importance.topk(num_keep, dim=-1)
                mask = torch.zeros_like(importance)
                mask.scatter_(1, topk_idx, 1.0)

        # Apply mask
        masked_tokens = tokens * mask.unsqueeze(-1)

        # At inference, fully pruned tokens could also be dropped from the sequence,
        # but that requires handling variable sequence lengths per batch (omitted here)

        x = torch.cat([cls_token, masked_tokens], dim=1)
        return x, mask
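
A pruning demo for DynamicTokenPruning: with no attention weights supplied, tokens are ranked by their feature norm, and the CLS token plus the top 70% of patch tokens are kept.

PYTHON
pruner = DynamicTokenPruning(keep_ratio=0.7)
x = torch.randn(2, 197, 768)           # CLS + 196 patch tokens
pruned, keep_idx = pruner(x)
print(x.shape, '->', pruned.shape)     # 197 tokens -> 1 + 137 tokens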

Knowledge Distillation for ViT

Transfer knowledge from large to small ViT:

PYTHON
class ViTDistillation(nn.Module):
    """
    Knowledge distillation for Vision Transformers.
    """

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        temperature: float = 4.0,
        alpha: float = 0.5
    ):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha

        # Freeze teacher
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.teacher.eval()

    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Student forward
        student_logits = self.student(x)

        # Teacher forward (no grad)
        with torch.no_grad():
            teacher_logits = self.teacher(x)

        # Hard label loss
        hard_loss = F.cross_entropy(student_logits, labels)

        # Soft label loss (KL divergence)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean')
        soft_loss = soft_loss * (self.temperature ** 2)

        # Combined loss
        loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        return loss, student_logits


class FeatureDistillation(nn.Module):
    """
    Distill intermediate features, not just logits.
    """

    def __init__(
        self,
        teacher: nn.Module,
        student: nn.Module,
        teacher_layers: list = [3, 6, 9, 11],
        student_layers: list = [1, 2, 3, 4]
    ):
        super().__init__()
        self.teacher = teacher
        self.student = student
        self.teacher_layers = teacher_layers
        self.student_layers = student_layers

        # Projectors to match dimensions
        self.projectors = nn.ModuleList([
            nn.Linear(student.embed_dim, teacher.embed_dim)
            for _ in student_layers
        ])

        for p in self.teacher.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Get teacher features
        with torch.no_grad():
            teacher_features = self.teacher.get_intermediate_layers(
                x, self.teacher_layers
            )

        # Get student features
        student_features = self.student.get_intermediate_layers(
            x, self.student_layers
        )

        # Feature matching loss
        feature_loss = 0
        for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
            s_feat_proj = self.projectors[i](s_feat)
            feature_loss += F.mse_loss(s_feat_proj, t_feat)

        return feature_loss / len(self.student_layers)
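
One distillation step with toy linear classifiers standing in for a large teacher ViT and a compact student ViT; only the student's parameters are updated. The classifiers here are placeholders chosen so the example runs quickly.

PYTHON
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
student = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))

distiller = ViTDistillation(teacher, student, temperature=4.0, alpha=0.5)
opt = torch.optim.AdamW(student.parameters(), lr=1e-4)

imgs, labels = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
loss, logits = distiller(imgs, labels)
opt.zero_grad(); loss.backward(); opt.step()
print(f"Distillation loss: {loss.item():.3f}")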

Mobile-Optimized Architectures

PYTHON
class MobileViTBlock(nn.Module):
    """
    MobileViT: combines CNN local processing with transformer global attention.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        embed_dim: int,
        num_heads: int = 4,
        patch_size: int = 2
    ):
        super().__init__()
        self.patch_size = patch_size

        # Local representation (CNN)
        self.local_rep = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, embed_dim, 1)
        )

        # Global representation (Transformer)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dim_feedforward=embed_dim * 2,
                batch_first=True
            ),
            num_layers=2
        )

        # Fusion
        self.fusion = nn.Sequential(
            nn.Conv2d(embed_dim, in_channels, 1),
            nn.BatchNorm2d(in_channels)
        )

        # Project to output
        self.proj = nn.Sequential(
            nn.Conv2d(2 * in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.SiLU()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        p = self.patch_size

        # Local features
        local = self.local_rep(x)  # [B, D, H, W]

        # Unfold to patches for transformer
        # [B, D, H, W] -> [B, D, H/p, p, W/p, p] -> [B, H/p*W/p, p*p, D]
        D = local.size(1)
        local_patches = local.reshape(B, D, H//p, p, W//p, p)
        local_patches = local_patches.permute(0, 2, 4, 3, 5, 1)
        local_patches = local_patches.reshape(B, (H//p)*(W//p), p*p, D)

        # Apply transformer to each patch location across the image
        # Reshape: [B, num_patches, patch_tokens, D] -> [B*patch_tokens, num_patches, D]
        local_patches = local_patches.permute(0, 2, 1, 3)
        local_patches = local_patches.reshape(B * p * p, (H//p)*(W//p), D)

        global_patches = self.transformer(local_patches)

        # Reshape back
        global_patches = global_patches.reshape(B, p*p, (H//p)*(W//p), D)
        global_patches = global_patches.permute(0, 2, 1, 3)
        global_patches = global_patches.reshape(B, H//p, W//p, p, p, D)
        global_patches = global_patches.permute(0, 5, 1, 3, 2, 4)
        global_out = global_patches.reshape(B, D, H, W)

        # Fuse and project
        global_out = self.fusion(global_out)
        out = self.proj(torch.cat([x, global_out], dim=1))

        return out


class EfficientFormerBlock(nn.Module):
    """
    EfficientFormer: separates token mixer and MLP for efficiency.
    Uses pooling-based token mixer in early stages.
    """

    def __init__(
        self,
        dim: int,
        use_attn: bool = False,
        pool_size: int = 3
    ):
        super().__init__()
        self.use_attn = use_attn

        if use_attn:
            self.token_mixer = nn.MultiheadAttention(dim, 8, batch_first=True)
        else:
            # Pool-based token mixer (efficient)
            self.token_mixer = nn.AvgPool2d(
                pool_size, stride=1, padding=pool_size//2, count_include_pad=False
            )

        self.norm1 = nn.LayerNorm(dim) if use_attn else nn.BatchNorm2d(dim)
        self.norm2 = nn.LayerNorm(dim) if use_attn else nn.BatchNorm2d(dim)

        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4) if use_attn else nn.Conv2d(dim, dim * 4, 1),
            nn.GELU(),
            nn.Linear(dim * 4, dim) if use_attn else nn.Conv2d(dim * 4, dim, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize once and reuse (the pool path needs the normalized tensor twice)
        h = self.norm1(x)
        if self.use_attn:
            # Attention path (later stages): x is a token sequence [B, N, dim]
            x = x + self.token_mixer(h, h, h)[0]
        else:
            # Pool path (early stages): x is a feature map [B, dim, H, W].
            # PoolFormer-style mixer: the residual branch is Pool(norm(x)) - norm(x)
            x = x + self.token_mixer(h) - h
        x = x + self.mlp(self.norm2(x))
        return x


class LeViT(nn.Module):
    """
    LeViT: Vision Transformer with convolutions for fast inference.
    Uses attention bias instead of positional embeddings.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        stages: list = [256, 384, 512],
        depths: list = [4, 4, 4],
        num_heads: list = [4, 6, 8]
    ):
        super().__init__()

        # Patch embed with convolutions
        self.patch_embed = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.Hardswish(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.Hardswish(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.Hardswish(),
            nn.Conv2d(128, stages[0], 3, stride=2, padding=1)
        )

        # Stages
        self.stages = nn.ModuleList()
        for i, (dim, depth, heads) in enumerate(zip(stages, depths, num_heads)):
            stage = nn.ModuleList([
                LeViTBlock(dim, heads) for _ in range(depth)
            ])
            self.stages.append(stage)

            # Downsample between stages
            if i < len(stages) - 1:
                self.stages.append(
                    nn.Conv2d(dim, stages[i+1], 3, stride=2, padding=1)
                )

        self.head = nn.Linear(stages[-1], num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)

        for stage in self.stages:
            if isinstance(stage, nn.Conv2d):
                x = stage(x)
            else:
                for block in stage:
                    x = block(x)

        x = x.mean(dim=[2, 3])
        return self.head(x)


class LeViTBlock(nn.Module):
    """LeViT attention block with hardswish and batch norm."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.attn = LeViTAttention(dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, dim * 2, 1),
            nn.BatchNorm2d(dim * 2),
            nn.Hardswish(),
            nn.Conv2d(dim * 2, dim, 1),
            nn.BatchNorm2d(dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x


class LeViTAttention(nn.Module):
    """LeViT attention with learned bias instead of positional encoding."""

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
        self.bn = nn.BatchNorm2d(dim)

        # Attention bias table (replaces explicit positional encoding).
        # In the full LeViT this is a learned per-head bias indexed by relative
        # patch offsets; here it is left as a resolution-dependent placeholder.
        self.attention_bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        N = H * W

        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, N)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5)

        # Add attention bias if available
        if self.attention_bias is not None:
            attn = attn + self.attention_bias[:, :, :N, :N]

        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).reshape(B, C, H, W)

        return self.bn(self.proj(x))

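Before comparing efficiency techniques, a quick sanity check confirms that the blocks defined above produce the expected output shapes. The snippet below reuses the LeViT and EfficientFormerBlock classes (and the torch import) from this section on random inputs; the batch size, channel counts, and resolutions are arbitrary illustrative choices.

PYTHON
# Shape sanity check for the models defined above (illustrative values)
model = LeViT(num_classes=10)
images = torch.randn(2, 3, 224, 224)
print(model(images).shape)        # torch.Size([2, 10])

# Pool-based EfficientFormer block operates on feature maps [B, C, H, W]
pool_block = EfficientFormerBlock(dim=64, use_attn=False)
feat = torch.randn(2, 64, 28, 28)
print(pool_block(feat).shape)     # torch.Size([2, 64, 28, 28])

# Attention-based variant operates on token sequences [B, N, C]
attn_block = EfficientFormerBlock(dim=64, use_attn=True)
tokens = torch.randn(2, 49, 64)
print(attn_block(tokens).shape)   # torch.Size([2, 49, 64])
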
Efficiency Comparison

PYTHON
def compare_efficient_vit_methods():
    """Compare different efficiency techniques."""
    methods = {
        'Linear Attention': {
            'complexity': 'O(N)',
            'memory': 'Low',
            'accuracy_drop': '1-3%',
            'use_case': 'Long sequences, high resolution'
        },
        'Token Merging (ToMe)': {
            'complexity': 'O(N²) with fewer N',
            'memory': 'Reduced by merge ratio',
            'accuracy_drop': '0.5-1%',
            'use_case': 'Inference speedup, works with any ViT'
        },
        'Token Pruning': {
            'complexity': 'O(N²) with fewer N',
            'memory': 'Reduced dynamically',
            'accuracy_drop': '0.5-2%',
            'use_case': 'Adaptive computation'
        },
        'Knowledge Distillation': {
            'complexity': 'Same as student',
            'memory': 'Student model size',
            'accuracy_drop': '1-3%',
            'use_case': 'Model compression'
        },
        'MobileViT': {
            'complexity': 'O(N) for local, O(P²) for global',
            'memory': 'Low (mobile optimized)',
            'accuracy_drop': '2-5% vs full ViT',
            'use_case': 'Mobile deployment'
        },
        'EfficientFormer': {
            'complexity': 'O(N) early, O(N²) late',
            'memory': 'Medium',
            'accuracy_drop': '1-2%',
            'use_case': 'Balanced speed/accuracy'
        }
    }

    for name, info in methods.items():
        print(f"\n{name}:")
        for k, v in info.items():
            print(f"  {k}: {v}")

compare_efficient_vit_methods()

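The table above is qualitative; concrete numbers depend on the model and hardware. The sketch below is one way to ground the comparison: it counts parameters and measures rough CPU latency for the LeViT model defined earlier in this section (reusing the torch and nn imports from above). The warmup and iteration counts are arbitrary, and the latencies you observe will vary across machines.

PYTHON
import time

def profile_model(model: nn.Module, input_size=(1, 3, 224, 224), warmup=3, iters=10):
    """Rough parameter count and CPU latency for a model (illustrative only)."""
    model.eval()
    params = sum(p.numel() for p in model.parameters())
    x = torch.randn(*input_size)

    with torch.no_grad():
        for _ in range(warmup):       # warm up before timing
            model(x)
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        latency_ms = (time.perf_counter() - start) / iters * 1000

    print(f"{model.__class__.__name__}: {params / 1e6:.1f}M params, "
          f"{latency_ms:.1f} ms / forward pass (CPU)")

profile_model(LeViT(num_classes=1000))
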
Key Takeaways

Making Vision Transformers efficient requires addressing the quadratic cost of attention and the overall computational budget. Linear attention approximations reduce complexity to O(N) but may sacrifice some accuracy. Token reduction methods such as ToMe and token pruning shrink the number of tokens processed, delivering significant speedups with minimal accuracy loss. Knowledge distillation transfers capability from a large teacher to a small student. Mobile-optimized architectures like MobileViT and EfficientFormer redesign the basic blocks for efficiency, typically combining the local processing of convolutions with the global attention of transformers. The right choice depends on deployment constraints: token merging works as a plug-in for any pretrained ViT, while mobile architectures must be trained from scratch but achieve better efficiency-accuracy trade-offs.
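
As a concrete illustration of why token merging can be dropped into a pretrained ViT, the sketch below implements a simplified bipartite merge step in the spirit of ToMe. It is a minimal sketch under stated assumptions, not the official implementation: it ignores the CLS token, proportional attention, and the per-layer merge schedule, and the function name tome_merge and the choice of r are illustrative.

PYTHON
import torch
import torch.nn.functional as F

def tome_merge(x: torch.Tensor, r: int) -> torch.Tensor:
    """
    Simplified ToMe-style bipartite token merging (illustrative sketch only).
    Averages the r most similar (src, dst) token pairs, shrinking N to N - r.

    x: [B, N, D] patch tokens (no CLS token; assumes N even and r <= N // 2).
    """
    B, N, D = x.shape
    src, dst = x[:, ::2], x[:, 1::2]                 # bipartite split, [B, N/2, D] each

    # Cosine similarity of every src token to every dst token
    sim = F.normalize(src, dim=-1) @ F.normalize(dst, dim=-1).transpose(-2, -1)
    best_sim, best_dst = sim.max(dim=-1)             # best dst match per src token

    # Merge away the r src tokens that are most similar to some dst token
    order = best_sim.argsort(dim=-1, descending=True)
    merged_idx, kept_idx = order[:, :r], order[:, r:]

    batch = torch.arange(B, device=x.device).unsqueeze(-1)
    kept_src = src[batch, kept_idx]                  # [B, N/2 - r, D]
    src_vals = src[batch, merged_idx]                # [B, r, D]
    target = best_dst[batch, merged_idx]             # dst index for each merged src

    # Average merged src tokens into their dst targets (scatter handles collisions)
    dst_sum = dst.clone()
    counts = torch.ones(B, dst.size(1), 1, device=x.device)
    dst_sum.scatter_add_(1, target.unsqueeze(-1).expand(-1, -1, D), src_vals)
    counts.scatter_add_(1, target.unsqueeze(-1), torch.ones(B, r, 1, device=x.device))

    return torch.cat([kept_src, dst_sum / counts], dim=1)   # [B, N - r, D]


# Example: reduce 196 patch tokens by 16 at one layer
tokens = torch.randn(2, 196, 768)
print(tome_merge(tokens, r=16).shape)                # torch.Size([2, 180, 768])

In a real deployment this merge step would sit between the attention and MLP sub-layers of each pretrained block, with the per-layer reduction r chosen to balance speed against accuracy.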