Intermediate Advanced 120 min read

Chapter 13: Transformer Architecture

Self-attention, multi-head attention, positional encoding, and transformer blocks.

Learning Objectives

  • Understand self-attention
  • Implement transformers from scratch
  • Apply positional encoding


13.1 Introduction to Attention Intermediate

The attention mechanism represents one of the most significant advances in deep learning, fundamentally changing how neural networks process sequential and structured data. Before attention, sequence-to-sequence models compressed entire inputs into fixed-size vectors, creating an information bottleneck that degraded performance on long sequences.

The Information Bottleneck Problem

Traditional encoder-decoder architectures suffered from a fundamental limitation. The encoder processed the entire input sequence and produced a single fixed-dimensional context vector. The decoder then had to generate the entire output conditioned only on this vector, regardless of input length.

Consider translating a long paragraph. The encoder must compress all nuances, entities, and grammatical structures into perhaps 512 numbers. Information inevitably gets lost, particularly for elements early in the sequence.

PYTHON
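import torch
import torch.nn as nn
import torch.nn.functional as F

def visualize_bottleneck():
    # Back-of-envelope illustration, not an information-theoretic bound:
    # treat the context vector's raw float32 storage as its capacity and
    # assume roughly 50 bits of information per input token.
    hidden_size = 256
    max_bits = hidden_size * 32      # 256 float32 values = 8192 bits
    bits_per_token = 50              # assumed figure, for illustration only
    
    print("Information Bottleneck Analysis")
    print("=" * 50)
    for seq_len in [10, 50, 100, 200]:
        input_bits = seq_len * bits_per_token
        ratio = input_bits / max_bits    # > 1 means the vector cannot hold everything
        status = "OK" if ratio <= 1 else "LOSSY"
        print(f"Length {seq_len:3d}: {ratio:.2f}x of capacity [{status}]")

visualize_bottleneck()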

The Core Attention Idea

Attention solves the bottleneck by giving the decoder direct access to all encoder hidden states. At each decoding step, the model computes a weighted combination of encoder states, where weights indicate which input positions are most relevant.

The mechanism works as follows:

  1. The encoder produces hidden states for each input position
  2. Compute compatibility scores between decoder state and each encoder state
  3. Normalize scores using softmax to get attention weights
  4. Compute weighted sum of encoder states as context vector
  5. Use context with decoder state to predict the next output

PYTHON
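class BasicAttention(nn.Module):
    """Dot-product attention of one decoder state over all encoder states.

    encoder_outputs: [batch, src_len, encoder_dim]
    decoder_hidden:  [batch, decoder_dim]
    """
    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        # A dot product needs matching sizes, so project the decoder state when they differ
        self.proj = (nn.Identity() if encoder_dim == decoder_dim
                     else nn.Linear(decoder_dim, encoder_dim, bias=False))
        
    def forward(self, encoder_outputs, decoder_hidden):
        query = self.proj(decoder_hidden)                                      # [batch, encoder_dim]
        scores = torch.bmm(encoder_outputs, query.unsqueeze(-1)).squeeze(-1)   # [batch, src_len]
        attention_weights = F.softmax(scores, dim=1)                           # sums to 1 over src_len
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # [batch, encoder_dim]
        return context, attention_weights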

Query, Key, and Value

Modern attention uses the query-key-value (QKV) paradigm:

  • Query (Q): What we are looking for
  • Key (K): What each position offers for matching
  • Value (V): The actual information to retrieve

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

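The scaling factor $\sqrt{d_k}$ keeps the scores in a reasonable range. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$, and dividing by $\sqrt{d_k}$ brings the variance back to 1. Without the scaling, large $d_k$ pushes the softmax toward a nearly one-hot distribution, where gradients become very small.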

PYTHON
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None):
        # query: [..., len_q, d_k], key: [..., len_k, d_k], value: [..., len_k, d_v]
        d_k = query.size(-1)
        # matmul (rather than bmm) also accepts 4-D [batch, heads, len, dim] inputs
        scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        
        if mask is not None:
            # Convention used in this chapter: mask == 0 marks positions that must not be attended
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)
        
        return output, attention_weights
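The variance argument behind the $\sqrt{d_k}$ factor can be checked numerically. The NumPy sketch below is a minimal illustration. It assumes query and key entries are independent standard normal values, an idealization that learned projections need not satisfy. The helper name `dot_product_variance` is hypothetical.

```python
import numpy as np

def dot_product_variance(d_k, n_samples=10000, seed=0):
    """Empirical variance of q.k and of q.k / sqrt(d_k) for random q, k."""
    rng = np.random.default_rng(seed)
    # Assumption: independent entries with mean 0 and variance 1
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=1)            # one dot product per sample
    return dots.var(), (dots / np.sqrt(d_k)).var()

for d_k in [16, 64, 256]:
    raw, scaled = dot_product_variance(d_k)
    # Unscaled variance grows like d_k; scaled variance stays near 1
    print(f"d_k={d_k:3d}: var(q.k) = {raw:6.1f}, var(q.k / sqrt(d_k)) = {scaled:.2f}")
```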

Additive vs Multiplicative Attention

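Additive Attention (Bahdanau): Projects the query and the key, adds them, applies tanh, and scores the result with a learned vector. This amounts to a one-hidden-layer feedforward network on the concatenated pair. It is potentially more expressive but slower, because it builds a hidden vector for every query-key pair.

Multiplicative Attention (Luong): Scores each pair with a dot product, optionally through a learned matrix. All scores come from a single matrix multiplication, so it is faster and more memory-efficient on modern hardware.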

PYTHON
class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, hidden_dim):
        super().__init__()
        self.query_proj = nn.Linear(query_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(key_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        
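    def forward(self, query, key, value, mask=None):
        # query: [batch, len_q, query_dim], key: [batch, len_k, key_dim], value: [batch, len_k, value_dim]
        query_proj = self.query_proj(query).unsqueeze(2)   # [batch, len_q, 1, hidden]
        key_proj = self.key_proj(key).unsqueeze(1)         # [batch, 1, len_k, hidden]
        # Broadcast-add to [batch, len_q, len_k, hidden], then score each pair with v
        scores = self.v(torch.tanh(query_proj + key_proj)).squeeze(-1)   # [batch, len_q, len_k]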
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.bmm(attention_weights, value)
        return output, attention_weights
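To make the cost difference concrete, the NumPy sketch below computes both score functions for the same queries and keys. Random matrices stand in for the learned weights `W_q`, `W_k`, and `v`; these are hypothetical placeholders, not trained parameters. The additive form builds an intermediate tensor with one hidden vector per query-key pair. The multiplicative form needs only one matrix product.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T_q, T_k, d, hidden = 3, 5, 8, 16    # illustrative sizes
Q = rng.standard_normal((T_q, d))
K = rng.standard_normal((T_k, d))

# Multiplicative (scaled dot-product): a single matrix product
mult_scores = Q @ K.T / np.sqrt(d)                                      # (T_q, T_k)

# Additive: project both sides, broadcast-add, tanh, then score with v.
# Random weights are placeholders for learned parameters.
W_q = rng.standard_normal((d, hidden)) * 0.1
W_k = rng.standard_normal((d, hidden)) * 0.1
v = rng.standard_normal(hidden)
hidden_states = np.tanh((Q @ W_q)[:, None, :] + (K @ W_k)[None, :, :])  # (T_q, T_k, hidden)
add_scores = hidden_states @ v                                          # (T_q, T_k)

# Either score matrix is normalized the same way
mult_weights, add_weights = softmax(mult_scores), softmax(add_scores)
print(mult_weights.shape, add_weights.shape, hidden_states.shape)
```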

Visualizing Attention

Attention weights reveal what the model focuses on when generating each output. For translation, these weights often correspond to word alignments.

PYTHON
import numpy as np

source = ['The', 'cat', 'sat', 'on', 'the', 'mat']
target = ['Le', 'chat', 'assis', 'sur', 'le', 'tapis']

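# Hand-written, illustrative weights (not produced by a trained model):
# each row is one target word's distribution over the source words and sums to 1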
attention = np.array([
    [0.8, 0.1, 0.0, 0.0, 0.1, 0.0],  # Le -> The
    [0.1, 0.8, 0.1, 0.0, 0.0, 0.0],  # chat -> cat
    [0.0, 0.1, 0.8, 0.1, 0.0, 0.0],  # assis -> sat
    [0.0, 0.0, 0.1, 0.8, 0.1, 0.0],  # sur -> on
    [0.1, 0.0, 0.0, 0.1, 0.7, 0.1],  # le -> the
    [0.0, 0.0, 0.0, 0.0, 0.1, 0.9],  # tapis -> mat
])

print("Attention shows word alignments:")
for i, tgt in enumerate(target):
    max_idx = attention[i].argmax()
    print(f"  '{tgt}' attends to '{source[max_idx]}' ({attention[i, max_idx]:.1f})")

Key Takeaways

Attention mechanisms removed the information bottleneck in encoder-decoder architectures. The query-key-value formulation provides a general framework where queries specify what to look for, keys specify what each position offers, and values contain the information to retrieve.

Attention weights can offer insight into model behavior and are useful for debugging. They should be interpreted with care, though: a high attention weight does not by itself prove that a position drove the prediction.

The attention mechanism laid the groundwork for transformers, which replace recurrence entirely with attention.

13.2 Self-Attention Intermediate

Self-attention extends the attention mechanism to relate different positions within a single sequence. Rather than attending from decoder to encoder, self-attention allows each position to attend to all positions in the same sequence, capturing dependencies regardless of distance.

From Cross-Attention to Self-Attention

In encoder-decoder attention, the query comes from the decoder and keys/values come from the encoder. Self-attention simplifies this: queries, keys, and values all derive from the same sequence. Each position can attend to every other position.

For a sequence of length n, self-attention computes an n x n attention matrix where entry (i, j) indicates how much position i attends to position j.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        
        scores = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        output = torch.bmm(attention_weights, V)
        return self.out_proj(output), attention_weights

Positional Encoding

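Without positional information, self-attention is permutation-equivariant: shuffling the input tokens simply shuffles the outputs, so the model effectively sees a set rather than a sequence. Positional encodings inject order information by adding a position-dependent vector to each token embedding. The original transformer uses fixed sinusoidal encodings. For any fixed offset k, the encoding at position pos + k is a linear function of the encoding at pos, which in principle makes relative positions easy for the model to use.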
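The set-like behavior can be verified directly. This NumPy sketch uses a single head with no positional encoding. Random matrices serve as hypothetical stand-ins for the learned projections. Permuting the input rows permutes the output rows in exactly the same way, so the output carries no information about the original order.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention without positional encoding. X: (n, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
n, d = 6, 8
X = rng.standard_normal((n, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

perm = rng.permutation(n)
out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)

# Permuting the inputs just permutes the outputs: order carries no information
print(np.allclose(out_perm, out[perm]))  # True
```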

PYTHON
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000, dropout=0.1):
        super().__init__()
        assert embed_dim % 2 == 0, "sin/cos pairs require an even embed_dim"
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)   # [max_len, 1]
        # Frequencies 1 / 10000^(2i / embed_dim), computed in log space
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * 
                            (-math.log(10000.0) / embed_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        pe = pe.unsqueeze(0)                           # [1, max_len, embed_dim]
        self.register_buffer('pe', pe)                 # saved with the model, not trained
        
    def forward(self, x):
        # x: [batch, seq_len, embed_dim]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
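The relative-position property can be checked with the same formula. Take each frequency $\omega_i = 10000^{-2i/d}$. When the position shifts by $k$, the pair $(\sin(\text{pos}\,\omega_i), \cos(\text{pos}\,\omega_i))$ is rotated by the angle $k\,\omega_i$. So one fixed block-diagonal matrix maps every position to the position $k$ steps later. The NumPy sketch below builds that matrix and confirms the identity; the helper names are hypothetical.

```python
import numpy as np

def sinusoidal_pe(max_len, d):
    """Same formula as SinusoidalPositionalEncoding above, in NumPy (d must be even)."""
    pos = np.arange(max_len)[:, None]
    omega = np.exp(np.arange(0, d, 2) * (-np.log(10000.0) / d))  # one frequency per sin/cos pair
    pe = np.zeros((max_len, d))
    pe[:, 0::2] = np.sin(pos * omega)
    pe[:, 1::2] = np.cos(pos * omega)
    return pe, omega

def shift_matrix(k, omega):
    """Block-diagonal rotation R(k) such that PE[pos + k] = R(k) @ PE[pos] for every pos."""
    d = 2 * len(omega)
    R = np.zeros((d, d))
    for i, w in enumerate(omega):
        c, s = np.cos(k * w), np.sin(k * w)
        # sin(a+b) = c*sin(a) + s*cos(a);  cos(a+b) = -s*sin(a) + c*cos(a)
        R[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, s], [-s, c]]
    return R

pe, omega = sinusoidal_pe(max_len=100, d=16)
R = shift_matrix(k=5, omega=omega)
# The same matrix R maps every position to the position 5 steps later
print(np.allclose(pe[5:], pe[:-5] @ R.T))  # True
```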

Causal Self-Attention

For autoregressive generation, positions must not attend to future positions. Causal self-attention therefore applies a lower-triangular mask, so that position i attends only to positions j ≤ i.

PYTHON
class CausalSelfAttention(nn.Module):
    def __init__(self, embed_dim, max_len=512, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        mask = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer('causal_mask', mask)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        
        scores = torch.bmm(Q, K.transpose(-2, -1)) / math.sqrt(self.embed_dim)
        
        mask = self.causal_mask[:seq_len, :seq_len]
        scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        output = torch.bmm(attention_weights, V)
        return self.out_proj(output), attention_weights
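A quick way to confirm that a causal mask works is to change the later tokens and check that earlier outputs do not move. This NumPy sketch reimplements the masking idea from `CausalSelfAttention` for a single sequence. Random matrices stand in for the learned projections, an assumption made for illustration.

```python
import numpy as np

def causal_attention(X, W_q, W_k, W_v):
    n = X.shape[0]
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True above the diagonal = future position
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)                              # exp(-inf) = 0 for future positions
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 6, 8
X = rng.standard_normal((n, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))
out, weights = causal_attention(X, W_q, W_k, W_v)

X_edit = X.copy()
X_edit[4:] = rng.standard_normal((2, d))   # overwrite the last two tokens
out_edit, _ = causal_attention(X_edit, W_q, W_k, W_v)

print(np.allclose(out[:4], out_edit[:4]))     # True: positions 0-3 never see 4-5
print(np.allclose(np.triu(weights, k=1), 0))  # True: no weight on future positions
```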

Computational Complexity

Self-attention has O(n^2) complexity in sequence length. For length n with embedding dimension d:

  • Attention matrix: O(n^2 * d) computation
  • Memory: O(n^2) for attention weights

With naive implementations that materialize the full attention matrix, this quadratic scaling makes sequences beyond a few thousand tokens expensive in both memory and compute.
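The point where the quadratic term takes over can be estimated with simple arithmetic. The sketch below counts matrix-multiply FLOPs for one attention layer, at 2 FLOPs per multiply-add. It ignores softmax, biases, and the feed-forward block. Splitting into heads does not change these totals, because h heads of size d/h cost the same as one head of size d. Under these assumptions, the n^2 terms equal the projection cost at n = 2d and dominate beyond it.

```python
def attention_layer_flops(n, d):
    """Approximate matmul FLOPs for one attention layer (2 FLOPs per multiply-add).

    Counts the four d x d projections and the two n x n products only;
    softmax, biases, and the feed-forward block are ignored.
    """
    projections = 4 * (2 * n * d * d)      # Q, K, V, and output projections: linear in n
    scores_and_mix = 2 * (2 * n * n * d)   # Q @ K^T and weights @ V: quadratic in n
    return projections, scores_and_mix

d = 768  # embedding size, e.g. BERT-base
for n in [512, 1536, 4096, 16384]:
    proj, attn = attention_layer_flops(n, d)
    print(f"n={n:6d}: projections {proj / 1e9:8.1f} GFLOP | n^2 terms {attn / 1e9:8.1f} GFLOP")
```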

Self-Attention vs RNNs

Self-attention advantages:

  1. Parallel computation - all positions are computed simultaneously
  2. Constant path length - any two positions are directly connected
  3. Better gradient flow - gradients need not pass through a long chain of recurrent steps

RNN advantages:

  • Linear complexity in sequence length
  • Natural streaming capability
  • Implicit position awareness

Key Takeaways

Self-attention enables each position to attend directly to all other positions, capturing long-range dependencies with a constant path length. Positional encodings are required because attention treats its inputs as a set.

Causal masking enables autoregressive generation by preventing attention to future tokens. The quadratic complexity limits application to moderately long sequences, motivating efficient variants.

13.3 Multi-Head Attention Intermediate

Multi-head attention extends single-head attention by running multiple attention operations in parallel, each learning different relationship patterns. Rather than performing one attention function with full-dimensional keys, queries, and values, multi-head attention projects them into multiple lower-dimensional subspaces and applies attention independently in each.

Motivation for Multiple Heads

A single attention head produces one attention distribution per query, so it must blend every kind of relationship into a single weighted average. In language, words relate to each other in multiple ways simultaneously: syntactic relationships (subject-verb agreement), semantic relationships (synonymy, antonymy), positional relationships (nearby words), and more.

Multi-head attention allows the model to jointly attend to information from different representation subspaces. One head might learn to attend to syntactic dependencies while another captures semantic similarity.

The Multi-Head Mechanism

Given an embedding dimension d and h heads, each head operates on dimension d/h. The computation proceeds as:

  1. Project input to h separate query, key, and value representations
  2. Apply scaled dot-product attention to each head independently
  3. Concatenate all head outputs
  4. Apply final linear projection

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Combined projections for efficiency
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, head_dim]
        Q, K, V = qkv[0], qkv[1], qkv[2]
        
        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # Apply attention
        output = torch.matmul(attention_weights, V)
        
        # Concatenate heads
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        output = self.out_proj(output)
        
        return output, attention_weights


def demo_multihead():
    batch_size = 2
    seq_len = 10
    embed_dim = 64
    num_heads = 8
    
    x = torch.randn(batch_size, seq_len, embed_dim)
    mha = MultiHeadAttention(embed_dim, num_heads)
    output, weights = mha(x)
    
    print("Multi-Head Attention Demo")
    print("-" * 40)
    print(f"Input: {x.shape}")
    print(f"Output: {output.shape}")
    print(f"Attention weights: {weights.shape}")
    print(f"  - {num_heads} heads")
    print(f"  - Each head: {embed_dim // num_heads} dimensions")

demo_multihead()

Head Specialization

Different heads learn to capture different types of relationships. Studies that inspect trained transformers have reported that:

  • Some heads attend to previous/next tokens (local patterns)
  • Some heads attend to specific syntactic roles
  • Some heads attend to rare or specific tokens
  • Some heads develop positional preferences

PYTHON
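def analyze_head_patterns():
    # Illustrative only: a hypothetical assignment of roles to 8 heads.
    # Real head behavior varies by model, layer, and input, and many heads
    # have no clean interpretation.
    print("Head Specialization Examples (illustrative)")
    print("=" * 50)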
    
    patterns = [
        ("Head 0", "Previous token attention", "Local context"),
        ("Head 1", "Subject-verb agreement", "Syntactic"),
        ("Head 2", "Coreference resolution", "Semantic"),
        ("Head 3", "Punctuation attention", "Structural"),
        ("Head 4", "Long-range dependency", "Global context"),
        ("Head 5", "Named entity attention", "Semantic"),
        ("Head 6", "Position-based", "Positional"),
        ("Head 7", "Aggregate/summary", "Global"),
    ]
    
    for head, pattern, category in patterns:
        print(f"  {head}: {pattern} ({category})")

analyze_head_patterns()

Efficient Implementation

Modern implementations fuse operations for efficiency. The key optimizations include:

  1. Combined QKV projection in a single matrix multiply
  2. Batched attention across all heads
  3. Memory-efficient attention computation
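The first optimization is exact, not an approximation. Projecting with one concatenated weight matrix gives the same Q, K, and V as three separate projections. The NumPy check below is a minimal sketch with random weights as placeholders. (`nn.Linear` stores its weight as `[out_features, in_features]`, but concatenating along the output dimension is the same idea. In `MultiHeadAttention` above, the `reshape(..., 3, ...)` splits the fused output back into Q, K, and V.)

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 10, 64
X = rng.standard_normal((n, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))

# Three separate projections: three matrix multiplies
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Fused projection: concatenate the weights once, then a single (d x 3d) multiply
W_qkv = np.concatenate([W_q, W_k, W_v], axis=1)
Q2, K2, V2 = np.split(X @ W_qkv, 3, axis=1)

print(all(np.allclose(a, b) for a, b in [(Q, Q2), (K, K2), (V, V2)]))  # True
```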

PYTHON
class EfficientMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5   # 1 / sqrt(head_dim)
        
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        B, N, C = x.shape
        
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
            
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        
        return x
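For the third optimization, PyTorch 2.0 and later provide `torch.nn.functional.scaled_dot_product_attention`. It can dispatch to fused FlashAttention or memory-efficient kernels when the hardware and inputs allow. Otherwise it falls back to a plain implementation. The sketch below checks that it matches the explicit computation used in `EfficientMultiHeadAttention`.

One caveat: a boolean `attn_mask` passed to this function uses True for positions that *may* be attended, the opposite of `nn.MultiheadAttention`. The example passes `is_causal=True` instead of a mask to sidestep the difference.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, D = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

# Reference: explicit score matrix with a causal mask (True = may attend here)
causal = torch.tril(torch.ones(N, N, dtype=torch.bool))
scores = (q @ k.transpose(-2, -1)) * D ** -0.5
scores = scores.masked_fill(~causal, float('-inf'))
reference = scores.softmax(dim=-1) @ v

# Fused call (PyTorch 2.0+); default scale is 1/sqrt(D)
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(reference, fused, atol=1e-5))
```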

Number of Heads

The choice of number of heads involves tradeoffs:

  • More heads: finer-grained attention patterns, but smaller head dimension
  • Fewer heads: larger head dimension, but less pattern diversity

Common configurations:

  • BERT-base: 12 heads, 768 dim (64 per head)
  • GPT-2 small: 12 heads, 768 dim (64 per head)
  • GPT-3: 96 heads, 12288 dim (128 per head)

PYTHON
def compare_configurations():
    print("Common Multi-Head Configurations")
    print("-" * 50)
    
    configs = [
        ("BERT-base", 768, 12),
        ("BERT-large", 1024, 16),
        ("GPT-2 small", 768, 12),
        ("GPT-2 medium", 1024, 16),
        ("GPT-2 large", 1280, 20),
        ("GPT-3", 12288, 96),
    ]
    
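    for name, dim, heads in configs:
        head_dim = dim // heads
        # Attention weights per layer: Q, K, V (3 * dim^2) + output projection (dim^2);
        # biases ignored. Note the count does not depend on the number of heads.
        params = dim * 3 * dim + dim * dim
        print(f"{name:15s}: {dim:5d} dim, {heads:2d} heads, "
              f"{head_dim:3d} per head, {params/1e6:.1f}M attention params/layer")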

compare_configurations()

Key Takeaways

Multi-head attention enables models to learn multiple types of relationships simultaneously. Each head can specialize in different patterns - syntactic, semantic, positional, or structural.

The mechanism splits the embedding into multiple subspaces, applies attention independently in each, and combines the results. Because each head works in dimension d/h, the total cost is similar to that of a single full-dimension head, yet the model gains several distinct attention patterns instead of one.

Efficient implementations combine projections and batch operations across heads, making multi-head attention practical for large-scale models.

13.4 Cross-Attention and Encoder-Decoder Attention Intermediate

Cross-attention connects two different sequences, allowing one sequence to attend to another. In transformer encoder-decoder architectures, cross-attention enables the decoder to access encoder representations.

Cross-Attention vs Self-Attention

In self-attention, queries, keys, and values come from the same sequence. In cross-attention, queries come from one sequence while keys and values come from another.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CrossAttention(nn.Module):
    def __init__(self, query_dim, kv_dim, num_heads, dropout=0.0):
        super().__init__()
        assert query_dim % num_heads == 0, "query_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = query_dim // num_heads
        
        self.q_proj = nn.Linear(query_dim, query_dim)
        self.k_proj = nn.Linear(kv_dim, query_dim)
        self.v_proj = nn.Linear(kv_dim, query_dim)
        self.out_proj = nn.Linear(query_dim, query_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key_value, mask=None):
        batch_size, query_len, _ = query.shape
        _, kv_len, _ = key_value.shape
        
        Q = self.q_proj(query)
        K = self.k_proj(key_value)
        V = self.v_proj(key_value)
        
        Q = Q.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            # mask must broadcast to [batch, heads, query_len, kv_len]; 0 marks blocked pairs
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        output = torch.matmul(attn, V)
        output = output.transpose(1, 2).reshape(batch_size, query_len, -1)
        
        return self.out_proj(output), attn
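The sketch below shows the two properties that matter most in practice. First, the output has one vector per query (decoder) position, regardless of the encoder length. Second, padded encoder positions receive zero weight. It is a NumPy simplification that assumes identity projections instead of learned `W_Q`, `W_K`, `W_V`, plus a hypothetical padding pattern.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T_dec, T_enc, d = 3, 5, 8
decoder_states = rng.standard_normal((T_dec, d))   # queries
encoder_states = rng.standard_normal((T_enc, d))   # keys and values
# Hypothetical batch padding: the last two encoder positions are padding
key_valid = np.array([True, True, True, False, False])

scores = decoder_states @ encoder_states.T / np.sqrt(d)  # (T_dec, T_enc)
scores = np.where(key_valid[None, :], scores, -np.inf)   # block padded keys for every query
weights = softmax(scores, axis=-1)
context = weights @ encoder_states                       # (T_dec, d): one vector per query

print(context.shape)                   # (3, 8)
print(weights[:, ~key_valid].max())    # 0.0: padding receives no attention
```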

Encoder-Decoder Architecture

In a transformer encoder-decoder, the decoder has both self-attention and cross-attention:

  1. Decoder self-attention: Attend to previous decoder positions (causal)
  2. Encoder-decoder cross-attention: Attend to encoder outputs
  3. Feed-forward network: Process combined representations

PYTHON
class TransformerDecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
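    def forward(self, x, encoder_output, causal_mask=None, memory_key_padding_mask=None):
        # Post-norm layout, as in the original transformer: residual add, then LayerNorm.
        # Note: nn.MultiheadAttention uses the opposite boolean convention from the
        # masks earlier in this chapter: True means "may NOT attend".
        attn_out, _ = self.self_attn(x, x, x, attn_mask=causal_mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        # memory_key_padding_mask: [batch, src_len], True at padded source positions.
        cross_out, _ = self.cross_attn(x, encoder_output, encoder_output,
                                       key_padding_mask=memory_key_padding_mask)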
        x = self.norm2(x + self.dropout(cross_out))
        
        ff_out = self.ff(x)
        x = self.norm3(x + self.dropout(ff_out))
        
        return x
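When calling `TransformerDecoderLayer`, the causal mask must follow `nn.MultiheadAttention`'s convention, where a boolean True blocks attention. The sketch below checks that convention using `nn.MultiheadAttention` directly, with arbitrary sizes chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, E = 5, 16
mha = nn.MultiheadAttention(E, num_heads=4, batch_first=True)
x = torch.randn(2, T, E)

# Boolean attn_mask for nn.MultiheadAttention: True marks positions that may NOT be attended
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

out, weights = mha(x, x, x, attn_mask=causal_mask)   # weights are averaged over heads: (2, T, T)
# The upper triangle (future positions) should contain only zeros
print(torch.triu(weights, diagonal=1).abs().max().item())
```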

Applications of Cross-Attention

Cross-attention enables many sequence-to-sequence tasks:

  • Machine Translation: Decoder attends to source language encoder outputs
  • Image Captioning: Text decoder attends to image features from vision encoder
  • Speech Recognition: Text decoder attends to audio encoder features
  • Document Summarization: Summary decoder attends to document encoder

Attention Patterns in Cross-Attention

Cross-attention learns meaningful alignments:

  • Word alignment in translation shows source-target correspondences
  • Spatial attention in image captioning highlights relevant image regions
  • Temporal alignment in speech recognition tracks audio-text alignment

Key Takeaways

Cross-attention connects different sequences, with queries from one sequence attending to keys and values from another. In encoder-decoder transformers, this allows the decoder to access encoded input representations.

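The mechanism enables translation, captioning, summarization, and any task requiring conditioning on another sequence. Unlike decoder self-attention, cross-attention does not need a causal mask, because the entire encoder output is available at every decoding step. Padding masks remain important for batches of variable-length source sequences.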

13.5 Efficient Attention Mechanisms Advanced

Standard self-attention has O(n^2) complexity in sequence length, limiting its application to long sequences. Efficient attention mechanisms reduce this complexity through sparse patterns, linear approximations, or memory-efficient computation.

The Quadratic Bottleneck

For a sequence of length n, standard attention computes n^2 attention scores and requires O(n^2 d) computation. At long sequence lengths, this becomes prohibitive.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

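def analyze_attention_memory():
    # Size of one materialized n x n score matrix in float32 (4 bytes),
    # for a single head and a single sequence; multiply by heads and batch size.
    print("Attention Memory Requirements")
    print("=" * 50)
    for seq_len in [512, 2048, 8192, 32768]:
        bytes_needed = seq_len * seq_len * 4
        gib = bytes_needed / (1024**3)
        print(f"Length {seq_len:6d}: {gib:8.2f} GiB per head")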

analyze_attention_memory()

Sparse Attention Patterns

Sparse attention computes only a subset of attention scores:

  • Local attention: Each position attends to nearby positions
  • Strided attention: Attend to every k-th position
  • Global tokens: Special tokens that attend to, and are attended by, all positions
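The three patterns can be written as boolean masks, where True marks a query-key pair whose score is computed. This NumPy sketch uses hypothetical helper names and one simple variant of each pattern. For example, the strided mask here lets every query see keys at multiples of the stride. Published models define their patterns in their own ways. The counts show how much sparsity cuts the number of scores; realizing those savings also requires kernels that skip the masked entries.

```python
import numpy as np

def local_mask(n, half_window):
    i = np.arange(n)
    return np.abs(i[None, :] - i[:, None]) <= half_window    # |j - i| <= w/2

def strided_mask(n, stride):
    j = np.arange(n)
    return np.broadcast_to(j[None, :] % stride == 0, (n, n)).copy()   # every stride-th key

def global_mask(n, global_idx):
    m = np.zeros((n, n), dtype=bool)
    m[global_idx, :] = True    # global tokens attend everywhere
    m[:, global_idx] = True    # and every token attends to them
    return m

n = 16
patterns = {
    "local (w/2=2)": local_mask(n, 2),
    "strided (k=4)": strided_mask(n, 4),
    "global (token 0)": global_mask(n, [0]),
}
combined = np.logical_or.reduce(list(patterns.values()))
for name, m in list(patterns.items()) + [("combined", combined)]:
    print(f"{name:18s}: {m.sum():3d} of {n * n} scores computed ({m.mean():.0%})")
```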

PYTHON
class LocalAttention(nn.Module):
    """Sliding-window attention, written for clarity rather than speed.

    This version still computes the full N x N score matrix and then masks it,
    so it shows the attention pattern but does not reduce cost; efficient
    implementations compute only the banded blocks.
    """
    def __init__(self, embed_dim, num_heads, window_size=256):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
        # Band mask built without a Python loop: position i may attend to j
        # when |i - j| <= window_size // 2 (always includes i itself)
        idx = torch.arange(N, device=x.device)
        allowed = (idx[None, :] - idx[:, None]).abs() <= self.window_size // 2
        
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.masked_fill(~allowed, float('-inf'))
        attn = attn.softmax(dim=-1)
        
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

Linear Attention

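Linear attention replaces the softmax with a kernel feature map φ applied to queries and keys, so the similarity of a pair becomes φ(q)·φ(k). Matrix multiplication is associative, so (φ(Q)φ(K)^T)V can be computed as φ(Q)(φ(K)^T V), which never forms the n x n matrix. This reduces complexity to O(n d^2), at the cost of no longer computing exact softmax attention.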

PYTHON
class LinearAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, eps=1e-6):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.eps = eps
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
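        # Feature map phi(x) = elu(x) + 1 keeps similarities positive (replaces softmax)
        q = F.elu(q) + 1
        k = F.elu(k) + 1
        
        # Associativity: compute phi(K)^T V first, a [head_dim x head_dim] summary per head
        kv = torch.einsum('bhnd,bhnv->bhdv', k, v)
        qkv = torch.einsum('bhnd,bhdv->bhnv', q, kv)
        
        # Row normalizer phi(q_i) . sum_j phi(k_j), the analogue of the softmax denominator
        k_sum = k.sum(dim=2, keepdim=True)
        normalizer = torch.einsum('bhnd,bhkd->bhn', q, k_sum) + self.eps
        out = qkv / normalizer.unsqueeze(-1)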
        
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
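For a given feature map, the reordering is exact. This NumPy sketch confirms it using the same elu(x) + 1 map as `LinearAttention`: building the n x n similarity matrix first and summarizing keys and values first give the same result. What changes relative to standard attention is the similarity function itself, so neither result equals softmax attention.

```python
import numpy as np

def elu_plus_one(x):
    return np.where(x > 0, x + 1.0, np.exp(x))   # elu(x) + 1, always positive

rng = np.random.default_rng(0)
n, d = 128, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
phi_q, phi_k = elu_plus_one(Q), elu_plus_one(K)

# Quadratic order: build the full (n x n) similarity matrix first, O(n^2 d)
A = phi_q @ phi_k.T
out_quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear order: summarize keys and values into a (d x d) matrix first, O(n d^2)
KV = phi_k.T @ V
normalizer = phi_q @ phi_k.sum(axis=0)           # (n,)
out_linear = (phi_q @ KV) / normalizer[:, None]

print(np.allclose(out_quadratic, out_linear))    # True
```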

Flash Attention

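Flash Attention is a memory-efficient exact attention algorithm. It computes attention in blocks small enough to fit in fast on-chip GPU memory and combines them with an online softmax, so the full attention matrix is never materialized in main GPU memory.

Key benefits:

  • Memory: O(N) extra memory instead of O(N^2)
  • Speed: often 2-4x faster on long sequences, mainly by reducing reads and writes to GPU high-bandwidth memory (gains vary with hardware and sequence length)
  • Exact: No approximation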
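The core trick behind exact tiled computation is an online softmax. Scores are processed one block of keys at a time, and running statistics are rescaled whenever a larger maximum appears. The NumPy sketch below illustrates only that idea; the function names are hypothetical. It is not FlashAttention itself, whose speed comes from a fused GPU kernel that keeps tiles in on-chip SRAM and also tiles over queries.

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=32):
    """Exact attention over key/value blocks using an online softmax.

    Only an (n x block) tile of scores exists at a time. The running max m,
    running denominator l, and unnormalized output acc are rescaled as blocks arrive.
    """
    n, d = Q.shape
    m = np.full(n, -np.inf)          # running row max
    l = np.zeros(n)                  # running softmax denominator
    acc = np.zeros((n, V.shape[1]))  # running unnormalized output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)                 # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))
        correction = np.exp(m - m_new)            # rescales earlier partial sums (0 on first block)
        P = np.exp(S - m_new[:, None])
        l = l * correction + P.sum(axis=1)
        acc = acc * correction[:, None] + P @ Vb
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((200, 16)) for _ in range(3))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V, block=32)))  # True
```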

Longformer and BigBird Patterns

These models combine local and global attention:

  • Local sliding window for nearby tokens
  • Global attention on special tokens (such as [CLS])
  • Random attention for additional connectivity (BigBird)

Complexity Comparison

  • Standard attention: O(n^2 d) time, O(n^2) memory
  • Linear attention: O(n d^2) time, O(n) memory
  • Local attention (window size w): O(n w d) time, O(n w) memory
  • Flash attention: O(n^2 d) time, O(n) memory (exact)

Key Takeaways

Efficient attention mechanisms enable transformers to handle long sequences. Sparse patterns reduce complexity by limiting attention scope. Linear attention reformulates the computation to avoid the quadratic bottleneck, trading exact softmax attention for a kernel approximation. Flash Attention achieves efficiency through memory-aware computation without approximation.