Intermediate Advanced 90 min read

Chapter 14: Pre-trained Language Models

BERT, GPT, tokenization, and the Hugging Face ecosystem.

Learning Objectives

  • Use pre-trained models
  • Understand tokenization
  • Navigate Hugging Face Hub


14.1 The Transformer Architecture (Intermediate)


The transformer architecture, introduced in "Attention Is All You Need" (2017), replaced recurrence with self-attention as the primary mechanism for sequence modeling. This shift allowed far more parallelism during training than recurrent models permit and led to large improvements in translation and language modeling, and eventually across most of natural language processing.

From RNNs to Transformers

Recurrent neural networks process sequences step by step, maintaining a hidden state that accumulates information. This sequential nature creates two fundamental limitations: training cannot be parallelized across time steps, and information must traverse many steps to connect distant positions.

Transformers address both limitations by processing all positions simultaneously through self-attention. Every position can directly attend to every other position in a single operation, eliminating the need for sequential processing.
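A minimal single-head sketch makes the contrast concrete. It assumes, for brevity, that the input vectors serve directly as queries, keys, and values (no learned projections). Each position's output depends only on the inputs, not on the previous position's output, so one matrix product computes every row at once and gives the same result as an explicit loop.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 5, 8
x = torch.randn(seq_len, dim)  # one sequence; x acts as queries, keys and values here

# Parallel: all positions in a single matrix product
scores = x @ x.T / dim ** 0.5  # [seq_len, seq_len]
parallel_out = torch.softmax(scores, dim=-1) @ x

# Loop: iteration i uses only the inputs, never the result of iteration i - 1,
# so the iterations are independent (unlike an RNN's hidden-state recurrence)
loop_out = torch.stack([
    torch.softmax(x[i] @ x.T / dim ** 0.5, dim=-1) @ x
    for i in range(seq_len)
])

print(torch.allclose(parallel_out, loop_out, atol=1e-6))  # True
```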

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Self-attention with residual
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        # Feed-forward with residual
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)
        
        return x

Core Components

The transformer consists of several key components working together:

Multi-Head Self-Attention: Allows each position to gather information from all other positions, with multiple heads capturing different relationship types.

Position-wise Feed-Forward Networks: Two-layer networks applied identically to each position, providing non-linear transformation capacity.

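Layer Normalization: Stabilizes training by normalizing each position's activation vector to zero mean and unit variance across the feature dimension, followed by a learned scale and shift.

Residual Connections: Enable gradient flow through deep networks by adding each sub-layer's input to its output.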
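The last two components can be checked in a few lines. This sketch recomputes nn.LayerNorm by hand (assuming its freshly initialized scale of 1 and shift of 0, and its default eps) and then wraps an arbitrary sub-layer, here a placeholder nn.Linear, in the post-norm residual pattern used by TransformerBlock above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 16
x = torch.randn(2, 4, embed_dim)  # [batch, seq, embed_dim]

norm = nn.LayerNorm(embed_dim)  # weight initialized to 1, bias to 0

# Manual LayerNorm: statistics are computed per position over the feature dimension
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # LayerNorm uses the biased variance
manual = (x - mean) / torch.sqrt(var + norm.eps)
print(torch.allclose(norm(x), manual, atol=1e-5))  # True

# Residual connection around a sub-layer, followed by LayerNorm (post-norm)
sublayer = nn.Linear(embed_dim, embed_dim)  # placeholder for attention or feed-forward
out = norm(x + sublayer(x))
print(out.shape)  # torch.Size([2, 4, 16])
```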

PYTHON
class Transformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, ff_dim, num_layers, max_len=512, dropout=0.1):
        super().__init__()
        
        # Embeddings
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        
        # Output
        self.norm = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size)
        
        self.embed_dim = embed_dim
        
    def forward(self, x, mask=None):
        batch_size, seq_len = x.shape
        
        # Token + position embeddings
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.token_embedding(x) + self.position_embedding(positions)
        x = self.dropout(x)
        
        # Apply transformer blocks
        for layer in self.layers:
            x = layer(x, mask)
        
        x = self.norm(x)
        return self.output(x)

The Attention Mechanism in Detail

Self-attention computes a weighted sum of values based on query-key compatibility:

  1. Project input to queries, keys, and values
  2. Compute attention scores as scaled dot products
  3. Apply softmax to get attention weights
  4. Multiply weights by values to get output

PYTHON
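class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)
        
        # Step 1: one linear layer produces queries, keys and values together
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x, mask=None):
        B, N, C = x.shape  # batch, sequence length, embed_dim
        
        # [B, N, 3*C] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        
        # Step 2: scaled dot-product scores, [B, heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        
        # mask must broadcast to [B, heads, N, N]; 1 = may attend, 0 = blocked
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        
        # Step 3: turn scores into attention weights
        attn = attn.softmax(dim=-1)
        
        # Step 4: weighted sum of values, then merge heads back to [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)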
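The four steps listed earlier can be written out for a single set of heads and compared with PyTorch's fused torch.nn.functional.scaled_dot_product_attention. This is a sketch assuming PyTorch 2.0 or later (where that function exists); random tensors stand in for the already-projected queries, keys, and values of step 1, and the boolean mask uses that function's convention, where True means "may attend".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, heads, N, d = 2, 4, 6, 8
# Stand-ins for the projected queries, keys and values (step 1)
q = torch.randn(B, heads, N, d)
k = torch.randn(B, heads, N, d)
v = torch.randn(B, heads, N, d)
keep = torch.ones(N, N).tril().bool()  # causal pattern: True = may attend

# Step 2: scaled dot-product scores
scores = (q @ k.transpose(-2, -1)) / d ** 0.5
scores = scores.masked_fill(~keep, float('-inf'))
# Step 3: softmax over the key dimension
weights = scores.softmax(dim=-1)
# Step 4: weighted sum of values
manual = weights @ v

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=keep)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```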

Feed-Forward Networks

Each transformer block includes a position-wise feed-forward network, typically expanding the dimension by 4x before projecting back:

PYTHON
class FeedForward(nn.Module):
    def __init__(self, embed_dim, ff_dim, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return self.net(x)
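"Position-wise" means the same two linear layers are applied to every position independently, so running the network on a whole sequence gives the same result as running it on each position separately. A quick check under illustrative sizes (embed_dim=64 with the 4x expansion), with the module in eval mode so dropout is inactive; the parameter count follows from the two weight matrices plus their biases.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 64
ff_dim = 4 * embed_dim  # the typical 4x expansion

ff = nn.Sequential(
    nn.Linear(embed_dim, ff_dim),
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(ff_dim, embed_dim),
    nn.Dropout(0.1),
).eval()  # eval() disables dropout so the comparison is deterministic

x = torch.randn(2, 10, embed_dim)  # [batch, seq, embed_dim]
whole = ff(x)
per_position = torch.stack([ff(x[:, i]) for i in range(x.size(1))], dim=1)
print(torch.allclose(whole, per_position, atol=1e-5))  # True

num_params = sum(p.numel() for p in ff.parameters())
print(num_params)  # 2 * 64 * 256 + 256 + 64 = 33088
```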

Why Transformers Work

Several factors contribute to transformer effectiveness:

Parallelization: All positions processed simultaneously during training, enabling efficient GPU utilization.

Direct connections: Any two positions are connected through a single attention operation, so the path between them has length one regardless of their distance, which makes long-range dependencies easier to learn than through many recurrent steps.

Flexibility: Attention patterns learned from data rather than fixed by architecture.

Scalability: Performance has been observed to improve predictably as parameters, data, and compute increase.

Key Takeaways

The transformer architecture replaces recurrence with self-attention, enabling parallel training and direct modeling of long-range dependencies. Core components include multi-head attention, feed-forward networks, layer normalization, and residual connections.

This architecture forms the foundation for modern language models including BERT, GPT, and their successors. Understanding transformer fundamentals is essential for working with contemporary NLP systems.

14.2 Encoder and Decoder Blocks (Intermediate)


The original transformer uses an encoder-decoder architecture. The encoder processes the input sequence, and the decoder generates output while attending to the encoded input.

Encoder Block Structure

Each encoder block contains two sub-layers:

  1. Multi-head self-attention
  2. Position-wise feed-forward network

Both use residual connections and layer normalization.

PYTHON
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None, padding_mask=None):
        attn_out, _ = self.self_attn(x, x, x, attn_mask=mask, key_padding_mask=padding_mask)
        x = self.norm1(x + self.dropout(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.dropout(ff_out))
        return x
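The padding_mask argument is passed to nn.MultiheadAttention as key_padding_mask, where True marks a key position to ignore. A minimal sketch of building that mask from per-sequence lengths (the lengths tensor is an illustrative assumption) and checking that padded keys receive zero attention weight:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, max_len = 16, 4, 5
lengths = torch.tensor([5, 3])          # real tokens per sequence; the rest is padding
x = torch.randn(2, max_len, embed_dim)  # [batch, seq, embed_dim]

# True where the position is padding (index >= length)
padding_mask = torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
out, weights = attn(x, x, x, key_padding_mask=padding_mask)  # weights averaged over heads

# In the second sequence, no query attends to the two padded keys
print(weights[1, :, 3:].abs().max().item())  # 0.0
```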

Decoder Block Structure

Each decoder block contains three sub-layers:

  1. Masked multi-head self-attention (causal)
  2. Multi-head cross-attention to encoder output
  3. Position-wise feed-forward network

PYTHON
class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.norm3 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, encoder_output, tgt_mask=None, memory_padding_mask=None):
        attn_out, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self.dropout(attn_out))
        
        cross_out, _ = self.cross_attn(x, encoder_output, encoder_output, 
                                        key_padding_mask=memory_padding_mask)
        x = self.norm2(x + self.dropout(cross_out))
        
        ff_out = self.ff(x)
        x = self.norm3(x + self.dropout(ff_out))
        return x
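In the cross-attention sub-layer the queries come from the decoder while the keys and values come from the encoder output, so source and target lengths may differ. A shape-only sketch with arbitrary sizes:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 32, 4
decoder_x = torch.randn(2, 7, embed_dim)        # [batch, tgt_len, embed_dim]
encoder_output = torch.randn(2, 11, embed_dim)  # [batch, src_len, embed_dim]

cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
out, weights = cross_attn(decoder_x, encoder_output, encoder_output)

print(out.shape)      # torch.Size([2, 7, 32]): one vector per target position
print(weights.shape)  # torch.Size([2, 7, 11]): each target position attends over all source positions
```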

Causal Masking

The decoder uses causal masking so that each position can attend only to itself and earlier positions:

PYTHON
def generate_causal_mask(seq_len, device):
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    return mask
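The mask holds 0 where attention is allowed and -inf above the diagonal; nn.MultiheadAttention adds a float attn_mask to the scores before the softmax, so blocked positions get zero weight. One way to confirm that nothing leaks from the future is to change only the last token and check that earlier outputs stay the same. The function is repeated here so the sketch runs on its own.

```python
import torch
import torch.nn as nn

def generate_causal_mask(seq_len, device):
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))

torch.manual_seed(0)
seq_len, embed_dim = 5, 16
mask = generate_causal_mask(seq_len, torch.device('cpu'))  # row i is -inf in columns j > i

attn = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)
x = torch.randn(1, seq_len, embed_dim)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(embed_dim)  # perturb only the final position

out1, _ = attn(x, x, x, attn_mask=mask)
out2, _ = attn(x_changed, x_changed, x_changed, attn_mask=mask)

# Positions 0..3 never see position 4, so their outputs are unchanged
print(torch.allclose(out1[:, :-1], out2[:, :-1]))  # True
```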

Full Encoder-Decoder Transformer

PYTHON
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, embed_dim, num_heads, ff_dim, 
                 num_layers, max_len=512, dropout=0.1):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
        
        self.encoder_norm = nn.LayerNorm(embed_dim)
        self.decoder_norm = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, tgt_vocab)
        
    def encode(self, src, mask=None):
        positions = torch.arange(src.size(1), device=src.device).unsqueeze(0)
        x = self.src_embed(src) + self.pos_embed(positions)
        x = self.dropout(x)
        for layer in self.encoder_layers:
            x = layer(x, mask)
        return self.encoder_norm(x)
    
    def decode(self, tgt, memory, tgt_mask=None):
        positions = torch.arange(tgt.size(1), device=tgt.device).unsqueeze(0)
        x = self.tgt_embed(tgt) + self.pos_embed(positions)
        x = self.dropout(x)
        for layer in self.decoder_layers:
            x = layer(x, memory, tgt_mask)
        return self.decoder_norm(x)
    
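    def forward(self, src, tgt, tgt_mask=None):
        # Default to a causal mask so target positions cannot see later tokens
        if tgt_mask is None:
            tgt_mask = generate_causal_mask(tgt.size(1), tgt.device)
        memory = self.encode(src)
        output = self.decode(tgt, memory, tgt_mask)
        return self.output(output)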

Pre-Norm vs Post-Norm

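The original transformer uses post-norm, applying LayerNorm after each residual addition. Modern implementations often use pre-norm, normalizing each sub-layer's input instead, because it tends to train more stably in deep stacks: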

PYTHON
class PreNormBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, embed_dim))
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Normalize once, then use the result as query, key and value
        h = self.norm1(x)
        x = x + self.dropout(self.attn(h, h, h, attn_mask=mask)[0])
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

Key Takeaways

Encoder blocks use bidirectional self-attention. Decoder blocks use causal self-attention plus cross-attention. The encoder-decoder architecture suits sequence-to-sequence tasks. Pre-norm placement often provides more stable training.

14.3 Position Embeddings (Intermediate)


Self-attention treats input as a set, lacking inherent position awareness. Position embeddings inject positional information, enabling the model to distinguish token order. Different approaches offer tradeoffs between generalization, efficiency, and expressiveness.

Why Position Information Matters

Without position information, the sentences "dog bites man" and "man bites dog" would give each word exactly the same representation, just listed in a different order. Position embeddings break this symmetry.

PYTHON
import torch
import torch.nn as nn
import math

def demonstrate_position_importance():
    # Without position info, self-attention cannot tell these apart
    sentence1 = ["dog", "bites", "man"]
    sentence2 = ["man", "bites", "dog"]
    
    print("Without position embeddings:")
    print(f"  '{' '.join(sentence1)}' and '{' '.join(sentence2)}'")
    print("  would have identical self-attention patterns")
    print()
    print("Position embeddings distinguish word order")

demonstrate_position_importance()
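The printout only asserts the point; this sketch checks it numerically. It assumes random learned vectors stand in for the three word embeddings and uses no position information at all. Self-attention is then permutation-equivariant: reordering the tokens only reorders the outputs, so each word ends up with the same vector in both sentences.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim = 8
vocab = {"dog": 0, "bites": 1, "man": 2}
embed = nn.Embedding(len(vocab), embed_dim)  # token embeddings only, no positions
attn = nn.MultiheadAttention(embed_dim, num_heads=2, batch_first=True)

def encode(words):
    ids = torch.tensor([[vocab[w] for w in words]])
    x = embed(ids)
    out, _ = attn(x, x, x)
    return {w: out[0, i] for i, w in enumerate(words)}

a = encode(["dog", "bites", "man"])
b = encode(["man", "bites", "dog"])

# Each word receives the same representation regardless of word order
print(all(torch.allclose(a[w], b[w], atol=1e-5) for w in vocab))  # True
```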

Sinusoidal Positional Encoding

The original transformer uses fixed sinusoidal functions:

PYTHON
class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * 
                            (-math.log(10000.0) / embed_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
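The relative-position property of the sinusoidal scheme can be verified directly: for a fixed offset k, each (sin, cos) pair of PE(pos + k) is the corresponding pair of PE(pos) rotated by an angle that depends only on k, which is a linear map. A standalone sketch that rebuilds the table with the same formula (embed_dim=16 and k=3 are arbitrary choices):

```python
import math
import torch

def sinusoidal_table(max_len, embed_dim):
    pe = torch.zeros(max_len, embed_dim)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe, div_term

pe, freqs = sinusoidal_table(max_len=50, embed_dim=16)
k = 3
sin_k, cos_k = torch.sin(k * freqs), torch.cos(k * freqs)

# Rotate every (sin, cos) pair of PE(pos) by the angle k * freq
s, c = pe[:-k, 0::2], pe[:-k, 1::2]
shifted = torch.empty_like(pe[:-k])
shifted[:, 0::2] = s * cos_k + c * sin_k  # sin((pos + k) * freq)
shifted[:, 1::2] = c * cos_k - s * sin_k  # cos((pos + k) * freq)

print(torch.allclose(shifted, pe[k:], atol=1e-4))  # True: PE(pos + k) is linear in PE(pos)
```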

Properties of sinusoidal encoding:

  • Deterministic, no learned parameters
  • Can extrapolate to longer sequences than training
  • Relative positions representable as linear functions
  • Learned Positional Embeddings

Many models learn position embeddings as parameters:

PYTHON
class LearnedPositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, max_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(max_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        positions = torch.arange(x.size(1), device=x.device)
        pos_embed = self.embedding(positions)
        return self.dropout(x + pos_embed)

Learned embeddings:

  • More flexible; can capture dataset-specific patterns
  • Limited to the max_len positions in the embedding table
  • Used by BERT and GPT-2

Rotary Position Embeddings (RoPE)

RoPE encodes position by rotating pairs of query and key dimensions through an angle proportional to the token's position:

PYTHON
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_len=2048, base=10000):
        super().__init__()
        # One rotation frequency per pair of dimensions (dim must be even)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_len = max_len
        
    def forward(self, x, seq_len):
        # Angle for position t and frequency j is t * inv_freq[j]
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Duplicate so both members of each rotated pair share an angle: [seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def apply_rotary_pos_emb(q, k, cos, sin):
    # Rotate queries and keys (values are left unchanged);
    # cos and sin broadcast over [..., seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rotate_half(x):
    # Pair dimension i with i + dim/2: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
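"Relative position" here means that after rotation, the dot product between a query at position m and a key at position n depends only on the offset m - n. A standalone check that reuses the same half-split rotation; dim=8 and the specific positions are arbitrary.

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope(x, pos, dim, base=10000):
    # Rotate vector x as if it sat at integer position pos
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = pos * inv_freq
    emb = torch.cat((angles, angles), dim=-1)
    return x * emb.cos() + rotate_half(x) * emb.sin()

torch.manual_seed(0)
dim = 8
q, k = torch.randn(dim), torch.randn(dim)

# Same offset (m - n = 2) at two different absolute positions
score_a = rope(q, 5, dim) @ rope(k, 3, dim)
score_b = rope(q, 12, dim) @ rope(k, 10, dim)

print(torch.allclose(score_a, score_b, atol=1e-4))  # True
```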

RoPE properties:

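  • Encodes relative position naturally: the query-key dot product depends only on the offset between positions
  • No learned position table, so it can be computed for any length (though quality beyond the training length typically degrades without additional scaling techniques)
  • Used by LLaMA and GPT-NeoX

ALiBi: Attention with Linear Biases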

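ALiBi uses no position embeddings; instead it adds a bias to each attention score that becomes more negative in proportion to the query-key distance, with a fixed slope per head: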

PYTHON
class ALiBi(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        # Slopes decrease geometrically per head: 2^(-8/n), 2^(-16/n), ..., 2^(-8)
        slopes = torch.tensor([2 ** (-8 * (i + 1) / num_heads) for i in range(num_heads)])
        self.register_buffer('slopes', slopes)
        
    def forward(self, attention_scores, seq_len):
        # attention_scores: [batch, heads, seq, seq]; a causal mask is assumed to be applied separately
        positions = torch.arange(seq_len, device=attention_scores.device)
        # relative_pos[i, j] = j - i, which is <= 0 for every key a causal query can see
        relative_pos = positions.unsqueeze(0) - positions.unsqueeze(1)
        
        # Apply head-specific slopes: [heads, seq, seq]
        bias = self.slopes.view(-1, 1, 1) * relative_pos.unsqueeze(0)
        
        return attention_scores + bias
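In the ALiBi paper the head slopes form a geometric sequence whose first term and ratio are both 2^(-8/n), so 8 heads get 1/2, 1/4, ..., 1/256. A sketch of that computation (assuming num_heads is a power of 2, the case the simple formula covers) plus the causal bias for one head, where a key j seen from query i receives -slope * (i - j):

```python
import torch

def alibi_slopes(num_heads):
    # Geometric sequence 2^(-8/n), 2^(-16/n), ..., 2^(-8); assumes num_heads is a power of 2
    start = 2 ** (-8 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

slopes = alibi_slopes(8)
print(slopes.tolist())  # 1/2, 1/4, ..., 1/256

# Causal bias for the first head; future keys (j > i) are masked separately, so clamp them to 0
seq_len = 4
pos = torch.arange(seq_len)
distance = pos.unsqueeze(1) - pos.unsqueeze(0)  # i - j
bias = -slopes[0] * distance.clamp(min=0).float()
print(bias)
```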

ALiBi properties:

  • No additional parameters
  • Reported to extrapolate well to sequences longer than those seen in training
  • Simple to implement

Relative Position Encodings

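Instead of absolute positions, encode relative distances. Here each head learns one bias per relative distance, with distances beyond max_distance clipped, and the bias is added to the attention scores: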

PYTHON
class RelativePositionBias(nn.Module):
    def __init__(self, num_heads, max_distance=128):
        super().__init__()
        self.num_heads = num_heads
        self.max_distance = max_distance
        
        # Learnable bias for each relative position
        self.relative_bias = nn.Embedding(2 * max_distance + 1, num_heads)
        
    def forward(self, seq_len):
        positions = torch.arange(seq_len, device=self.relative_bias.weight.device)
        relative_pos = positions.unsqueeze(0) - positions.unsqueeze(1)
        relative_pos = relative_pos.clamp(-self.max_distance, self.max_distance)
        relative_pos = relative_pos + self.max_distance  # Shift to positive
        
        bias = self.relative_bias(relative_pos)
        return bias.permute(2, 0, 1)  # [heads, seq, seq]

Comparison

| Method | Parameters | Length Generalization | Relative Position |
|--------|------------|-----------------------|-------------------|
| Sinusoidal | None | Limited | Implicit |
| Learned | O(max_len × embed_dim) | None beyond max_len | No |
| RoPE | None | Moderate | Yes |
| ALiBi | None | Strong | Yes |
| Relative | O(max_distance × heads) | Good | Yes |

Key Takeaways

Position embeddings are essential for transformers to understand token order. Sinusoidal encodings require no parameters and can be computed for any length, although models using them tend to degrade on sequences longer than those seen in training. Learned embeddings are more flexible but limited to the training length.

Modern approaches like RoPE and ALiBi provide better length generalization and natural relative position encoding. The choice depends on requirements for extrapolation, parameter efficiency, and computational cost.

14.4 Training Transformers (Advanced)


Training transformers requires careful attention to optimization, regularization, and efficiency. The architecture presents unique challenges: deep networks with attention mechanisms can be unstable, and large models require distributed training strategies.

Learning Rate Schedule

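Transformers benefit from learning rate warmup followed by decay. The schedule from the original paper increases the rate linearly for the first warmup_steps steps and then decays it in proportion to the inverse square root of the step number: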

PYTHON
import torch
import torch.nn as nn
import math

class TransformerLRScheduler:
    def __init__(self, optimizer, embed_dim, warmup_steps=4000):
        self.optimizer = optimizer
        self.embed_dim = embed_dim
        self.warmup_steps = warmup_steps
        self.step_num = 0
        
    def step(self):
        self.step_num += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            
    def get_lr(self):
        step = self.step_num
        return self.embed_dim ** (-0.5) * min(
            step ** (-0.5),
            step * self.warmup_steps ** (-1.5)
        )


def visualize_lr_schedule():
    embed_dim = 512
    warmup = 4000
    
    print("Learning Rate Schedule")
    print("-" * 40)
    for step in [1, 1000, 4000, 10000, 100000]:
        lr = embed_dim ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
        print(f"Step {step:6d}: lr = {lr:.6f}")

visualize_lr_schedule()
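The same schedule can also be expressed with PyTorch's built-in torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base learning rate by whatever the lambda returns; a base rate of 1.0 makes the lambda's output the actual rate. In this sketch the tiny linear model is a placeholder, and the step is offset by one because LambdaLR counts from 0.

```python
import torch
import torch.nn as nn

embed_dim, warmup = 512, 4000

def noam_rate(step):
    # step counts from 1, as in the formula above
    return embed_dim ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

model = nn.Linear(10, 10)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: noam_rate(s + 1))

lrs = []
for _ in range(6000):
    lrs.append(optimizer.param_groups[0]['lr'])  # rate used for this update
    optimizer.step()  # in real training this follows loss.backward()
    scheduler.step()

peak_step = max(range(len(lrs)), key=lrs.__getitem__) + 1
print(peak_step)  # 4000: the rate peaks at the end of warmup
```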

Adam Optimizer Settings

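Transformers are typically trained with Adam or its decoupled-weight-decay variant AdamW (used below); the original paper set β1 = 0.9, β2 = 0.98, and ε = 1e-9. Biases and normalization weights are commonly excluded from weight decay, and this version excludes embeddings as well: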

PYTHON
def create_optimizer(model, lr=1e-4, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.01):
    # Separate parameters for weight decay
    decay_params = []
    no_decay_params = []
    
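    for name, param in model.named_parameters():
        # Name matching depends on how the model names its submodules
        # (e.g. 'norm1', 'token_embedding'); adjust the substrings for other models
        if 'bias' in name or 'norm' in name or 'embedding' in name: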
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    
    param_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
    
    return torch.optim.AdamW(param_groups, lr=lr, betas=betas, eps=eps)

Label Smoothing

Label smoothing prevents overconfident predictions and improves generalization:

PYTHON
class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size, smoothing=0.1, ignore_index=-100):
        super().__init__()
        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        
    def forward(self, logits, target):
        # logits: [batch, seq, vocab]
        # target: [batch, seq]
        
        # reshape (not view) also handles non-contiguous inputs such as tgt[:, 1:]
        logits = logits.reshape(-1, self.vocab_size)
        target = target.reshape(-1)
        
        # Create smoothed distribution
        smooth_target = torch.zeros_like(logits)
        smooth_target.fill_(self.smoothing / (self.vocab_size - 1))
        
        # Set correct class probability
        mask = target != self.ignore_index
        smooth_target[mask] = smooth_target[mask].scatter_(
            1, target[mask].unsqueeze(1), 1.0 - self.smoothing
        )
        
        # Cross entropy with smoothed targets
        log_probs = torch.log_softmax(logits, dim=-1)
        loss = -torch.sum(smooth_target * log_probs, dim=-1)
        
        return loss[mask].mean()
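PyTorch also provides label smoothing directly through nn.CrossEntropyLoss(label_smoothing=...), available since PyTorch 1.10. It spreads the smoothing mass as ε/V over all V classes, including the correct one, whereas LabelSmoothingLoss above gives ε/(V-1) to each incorrect class only, so the two losses are close but not identical. A sketch that reproduces the built-in value by hand (random logits and targets, no ignored positions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, eps = 10, 0.1
logits = torch.randn(2, 5, vocab_size)         # [batch, seq, vocab]
target = torch.randint(0, vocab_size, (2, 5))  # [batch, seq]

criterion = nn.CrossEntropyLoss(label_smoothing=eps, ignore_index=-100)
builtin = criterion(logits.reshape(-1, vocab_size), target.reshape(-1))

# Manual equivalent: target distribution (1 - eps) * one_hot + eps / V
log_probs = F.log_softmax(logits.reshape(-1, vocab_size), dim=-1)
nll = -log_probs.gather(1, target.reshape(-1, 1)).squeeze(1)
uniform = -log_probs.mean(dim=-1)
manual = ((1 - eps) * nll + eps * uniform).mean()

print(torch.allclose(builtin, manual, atol=1e-6))  # True
```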

Dropout Strategies

Transformers apply dropout at multiple points:

PYTHON
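class TransformerWithDropout(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        
        # Embedding dropout (in a full model, applied once to token + position embeddings)
        self.embed_dropout = nn.Dropout(dropout)
        
        # Attention dropout (applied to attention weights)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        
        # Residual dropout (applied to each sub-layer output before the residual add)
        self.residual_dropout = nn.Dropout(dropout)
        
        # Feed-forward with internal dropout
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
    def forward(self, x, mask=None):
        # x: embeddings of shape [batch, seq, embed_dim]
        x = self.embed_dropout(x)
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.residual_dropout(attn_out))
        # self.ff already ends with dropout, which serves as the residual dropout here
        x = self.norm2(x + self.ff(x))
        return x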

Gradient Clipping

Gradient clipping rescales all gradients whenever their combined norm exceeds a threshold, preventing occasional exploding updates:

PYTHON
def train_step(model, optimizer, batch, max_grad_norm=1.0):
    model.train()
    optimizer.zero_grad()
    
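    src, tgt = batch
    # Teacher forcing: the decoder sees tokens 0..n-2 and predicts tokens 1..n-1
    output = model(src, tgt[:, :-1])
    
    # compute_loss is a placeholder, e.g. LabelSmoothingLoss or nn.CrossEntropyLoss
    loss = compute_loss(output, tgt[:, 1:])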
    loss.backward()
    
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    
    optimizer.step()
    return loss.item()
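clip_grad_norm_ treats all gradients as one concatenated vector: if its L2 norm exceeds max_norm, every gradient is multiplied by roughly max_norm / total_norm (PyTorch adds a tiny epsilon to the denominator), which preserves the update direction. It returns the norm measured before clipping, which is useful to log. A small check with hand-set gradients on a placeholder layer:

```python
import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
# Hand-set gradients: weight entries of 3.0 and bias entries of 4.0
layer.weight.grad = torch.full_like(layer.weight, 3.0)
layer.bias.grad = torch.full_like(layer.bias, 4.0)

total_norm = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
print(total_norm.item())  # sqrt(6 * 9 + 2 * 16) = sqrt(86), about 9.27

clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in layer.parameters()))
print(round(clipped.item(), 4))  # 1.0: rescaled to max_norm
```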

Mixed Precision Training

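Mixed-precision training runs most operations in FP16 while keeping the model weights in FP32, which reduces memory use and speeds up training on GPUs with tensor cores. GradScaler multiplies the loss by a scale factor so that small FP16 gradients do not underflow to zero: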

PYTHON
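import torch
from torch.cuda.amp import GradScaler  # recent releases also expose torch.amp.GradScaler("cuda")

def train_with_amp(model, optimizer, dataloader, epochs):
    model.train()
    scaler = GradScaler()
    
    for epoch in range(epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            
            # Run the forward pass and loss in reduced precision where it is safe
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                output = model(batch['src'], batch['tgt'][:, :-1])
                loss = compute_loss(output, batch['tgt'][:, 1:])  # compute_loss: placeholder
            
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)  # skips the update if gradients contain inf or NaN
            scaler.update()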

Gradient Accumulation

Gradient accumulation sums gradients over several small batches before each optimizer step, giving a large effective batch size when memory is limited:

PYTHON
def train_with_accumulation(model, optimizer, dataloader, accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    
    for i, batch in enumerate(dataloader):
        output = model(batch['src'], batch['tgt'][:, :-1])
        loss = compute_loss(output, batch['tgt'][:, 1:])
        loss = loss / accumulation_steps
        loss.backward()
        
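        if (i + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
    # If the number of batches is not a multiple of accumulation_steps,
    # gradients from the final partial group are never applied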
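Dividing each micro-batch loss by accumulation_steps is what makes this equivalent to one large batch: with a mean-reduced loss and equally sized micro-batches, the summed gradients equal the full-batch gradient. A check using a toy linear regression (the model, random data, and MSE loss are placeholders, not the seq2seq setup above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(16, 4), torch.randn(16, 1)
loss_fn = nn.MSELoss()  # mean reduction
accumulation_steps = 4

# Full batch of 16 in one pass
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Four micro-batches of 4, each loss scaled by 1 / accumulation_steps
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (loss_fn(model(xb), yb) / accumulation_steps).backward()
accumulated_grad = model.weight.grad

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))  # True
```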

Key Takeaways

Training transformers requires warmup learning rate schedules, careful optimizer settings, and regularization through dropout and label smoothing. Gradient clipping prevents instability.

Mixed precision training and gradient accumulation enable training larger models with limited hardware. These techniques combine to make transformer training practical across different scales.

14.5 Transformer Variants (Advanced)


The original transformer architecture has spawned numerous variants, each optimized for different tasks and constraints. Understanding these variants illuminates the design space of attention-based models.

Encoder-Only Transformers

Encoder-only models process input bidirectionally: every token attends to every other token. This makes them well suited to understanding tasks such as classification and extractive question answering.

PYTHON
import torch
import torch.nn as nn

class BERTStyleEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=768, num_heads=12,
                 num_layers=12, ff_dim=3072, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.segment_embed = nn.Embedding(2, embed_dim)
        self.embed_norm = nn.LayerNorm(embed_dim)
        self.embed_dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=ff_dim, dropout=dropout,
            activation="gelu", batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.pooler = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.Tanh())
        
    def forward(self, input_ids, segment_ids=None, attention_mask=None):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.token_embed(input_ids) + self.pos_embed(positions)
        if segment_ids is not None:
            x = x + self.segment_embed(segment_ids)
        x = self.embed_dropout(self.embed_norm(x))
        if attention_mask is not None:
            # Hugging Face-style mask (1 = token, 0 = padding) -> True at padding positions
            attention_mask = attention_mask == 0
        encoded = self.encoder(x, src_key_padding_mask=attention_mask)
        pooled = self.pooler(encoded[:, 0])
        return encoded, pooled

The bidirectional context allows each token to attend to all other tokens, capturing rich contextual representations.

Decoder-Only Transformers

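Decoder-only models use causal self-attention and are trained to predict the next token, which suits generation tasks. With no encoder to attend to, the layers need no cross-attention, so the code reuses nn.TransformerEncoderLayer with a causal mask.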

PYTHON
class GPTStyleDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=768, num_heads=12,
                 num_layers=12, ff_dim=3072, max_seq_len=1024, dropout=0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads,
                dim_feedforward=ff_dim, dropout=dropout,
                activation="gelu", batch_first=True,
                norm_first=True  # pre-norm, as in GPT-2; final_norm closes the stack
            ) for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight
        
    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.dropout(self.token_embed(input_ids) + self.pos_embed(positions))
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()
        for layer in self.layers:
            x = layer(x, src_mask=causal_mask)
        return self.lm_head(self.final_norm(x))
    
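    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens, temperature=1.0):
        # Call model.eval() first so dropout is disabled during sampling
        max_len = self.pos_embed.num_embeddings
        for _ in range(max_new_tokens):
            # Crop the context so positions stay within the position-embedding table
            context = input_ids[:, -max_len:]
            logits = self.forward(context)[:, -1, :] / temperature
            next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids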
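Dividing the logits by temperature before the softmax controls how peaked the sampling distribution is: values below 1 sharpen it toward the most likely token, and values above 1 flatten it. A small illustration using made-up logits for a hypothetical four-token vocabulary:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # hypothetical next-token logits

results = {}
for temperature in [0.5, 1.0, 2.0]:
    probs = torch.softmax(logits / temperature, dim=-1)
    entropy = -(probs * probs.log()).sum().item()
    results[temperature] = (probs.max().item(), entropy)
    print(f"T={temperature}: top prob {results[temperature][0]:.3f}, entropy {entropy:.3f}")
```

Lower temperatures raise the top probability and lower the entropy, so sampling behaves more like greedy decoding.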

Encoder-Decoder Transformers

The original transformer uses both encoder and decoder for sequence-to-sequence tasks.

PYTHON
class T5StyleTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=512, num_heads=8,
                 num_encoder_layers=6, num_decoder_layers=6, ff_dim=2048):
        super().__init__()
        self.shared_embed = nn.Embedding(vocab_size, embed_dim)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=ff_dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=ff_dim, batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
        self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
        
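    def forward(self, src_ids, tgt_ids, src_mask=None):
        # src_mask: key padding mask for the source, True at padding positions.
        # Positional information is omitted for brevity; T5 itself adds a learned
        # relative position bias inside attention.
        memory = self.encoder(self.shared_embed(src_ids), src_key_padding_mask=src_mask)
        seq_len = tgt_ids.size(1)
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=tgt_ids.device), diagonal=1).bool()
        decoded = self.decoder(self.shared_embed(tgt_ids), memory, tgt_mask=causal_mask,
                               memory_key_padding_mask=src_mask)
        return self.lm_head(decoded)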

Pre-Norm vs Post-Norm

Layer normalization placement affects training dynamics significantly.

PYTHON
class PreNormBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        h = self.norm1(x)  # normalize once; reuse as query, key and value
        attn_out, _ = self.attn(h, h, h)
        x = x + self.dropout(attn_out)
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

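Pre-norm keeps an unnormalized residual path from input to output, which improves gradient flow in deep networks. Many modern large language models, starting with GPT-2, use pre-norm for training stability and add a final LayerNorm after the last block.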

Mixture of Experts

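Mixture of Experts (MoE) increases parameter count without a proportional increase in compute: a learned gate routes each token to only top_k of the num_experts feed-forward experts and combines their outputs using the renormalized gate probabilities.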

PYTHON
class MoELayer(nn.Module):
    def __init__(self, embed_dim, ff_dim, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(embed_dim, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, embed_dim))
            for _ in range(num_experts)
        ])
        
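    def forward(self, x):
        batch, seq_len, dim = x.shape
        x_flat = x.reshape(-1, dim)  # [tokens, dim]
        # Gate: probability of each expert for each token
        probs = torch.softmax(self.gate(x_flat), dim=-1)
        # Keep the top_k experts per token and renormalize their weights to sum to 1
        top_probs, top_idx = torch.topk(probs, self.top_k, dim=-1)
        top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
        
        output = torch.zeros_like(x_flat)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                # Tokens whose k-th choice is expert e
                mask = top_idx[:, k] == e
                if mask.any():
                    output[mask] += top_probs[mask, k:k+1] * expert(x_flat[mask])
        # Production MoE layers usually add an auxiliary load-balancing loss (omitted here)
        return output.view(batch, seq_len, dim)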
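The saving comes from the gap between stored and active parameters: every expert is stored, but each token runs through only top_k of them. Counting both for the layer's default num_experts=8 and top_k=2, with embed_dim=512 and ff_dim=2048 chosen as illustrative sizes (gate parameters are counted separately):

```python
import torch.nn as nn

embed_dim, ff_dim, num_experts, top_k = 512, 2048, 8, 2

def make_expert():
    return nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, embed_dim))

experts = nn.ModuleList([make_expert() for _ in range(num_experts)])
gate = nn.Linear(embed_dim, num_experts)

per_expert = sum(p.numel() for p in experts[0].parameters())
total = sum(p.numel() for p in experts.parameters())
active = top_k * per_expert  # expert parameters actually used for one token
gate_params = sum(p.numel() for p in gate.parameters())

print(per_expert)   # 2099712
print(total)        # 16797696
print(active)       # 4199424
print(gate_params)  # 4104
```

Compute per token grows with top_k, while capacity grows with num_experts.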

Key Takeaways

Encoder-only models like BERT excel at understanding. Decoder-only models like GPT specialize in generation. Encoder-decoder models handle sequence-to-sequence tasks. Pre-norm improves stability; MoE scales efficiently. These choices form the foundation for modern language models.