Advanced Expert 120 min read

Chapter 15: Fine-tuning and Alignment

LoRA, RLHF, DPO, and parameter-efficient fine-tuning.

Learning Objectives

- Apply LoRA and QLoRA
- Understand the RLHF pipeline
- Implement DPO


15.1 Introduction to Large Language Models (Intermediate)

Large Language Models (LLMs) represent a paradigm shift in natural language processing. These models, with billions of parameters trained on massive text corpora, exhibit remarkable capabilities in understanding and generating human language.

What Makes a Model Large

Scale defines LLMs across three dimensions: parameters, training data, and compute.

PYTHON
import torch
import torch.nn as nn

def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Model size (FP32): {total * 4 / 1e9:.2f} GB")
    print(f"Model size (FP16): {total * 2 / 1e9:.2f} GB")
    return total

def estimate_parameters(n_layers, embed_dim, vocab_size=50000, ff_mult=4):
    # Rough count for a GPT-style model; assumes tied input/output embeddings
    embedding_params = vocab_size * embed_dim
    attention_params = 4 * embed_dim * embed_dim        # Q, K, V and output projections
    ff_params = 2 * embed_dim * (ff_mult * embed_dim)   # up and down projections
    layer_norm_params = 4 * embed_dim                   # two LayerNorms, scale + bias each
    layer_params = attention_params + ff_params + layer_norm_params
    total = embedding_params + (n_layers * layer_params) + embed_dim  # + final norm scale
    return total

params_7b = estimate_parameters(n_layers=32, embed_dim=4096)
print(f"Estimated 7B model parameters: {params_7b / 1e9:.1f}B")

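Widely used LLMs range from a few billion to several hundred billion parameters. Training compute grows with both model size and data (roughly 6 FLOPs per parameter per training token), and the loss reached for a given compute budget follows predictable scaling laws.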

The Scaling Laws

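Empirical studies have found that pre-training loss decreases as a smooth power law in model size N and training tokens D. The class below uses the parametric fit from the Chinchilla study (Hoffmann et al., 2022), L(N, D) = E + A / N^α + B / D^β.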

PYTHON
import numpy as np

class ScalingLaws:
    def __init__(self, E=1.69, A=406.4, B=410.7, alpha=0.34, beta=0.28):
        self.E = E
        self.A = A
        self.B = B
        self.alpha = alpha
        self.beta = beta
    
    def predict_loss(self, N, D):
        return self.E + self.A / (N ** self.alpha) + self.B / (D ** self.beta)
    
    def optimal_tokens(self, N):
        return 20 * N  # Chinchilla ratio
    
    def compute_optimal_allocation(self, compute_budget):
        N = np.sqrt(compute_budget / 120)
        D = self.optimal_tokens(N)
        return int(N), int(D)

scaling = ScalingLaws()
for N in [7e9, 13e9, 70e9]:
    optimal_D = scaling.optimal_tokens(N)
    loss = scaling.predict_loss(N, optimal_D)
    print(f"{N/1e9:.0f}B model: {optimal_D/1e12:.1f}T tokens, Loss: {loss:.3f}")

The Chinchilla scaling laws suggest training on approximately 20 tokens per parameter for compute-optimal training.
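The allocation in `compute_optimal_allocation` follows from two approximations: training compute C ≈ 6ND FLOPs (about 2ND for the forward pass and 4ND for the backward pass) and the ratio D = 20N. Substituting gives C = 120N², so N = √(C/120). A standalone sketch of that arithmetic; the function names are illustrative, not from a library:

```python
import math


def training_flops(n_params, n_tokens):
    # Standard approximation: about 6 FLOPs per parameter per training token
    return 6 * n_params * n_tokens


def chinchilla_allocation(compute_budget, tokens_per_param=20):
    # With D = r * N and C = 6 * N * D, C = 6 * r * N**2, so N = sqrt(C / (6 * r))
    n_params = math.sqrt(compute_budget / (6 * tokens_per_param))
    return n_params, tokens_per_param * n_params


budget = training_flops(70e9, 1.4e12)  # a 70B model trained on 1.4T tokens
n, d = chinchilla_allocation(budget)
print(f"{budget:.2e} FLOPs -> N = {n / 1e9:.0f}B params, D = {d / 1e12:.1f}T tokens")
```

Running the allocation backwards on the budget of a 70B model trained on 1.4T tokens recovers the same 70B / 1.4T split, which is a quick consistency check on the 20-tokens-per-parameter rule.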

Architecture Overview

Modern LLMs use decoder-only transformer architectures with specific modifications.

PYTHON
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        
    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


class LLMBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_mult=4):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.norm1 = RMSNorm(embed_dim)
        self.norm2 = RMSNorm(embed_dim)
        
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.o_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        
        ff_dim = int(ff_mult * embed_dim * 2 / 3)
        self.gate_proj = nn.Linear(embed_dim, ff_dim, bias=False)
        self.up_proj = nn.Linear(embed_dim, ff_dim, bias=False)
        self.down_proj = nn.Linear(ff_dim, embed_dim, bias=False)
        
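    def forward(self, x, mask=None):
        # mask: boolean tensor broadcastable to (batch, heads, seq, seq), True where
        # attention is blocked (e.g. future positions); RoPE is omitted for brevity
        h = self.norm1(x)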
        batch, seq_len, _ = h.shape
        
        q = self.q_proj(h).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(h).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(h).view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        attn = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            attn = attn.masked_fill(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(batch, seq_len, self.embed_dim)
        x = x + self.o_proj(out)
        
        h = self.norm2(x)
        x = x + self.down_proj(nn.functional.silu(self.gate_proj(h)) * self.up_proj(h))
        return x

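Key modifications include RMSNorm, which drops LayerNorm's mean subtraction and bias and is cheaper to compute; SwiGLU gated feed-forward layers, which empirically improve quality at a matched parameter count; and rotary position embeddings (RoPE), which encode relative position by rotating query and key vectors. The block above omits RoPE for brevity.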
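A minimal RoPE sketch, assuming the half-split pairing (dimension i rotated together with dimension i + d/2) used by several open implementations; the original RoFormer formulation pairs adjacent dimensions instead. Each pair is rotated by an angle proportional to the token position, so the dot product of a rotated query and key depends only on their relative offset. `rope_frequencies` and `apply_rope` are hypothetical helper names, and base 10000 is the commonly used default.

```python
import torch


def rope_frequencies(head_dim, max_pos, base=10000.0):
    # Angle for pair i at position m is m * base^(-2i / head_dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(max_pos).float(), inv_freq)  # (max_pos, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)


def apply_rope(x, cos, sin):
    # x: (..., seq_len, head_dim); dimension i is paired with i + head_dim // 2
    seq_len = x.shape[-2]
    cos, sin = cos[:seq_len], sin[:seq_len]
    x1, x2 = x.chunk(2, dim=-1)
    # 2-D rotation of every (x1, x2) pair by its position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


# Usage: rotate queries (and keys) of shape (batch, heads, seq_len, head_dim) before attention
cos, sin = rope_frequencies(head_dim=64, max_pos=2048)
q = torch.randn(2, 8, 16, 64)
q_rot = apply_rope(q, cos, sin)
print(q_rot.shape)  # torch.Size([2, 8, 16, 64])
```

Because the rotation is applied to queries and keys rather than added to the input embeddings, the values passed to attention are untouched and no learned position table is needed.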

Training Data Composition

LLMs train on diverse internet-scale text corpora with careful curation.

PYTHON
# Illustrative mixture (fractions of training tokens); real mixtures vary by model
TRAINING_DATA_MIX = {
    "web_crawl": 0.60,      # Common Crawl, filtered
    "code": 0.15,           # GitHub repositories  
    "books": 0.08,          # Book corpora
    "conversations": 0.08,  # Forums, Q&A sites
    "scientific": 0.05,     # arXiv papers
    "wikipedia": 0.04,      # Wikipedia dumps
}

def calculate_tokens_needed(params_billions, chinchilla_ratio=20):
    tokens = params_billions * chinchilla_ratio * 1e9
    print(f"{params_billions}B model needs {tokens/1e9:,.0f}B tokens")
    return tokens

for size in [7, 13, 70]:
    calculate_tokens_needed(size)

Data quality significantly impacts model capabilities. Filtering, deduplication, and careful mixing of data sources are essential preprocessing steps.
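Exact deduplication is the simplest of these steps: hash a normalized form of each document and keep only the first copy. A minimal sketch; the normalization rule (lowercase, collapse whitespace) is an assumption, and production pipelines typically add near-duplicate detection such as MinHash on top.

```python
import hashlib


def normalize(text):
    # Lowercase and collapse runs of whitespace so trivial variants hash identically
    return " ".join(text.lower().split())


def exact_dedup(docs):
    seen = set()
    kept = []
    for doc in docs:
        digest = hashlib.sha256(normalize(doc).encode("utf-8")).hexdigest()
        if digest not in seen:  # keep only the first occurrence of each normalized text
            seen.add(digest)
            kept.append(doc)
    return kept


docs = ["The cat sat.", "the  cat sat.", "A different document."]
print(exact_dedup(docs))  # ['The cat sat.', 'A different document.']
```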

Key Takeaways

Large Language Models achieve capabilities through massive scale in parameters, data, and compute. Scaling laws predict performance improvements and guide resource allocation. Modern architectures incorporate RMSNorm, SwiGLU activation, and rotary embeddings. Training data diversity and quality significantly impact model capabilities.

15.2 Pre-training at Scale (Advanced)

Pre-training frontier-scale LLMs typically requires distributed training across hundreds to thousands of GPUs. Models with billions of parameters exceed the memory of a single device, so training combines several parallelization strategies with careful optimization.

The Pre-training Objective

LLMs typically use next-token prediction as their pre-training objective.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalLanguageModelLoss(nn.Module):
    def __init__(self, vocab_size, ignore_index=-100):
        super().__init__()
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index

    def forward(self, logits, labels):
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss = F.cross_entropy(
            shift_logits.view(-1, self.vocab_size),
            shift_labels.view(-1),
            ignore_index=self.ignore_index,
            reduction="mean"
        )
        return loss


def compute_perplexity(loss):
    return torch.exp(loss).item()


class PretrainingStep:
    def __init__(self, model, optimizer, scaler, gradient_accumulation=1):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.gradient_accumulation = gradient_accumulation
        self.step_count = 0

    def train_step(self, batch):
        self.model.train()

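        # torch.cuda.amp.autocast is deprecated; torch.autocast is the current API
        with torch.autocast(device_type="cuda", dtype=torch.float16):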
            logits = self.model(batch["input_ids"])
            loss = F.cross_entropy(
                logits[:, :-1, :].reshape(-1, logits.size(-1)),
                batch["input_ids"][:, 1:].reshape(-1)
            )
            loss = loss / self.gradient_accumulation

        self.scaler.scale(loss).backward()
        self.step_count += 1

        if self.step_count % self.gradient_accumulation == 0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad()

        return loss.item() * self.gradient_accumulation
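A quick sanity check ties the loss to perplexity: a model that assigns uniform probability over a vocabulary of size V has cross-entropy ln V per token, so its perplexity (the exponential of the loss) is exactly V. The sketch applies the same label shift as the code above to random token ids; `next_token_loss` is a hypothetical helper.

```python
import math

import torch
import torch.nn.functional as F


def next_token_loss(logits, input_ids):
    # Predict token t+1 from positions <= t: drop the last logit and the first label
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, logits.size(-1)), shift_labels.reshape(-1))


vocab_size = 100
input_ids = torch.randint(0, vocab_size, (2, 16))
uniform_logits = torch.zeros(2, 16, vocab_size)  # a model that knows nothing
loss = next_token_loss(uniform_logits, input_ids)
print(loss.item(), math.log(vocab_size))  # both ≈ 4.605
print(torch.exp(loss).item())             # ≈ 100.0
```

A freshly initialized model should start near this value; a loss far above ln V at step 0 usually points to an initialization or label-shift bug.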

Data Parallelism

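The simplest form of distributed training replicates the full model on every GPU. Each replica processes a different shard of the global batch, and gradients are averaged across replicas with an all-reduce during the backward pass, so every replica applies the same update.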

PYTHON
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed(rank, world_size):
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )
    torch.cuda.set_device(rank)

def create_ddp_model(model, rank):
    model = model.to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)
    return model


class DistributedDataLoader:
    def __init__(self, dataset, batch_size, rank, world_size):
        self.sampler = torch.utils.data.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )
        self.loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=self.sampler,
            num_workers=4,
            pin_memory=True
        )

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        return iter(self.loader)
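Why averaging gradients is correct: when the loss is a mean over examples and every rank receives an equal-sized shard, the average of the per-shard gradients equals the full-batch gradient. The single-process sketch below simulates four ranks, using `chunk` in place of `DistributedSampler` and `torch.stack(...).mean` in place of the all-reduce; it illustrates the math, not how DDP is implemented internally.

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
x, y = torch.randn(8, 4), torch.randn(8)


def mse(w, x, y):
    # Mean-reduced loss, as in typical training loops
    return ((x @ w - y) ** 2).mean()


# Gradient on the full batch, as a single GPU would compute it
full_grad = torch.autograd.grad(mse(w, x, y), w)[0]

# Four simulated ranks, each seeing an equal 2-example shard
shard_grads = [
    torch.autograd.grad(mse(w, xs, ys), w)[0]
    for xs, ys in zip(x.chunk(4), y.chunk(4))
]
averaged = torch.stack(shard_grads).mean(dim=0)  # stands in for all_reduce(AVG)
print(torch.allclose(full_grad, averaged, atol=1e-6))  # True
```

With unequal shard sizes the plain average is no longer exact, which is one reason `DistributedSampler` pads the dataset so every rank gets the same number of samples.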

Model Parallelism

For models whose layers are too large to fit on a single GPU, we split individual weight matrices across devices.

PYTHON
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        self.world_size = world_size
        self.rank = rank
        self.local_out = out_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.local_out, in_features) * 0.02
        )

    def forward(self, x):
        # Output stays sharded: this rank produces only its slice of the output features
        return F.linear(x, self.weight)


class RowParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size, rank):
        super().__init__()
        self.world_size = world_size
        self.local_in = in_features // world_size

        self.weight = nn.Parameter(
            torch.randn(out_features, self.local_in) * 0.02
        )

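    def forward(self, x):
        # x holds this rank's slice of the input features (e.g. a ColumnParallelLinear output)
        local_out = F.linear(x, self.weight)
        # Sum the partial products from all ranks. dist.all_reduce is not tracked by
        # autograd; real implementations wrap it in a custom autograd.Function.
        dist.all_reduce(local_out, op=dist.ReduceOp.SUM)
        return local_out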

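Tensor parallelism splits matrix multiplications across GPUs. Column-parallel layers split the output features, and row-parallel layers split the input features. Placing a column-parallel layer before a row-parallel one (as in a transformer MLP) keeps the intermediate activation sharded, so the pair needs only one all-reduce in the forward pass.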
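The column-then-row pairing can be checked in a single process: shard the first weight by rows (output features) and the second by columns (input features), run each simulated rank independently, and sum the partial results in place of the all-reduce. A sketch with arbitrary sizes and a ReLU standing in for the MLP activation:

```python
import torch

torch.manual_seed(0)
world_size = 2
x = torch.randn(3, 8)
w1 = torch.randn(16, 8)  # first linear layer, 8 -> 16 (column-parallel)
w2 = torch.randn(8, 16)  # second linear layer, 16 -> 8 (row-parallel)

reference = torch.relu(x @ w1.T) @ w2.T  # unsharded computation

partials = []
for rank in range(world_size):
    w1_shard = w1.chunk(world_size, dim=0)[rank]  # this rank's output features
    w2_shard = w2.chunk(world_size, dim=1)[rank]  # the matching input features
    h_local = torch.relu(x @ w1_shard.T)          # elementwise, so no communication needed
    partials.append(h_local @ w2_shard.T)         # partial sum of the final output

output = torch.stack(partials).sum(dim=0)         # stands in for all_reduce(SUM)
print(torch.allclose(output, reference, atol=1e-4))  # True
```

The activation can be applied locally only because it is elementwise; an operation that mixes features across shards would require communication before it.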

Pipeline Parallelism

Pipeline parallelism splits layers sequentially across GPUs with micro-batching.

PYTHON
class PipelineStage(nn.Module):
    def __init__(self, layers, stage_id, num_stages):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.stage_id = stage_id
        self.num_stages = num_stages

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class GPipeScheduler:
    def __init__(self, num_microbatches, num_stages):
        self.num_microbatches = num_microbatches
        self.num_stages = num_stages

    def generate_schedule(self):
        # Forward passes only; GPipe runs the backward passes after all forwards finish
        schedule = []
        for step in range(self.num_microbatches + self.num_stages - 1):
            stage_ops = []
            for stage in range(self.num_stages):
                microbatch = step - stage
                if 0 <= microbatch < self.num_microbatches:
                    stage_ops.append((stage, "forward", microbatch))
            schedule.append(stage_ops)
        return schedule
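The schedule above leaves some stages idle while the pipeline fills and drains. Counting busy slots gives the bubble fraction (p - 1) / (m + p - 1) for p stages and m microbatches, which is why GPipe-style training uses many microbatches. A self-contained sketch that rebuilds the same forward schedule; the helper names are illustrative:

```python
def forward_schedule(num_microbatches, num_stages):
    # At clock step t, stage s processes microbatch t - s if it exists
    return [
        [(s, t - s) for s in range(num_stages) if 0 <= t - s < num_microbatches]
        for t in range(num_microbatches + num_stages - 1)
    ]


def bubble_fraction(num_microbatches, num_stages):
    schedule = forward_schedule(num_microbatches, num_stages)
    busy = sum(len(step) for step in schedule)
    total = len(schedule) * num_stages  # every stage-slot across all clock steps
    return 1 - busy / total


for m in [1, 4, 16, 64]:
    print(f"{m:>2} microbatches, 4 stages: bubble = {bubble_fraction(m, 4):.2f}")
```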

ZeRO Optimizer

ZeRO partitions optimizer states, gradients, and parameters across GPUs.

PYTHON
class ZeROStage1Optimizer:
    def __init__(self, model, lr, world_size, rank):
        self.model = model
        self.world_size = world_size
        self.rank = rank

        self.param_groups = list(model.parameters())
        self.local_params = self.param_groups[rank::world_size]
        self.optimizer = torch.optim.AdamW(self.local_params, lr=lr)

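    def step(self):
        # Stage 1: every rank still holds full gradients; average them across ranks
        for param in self.param_groups:
            if param.grad is not None:
                # ReduceOp.AVG is only supported by the NCCL backend
                dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

        # Each rank updates only the parameters whose optimizer state it owns
        self.optimizer.step()

        # The owner of each parameter broadcasts its updated values to all ranks
        for i, param in enumerate(self.param_groups):
            src_rank = i % self.world_size
            dist.broadcast(param.data, src=src_rank)

    def zero_grad(self):
        # Clear gradients on all parameters, not only the locally owned ones
        for param in self.param_groups:
            param.grad = None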

ZeRO Stage 1 partitions optimizer states, Stage 2 also partitions gradients, and Stage 3 partitions the parameters themselves, which makes it possible to train models whose full training state does not fit in a single GPU's memory.
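The savings can be estimated with the accounting used in the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for optimizer state (fp32 master weights, momentum, and variance). A sketch that ignores activations, buffers, and fragmentation; `zero_memory_gb` is a hypothetical helper:

```python
def zero_memory_gb(num_params, num_gpus, stage):
    # Per-GPU model-state memory (GB) for mixed-precision Adam under ZeRO stages 0-3
    params, grads, optim = 2 * num_params, 2 * num_params, 12 * num_params
    if stage >= 1:
        optim /= num_gpus   # Stage 1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus   # Stage 2: also partition gradients
    if stage >= 3:
        params /= num_gpus  # Stage 3: also partition parameters
    return (params + grads + optim) / 1e9


for stage in range(4):
    print(f"Stage {stage}: {zero_memory_gb(7.5e9, 64, stage):.1f} GB per GPU")
# Stage 0: 120.0, Stage 1: 31.4, Stage 2: 16.6, Stage 3: 1.9
```

Stage 1 alone removes most of the footprint because optimizer state dominates (12 of the 16 bytes per parameter); Stage 3 buys the rest at the cost of gathering parameters during every forward and backward pass.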

Key Takeaways

Pre-training LLMs requires combining multiple parallelism strategies. Data parallelism replicates models across GPUs for batch scaling. Tensor parallelism splits individual layers. Pipeline parallelism distributes layers sequentially. ZeRO reduces memory by partitioning optimizer states, gradients, and (at Stage 3) parameters across data-parallel ranks. Together, these techniques enable training models with hundreds of billions of parameters.

15.3 Tokenization and Vocabulary (Intermediate)

Tokenization converts raw text into discrete tokens that language models can process. The choice of tokenization strategy significantly impacts model performance, efficiency, and ability to handle diverse languages and domains.

Byte-Pair Encoding

BPE builds a vocabulary by starting from individual characters and repeatedly merging the most frequent adjacent pair of symbols into a new symbol.

PYTHON
from collections import Counter
import re

class BPETokenizer:
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.merges = {}  # (left, right) -> new token id, in the order merges were learned
        self.vocab = {}   # token string -> id

    def get_pairs(self, word):
        # `word` is a space-separated symbol sequence, e.g. "l o w </w>"
        pairs = Counter()
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[(symbols[i], symbols[i + 1])] += 1
        return pairs

    def merge_pair(self, pair, word):
        # Match whole symbols only, so ("a", "b") is not merged inside "ca b" or "a bc"
        pattern = r"(?<!\S)" + re.escape(" ".join(pair)) + r"(?!\S)"
        replacement = "".join(pair)
        return re.sub(pattern, lambda _: replacement, word)

    def train(self, corpus):
        # Represent each word as characters plus an end-of-word marker
        word_freqs = Counter()
        for text in corpus:
            for word in text.split():
                word_freqs[" ".join(word) + " </w>"] += 1

        # Base vocabulary: the first 256 code points plus the end-of-word marker.
        # Characters outside this range are not handled (byte-level BPE avoids this).
        self.vocab = {chr(i): i for i in range(256)}
        self.vocab["</w>"] = 256
        vocab_idx = 257

        # Iteratively merge the most frequent adjacent pair
        while vocab_idx < self.vocab_size:
            pairs = Counter()
            for word, freq in word_freqs.items():
                for pair, count in self.get_pairs(word).items():
                    pairs[pair] += count * freq

            if not pairs:
                break

            best_pair = pairs.most_common(1)[0][0]
            self.merges[best_pair] = vocab_idx
            self.vocab["".join(best_pair)] = vocab_idx
            vocab_idx += 1

            # Apply the merge to every word
            new_word_freqs = Counter()
            for word, freq in word_freqs.items():
                new_word_freqs[self.merge_pair(best_pair, word)] += freq
            word_freqs = new_word_freqs

    def encode(self, text):
        tokens = []
        for word in text.split():
            word = " ".join(word) + " </w>"
            # Replay merges in the order they were learned (dicts keep insertion order)
            for pair in self.merges:
                word = self.merge_pair(pair, word)
            tokens.extend(self.vocab[t] for t in word.split())
        return tokens

SentencePiece and Unigram

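SentencePiece operates directly on raw text, treating whitespace as an ordinary symbol (▁), so it needs no language-specific pre-tokenization. It implements both BPE and the Unigram language model; the class below sketches Unigram training and Viterbi decoding.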

PYTHON
import math
from collections import Counter

class UnigramTokenizer:
    def __init__(self, vocab_size=32000, max_token_len=20):
        self.vocab_size = vocab_size
        self.max_token_len = max_token_len  # tokens have at most max_token_len - 1 chars
        self.vocab = {}
        self.log_probs = {}

    def train(self, corpus):
        # Start with a large vocabulary of substrings
        substring_counts = Counter()
        for text in corpus:
            for i in range(len(text)):
                for j in range(i + 1, min(i + self.max_token_len, len(text) + 1)):
                    substring_counts[text[i:j]] += 1

        # Initialize log-probabilities from relative frequencies
        total = sum(substring_counts.values())
        self.log_probs = {t: math.log(c / total) for t, c in substring_counts.items()}

        # Prune to the target size. Real Unigram training re-estimates probabilities
        # with EM between pruning rounds; that step is omitted here.
        while len(self.log_probs) > self.vocab_size:
            pruned = self._prune_vocabulary(corpus)
            if len(pruned) == len(self.log_probs):
                break  # only single characters remain, and they must be kept
            self.log_probs = pruned

        self.vocab = {token: i for i, token in enumerate(self.log_probs)}

    def _prune_vocabulary(self, corpus):
        # Rank multi-character tokens by the estimated loss increase from removing them
        candidates = [t for t in self.log_probs if len(t) > 1]
        candidates.sort(key=lambda t: self._compute_loss_increase(t, corpus), reverse=True)
        # Always keep single characters so every input stays segmentable
        chars = [t for t in self.log_probs if len(t) == 1]
        keep_size = max(self.vocab_size, int(len(self.log_probs) * 0.9))
        kept = set(chars + candidates[:max(0, keep_size - len(chars))])
        return {t: lp for t, lp in self.log_probs.items() if t in kept}

    def _compute_loss_increase(self, removed_token, corpus):
        # Simplified: use the token's log-probability as a proxy for its importance
        return self.log_probs.get(removed_token, -math.inf)

    def encode(self, text):
        # Viterbi search for the most probable segmentation
        return self._viterbi_segment(text)

    def _viterbi_segment(self, text):
        n = len(text)
        best_score = [-math.inf] * (n + 1)
        best_score[0] = 0.0
        best_edge = [None] * (n + 1)

        for i in range(n):
            if best_score[i] == -math.inf:
                continue  # no segmentation reaches position i
            for j in range(i + 1, min(i + self.max_token_len, n + 1)):
                token = text[i:j]
                if token in self.log_probs:
                    score = best_score[i] + self.log_probs[token]
                    if score > best_score[j]:
                        best_score[j] = score
                        best_edge[j] = i

        if n > 0 and best_edge[n] is None:
            raise ValueError("text contains characters not in the vocabulary")

        # Backtrack from the end to recover token ids
        tokens = []
        i = n
        while i > 0:
            j = best_edge[i]
            tokens.append(self.vocab[text[j:i]])
            i = j
        return tokens[::-1]

Vocabulary Design Considerations

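Vocabulary size trades compression against cost: larger vocabularies encode text in fewer tokens (shorter sequences), but they enlarge the embedding and output layers and leave more rare tokens with little training signal.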

PYTHON
class TokenizerAnalysis:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def analyze_compression(self, texts):
        total_chars = sum(len(t) for t in texts)
        total_tokens = sum(len(self.tokenizer.encode(t)) for t in texts)

        compression_ratio = total_chars / total_tokens
        print(f"Characters: {total_chars}")
        print(f"Tokens: {total_tokens}")
        print(f"Compression ratio: {compression_ratio:.2f}")
        return compression_ratio

    def analyze_fertility(self, texts, languages):
        for lang, text in zip(languages, texts):
            tokens = self.tokenizer.encode(text)
            words = text.split()
            fertility = len(tokens) / len(words)
            print(f"{lang}: {fertility:.2f} tokens per word")

    def find_rare_tokens(self, corpus, threshold=10):
        token_counts = Counter()
        for text in corpus:
            tokens = self.tokenizer.encode(text)
            token_counts.update(tokens)

        rare = [t for t, c in token_counts.items() if c < threshold]
        print(f"Rare tokens (count < {threshold}): {len(rare)}")
        return rare


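# Approximate vocabulary sizes of some well-known tokenizers
VOCAB_SIZES = {
    "GPT-2": 50257,
    "GPT-3": 50257,      # reuses the GPT-2 BPE vocabulary
    "GPT-4": 100000,     # cl100k_base, roughly 100k tokens
    "LLaMA": 32000,
    "Gemini": 256000,    # approximate
}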

Special Tokens

LLMs use special tokens for formatting and control.

PYTHON
class SpecialTokens:
    PAD = "<pad>"
    UNK = "<unk>"
    BOS = "<s>"
    EOS = "</s>"
    SEP = "<sep>"
    MASK = "<mask>"

    # Chat format tokens
    SYSTEM = "<|system|>"
    USER = "<|user|>"
    ASSISTANT = "<|assistant|>"
    END_TURN = "<|end|>"


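class ChatTokenizer:
    def __init__(self, base_tokenizer, special_token_ids):
        # special_token_ids: dict mapping each special-token string (e.g. SpecialTokens.USER)
        # to a reserved integer id outside the base tokenizer's range
        self.tokenizer = base_tokenizer
        self.special_token_ids = special_token_ids
        self.role_tokens = {
            "system": SpecialTokens.SYSTEM,
            "user": SpecialTokens.USER,
            "assistant": SpecialTokens.ASSISTANT,
        }

    def encode_chat(self, messages):
        # Emit integer ids only: role marker, content tokens, end-of-turn marker
        ids = []
        for msg in messages:
            role_token = self.role_tokens[msg["role"]]  # KeyError on unknown roles
            ids.append(self.special_token_ids[role_token])
            ids.extend(self.tokenizer.encode(msg["content"]))
            ids.append(self.special_token_ids[SpecialTokens.END_TURN])
        return ids

    def decode_chat(self, ids):
        # Decode runs of ordinary ids with the base tokenizer; render special ids as text
        id_to_special = {i: s for s, i in self.special_token_ids.items()}
        pieces, run = [], []
        for i in ids:
            if i in id_to_special:
                if run:
                    pieces.append(self.tokenizer.decode(run))
                    run = []
                pieces.append(id_to_special[i])
            else:
                run.append(i)
        if run:
            pieces.append(self.tokenizer.decode(run))
        return "".join(pieces)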

Key Takeaways

Tokenization is a critical preprocessing step for LLMs. BPE and Unigram are the dominant algorithms; both learn subword vocabularies from training data. Vocabulary size affects compression ratio and multilingual performance. Special tokens enable chat formatting and control sequences. Because the tokenizer is fixed before pre-training and sets the size of the embedding and output layers, it should be designed with the target model's languages, domains, and compute budget in mind.

15.4 Emergent Abilities (Intermediate)

Emergent abilities are capabilities that appear suddenly as models scale, seemingly absent in smaller models but present in larger ones. These abilities have transformed our understanding of what language models can achieve.

In-Context Learning

LLMs can learn new tasks from examples provided in the prompt without parameter updates.

PYTHON
class InContextLearning:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def few_shot_classification(self, examples, query):
        prompt = "Classify the sentiment of the following texts.\n\n"

        for text, label in examples:
            prompt += f"Text: {text}\nSentiment: {label}\n\n"

        prompt += f"Text: {query}\nSentiment:"

        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=5)
        return self.tokenizer.decode(output)

    def few_shot_translation(self, examples, query, src_lang, tgt_lang):
        prompt = f"Translate from {src_lang} to {tgt_lang}.\n\n"

        for src, tgt in examples:
            prompt += f"{src_lang}: {src}\n{tgt_lang}: {tgt}\n\n"

        prompt += f"{src_lang}: {query}\n{tgt_lang}:"

        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=100)
        return self.tokenizer.decode(output)


def demonstrate_icl_scaling():
    model_sizes = ["1B", "7B", "13B", "70B"]
    tasks = ["arithmetic", "translation", "reasoning"]
    # Made-up values that illustrate the trend; these are not benchmark results
    base = {"1B": 0.2, "7B": 0.5, "13B": 0.7, "70B": 0.9}
    task_mod = {"arithmetic": 0.1, "translation": 0.0, "reasoning": -0.1}

    print("In-Context Learning Performance by Scale (illustrative)")
    print("-" * 50)
    for task in tasks:
        print(f"\n{task.capitalize()}:")
        for size in model_sizes:
            acc = min(base[size] + task_mod[task], 1.0)  # cap at 100%
            print(f"  {size}: {acc:.1%} accuracy")

demonstrate_icl_scaling()

Chain-of-Thought Reasoning

Explicit reasoning steps improve performance on complex tasks.

PYTHON
class ChainOfThought:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def solve_with_cot(self, problem):
        prompt = f"""Solve this problem step by step.

Problem: {problem}

Let me think through this step by step:
"""
        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=500)
        return self.tokenizer.decode(output)

    def few_shot_cot(self, examples, problem):
        prompt = "Solve these math problems step by step.\n\n"

        for prob, solution, answer in examples:
            prompt += f"Problem: {prob}\n"
            prompt += f"Solution: {solution}\n"
            prompt += f"Answer: {answer}\n\n"

        prompt += f"Problem: {problem}\nSolution:"

        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=300)
        return self.tokenizer.decode(output)


# Example chain of thought
COT_EXAMPLE = {
    "problem": "If a train travels 120 miles in 2 hours, how far will it travel in 5 hours at the same speed?",
    "reasoning": [
        "First, I need to find the speed of the train.",
        "Speed = Distance / Time = 120 miles / 2 hours = 60 mph",
        "Now I can find the distance in 5 hours.",
        "Distance = Speed x Time = 60 mph x 5 hours = 300 miles"
    ],
    "answer": "300 miles"
}
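Because the few-shot prompt in `few_shot_cot` ends each worked example with an `Answer:` line, the final answer can be parsed from a completion that follows the same format. A sketch, assuming `completion` contains only the generated text (not the prompt); `extract_answer` is a hypothetical helper:

```python
import re


def extract_answer(completion):
    # Take the first "Answer:" line; the model may keep generating new
    # "Problem:" blocks afterwards, which are ignored
    match = re.search(r"Answer:\s*(.+)", completion)
    return match.group(1).strip() if match else None


completion = (
    " Speed = 120 miles / 2 hours = 60 mph. Distance = 60 mph x 5 hours = 300 miles.\n"
    "Answer: 300 miles\n\nProblem: A car travels..."
)
print(extract_answer(completion))  # 300 miles
```

Returning `None` when no `Answer:` line appears lets the caller count unparseable completions separately from wrong answers.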

Instruction Following

Larger models better follow complex, multi-step instructions.

PYTHON
class InstructionFollowing:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def execute_instruction(self, instruction):
        prompt = f"""Follow these instructions carefully:

{instruction}

Response:"""

        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=500)
        return self.tokenizer.decode(output)

    def complex_instruction(self, task, constraints):
        prompt = f"""Complete the following task while respecting all constraints.

Task: {task}

Constraints:
"""
        for i, constraint in enumerate(constraints, 1):
            prompt += f"{i}. {constraint}\n"

        prompt += "\nResponse:"

        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=500)
        return self.tokenizer.decode(output)


# Instruction complexity levels
INSTRUCTION_BENCHMARKS = {
    "simple": "Write a haiku about mountains.",
    "medium": "Write a haiku about mountains that includes the word 'snow'.",
    "complex": "Write a haiku about mountains, include 'snow', avoid 'cold', end with a question.",
    "very_complex": "Write 3 haikus about mountains, each must include 'snow', the second must rhyme, the third must be a question, and avoid repeating adjectives."
}
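Constraint-heavy instructions are easier to evaluate when the checkable constraints are tested programmatically rather than by eye. A minimal sketch for the "complex" entry above; `check_haiku_constraints` is a hypothetical helper, and matching words as lowercase alphabetic runs is an assumption.

```python
import re


def check_haiku_constraints(text, must_include="snow", must_avoid="cold"):
    # Checks the lexical constraints only, not syllable counts or topic
    words = re.findall(r"[a-z']+", text.lower())
    return {
        "includes_required_word": must_include in words,
        "avoids_forbidden_word": must_avoid not in words,
        "ends_with_question": text.rstrip().endswith("?"),
    }


haiku = "Peaks wrapped in deep snow\nSilent giants touch the sky\nWho will climb at dawn?"
print(check_haiku_constraints(haiku))
# {'includes_required_word': True, 'avoids_forbidden_word': True, 'ends_with_question': True}
```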

Self-Correction

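Large models can sometimes identify and correct their own errors, although self-critique without external feedback is not always reliable. The loop below generates a solution, asks the model to verify it, and retries with the critique.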

PYTHON
class SelfCorrection:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate_with_verification(self, problem, max_attempts=3):
        prompt_problem = problem
        solution = None
        for attempt in range(max_attempts):
            # Generate a solution (first attempt, or a correction of the previous one)
            solution = self.generate_solution(prompt_problem)

            # Verify against the original problem, not the correction prompt
            verification = self.verify_solution(problem, solution)

            if verification["is_correct"]:
                return solution

            # Self-correct: rebuild the prompt from the original problem and latest critique
            prompt_problem = self.create_correction_prompt(
                problem, solution, verification["errors"]
            )

        return solution

    def generate_solution(self, problem):
        prompt = f"Solve: {problem}\nSolution:"
        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=300)
        return self.tokenizer.decode(output)

    def verify_solution(self, problem, solution):
        prompt = f"""Check if this solution is correct.

Problem: {problem}
Solution: {solution}

Start your reply with "VERDICT: CORRECT" or "VERDICT: INCORRECT", then identify any errors.
Analysis:"""

        tokens = self.tokenizer.encode(prompt)
        output = self.model.generate(tokens, max_new_tokens=200)
        analysis = self.tokenizer.decode(output)

        # Assumes `output` holds only the newly generated tokens. Parse an explicit
        # verdict instead of searching for "correct", which also matches "not correct".
        first_line = analysis.strip().split("\n", 1)[0].upper()
        return {
            "is_correct": first_line.startswith("VERDICT: CORRECT"),
            "errors": analysis
        }

    def create_correction_prompt(self, problem, solution, errors):
        return f"""The previous solution had errors.

Problem: {problem}
Previous attempt: {solution}
Errors identified: {errors}

Please provide a corrected solution."""

Measuring Emergence

Emergence is usually identified by a sharp jump in a task metric as model size grows.

PYTHON
def plot_emergence(model_sizes, performance, task_name):
    print(f"\nEmergence Pattern for {task_name}")
    print("-" * 40)

    for size, perf in zip(model_sizes, performance):
        bar = "#" * int(perf * 30)
        print(f"{size:>8}: {bar} {perf:.1%}")


# Illustrative emergence patterns (made-up values, not measured results)
EMERGENCE_PATTERNS = {
    "arithmetic": {
        "sizes": ["100M", "1B", "10B", "100B"],
        "performance": [0.05, 0.10, 0.25, 0.95]  # Sharp jump
    },
    "word_unscrambling": {
        "sizes": ["100M", "1B", "10B", "100B"],
        "performance": [0.02, 0.05, 0.15, 0.85]
    },
    "multi_step_reasoning": {
        "sizes": ["100M", "1B", "10B", "100B"],
        "performance": [0.10, 0.12, 0.20, 0.75]
    }
}

for task, data in EMERGENCE_PATTERNS.items():
    plot_emergence(data["sizes"], data["performance"], task)
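One caveat when reading such curves: a metric that requires every token of an answer to be right can jump sharply even when per-token accuracy improves smoothly, so part of an apparent jump may come from the metric rather than from the model. A toy calculation, assuming independent per-token errors and a 10-token answer; `exact_match_rate` is a hypothetical helper:

```python
def exact_match_rate(per_token_accuracy, answer_length):
    # Probability that every token of the answer is right, assuming independent errors
    return per_token_accuracy ** answer_length


for p in [0.5, 0.7, 0.8, 0.9, 0.95, 0.99]:
    print(f"per-token accuracy {p:.2f} -> 10-token exact match {exact_match_rate(p, 10):.3f}")
```

Per-token accuracy rises steadily across these rows, while exact match stays below 0.03 through 0.7 and then climbs quickly.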

Key Takeaways

Emergent abilities arise unpredictably with scale. In-context learning enables few-shot task performance without fine-tuning. Chain-of-thought reasoning improves complex problem solving. Instruction following becomes more reliable with scale. Self-correction capabilities enable iterative refinement. These abilities have fundamentally changed how we design and deploy AI systems.

15.5 LLM Inference and Deployment (Advanced)

Deploying LLMs at scale requires careful optimization of memory, compute, and latency. Production systems must handle thousands of concurrent requests while maintaining acceptable response times.

KV Cache Management

The key-value cache stores attention states to avoid recomputation during autoregressive generation.

PYTHON
import torch
import torch.nn as nn

class KVCache:
    def __init__(self, batch_size, max_seq_len, num_layers, num_heads, head_dim, device):
        self.max_seq_len = max_seq_len
        self.cache_k = torch.zeros(
            num_layers, batch_size, max_seq_len, num_heads, head_dim,
            device=device
        )
        self.cache_v = torch.zeros(
            num_layers, batch_size, max_seq_len, num_heads, head_dim,
            device=device
        )
        self.seq_len = 0

    def update(self, layer_idx, k, v):
        batch_size, seq_len, num_heads, head_dim = k.shape
        self.cache_k[layer_idx, :batch_size, self.seq_len:self.seq_len + seq_len] = k
        self.cache_v[layer_idx, :batch_size, self.seq_len:self.seq_len + seq_len] = v

    def get(self, layer_idx):
        return (
            self.cache_k[layer_idx, :, :self.seq_len],
            self.cache_v[layer_idx, :, :self.seq_len]
        )

    def increment_seq_len(self, n=1):
        self.seq_len += n


class PagedKVCache:
    def __init__(self, num_blocks, block_size, num_heads, head_dim, device):
        self.block_size = block_size
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.k_cache = torch.zeros(
            num_blocks, block_size, num_heads, head_dim, device=device
        )
        self.v_cache = torch.zeros(
            num_blocks, block_size, num_heads, head_dim, device=device
        )

        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}

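    def allocate(self, request_id, num_tokens):
        # Ceiling division: number of fixed-size blocks that cover num_tokens
        num_blocks_needed = (num_tokens + self.block_size - 1) // self.block_size
        if num_blocks_needed > len(self.free_blocks):
            # Fail before popping anything, so no blocks leak; the caller can queue or preempt
            raise RuntimeError("not enough free KV-cache blocks")
        blocks = [self.free_blocks.pop() for _ in range(num_blocks_needed)]
        self.block_tables[request_id] = blocks
        return blocks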

    def free(self, request_id):
        blocks = self.block_tables.pop(request_id)
        self.free_blocks.extend(blocks)

Quantization

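Storing weights in fewer bits reduces memory and can improve throughput: int8 halves weight memory relative to fp16, and int4 quarters it. Because autoregressive decoding is usually limited by memory bandwidth, reading fewer bytes per generated token often speeds up generation.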

PYTHON
class Int8Quantizer:
    @staticmethod
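    def quantize_weights(weights):
        # Symmetric per-tensor quantization; clamp avoids division by zero for all-zero weights
        scale = weights.abs().max().clamp(min=1e-8) / 127
        quantized = (weights / scale).round().clamp(-128, 127).to(torch.int8)
        return quantized, scale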

    @staticmethod
    def dequantize(quantized, scale):
        return quantized.float() * scale


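class Int4Quantizer:
    @staticmethod
    def quantize_block(weights, block_size=128):
        # Block-wise symmetric quantization: one scale per block of block_size weights.
        # Assumes weights.numel() is divisible by block_size and block_size is even.
        original_shape = weights.shape
        weights = weights.reshape(-1, block_size)

        scales = weights.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-8) / 7
        quantized = (weights / scales).round().clamp(-8, 7)

        packed = Int4Quantizer._pack_int4(quantized.to(torch.int8))
        return packed, scales, original_shape

    @staticmethod
    def _pack_int4(tensor):
        # Keep the low 4 bits (two's complement) of each code as uint8, then store
        # two codes per byte: even columns in the low nibble, odd in the high nibble.
        # Shifting in uint8 avoids signed int8 overflow.
        nibbles = (tensor & 0x0F).to(torch.uint8)
        low = nibbles[:, ::2]
        high = nibbles[:, 1::2] << 4
        return low | high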


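class AWQQuantizer:
    def __init__(self, model, calibration_data):
        self.model = model
        self.calibration_data = calibration_data

    def compute_scales(self, layer):
        # Per-input-channel scales for an nn.Linear whose weight is (out, in).
        # Simplification: a fixed 0.5/0.5 balance of activation and weight magnitudes,
        # reminiscent of SmoothQuant; AWQ itself uses s = s_X ** alpha and grid-searches
        # alpha to minimize the quantized layer's output error.
        # Apply as W[:, j] *= s[j] and x[..., j] /= s[j] before quantizing W.
        activations = self._collect_activations(layer)  # (tokens, in)
        weights = layer.weight.data                      # (out, in)

        activation_scales = activations.abs().mean(dim=0)  # (in,)
        weight_scales = weights.abs().mean(dim=0)           # (in,), per input channel

        optimal_scales = (activation_scales ** 0.5) / (weight_scales ** 0.5).clamp(min=1e-5)
        return optimal_scales.clamp(min=1e-5)

    def _collect_activations(self, layer):
        # Record the inputs to `layer`, flattened to (tokens, in_features)
        activations = []
        hook = layer.register_forward_hook(
            lambda m, i, o: activations.append(i[0].detach().reshape(-1, i[0].shape[-1]))
        )
        try:
            with torch.no_grad():
                for batch in self.calibration_data:
                    self.model(batch)
        finally:
            hook.remove()
        return torch.cat(activations, dim=0)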
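Int4Quantizer packs two codes per byte but has no inverse. As a sketch, the standalone functions below (hypothetical names quantize_int4_blocks and dequantize_int4_blocks, mirroring the class's symmetric block scheme) perform the full round trip: they unpack each byte's nibbles, sign-extend them from 4-bit two's complement, and check that every weight comes back within half a quantization step.

```python
import torch

def quantize_int4_blocks(weights, block_size=128):
    # Symmetric block-wise INT4: one scale per block, codes in [-8, 7]
    blocks = weights.reshape(-1, block_size)
    scales = blocks.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-8) / 7
    codes = (blocks / scales).round().clamp(-8, 7).to(torch.int8)
    nibbles = (codes & 0x0F).to(torch.uint8)
    packed = nibbles[:, ::2] | (nibbles[:, 1::2] << 4)  # two codes per byte
    return packed, scales

def dequantize_int4_blocks(packed, scales, shape):
    # Split each byte into its low and high nibble
    low = (packed & 0x0F).to(torch.int16)
    high = (packed >> 4).to(torch.int16)
    # Sign-extend: nibble values >= 8 encode negative numbers
    low = torch.where(low >= 8, low - 16, low)
    high = torch.where(high >= 8, high - 16, high)
    # Re-interleave: even columns came from the low nibble, odd from the high one
    codes = torch.stack([low, high], dim=-1).reshape(packed.shape[0], -1)
    return (codes.float() * scales).reshape(shape)

torch.manual_seed(0)
w = torch.randn(64, 256)
packed, scales = quantize_int4_blocks(w)
w_hat = dequantize_int4_blocks(packed, scales, w.shape)
max_error = (w - w_hat).abs().max().item()
```

The 64 x 256 matrix becomes 128 blocks of 64 bytes each, a quarter of the FP16 storage before the per-block scales are counted.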

Continuous Batching

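Continuous (iteration-level) batching keeps the GPU busy when requests have very different lengths. Instead of waiting for every sequence in a batch to finish, the scheduler removes completed requests and admits waiting ones at every decoding step.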

PYTHON
import time
from collections import deque
from dataclasses import dataclass, field
from typing import List

import torch

@dataclass
class Request:
    id: str
    prompt_tokens: List[int]
    max_new_tokens: int
    generated_tokens: List[int] = field(default_factory=list)
    start_time: float = field(default_factory=time.time)


class ContinuousBatcher:
    def __init__(self, model, max_batch_size, max_seq_len, eos_token_id=2):
        # model: maps (batch, seq) token ids to (batch, seq, vocab) logits
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.eos_token_id = eos_token_id

        self.waiting_queue = deque()
        self.running_batch = []

    def add_request(self, request):
        self.waiting_queue.append(request)

    def step(self):
        # Fill free slots from the waiting queue (FIFO)
        while self.waiting_queue and len(self.running_batch) < self.max_batch_size:
            self.running_batch.append(self.waiting_queue.popleft())

        if not self.running_batch:
            return []

        # Right-padded batch plus each sequence's true length
        input_tokens, lengths = self._prepare_batch()

        # Forward pass; without a KV cache the full sequence is recomputed each step
        with torch.no_grad():
            logits = self.model(input_tokens)

        # Greedy next token, read at each sequence's last real position rather than
        # the padded end; causal masking keeps right padding from affecting it
        rows = torch.arange(len(self.running_batch))
        next_tokens = logits[rows, lengths - 1, :].argmax(dim=-1)

        # Update requests and split finished from still-running
        completed = []
        still_running = []

        for i, request in enumerate(self.running_batch):
            token = next_tokens[i].item()
            request.generated_tokens.append(token)

            if self._is_complete(request, token):
                completed.append(request)
            else:
                still_running.append(request)

        self.running_batch = still_running
        return completed

    def _prepare_batch(self):
        sequences = [r.prompt_tokens + r.generated_tokens for r in self.running_batch]
        lengths = torch.tensor([len(s) for s in sequences], dtype=torch.long)
        # Pad with token id 0 (assumed to be a safe padding id)
        batch = torch.zeros(len(sequences), int(lengths.max()), dtype=torch.long)

        for i, tokens in enumerate(sequences):
            batch[i, :len(tokens)] = torch.tensor(tokens, dtype=torch.long)

        return batch, lengths

    def _is_complete(self, request, token):
        total_len = len(request.prompt_tokens) + len(request.generated_tokens)
        return (
            token == self.eos_token_id
            or len(request.generated_tokens) >= request.max_new_tokens
            or total_len >= self.max_seq_len
        )
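To see why iteration-level scheduling helps, the toy simulation below counts decoding steps for the same requests under static and continuous batching. It assumes each request needs a known number of steps and that a step costs the same regardless of batch composition, which real serving systems only approximate; the function names are illustrative.

```python
import heapq

def static_batching_steps(lengths, batch_size):
    # Each fixed batch runs until its longest request finishes
    return sum(
        max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size)
    )

def continuous_batching_steps(lengths, batch_size):
    # A slot is refilled from the FIFO queue as soon as its request finishes
    slots = [0] * batch_size  # step at which each slot becomes free
    heapq.heapify(slots)
    finish = 0
    for n in lengths:
        start = heapq.heappop(slots)
        heapq.heappush(slots, start + n)
        finish = max(finish, start + n)
    return finish

lengths = [3, 10, 2, 8]  # decode steps each request needs (toy numbers)
print(static_batching_steps(lengths, batch_size=2))      # 18
print(continuous_batching_steps(lengths, batch_size=2))  # 13
```

With static batching the short request in each pair leaves its slot idle until the long one finishes; continuous batching reuses that slot immediately, finishing the same 23 token-steps of work in 13 steps instead of 18.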

Speculative Decoding

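Speculative decoding uses a small draft model to propose gamma tokens, which the main model then scores in a single forward pass. A rejection-sampling rule decides how many draft tokens to keep, so the output follows the main model's distribution while the main model runs fewer times per generated token.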

PYTHON
import torch

class SpeculativeDecoder:
    def __init__(self, main_model, draft_model, tokenizer, gamma=4, eos_token_id=2):
        # Both models map (1, seq) token ids to (1, seq, vocab) logits
        self.main_model = main_model
        self.draft_model = draft_model
        self.tokenizer = tokenizer
        self.gamma = gamma
        self.eos_token_id = eos_token_id

    @torch.no_grad()
    def generate(self, prompt_tokens, max_tokens):
        generated = list(prompt_tokens)

        while len(generated) - len(prompt_tokens) < max_tokens:
            # Draft model samples gamma tokens and records its distributions
            draft_tokens, draft_probs = self._draft(generated)

            # Main model scores all draft positions in one forward pass
            accepted = self._verify(generated, draft_tokens, draft_probs)

            # Never exceed max_tokens, and stop at the first EOS
            remaining = max_tokens - (len(generated) - len(prompt_tokens))
            accepted = accepted[:remaining]
            if self.eos_token_id in accepted:
                generated.extend(accepted[:accepted.index(self.eos_token_id) + 1])
                break
            generated.extend(accepted)

        return generated

    def _draft(self, context):
        draft_tokens, draft_probs = [], []
        current = list(context)

        for _ in range(self.gamma):
            logits = self.draft_model(torch.tensor([current]))
            probs = torch.softmax(logits[0, -1], dim=-1)
            # Sample rather than argmax: the acceptance test assumes tokens ~ draft probs
            next_token = torch.multinomial(probs, 1).item()
            draft_tokens.append(next_token)
            draft_probs.append(probs)
            current.append(next_token)

        return draft_tokens, draft_probs

    def _verify(self, context, draft_tokens, draft_probs):
        full_sequence = context + draft_tokens
        main_logits = self.main_model(torch.tensor([full_sequence]))

        accepted = []
        for i, draft_token in enumerate(draft_tokens):
            # Logits at position p predict the token at position p + 1
            pos = len(context) + i - 1
            main_probs = torch.softmax(main_logits[0, pos], dim=-1)

            if self._accept_token(draft_token, main_probs, draft_probs[i]):
                accepted.append(draft_token)
            else:
                accepted.append(self._sample_correction(main_probs, draft_probs[i]))
                return accepted

        # All gamma drafts accepted: take one bonus token from the main model
        bonus_probs = torch.softmax(main_logits[0, -1], dim=-1)
        accepted.append(torch.multinomial(bonus_probs, 1).item())
        return accepted

    def _accept_token(self, token, main_probs, draft_probs):
        # Accept with probability min(1, p(x) / q(x))
        ratio = (main_probs[token] / draft_probs[token]).item()
        return torch.rand(1).item() < min(1.0, ratio)

    def _sample_correction(self, main_probs, draft_probs):
        # On rejection, resample from norm(max(0, p - q))
        corrected_probs = torch.clamp(main_probs - draft_probs, min=0)
        total = corrected_probs.sum()
        if total <= 0:
            corrected_probs, total = main_probs, main_probs.sum()
        return torch.multinomial(corrected_probs / total, 1).item()
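The claim that speculative decoding preserves the main model's distribution can be checked exactly for a single draft position. The sketch below (function names are illustrative) computes the output distribution of the accept-or-resample rule for toy target p and draft q, and also evaluates the expected number of tokens per main-model call, (1 - alpha^(gamma+1)) / (1 - alpha), under the simplifying assumption that every draft token is accepted independently with probability alpha = sum_x min(p(x), q(x)).

```python
import torch

def speculative_output_distribution(p, q):
    # Token emitted at one draft position: x ~ q is accepted with prob
    # min(1, p(x)/q(x)); otherwise a token is resampled from norm(max(0, p - q))
    accepted = torch.minimum(p, q)      # q(x) * min(1, p(x) / q(x))
    reject_mass = 1.0 - accepted.sum()  # probability the draft token is rejected
    residual = torch.clamp(p - q, min=0)
    if residual.sum() > 0:
        accepted = accepted + reject_mass * residual / residual.sum()
    return accepted

def expected_tokens_per_main_call(alpha, gamma):
    # Assumes i.i.d. acceptance with probability alpha at every draft position
    if alpha >= 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

p = torch.tensor([0.5, 0.3, 0.2])  # target (main model) distribution
q = torch.tensor([0.2, 0.2, 0.6])  # draft model distribution
out = speculative_output_distribution(p, q)
alpha = torch.minimum(p, q).sum().item()
speedup_tokens = expected_tokens_per_main_call(alpha, gamma=4)
```

Here out matches p exactly even though q is quite different, and alpha is 0.6, giving roughly 2.31 tokens per main-model forward pass with gamma = 4. A closer draft model raises alpha and therefore the number of tokens gained per call.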

Key Takeaways

LLM inference optimization is critical for production deployment. KV caching avoids recomputing the keys and values of earlier tokens at every generation step. Quantizing weights from FP16 to INT8 or INT4 reduces their memory footprint by 2-4x, typically with small quality loss. Continuous batching keeps the GPU busy by admitting new requests as soon as others finish. Speculative decoding speeds up generation by letting a draft model propose tokens that the main model verifies in parallel, without changing the main model's output distribution. Together, these techniques make LLM serving practical at scale.