Vision Transformer Architecture
The Vision Transformer (ViT) marked a paradigm shift in computer vision by demonstrating that the transformer architecture, originally designed for natural language processing, could achieve state-of-the-art results on image classification when trained on sufficient data. Rather than relying on the inductive biases of convolutional neural networks—locality and translation equivariance—ViT treats images as sequences of patches and learns spatial relationships entirely through attention mechanisms. This approach opened new possibilities for unified architectures across modalities and sparked intense research into adapting transformers for visual understanding.
From Sequences to Images
Transformers process sequences of tokens, but images are 2D grids of pixels. The key insight of ViT is to convert images into sequences by dividing them into fixed-size patches, flattening each patch into a vector, and treating these vectors as the input sequence. A 224×224 image divided into 16×16 patches yields 196 patch tokens, a manageable sequence length for transformer processing.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def image_to_patches(image, patch_size=16):
"""
Convert an image into a sequence of flattened patches.
Args:
image: Tensor of shape (B, C, H, W)
patch_size: Size of each square patch
Returns:
patches: Tensor of shape (B, num_patches, patch_dim)
where patch_dim = C * patch_size * patch_size
"""
B, C, H, W = image.shape
assert H % patch_size == 0 and W % patch_size == 0, \
f"Image dimensions must be divisible by patch size"
# Number of patches along each dimension
num_patches_h = H // patch_size
num_patches_w = W // patch_size
num_patches = num_patches_h * num_patches_w
# Reshape: (B, C, H, W) -> (B, C, n_h, p_h, n_w, p_w)
patches = image.reshape(B, C, num_patches_h, patch_size, num_patches_w, patch_size)
# Rearrange: (B, C, n_h, p_h, n_w, p_w) -> (B, n_h, n_w, p_h, p_w, C)
patches = patches.permute(0, 2, 4, 3, 5, 1)
# Flatten patches: (B, n_h, n_w, p_h, p_w, C) -> (B, num_patches, patch_dim)
patches = patches.reshape(B, num_patches, -1)
return patches
# Demonstrate patch extraction
def visualize_patching():
"""Show how an image is converted to patch sequence."""
# Example: 224x224 RGB image with 16x16 patches
B, C, H, W = 1, 3, 224, 224
patch_size = 16
image = torch.randn(B, C, H, W)
patches = image_to_patches(image, patch_size)
num_patches = (H // patch_size) * (W // patch_size)
patch_dim = C * patch_size * patch_size
print("Vision Transformer Patching:")
print(f" Input image: {B} x {C} x {H} x {W}")
print(f" Patch size: {patch_size} x {patch_size}")
print(f" Number of patches: {num_patches} ({H//patch_size} x {W//patch_size})")
print(f" Patch dimension: {patch_dim} ({C} x {patch_size} x {patch_size})")
print(f" Output sequence: {patches.shape}")
visualize_patching()
This patching strategy preserves local structure within each patch while allowing the transformer to learn global relationships between patches through attention. The patch size represents a trade-off: smaller patches capture finer details but create longer sequences with quadratic attention cost, while larger patches reduce computational cost but may miss fine-grained features.
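To make this trade-off concrete, the short helper below (a hypothetical patch_size_tradeoff function, not part of the ViT reference code) prints the sequence length, flattened patch dimension, and per-layer attention-score count for several patch sizes on a 224×224 image.
def patch_size_tradeoff(img_size=224, patch_sizes=(8, 16, 32), channels=3):
    """Illustrate how patch size affects sequence length and attention cost."""
    print(f"Patch size trade-off at {img_size}x{img_size}:")
    for p in patch_sizes:
        num_patches = (img_size // p) ** 2        # sequence length (patch tokens)
        patch_dim = channels * p * p              # flattened patch dimension
        attn_scores = num_patches ** 2            # pairwise scores per head per layer
        print(f"  patch={p:2d}: {num_patches:5d} tokens, "
              f"patch_dim={patch_dim:5d}, {attn_scores:,} attention scores")
patch_size_tradeoff()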
Patch Embedding Layer
The patch embedding layer projects flattened patches to the transformer's hidden dimension. While this can be written as a simple linear projection, implementing it as a convolution whose kernel size and stride both equal the patch size is more efficient and mathematically equivalent.
class PatchEmbedding(nn.Module):
"""
Convert image to patch embeddings.
Uses convolution for efficient implementation.
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# Convolution with kernel_size=stride=patch_size extracts non-overlapping patches
# and projects them to embed_dim in one operation
self.projection = nn.Conv2d(
in_channels,
embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
"""
Args:
x: Input image (B, C, H, W)
Returns:
Patch embeddings (B, num_patches, embed_dim)
"""
B, C, H, W = x.shape
assert H == self.img_size and W == self.img_size, \
f"Input size ({H}x{W}) doesn't match model ({self.img_size}x{self.img_size})"
# (B, C, H, W) -> (B, embed_dim, H/P, W/P)
x = self.projection(x)
# (B, embed_dim, H/P, W/P) -> (B, embed_dim, num_patches)
x = x.flatten(2)
# (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
x = x.transpose(1, 2)
return x
class PatchEmbeddingLinear(nn.Module):
"""
Alternative implementation using reshape and linear layer.
Mathematically equivalent to convolution approach.
"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
patch_dim = in_channels * patch_size * patch_size
self.projection = nn.Linear(patch_dim, embed_dim)
def forward(self, x):
B, C, H, W = x.shape
p = self.patch_size
# Reshape to patches
x = x.reshape(B, C, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 3, 5, 1) # (B, H/p, W/p, p, p, C)
x = x.reshape(B, self.num_patches, -1) # (B, num_patches, patch_dim)
# Linear projection
x = self.projection(x)
return x
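Because the two modules above are interchangeable, a quick sanity check can confirm it. The sketch below (an illustrative helper assuming the PatchEmbedding and PatchEmbeddingLinear classes defined above) copies the convolution kernel into the linear weight, matching the (p, p, C) patch flattening order, and compares the outputs.
def check_patch_embed_equivalence():
    """Sanity check: conv and linear patch embeddings match when weights are shared."""
    conv_embed = PatchEmbedding(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
    linear_embed = PatchEmbeddingLinear(img_size=224, patch_size=16, in_channels=3, embed_dim=768)
    with torch.no_grad():
        # Conv weight is (embed_dim, C, p, p); the linear layer expects the
        # (p, p, C) flattening order used in PatchEmbeddingLinear.forward
        w = conv_embed.projection.weight
        linear_embed.projection.weight.copy_(w.permute(0, 2, 3, 1).reshape(768, -1))
        linear_embed.projection.bias.copy_(conv_embed.projection.bias)
    x = torch.randn(2, 3, 224, 224)
    diff = (conv_embed(x) - linear_embed(x)).abs().max().item()
    print(f"Max difference between conv and linear patch embedding: {diff:.2e}")
check_patch_embed_equivalence()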
Positional Embeddings
Unlike CNNs that inherently encode spatial information through local connectivity, transformers have no notion of position. ViT adds learnable position embeddings to patch embeddings, allowing the model to learn spatial relationships. These embeddings are typically initialized randomly and learned during training.
class PositionalEmbedding(nn.Module):
"""
Learnable 1D positional embeddings for ViT.
Added to patch embeddings to encode spatial position.
"""
def __init__(self, num_patches, embed_dim):
super().__init__()
# +1 for the [CLS] token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
# Initialize with small random values
nn.init.trunc_normal_(self.pos_embed, std=0.02)
def forward(self, x):
"""Add positional embeddings to input."""
return x + self.pos_embed
class SinusoidalPositionalEmbedding(nn.Module):
"""
Fixed sinusoidal positional embeddings (alternative to learned).
Similar to original Transformer, adapted for 2D positions.
"""
def __init__(self, num_patches, embed_dim, temperature=10000):
super().__init__()
self.embed_dim = embed_dim
# Compute grid of positions
grid_size = int(num_patches ** 0.5)
pos_h = torch.arange(grid_size).float()
pos_w = torch.arange(grid_size).float()
grid = torch.stack(torch.meshgrid(pos_h, pos_w, indexing='ij'), dim=-1)
grid = grid.reshape(-1, 2) # (num_patches, 2)
# Compute sinusoidal embeddings
dim_t = torch.arange(embed_dim // 4).float()
dim_t = temperature ** (2 * dim_t / (embed_dim // 2))
pos_embed = torch.zeros(num_patches, embed_dim)
pos_embed[:, 0::4] = torch.sin(grid[:, 0:1] / dim_t)
pos_embed[:, 1::4] = torch.cos(grid[:, 0:1] / dim_t)
pos_embed[:, 2::4] = torch.sin(grid[:, 1:2] / dim_t)
pos_embed[:, 3::4] = torch.cos(grid[:, 1:2] / dim_t)
# Add CLS token position (zeros)
cls_pos = torch.zeros(1, embed_dim)
pos_embed = torch.cat([cls_pos, pos_embed], dim=0)
self.register_buffer('pos_embed', pos_embed.unsqueeze(0))
def forward(self, x):
return x + self.pos_embed
def visualize_position_embeddings():
"""
Show the learned spatial structure in position embeddings.
Trained ViT position embeddings show clear 2D structure.
"""
num_patches = 196 # 14x14 grid
embed_dim = 768
pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
# Compute similarity between position embeddings
# In trained models, nearby patches have similar embeddings
pos_only = pos_embed[0, 1:] # Exclude CLS token
similarity = torch.matmul(pos_only, pos_only.T)
print("Position Embedding Similarity Matrix:")
print(f" Shape: {similarity.shape}")
print(" In trained ViT, this shows 2D spatial structure")
print(" Nearby patches have high similarity")The learned position embeddings in trained ViT models show remarkable structure: embeddings for spatially close patches are similar, and the model learns row and column structure despite receiving only 1D position indices. This demonstrates that transformers can discover spatial relationships from data.
The CLS Token
Following BERT, ViT prepends a learnable [CLS] token to the sequence of patch embeddings. After processing through transformer layers, the final representation of this token serves as the aggregate image representation for classification. The CLS token attends to all patches, effectively learning to summarize the entire image.
class CLSToken(nn.Module):
"""
Learnable [CLS] token prepended to patch sequence.
Its final representation is used for classification.
"""
def __init__(self, embed_dim):
super().__init__()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.init.trunc_normal_(self.cls_token, std=0.02)
def forward(self, x):
"""
Prepend CLS token to input sequence.
Args:
x: Patch embeddings (B, num_patches, embed_dim)
Returns:
Sequence with CLS token (B, num_patches + 1, embed_dim)
"""
B = x.shape[0]
cls_tokens = self.cls_token.expand(B, -1, -1)
return torch.cat([cls_tokens, x], dim=1)
def alternative_pooling_strategies():
"""
Alternatives to CLS token for image representation.
"""
# 1. Global Average Pooling (GAP)
# Average all patch representations
def gap_pooling(x):
# x: (B, num_patches + 1, embed_dim)
# Exclude CLS token, average patches
return x[:, 1:].mean(dim=1) # (B, embed_dim)
# 2. Global Max Pooling
def gmp_pooling(x):
return x[:, 1:].max(dim=1)[0] # (B, embed_dim)
# 3. Attention Pooling
class AttentionPooling(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.attention = nn.Linear(embed_dim, 1)
def forward(self, x):
# x: (B, seq_len, embed_dim)
weights = F.softmax(self.attention(x), dim=1) # (B, seq_len, 1)
return (weights * x).sum(dim=1) # (B, embed_dim)
print("Pooling Strategies for ViT:")
print(" CLS Token: Standard ViT approach, learns to aggregate")
print(" GAP: Simple averaging, works well for some tasks")
print(" Attention Pooling: Learned weighted average")Multi-Head Self-Attention for Vision
The core of the transformer is multi-head self-attention, which allows each patch to attend to all other patches. For vision, this means every part of the image can directly influence every other part, unlike CNNs where information flows only through local connections.
class MultiHeadSelfAttention(nn.Module):
"""
Multi-Head Self-Attention mechanism for Vision Transformer.
"""
def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5 # 1/sqrt(d_k)
# Combined QKV projection for efficiency
self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
# Output projection
self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.attn_dropout = nn.Dropout(dropout)
self.proj_dropout = nn.Dropout(dropout)
def forward(self, x, return_attention=False):
"""
Args:
x: Input tensor (B, seq_len, embed_dim)
return_attention: Whether to return attention weights
Returns:
Output tensor (B, seq_len, embed_dim)
Optionally: attention weights (B, num_heads, seq_len, seq_len)
"""
B, N, C = x.shape
# Compute Q, K, V in one projection
qkv = self.qkv(x) # (B, N, 3 * embed_dim)
qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
# Scaled dot-product attention
# (B, num_heads, N, head_dim) @ (B, num_heads, head_dim, N)
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, num_heads, N, N)
attn = F.softmax(attn, dim=-1)
attn = self.attn_dropout(attn)
# Apply attention to values
# (B, num_heads, N, N) @ (B, num_heads, N, head_dim)
x = attn @ v # (B, num_heads, N, head_dim)
# Reshape back
x = x.transpose(1, 2).reshape(B, N, C) # (B, N, embed_dim)
# Output projection
x = self.proj(x)
x = self.proj_dropout(x)
if return_attention:
return x, attn
return x
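The return_attention flag makes it straightforward to inspect how the [CLS] token distributes attention over patches. The sketch below is illustrative only (with random, untrained weights the map is near-uniform); it extracts the CLS-to-patch attention averaged over heads and reshapes it to the patch grid.
def inspect_cls_attention():
    """Illustrative only: with random weights the CLS attention map is near-uniform."""
    attn_layer = MultiHeadSelfAttention(embed_dim=768, num_heads=12)
    tokens = torch.randn(1, 197, 768)        # [CLS] token + 196 patch tokens
    out, attn = attn_layer(tokens, return_attention=True)
    # Attention from the CLS query (index 0) to the 196 patch tokens,
    # averaged over heads and reshaped to the 14x14 patch grid
    cls_to_patches = attn[0, :, 0, 1:].mean(dim=0).reshape(14, 14)
    print(f"Output: {out.shape}, CLS attention map: {cls_to_patches.shape}")
inspect_cls_attention()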
def visualize_attention_patterns():
"""
Demonstrate how attention connects patches across the image.
"""
print("ViT Attention Characteristics:")
print(" - Every patch can attend to every other patch")
print(" - Global receptive field from the first layer")
print(" - Different heads learn different attention patterns:")
print(" * Some attend to nearby patches (local)")
print(" * Some attend to semantically similar regions")
print(" * Some attend to specific image locations")
print(" - CLS token attends to all patches, aggregating information")
visualize_attention_patterns()
The attention mechanism computes pairwise relationships between all patches with complexity $O(N^2)$ where $N$ is the number of patches. For a 224×224 image with 16×16 patches, this is $196^2 = 38,416$ pairwise attention scores per layer per head ($197^2$ including the [CLS] token)—substantial but manageable.
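A rough estimate of the memory this implies, ignoring fused-attention kernels and other implementation details, shows how quickly the cost grows when the input resolution increases at a fixed patch size. The helper below is an illustrative sketch, not part of the model code.
def attention_memory_estimate(img_size, patch_size=16, num_heads=12, bytes_per_score=4):
    """Rough fp32 memory for one layer's attention matrix (ignores fused-attention kernels)."""
    n = (img_size // patch_size) ** 2 + 1             # patch tokens + [CLS]
    return num_heads * n * n * bytes_per_score / 1e6  # MB per image
for size in (224, 384, 512):
    print(f"  {size}x{size}: ~{attention_memory_estimate(size):.1f} MB of attention weights per layer")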
MLP Block
Each transformer block contains a multi-layer perceptron (MLP) applied independently to each token. The MLP typically expands the hidden dimension by a factor of 4, applies a non-linearity, then projects back to the original dimension.
class MLP(nn.Module):
"""
Multi-Layer Perceptron block in Vision Transformer.
Applied independently to each token position.
"""
def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0, activation=nn.GELU):
super().__init__()
hidden_dim = int(embed_dim * mlp_ratio)
self.fc1 = nn.Linear(embed_dim, hidden_dim)
self.act = activation()
self.fc2 = nn.Linear(hidden_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class GLU_MLP(nn.Module):
"""
Gated Linear Unit MLP variant.
Used in some modern ViT variants for improved performance.
"""
def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
super().__init__()
hidden_dim = int(embed_dim * mlp_ratio)
# Two parallel projections for gating
self.w1 = nn.Linear(embed_dim, hidden_dim)
self.w2 = nn.Linear(embed_dim, hidden_dim)
self.w3 = nn.Linear(hidden_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
# SwiGLU: Swish(W1(x)) * W2(x)
hidden = F.silu(self.w1(x)) * self.w2(x)
hidden = self.dropout(hidden)
output = self.w3(hidden)
return self.dropout(output)
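As written, the gated variant uses three projection matrices instead of two and therefore has roughly 1.5× the parameters of the standard MLP at the same mlp_ratio; many implementations shrink the hidden width (often to about two-thirds of 4×) to compensate. The small comparison below (an illustrative helper using the classes defined above) simply counts parameters as defined here.
def compare_mlp_params(embed_dim=768, mlp_ratio=4.0):
    """Count parameters in the standard MLP versus the gated variant above."""
    def count(module):
        return sum(p.numel() for p in module.parameters())
    mlp = MLP(embed_dim, mlp_ratio)
    glu = GLU_MLP(embed_dim, mlp_ratio)
    print(f"MLP params:     {count(mlp) / 1e6:.2f}M")
    print(f"GLU_MLP params: {count(glu) / 1e6:.2f}M  (three projections instead of two)")
compare_mlp_params()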
Transformer Encoder Block
A transformer encoder block combines multi-head self-attention and MLP with residual connections and layer normalization. ViT uses pre-normalization (LayerNorm before attention/MLP) rather than the original post-normalization, which improves training stability.
class TransformerEncoderBlock(nn.Module):
"""
Single Transformer Encoder block for ViT.
Combines self-attention and MLP with residual connections.
"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0,
attention_dropout=0.0, drop_path=0.0):
super().__init__()
# Pre-normalization
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadSelfAttention(
embed_dim, num_heads, dropout=attention_dropout
)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(embed_dim, mlp_ratio, dropout)
# Stochastic depth (drop path) for regularization
self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()
def forward(self, x):
# Self-attention with residual
x = x + self.drop_path(self.attn(self.norm1(x)))
# MLP with residual
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class DropPath(nn.Module):
"""
Stochastic Depth: randomly drop entire residual branches during training.
Regularization technique that improves generalization.
"""
def __init__(self, drop_prob=0.0):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.drop_prob == 0 or not self.training:
return x
keep_prob = 1 - self.drop_prob
# Generate random tensor, same for all elements in batch
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor = random_tensor.floor()
# Scale to maintain expected value
output = x / keep_prob * random_tensor
return output
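For contrast with the pre-normalization ordering used above, a post-normalization block (as in the original Transformer) applies LayerNorm after each residual addition. A minimal sketch is shown below purely for comparison; it is not used by the ViT model that follows.
class PostNormEncoderBlock(nn.Module):
    """
    Post-normalization variant (the original Transformer ordering), shown only
    for contrast; the ViT model below uses the pre-normalization block above.
    """
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout=dropout)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
    def forward(self, x):
        # LayerNorm is applied after each residual addition, not before the sublayer
        x = self.norm1(x + self.attn(x))
        x = self.norm2(x + self.mlp(x))
        return x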
Complete Vision Transformer
Combining all components yields the complete Vision Transformer architecture:
class VisionTransformer(nn.Module):
"""
Vision Transformer (ViT) for image classification.
Architecture:
1. Patch embedding: Convert image to sequence of patch embeddings
2. Add CLS token and positional embeddings
3. Process through transformer encoder blocks
4. Extract CLS token representation
5. Classification head
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
dropout=0.0,
attention_dropout=0.0,
drop_path_rate=0.0,
):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
self.num_patches = (img_size // patch_size) ** 2
# Patch embedding
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
# CLS token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Position embeddings
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.pos_dropout = nn.Dropout(dropout)
# Stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# Transformer encoder blocks
self.blocks = nn.ModuleList([
TransformerEncoderBlock(
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
dropout=dropout,
attention_dropout=attention_dropout,
drop_path=dpr[i]
)
for i in range(depth)
])
# Final layer norm
self.norm = nn.LayerNorm(embed_dim)
# Classification head
self.head = nn.Linear(embed_dim, num_classes)
# Initialize weights
self._init_weights()
def _init_weights(self):
# Initialize patch embedding
nn.init.trunc_normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.cls_token, std=0.02)
# Initialize other layers
self.apply(self._init_layer_weights)
def _init_layer_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward_features(self, x):
"""Extract features without classification head."""
B = x.shape[0]
# Patch embedding
x = self.patch_embed(x) # (B, num_patches, embed_dim)
# Prepend CLS token
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_tokens, x], dim=1) # (B, num_patches + 1, embed_dim)
# Add positional embeddings
x = x + self.pos_embed
x = self.pos_dropout(x)
# Transformer blocks
for block in self.blocks:
x = block(x)
# Final normalization
x = self.norm(x)
return x
def forward(self, x):
"""Forward pass with classification."""
x = self.forward_features(x)
# Extract CLS token
cls_token_final = x[:, 0]
# Classification
x = self.head(cls_token_final)
return x
# Standard ViT configurations
def vit_tiny(num_classes=1000, **kwargs):
return VisionTransformer(
embed_dim=192, depth=12, num_heads=3,
num_classes=num_classes, **kwargs
)
def vit_small(num_classes=1000, **kwargs):
return VisionTransformer(
embed_dim=384, depth=12, num_heads=6,
num_classes=num_classes, **kwargs
)
def vit_base(num_classes=1000, **kwargs):
return VisionTransformer(
embed_dim=768, depth=12, num_heads=12,
num_classes=num_classes, **kwargs
)
def vit_large(num_classes=1000, **kwargs):
return VisionTransformer(
embed_dim=1024, depth=24, num_heads=16,
num_classes=num_classes, **kwargs
)
def vit_huge(num_classes=1000, **kwargs):
return VisionTransformer(
embed_dim=1280, depth=32, num_heads=16,
num_classes=num_classes, **kwargs
)
# Compare model sizes
def compare_vit_variants():
"""Compare different ViT model sizes."""
variants = {
'ViT-Ti': vit_tiny,
'ViT-S': vit_small,
'ViT-B': vit_base,
'ViT-L': vit_large,
}
print("Vision Transformer Variants:")
print("-" * 50)
for name, factory in variants.items():
model = factory()
params = sum(p.numel() for p in model.parameters())
print(f"{name:10} Parameters: {params / 1e6:.1f}M")
compare_vit_variants()
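A quick end-to-end sketch (using untrained weights, purely to check output shapes) shows how the pieces fit together.
def demo_forward_pass():
    """Run a random batch through ViT-Base to check output shapes (untrained weights)."""
    model = vit_base(num_classes=10)
    images = torch.randn(2, 3, 224, 224)
    logits = model(images)                      # (2, 10) class scores
    features = model.forward_features(images)   # (2, 197, 768): [CLS] + 196 patch tokens
    print(f"Logits: {logits.shape}, features: {features.shape}")
demo_forward_pass()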
Key Takeaways
The Vision Transformer represents a fundamental departure from convolutional neural networks by treating images as sequences of patches and learning all spatial relationships through attention. The architecture consists of patch embedding to convert images to sequences, learnable position embeddings to encode spatial information, a CLS token for aggregating image-level representations, and stacked transformer encoder blocks with self-attention and MLP layers. ViT requires large-scale pretraining to match CNN performance because it lacks the inductive biases (locality, translation equivariance) that CNNs build in architecturally. However, when trained on sufficient data, ViT scales better than CNNs and achieves superior results. The attention mechanism provides global receptive fields from the first layer, enabling direct communication between distant image regions. Understanding ViT architecture is essential for working with modern vision systems, as transformers have become the dominant architecture for large-scale vision models.