Intermediate–Advanced · 105 min read

Chapter 12: Sequence Models

RNNs, LSTMs, GRUs, and attention mechanisms.

Libraries covered: PyTorch

Learning Objectives

- Build recurrent networks
- Understand LSTM gates
- Apply attention


12.1 Recurrent Neural Networks (Intermediate)

Recurrent neural networks represent a fundamental departure from feedforward architectures by introducing connections that loop back through time, enabling the network to maintain memory of previous inputs when processing sequences. While convolutional networks excel at spatial patterns in fixed-size inputs, recurrent networks are designed for sequential data where the length varies and the order of elements carries meaning. From natural language and speech to time series and genomic sequences, RNNs provide the foundational architecture for learning from ordered data, though they have been largely superseded by transformers for many applications.

The Need for Sequential Processing

Traditional feedforward networks treat each input independently, making them fundamentally unsuitable for sequential data where context matters. Consider the task of predicting the next word in a sentence: "The clouds darkened and it began to _." The appropriate prediction depends entirely on the preceding words. A feedforward network processing each word in isolation cannot capture these dependencies. Similarly, understanding whether a stock price movement is anomalous requires knowledge of recent price history, and recognizing speech requires tracking phoneme sequences over time.

Sequential data exhibits temporal dependencies that span varying distances. Some patterns involve only adjacent elements, like the spelling rules that constrain which letters can follow others. Other patterns span many time steps, like the subject-verb agreement in "The student who excelled in all the advanced mathematics courses was..." where "was" must agree with "student" despite many intervening words. Effective sequence models must capture both short-range and long-range dependencies.

The key insight of recurrent networks is that the same operation can be applied repeatedly at each time step, with information passed from one step to the next through a hidden state. This hidden state acts as the network's memory, accumulating relevant information from past inputs to inform processing of the current input. The recurrent structure enables weight sharing across time steps, making the model applicable to sequences of any length.

PYTHON
import torch
import torch.nn as nn
import numpy as np

# Demonstrate why feedforward fails for sequences
def feedforward_limitation():
    """Show that feedforward networks can't handle variable-length sequences."""
    # Feedforward expects fixed input size
    ff_network = nn.Sequential(
        nn.Linear(10, 32),  # Expects exactly 10 inputs
        nn.ReLU(),
        nn.Linear(32, 1)
    )

    # Works for length 10
    x_10 = torch.randn(1, 10)
    print(f"Input length 10: {ff_network(x_10).shape}")

    # Fails for length 15
    x_15 = torch.randn(1, 15)
    try:
        ff_network(x_15)
    except RuntimeError as e:
        print(f"Input length 15: Error - {str(e)[:50]}...")

    # Each position processed independently - no context
    print("\nFeedforward treats each position independently:")
    print("Cannot learn that 'it began to ___' needs rain/snow context")

feedforward_limitation()

The Vanilla RNN Architecture

The basic recurrent neural network, often called vanilla RNN, processes sequences by maintaining a hidden state that evolves at each time step. At time $t$, the network receives the current input $x_t$ and the previous hidden state $h_{t-1}$, combining them to produce a new hidden state $h_t$. This hidden state can then be used to generate an output $y_t$ or passed to the next time step.

The mathematical formulation of a vanilla RNN is:

$$h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h)$$
$$y_t = W_{hy} h_t + b_y$$

Here, $W_{xh}$ transforms the input, $W_{hh}$ transforms the previous hidden state, and $W_{hy}$ produces the output. The hyperbolic tangent activation squashes values to the range [-1, 1], preventing hidden state values from growing unboundedly. The same weights are applied at every time step, making the network applicable to sequences of any length.

The hidden state serves as a compressed summary of the sequence seen so far. At each step, the network must decide what information from the new input to incorporate and what to retain from its previous memory. This compression is both the power and limitation of RNNs: they can process arbitrarily long sequences with fixed memory, but that fixed capacity limits what they can remember.

PYTHON
import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    """
    Vanilla RNN implementation from scratch.
    """
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Weight matrices
        self.W_xh = nn.Linear(input_size, hidden_size)   # Input to hidden
        self.W_hh = nn.Linear(hidden_size, hidden_size)  # Hidden to hidden
        self.W_hy = nn.Linear(hidden_size, output_size)  # Hidden to output

    def forward(self, x, h_prev=None):
        """
        Process one time step.

        Args:
            x: Input at current time step (batch, input_size)
            h_prev: Previous hidden state (batch, hidden_size)

        Returns:
            output: Output at current time step
            h: New hidden state
        """
        if h_prev is None:
            h_prev = torch.zeros(x.size(0), self.hidden_size, device=x.device)

        # Combine input and previous hidden state
        h = torch.tanh(self.W_xh(x) + self.W_hh(h_prev))

        # Compute output
        output = self.W_hy(h)

        return output, h

    def forward_sequence(self, x_seq):
        """
        Process entire sequence.

        Args:
            x_seq: Sequence tensor (batch, seq_len, input_size)

        Returns:
            outputs: All outputs (batch, seq_len, output_size)
            h_final: Final hidden state
        """
        batch_size, seq_len, _ = x_seq.shape
        h = None
        outputs = []

        for t in range(seq_len):
            output, h = self.forward(x_seq[:, t, :], h)
            outputs.append(output)

        outputs = torch.stack(outputs, dim=1)
        return outputs, h

# Test the vanilla RNN
rnn = VanillaRNN(input_size=10, hidden_size=32, output_size=5)
x_seq = torch.randn(4, 20, 10)  # Batch of 4, sequence length 20, 10 features

outputs, h_final = rnn.forward_sequence(x_seq)
print(f"Input sequence: {x_seq.shape}")
print(f"Output sequence: {outputs.shape}")
print(f"Final hidden state: {h_final.shape}")

Unfolding Through Time

Understanding how RNNs process sequences becomes clearer when we "unfold" the network through time, visualizing each time step as a separate layer. The unfolded view reveals that an RNN processing a sequence of length $T$ is equivalent to a deep feedforward network with $T$ layers, where all layers share the same weights. This perspective illuminates both the power of RNNs and their training challenges.

When unfolded, the RNN forms a computational graph connecting inputs to outputs through a chain of hidden states. The initial hidden state $h_0$ (often initialized to zeros) connects to $h_1$, which connects to $h_2$, and so on. Each hidden state receives input from both the current input $x_t$ and the previous hidden state, creating paths through which information flows across time.

This unfolded view directly informs how we train RNNs through backpropagation through time (BPTT). Gradients flow backward through the unfolded network, from the loss at the final output back through each time step to the beginning of the sequence. The gradient at each time step depends on gradients from all future time steps, so the backward pass, like the forward pass, must proceed sequentially through time; this inherent sequentiality is a key reason RNNs train more slowly than architectures that process all positions in parallel.

PYTHON
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

def visualize_unfolding():
    """
    Demonstrate RNN unfolding conceptually.
    """
    print("RNN Unfolding Through Time:")
    print("=" * 60)
    print()
    print("Folded (recurrent) view:")
    print("    ┌────────────────┐")
    print("    │                │")
    print("    │     ┌──────┐   │")
    print("x_t ─────►│  RNN │───┼───► y_t")
    print("    │     └──────┘   │")
    print("    │         ▲      │")
    print("    │         │      │")
    print("    └─────────┴──────┘")
    print("         h_t (loops back)")
    print()
    print("Unfolded view (for sequence length 4):")
    print()
    print("x_0      x_1      x_2      x_3")
    print(" │        │        │        │")
    print(" ▼        ▼        ▼        ▼")
    print("┌───┐    ┌───┐    ┌───┐    ┌───┐")
    print("│RNN│───►│RNN│───►│RNN│───►│RNN│")
    print("└───┘    └───┘    └───┘    └───┘")
    print(" │        │        │        │")
    print(" ▼        ▼        ▼        ▼")
    print("y_0      y_1      y_2      y_3")
    print()
    print("h_0 ──► h_1 ──► h_2 ──► h_3 ──► h_4")
    print()
    print("All RNN blocks share the SAME weights!")
    print("Gradient flows backward through all time steps (BPTT)")

visualize_unfolding()

# Demonstrate gradient flow
def trace_gradient_flow():
    """Show how gradients flow through time."""
    rnn = nn.RNN(input_size=10, hidden_size=20, batch_first=True)
    x = torch.randn(1, 5, 10, requires_grad=True)

    output, h_n = rnn(x)

    # Loss at final time step
    loss = output[:, -1, :].sum()
    loss.backward()

    print("\nGradient flow demonstration:")
    print(f"Input shape: {x.shape} (batch=1, seq_len=5, features=10)")
    print(f"Input gradient shape: {x.grad.shape}")
    print(f"Gradient magnitude at each time step:")
    for t in range(5):
        grad_mag = x.grad[0, t, :].norm().item()
        print(f"  t={t}: {grad_mag:.4f}")

trace_gradient_flow()

The Vanishing Gradient Problem

The fundamental limitation of vanilla RNNs emerges from the dynamics of gradient flow through the unfolded network. When gradients backpropagate through many time steps, they pass through repeated matrix multiplications with the hidden-to-hidden weight matrix $W_{hh}$ and the derivative of the activation function. If the largest singular value of $W_{hh}$ is less than 1, gradients shrink exponentially; if greater than 1, they explode exponentially.

The vanishing gradient problem manifests as the network's inability to learn long-range dependencies. When gradients from distant time steps shrink to near zero by the time they reach early time steps, the network cannot effectively credit or blame early inputs for outcomes at the end of the sequence. In practice, vanilla RNNs struggle with dependencies spanning more than 10-20 time steps.

Mathematically, consider the gradient of the loss with respect to the hidden state at time $t$:

$$\frac{\partial L}{\partial h_t} = \frac{\partial L}{\partial h_T} \prod_{k=t+1}^{T} \frac{\partial h_k}{\partial h_{k-1}}$$

Each factor $\frac{\partial h_k}{\partial h_{k-1}}$ involves $W_{hh}$ and the derivative of tanh, which is at most 1. The product of many numbers less than 1 vanishes exponentially, while the product of numbers greater than 1 explodes.

PYTHON
import torch
import torch.nn as nn
import numpy as np

def demonstrate_vanishing_gradients():
    """
    Show how gradients vanish in vanilla RNNs.
    """
    hidden_size = 50
    seq_lengths = [10, 50, 100, 200]

    print("Vanishing Gradient Demonstration")
    print("=" * 50)

    for seq_len in seq_lengths:
        # Create RNN and process sequence
        rnn = nn.RNN(input_size=10, hidden_size=hidden_size, batch_first=True)
        x = torch.randn(1, seq_len, 10, requires_grad=True)

        output, _ = rnn(x)
        loss = output[:, -1, :].sum()
        loss.backward()

        # Measure gradient magnitude at each time step
        grad_norms = [x.grad[0, t, :].norm().item() for t in range(seq_len)]

        # Compare first and last
        ratio = grad_norms[0] / grad_norms[-1] if grad_norms[-1] > 0 else float('inf')

        print(f"\nSequence length {seq_len}:")
        print(f"  Gradient at t=0: {grad_norms[0]:.6f}")
        print(f"  Gradient at t={seq_len-1}: {grad_norms[-1]:.6f}")
        print(f"  Ratio (first/last): {ratio:.2e}")

    # Theoretical analysis
    print("\n" + "=" * 50)
    print("Theoretical Analysis:")
    print("For tanh activation, derivative is in (0, 1]")
    print("If ||W_hh|| < 1, gradients vanish exponentially")
    print("After T steps: gradient ≈ (||W_hh|| * tanh_derivative)^T")

    # Scale W_hh so its largest singular value is below 1 (a contractive map);
    # a raw randn matrix of this size would have spectral norm well above 1.
    W = torch.randn(hidden_size, hidden_size) * 0.05
    singular_values = torch.linalg.svdvals(W)
    print(f"\nExample W_hh largest singular value: {singular_values[0]:.3f}")
    for T in [10, 50, 100]:
        decay = (singular_values[0].item() * 0.5) ** T  # 0.5 ≈ typical tanh derivative
        print(f"  After {T} steps: gradient scale ≈ {decay:.2e}")

demonstrate_vanishing_gradients()

Gradient Clipping

While the vanishing gradient problem leads to slow learning of long-range dependencies, the exploding gradient problem causes training instability where gradients become so large that weight updates overshoot drastically. Gradient clipping provides a simple but effective solution by rescaling gradients when their norm exceeds a threshold.

The most common approach, global gradient clipping, computes the norm of all gradients concatenated into a single vector and rescales if this norm exceeds a threshold:

$$\text{if } \|\nabla\| > \theta: \quad \nabla \leftarrow \frac{\theta}{\|\nabla\|} \nabla$$

This preserves the direction of the gradient update while limiting its magnitude. The threshold $\theta$ is a hyperparameter, typically set between 1 and 5. Gradient clipping is essential for training RNNs and remains important even for more advanced architectures like LSTMs.

PYTHON
import torch
import torch.nn as nn

def demonstrate_gradient_clipping():
    """
    Show gradient clipping in action.
    """
    # Create a simple RNN
    rnn = nn.RNN(10, 20, batch_first=True)
    x = torch.randn(1, 100, 10)
    target = torch.randn(1, 100, 20)

    # Forward pass
    output, _ = rnn(x)
    loss = nn.MSELoss()(output, target)

    # Backward pass
    loss.backward()

    # Compute gradient norm before clipping
    total_norm_before = 0
    for p in rnn.parameters():
        if p.grad is not None:
            total_norm_before += p.grad.norm().item() ** 2
    total_norm_before = total_norm_before ** 0.5

    print("Gradient Clipping Demonstration")
    print("=" * 50)
    print(f"Gradient norm before clipping: {total_norm_before:.4f}")

    # Apply gradient clipping
    max_norm = 1.0
    torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm)

    # Compute gradient norm after clipping
    total_norm_after = 0
    for p in rnn.parameters():
        if p.grad is not None:
            total_norm_after += p.grad.norm().item() ** 2
    total_norm_after = total_norm_after ** 0.5

    print(f"Gradient norm after clipping: {total_norm_after:.4f}")
    print(f"Clipping threshold: {max_norm}")
    print(f"Clipped: {total_norm_before > max_norm}")

demonstrate_gradient_clipping()

# Gradient clipping in training loop
def train_with_clipping(model, optimizer, data_loader, clip_value=1.0):
    """Training loop with gradient clipping."""
    model.train()
    total_loss = 0

    for batch_x, batch_y in data_loader:
        optimizer.zero_grad()

        output, _ = model(batch_x)
        loss = nn.CrossEntropyLoss()(output.view(-1, output.size(-1)),
                                      batch_y.view(-1))
        loss.backward()

        # Clip gradients before optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(data_loader)

RNN Variants and Configurations

RNNs can be configured in various ways depending on the task. The relationship between input and output sequences determines the architecture: many-to-one for classification of entire sequences, one-to-many for generation from a single seed, many-to-many for sequence labeling or translation. Stacking multiple RNN layers creates deeper networks that can learn more complex transformations.

Bidirectional RNNs process sequences in both forward and backward directions, concatenating hidden states from both passes. This allows each position to access context from both past and future, beneficial for tasks like named entity recognition where understanding the full sentence helps identify entities. Bidirectional processing is only possible when the full sequence is available before generating outputs.

Deep RNNs stack multiple recurrent layers, where the output sequence of one layer becomes the input to the next. Each layer can learn different levels of abstraction, similar to layers in feedforward networks. Residual connections between layers help gradient flow in deep stacks.

PYTHON
import torch
import torch.nn as nn

# Different RNN configurations

# Many-to-one: Sequence classification
class SentimentClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_size, batch_first=True)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)
        _, h_n = self.rnn(embedded)  # Only use final hidden state
        return self.classifier(h_n.squeeze(0))

# One-to-many: Sequence generation
class TextGenerator(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_size, batch_first=True)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, seed, length):
        # Start with seed token
        x = self.embedding(seed)
        h = None
        outputs = []

        for _ in range(length):
            out, h = self.rnn(x, h)
            logits = self.output(out)
            next_token = logits.argmax(dim=-1)
            outputs.append(next_token)
            x = self.embedding(next_token)

        return torch.cat(outputs, dim=1)

# Many-to-many: Sequence labeling
class POSTagger(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_tags):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_size, batch_first=True)
        self.tagger = nn.Linear(hidden_size, num_tags)

    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.rnn(embedded)  # Output at each position
        return self.tagger(output)

# Bidirectional RNN
class BiRNNEncoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size,
                          bidirectional=True, batch_first=True)

    def forward(self, x):
        # Output has 2*hidden_size features (forward + backward)
        output, h_n = self.rnn(x)
        # h_n shape: (2, batch, hidden) for forward and backward
        return output, h_n

# Deep (stacked) RNN
class DeepRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size,
                          num_layers=num_layers, batch_first=True)

    def forward(self, x):
        # h_n shape: (num_layers, batch, hidden)
        output, h_n = self.rnn(x)
        return output, h_n

# Test configurations
print("RNN Configuration Examples:")
print("=" * 50)

x = torch.randn(4, 20, 32)  # batch=4, seq_len=20, features=32

bi_rnn = BiRNNEncoder(32, 64)
output, h_n = bi_rnn(x)
print(f"\nBidirectional RNN:")
print(f"  Input: {x.shape}")
print(f"  Output: {output.shape} (hidden*2 = 128)")
print(f"  Hidden: {h_n.shape} (2 directions)")

deep_rnn = DeepRNN(32, 64, num_layers=3)
output, h_n = deep_rnn(x)
print(f"\nDeep RNN (3 layers):")
print(f"  Input: {x.shape}")
print(f"  Output: {output.shape}")
print(f"  Hidden: {h_n.shape} (3 layers)")
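
The DeepRNN above passes each layer's output straight into the next layer. The residual connections mentioned earlier are not built into nn.RNN, but the idea amounts to adding each layer's input back to its output once the feature sizes match; a minimal illustrative sketch (the ResidualRNNStack class is an assumption for illustration, not a PyTorch module):

PYTHON
import torch
import torch.nn as nn

class ResidualRNNStack(nn.Module):
    """Stack of RNN layers with residual (skip) connections between layers.

    Illustrative sketch: every layer uses the same hidden_size so the skip
    connection can be a simple element-wise addition.
    """
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        # First layer maps input_size -> hidden_size, the rest are hidden -> hidden
        self.layers = nn.ModuleList(
            [nn.RNN(input_size if i == 0 else hidden_size, hidden_size,
                    batch_first=True) for i in range(num_layers)]
        )

    def forward(self, x):
        out = x
        for i, layer in enumerate(self.layers):
            new_out, _ = layer(out)
            # Residual connection, valid only once the shapes match (i > 0)
            out = new_out + out if i > 0 else new_out
        return out

res_rnn = ResidualRNNStack(input_size=32, hidden_size=64, num_layers=3)
y = res_rnn(torch.randn(4, 20, 32))
print(f"\nResidual RNN stack output: {y.shape}")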

Using PyTorch's RNN Module

PyTorch provides optimized RNN implementations that handle batching, padding, and GPU acceleration efficiently. The nn.RNN, nn.LSTM, and nn.GRU modules support variable-length sequences through packed sequences, multiple layers, bidirectional processing, and dropout between layers.

When working with variable-length sequences in batches, sequences must be padded to the same length, but we want the RNN to ignore the padding tokens. PyTorch's pack_padded_sequence and pad_packed_sequence utilities handle this efficiently, ensuring the RNN only processes actual tokens.

PYTHON
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Using PyTorch's RNN with packed sequences
class EfficientRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers=1,
                 bidirectional=False, dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.RNN(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0
        )

    def forward(self, x, lengths):
        """
        Forward pass with variable-length sequences.

        Args:
            x: Padded input (batch, max_seq_len)
            lengths: Actual lengths of each sequence
        """
        # Embed tokens
        embedded = self.embedding(x)

        # Pack sequences for efficient processing
        packed = pack_padded_sequence(embedded, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)

        # Process with RNN
        packed_output, h_n = self.rnn(packed)

        # Unpack back to padded tensor
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        return output, h_n

# Example with variable-length sequences
model = EfficientRNN(vocab_size=1000, embed_dim=64, hidden_size=128)

# Batch with different lengths (padded with 0)
sequences = torch.tensor([
    [1, 2, 3, 4, 5, 0, 0, 0],  # Length 5
    [1, 2, 3, 0, 0, 0, 0, 0],  # Length 3
    [1, 2, 3, 4, 5, 6, 7, 8],  # Length 8
    [1, 2, 0, 0, 0, 0, 0, 0],  # Length 2
])
lengths = torch.tensor([5, 3, 8, 2])

output, h_n = model(sequences, lengths)
print(f"Variable-length sequence handling:")
print(f"  Input shape: {sequences.shape}")
print(f"  Lengths: {lengths.tolist()}")
print(f"  Output shape: {output.shape}")
print(f"  Hidden shape: {h_n.shape}")

Key Takeaways

Recurrent neural networks process sequential data by maintaining hidden states that carry information across time steps, enabling learning from variable-length sequences where order matters. The vanilla RNN applies the same transformation at each step, combining the current input with the previous hidden state through learned weight matrices. When unfolded through time, RNNs reveal their structure as deep networks with shared weights, trained through backpropagation through time. The vanishing gradient problem severely limits vanilla RNNs' ability to learn long-range dependencies, as gradients shrink exponentially when backpropagating through many time steps. Gradient clipping addresses the complementary exploding gradient problem by rescaling large gradients. Various RNN configurations including bidirectional and deep stacked architectures adapt the basic structure for different tasks. Despite their historical importance, vanilla RNNs have been largely replaced by LSTMs, GRUs, and more recently transformers, which better address the long-range dependency problem.

12.2 LSTM and GRU (Intermediate)

Long Short-Term Memory networks and Gated Recurrent Units represent the two most successful solutions to the vanishing gradient problem that plagues vanilla RNNs. Both architectures introduce gating mechanisms that control information flow through the network, enabling learning of dependencies spanning hundreds of time steps. These gated architectures dominated sequence modeling from their introduction until the rise of transformers, and they remain important for many applications where their computational properties offer advantages.

The LSTM Architecture

The Long Short-Term Memory network, introduced by Hochreiter and Schmidhuber in 1997, fundamentally reimagines how recurrent networks store and access information. The key innovation is the cell state, a dedicated pathway for information that can flow through time with minimal modification. Gates learned by the network control what information enters the cell state, what exits, and what remains. This architecture allows gradients to flow through the cell state relatively unchanged, solving the vanishing gradient problem.

An LSTM cell contains four interacting components: three gates (input, forget, output) and the cell state update. Each gate is a neural network layer with sigmoid activation, producing values between 0 and 1 that determine how much information passes through. The forget gate decides what to discard from the cell state, the input gate controls what new information to store, and the output gate determines what to expose as the hidden state.

The mathematical formulation of an LSTM at time step $t$ is:

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \quad \text{(forget gate)}$$
$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \quad \text{(input gate)}$$
$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \quad \text{(candidate cell state)}$$
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \quad \text{(cell state update)}$$
$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \quad \text{(output gate)}$$
$$h_t = o_t \odot \tanh(C_t) \quad \text{(hidden state)}$$

Here $\sigma$ denotes the sigmoid function, $\odot$ represents element-wise multiplication, and $[h_{t-1}, x_t]$ denotes concatenation.

PYTHON
import torch
import torch.nn as nn

class LSTMCell(nn.Module):
    """
    LSTM cell implementation from scratch.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Combined weights for all gates (more efficient)
        self.gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x, state=None):
        """
        Process one time step.

        Args:
            x: Input (batch, input_size)
            state: Tuple of (h, c) or None

        Returns:
            h: New hidden state
            (h, c): New state tuple
        """
        batch_size = x.size(0)

        if state is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = state

        # Concatenate input and hidden state
        combined = torch.cat([x, h], dim=1)

        # Compute all gates at once
        gates = self.gates(combined)

        # Split into individual gates
        i, f, g, o = gates.chunk(4, dim=1)

        # Apply activations
        i = torch.sigmoid(i)  # Input gate
        f = torch.sigmoid(f)  # Forget gate
        g = torch.tanh(g)     # Candidate cell state
        o = torch.sigmoid(o)  # Output gate

        # Update cell state
        c_new = f * c + i * g

        # Compute hidden state
        h_new = o * torch.tanh(c_new)

        return h_new, (h_new, c_new)

# Test LSTM cell
lstm_cell = LSTMCell(input_size=32, hidden_size=64)
x = torch.randn(8, 32)  # batch=8, features=32

h, state = lstm_cell(x)
print(f"LSTM Cell:")
print(f"  Input: {x.shape}")
print(f"  Hidden state: {h.shape}")
print(f"  Cell state: {state[1].shape}")

# Process sequence
def process_sequence(cell, x_seq):
    """Process sequence with LSTM cell."""
    batch_size, seq_len, _ = x_seq.shape
    state = None
    outputs = []

    for t in range(seq_len):
        h, state = cell(x_seq[:, t, :], state)
        outputs.append(h)

    return torch.stack(outputs, dim=1), state

x_seq = torch.randn(8, 20, 32)
outputs, final_state = process_sequence(lstm_cell, x_seq)
print(f"\nSequence processing:")
print(f"  Input sequence: {x_seq.shape}")
print(f"  Output sequence: {outputs.shape}")

Understanding LSTM Gates

Each gate in an LSTM serves a specific purpose in controlling information flow. The forget gate $f_t$ determines what information from the previous cell state to discard. When the forget gate outputs values near 0, the corresponding cell state values are forgotten; when near 1, they are retained. This gate enables the LSTM to clear irrelevant information that would otherwise persist indefinitely.

The input gate $i_t$ and candidate cell state $\tilde{C}_t$ work together to add new information. The candidate represents potential new information derived from the current input and previous hidden state. The input gate then scales this candidate, determining how much of each dimension to actually incorporate into the cell state. This two-stage process allows fine-grained control over what new information to store.

The output gate $o_t$ controls what part of the cell state to expose as the hidden state, which serves as both the output at this time step and the input to the next time step's computations. The cell state passes through tanh (bounding it between -1 and 1) before being gated, ensuring the hidden state stays in a reasonable range.

PYTHON
import torch
import torch.nn as nn
import numpy as np

def visualize_gates():
    """
    Demonstrate how LSTM gates control information flow.
    """
    print("LSTM Gate Behavior:")
    print("=" * 60)

    # Create LSTM
    lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)

    # We can't easily access internal gates with nn.LSTM,
    # so we'll demonstrate the concept

    print("\nForget Gate (f_t):")
    print("  f_t ≈ 0: Forget previous cell state (clear memory)")
    print("  f_t ≈ 1: Keep previous cell state (preserve memory)")
    print("  Use case: Clear context when new sentence starts")

    print("\nInput Gate (i_t):")
    print("  i_t ≈ 0: Don't add new information")
    print("  i_t ≈ 1: Add new information to cell state")
    print("  Use case: Store important words, ignore filler words")

    print("\nOutput Gate (o_t):")
    print("  o_t ≈ 0: Don't expose cell state to output")
    print("  o_t ≈ 1: Expose cell state as hidden state")
    print("  Use case: Reveal information only when relevant")

    # Demonstrate gate patterns
    print("\n" + "=" * 60)
    print("Example: Processing 'The cat sat on the mat.'")
    print("-" * 60)

    words = ['The', 'cat', 'sat', 'on', 'the', 'mat', '.']
    patterns = [
        ('The', 'f≈1, i≈0.3', 'Article, minor info'),
        ('cat', 'f≈1, i≈0.9', 'Subject, important to remember'),
        ('sat', 'f≈1, i≈0.7', 'Verb, moderately important'),
        ('on', 'f≈1, i≈0.2', 'Preposition, minor'),
        ('the', 'f≈1, i≈0.1', 'Article, minimal info'),
        ('mat', 'f≈1, i≈0.8', 'Object, important'),
        ('.', 'f≈0.1, i≈0.1', 'End of sentence, reset context'),
    ]

    for word, gates, explanation in patterns:
        print(f"  {word:5s}: {gates:15s} - {explanation}")

visualize_gates()

# Demonstrate long-term memory preservation
def test_long_term_memory():
    """Show LSTM preserving information over long sequences."""
    print("\n" + "=" * 60)
    print("Long-term Memory Test:")
    print("-" * 60)

    # Compare RNN and LSTM on a simple memory task
    seq_len = 100
    hidden_size = 32

    rnn = nn.RNN(1, hidden_size, batch_first=True)
    lstm = nn.LSTM(1, hidden_size, batch_first=True)

    # Input: signal at t=0, zeros elsewhere
    x = torch.zeros(1, seq_len, 1)
    x[0, 0, 0] = 1.0  # Signal only at first step
    x_zero = torch.zeros(1, seq_len, 1)  # Baseline: no signal at all

    with torch.no_grad():
        rnn_out, _ = rnn(x)
        lstm_out, _ = lstm(x)
        rnn_base, _ = rnn(x_zero)
        lstm_base, _ = lstm(x_zero)

    # Measure how much the t=0 impulse still affects the final hidden state.
    # Subtracting the zero-input baseline removes the contribution of biases.
    rnn_signal = (rnn_out[0, -1, :] - rnn_base[0, -1, :]).norm().item()
    lstm_signal = (lstm_out[0, -1, :] - lstm_base[0, -1, :]).norm().item()

    print(f"Signal strength at t=0: 1.0")
    print(f"Signal remaining at t={seq_len-1}:")
    print(f"  RNN:  {rnn_signal:.3e}")
    print(f"  LSTM: {lstm_signal:.3e}")
    if rnn_signal > 0:
        print(f"  Ratio (LSTM/RNN): {lstm_signal/rnn_signal:.2f}x")

test_long_term_memory()
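
Because the from-scratch LSTMCell defined earlier exposes its combined gate layer, the gate activations can be inspected directly rather than only described; a small sketch, assuming that LSTMCell class is still in scope:

PYTHON
import torch

torch.manual_seed(0)
cell = LSTMCell(input_size=8, hidden_size=16)  # the class defined above
x = torch.randn(4, 8)
h = torch.zeros(4, 16)

with torch.no_grad():
    # Recompute the gate pre-activations exactly as the cell's forward() does
    combined = torch.cat([x, h], dim=1)
    i, f, g, o = cell.gates(combined).chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)

print(f"Mean input gate:  {i.mean().item():.3f}")
print(f"Mean forget gate: {f.mean().item():.3f}")
print(f"Mean output gate: {o.mean().item():.3f}")
# With random initialization the gates hover around 0.5; training pushes them
# toward 0 or 1 depending on what each dimension needs to remember or forget.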

The GRU Architecture

The Gated Recurrent Unit, introduced by Cho et al. in 2014, provides a simpler alternative to the LSTM with comparable performance on many tasks. The GRU combines the forget and input gates into a single update gate and merges the cell state with the hidden state, reducing the number of parameters and computational cost while maintaining the ability to learn long-range dependencies.

The GRU uses two gates: the reset gate $r_t$ and the update gate $z_t$. The reset gate determines how much of the previous hidden state to forget when computing the candidate hidden state. The update gate controls the balance between the previous hidden state and the new candidate, functioning somewhat like a combined forget and input gate.

The mathematical formulation of a GRU is:

$$z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z) \quad \text{(update gate)}$$
$$r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r) \quad \text{(reset gate)}$$
$$\tilde{h}_t = \tanh(W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h) \quad \text{(candidate hidden state)}$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \quad \text{(hidden state update)}$$

The update gate's interpolation between old and new hidden states provides a direct path for gradient flow when $z_t \approx 0$, addressing the vanishing gradient problem.

PYTHON
import torch
import torch.nn as nn

class GRUCell(nn.Module):
    """
    GRU cell implementation from scratch.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Gates: reset and update
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        # Candidate hidden state
        self.W_h = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x, h=None):
        """
        Process one time step.

        Args:
            x: Input (batch, input_size)
            h: Previous hidden state or None

        Returns:
            h_new: New hidden state
        """
        batch_size = x.size(0)

        if h is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)

        # Concatenate input and hidden state
        combined = torch.cat([x, h], dim=1)

        # Compute gates
        r = torch.sigmoid(self.W_r(combined))  # Reset gate
        z = torch.sigmoid(self.W_z(combined))  # Update gate

        # Compute candidate hidden state (with reset applied)
        combined_reset = torch.cat([x, r * h], dim=1)
        h_candidate = torch.tanh(self.W_h(combined_reset))

        # Interpolate between old and new
        h_new = (1 - z) * h + z * h_candidate

        return h_new

# Test GRU cell
gru_cell = GRUCell(input_size=32, hidden_size=64)
x = torch.randn(8, 32)

h = gru_cell(x)
print(f"GRU Cell:")
print(f"  Input: {x.shape}")
print(f"  Hidden state: {h.shape}")

# Compare GRU and LSTM parameter counts
def compare_parameters():
    """Compare LSTM and GRU parameter counts."""
    input_size = 256
    hidden_size = 512

    lstm = nn.LSTM(input_size, hidden_size)
    gru = nn.GRU(input_size, hidden_size)

    lstm_params = sum(p.numel() for p in lstm.parameters())
    gru_params = sum(p.numel() for p in gru.parameters())

    print(f"\nParameter Comparison (input={input_size}, hidden={hidden_size}):")
    print(f"  LSTM: {lstm_params:,} parameters")
    print(f"  GRU:  {gru_params:,} parameters")
    print(f"  Ratio: LSTM has {lstm_params/gru_params:.2f}x more parameters")
    print(f"\nLSTM formula: 4 * hidden * (input + hidden + 2)")
    print(f"GRU formula:  3 * hidden * (input + hidden + 2)")
    print("(the +2 accounts for PyTorch's two bias vectors, b_ih and b_hh)")

compare_parameters()
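
The claim above that the update gate provides a direct path when $z_t \approx 0$ can be checked numerically with the GRUCell just defined. The large negative bias written into the update gate below is an artificial setting used only to force $z_t$ toward zero:

PYTHON
import torch

torch.manual_seed(0)
cell = GRUCell(input_size=8, hidden_size=16)  # the class defined above

# Push the update gate toward 0 with a large negative bias, so
# h_new = (1 - z) * h + z * h_candidate reduces to approximately h.
with torch.no_grad():
    cell.W_z.bias.fill_(-10.0)

x = torch.randn(4, 8)
h_prev = torch.randn(4, 16)
h_new = cell(x, h_prev)

print(f"Max |h_new - h_prev| with z ≈ 0: {(h_new - h_prev).abs().max().item():.6f}")
# The hidden state passes through nearly unchanged, which is the shortcut
# that lets gradients flow across many time steps without vanishing.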

Comparing LSTM and GRU

The choice between LSTM and GRU often comes down to the specific task and computational constraints. Both architectures effectively solve the vanishing gradient problem, and neither consistently outperforms the other across all tasks. GRUs are computationally more efficient due to fewer parameters and simpler structure, while LSTMs offer more modeling flexibility through the separate cell state.

Empirically, LSTMs tend to perform slightly better on tasks requiring very long-term memory, where the separate cell state pathway provides advantages. GRUs often match or exceed LSTM performance on many practical tasks while training faster. For language modeling, both perform comparably, with GRUs sometimes preferred for their simplicity.

The computational difference becomes significant at scale. GRUs require about 25% fewer parameters and correspondingly less computation per time step. When processing millions of sequences or deploying on resource-constrained devices, this efficiency gain matters. When accuracy is paramount and resources are plentiful, LSTMs remain a solid choice.

PYTHON
import torch
import torch.nn as nn
import time

def benchmark_lstm_vs_gru():
    """
    Compare LSTM and GRU performance.
    """
    print("LSTM vs GRU Benchmark")
    print("=" * 60)

    input_size = 256
    hidden_size = 512
    seq_len = 100
    batch_size = 64
    num_iterations = 100

    # Create models
    lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
    gru = nn.GRU(input_size, hidden_size, batch_first=True)

    # Move to GPU if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lstm = lstm.to(device)
    gru = gru.to(device)

    # Benchmark data
    x = torch.randn(batch_size, seq_len, input_size, device=device)

    # Warm up
    _ = lstm(x)
    _ = gru(x)

    # Benchmark LSTM
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_iterations):
        _ = lstm(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    lstm_time = time.time() - start

    # Benchmark GRU
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_iterations):
        _ = gru(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    gru_time = time.time() - start

    print(f"Device: {device}")
    print(f"Sequence length: {seq_len}, Batch size: {batch_size}")
    print(f"Input size: {input_size}, Hidden size: {hidden_size}")
    print(f"\nTiming ({num_iterations} iterations):")
    print(f"  LSTM: {lstm_time:.3f}s ({lstm_time/num_iterations*1000:.2f}ms per batch)")
    print(f"  GRU:  {gru_time:.3f}s ({gru_time/num_iterations*1000:.2f}ms per batch)")
    print(f"  GRU speedup: {lstm_time/gru_time:.2f}x")

# Run benchmark (may be slow on CPU)
# benchmark_lstm_vs_gru()

# Theoretical comparison
print("\nTheoretical Comparison:")
print("-" * 60)
print("LSTM:")
print("  - 4 gate operations (forget, input, output, candidate)")
print("  - Separate cell state and hidden state")
print("  - More parameters, more expressive")
print("  - Better for very long sequences")
print("\nGRU:")
print("  - 2 gate operations (reset, update)")
print("  - Single hidden state")
print("  - Fewer parameters, faster training")
print("  - Often sufficient for most tasks")
print("\nRecommendation:")
print("  - Start with GRU for efficiency")
print("  - Switch to LSTM if GRU underperforms")
print("  - Consider Transformer for long sequences")
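
One practical difference when swapping the two in code: PyTorch's nn.LSTM returns a (h_n, c_n) state tuple, while nn.GRU returns only h_n. A small illustrative sketch of an encoder that can switch between them (the ConfigurableEncoder class is an assumption for illustration, not a library component):

PYTHON
import torch
import torch.nn as nn

class ConfigurableEncoder(nn.Module):
    """Sequence encoder that can swap between LSTM and GRU (illustrative)."""
    def __init__(self, input_size, hidden_size, cell_type="gru"):
        super().__init__()
        rnn_cls = nn.LSTM if cell_type == "lstm" else nn.GRU
        self.rnn = rnn_cls(input_size, hidden_size, batch_first=True)

    def forward(self, x):
        output, state = self.rnn(x)
        # nn.LSTM returns (h_n, c_n); nn.GRU returns just h_n
        h_n = state[0] if isinstance(state, tuple) else state
        return output, h_n

x = torch.randn(4, 30, 64)
for cell_type in ["lstm", "gru"]:
    enc = ConfigurableEncoder(64, 128, cell_type)
    output, h_n = enc(x)
    print(f"{cell_type.upper()}: output {tuple(output.shape)}, h_n {tuple(h_n.shape)}")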

Peephole Connections and Variants

Several LSTM variants have been proposed to improve performance or efficiency. Peephole connections allow the gates to look directly at the cell state, not just the hidden state. This gives gates more information for their decisions, potentially improving performance on tasks requiring precise timing or counting.

With peephole connections, the gate equations become:

$$f_t = \sigma(W_f \cdot [C_{t-1}, h_{t-1}, x_t] + b_f)$$
$$i_t = \sigma(W_i \cdot [C_{t-1}, h_{t-1}, x_t] + b_i)$$
$$o_t = \sigma(W_o \cdot [C_t, h_{t-1}, x_t] + b_o)$$

Other variants include coupled forget-input gates (where $i_t = 1 - f_t$, reducing parameters), LSTMs without output gates, and various combinations. Research has shown that no single variant consistently outperforms others, and the vanilla LSTM remains a strong baseline.

PYTHON
import torch
import torch.nn as nn

class PeepholeLSTMCell(nn.Module):
    """
    LSTM cell with peephole connections.
    Gates can look at the cell state directly.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Standard weights
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)

        # Peephole weights (diagonal, so just vectors)
        self.p_i = nn.Parameter(torch.randn(hidden_size) * 0.1)
        self.p_f = nn.Parameter(torch.randn(hidden_size) * 0.1)
        self.p_o = nn.Parameter(torch.randn(hidden_size) * 0.1)

    def forward(self, x, state=None):
        batch_size = x.size(0)

        if state is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = state

        combined = torch.cat([x, h], dim=1)

        # Gates with peephole connections
        f = torch.sigmoid(self.W_f(combined) + self.p_f * c)  # Forget gate
        i = torch.sigmoid(self.W_i(combined) + self.p_i * c)  # Input gate
        g = torch.tanh(self.W_c(combined))                     # Candidate

        # Update cell state
        c_new = f * c + i * g

        # Output gate (uses new cell state)
        o = torch.sigmoid(self.W_o(combined) + self.p_o * c_new)

        # Hidden state
        h_new = o * torch.tanh(c_new)

        return h_new, (h_new, c_new)

# Test peephole LSTM
peephole_lstm = PeepholeLSTMCell(32, 64)
x = torch.randn(8, 32)
h, state = peephole_lstm(x)
print(f"Peephole LSTM Cell:")
print(f"  Input: {x.shape}")
print(f"  Output: {h.shape}")

# Coupled input-forget gate variant
class CoupledLSTMCell(nn.Module):
    """
    LSTM with coupled input-forget gates: i_t = 1 - f_t
    Reduces parameters while maintaining performance.
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Only need 3 gates instead of 4
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)
        self.W_c = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x, state=None):
        batch_size = x.size(0)

        if state is None:
            h = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h, c = state

        combined = torch.cat([x, h], dim=1)

        f = torch.sigmoid(self.W_f(combined))
        i = 1 - f  # Coupled: input = 1 - forget
        g = torch.tanh(self.W_c(combined))
        o = torch.sigmoid(self.W_o(combined))

        c_new = f * c + i * g
        h_new = o * torch.tanh(c_new)

        return h_new, (h_new, c_new)

print("\nLSTM Variants Parameter Comparison:")
standard = nn.LSTMCell(32, 64)
peephole = PeepholeLSTMCell(32, 64)
coupled = CoupledLSTMCell(32, 64)

for name, model in [('Standard', standard), ('Peephole', peephole), ('Coupled', coupled)]:
    params = sum(p.numel() for p in model.parameters())
    print(f"  {name}: {params:,} parameters")

Stacking and Bidirectional Configurations

Both LSTMs and GRUs can be stacked into multiple layers, where each layer processes the output sequence of the layer below. Deeper networks can learn more complex transformations, though diminishing returns typically appear after 2-4 layers. Dropout between layers helps prevent overfitting in deep stacked configurations.

Bidirectional versions process sequences in both directions simultaneously, providing each position with context from both past and future. The forward and backward hidden states are typically concatenated, doubling the effective hidden size. Bidirectional LSTMs and GRUs are standard for tasks like named entity recognition, part-of-speech tagging, and any task where the full sequence is available before producing outputs.

PYTHON
import torch
import torch.nn as nn

# Stacked and bidirectional configurations
class StackedBiLSTM(nn.Module):
    """
    Stacked bidirectional LSTM encoder.
    """
    def __init__(self, input_size, hidden_size, num_layers=2, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )

    def forward(self, x):
        # output: (batch, seq_len, 2*hidden_size)
        # h_n: (2*num_layers, batch, hidden_size)
        output, (h_n, c_n) = self.lstm(x)
        return output, (h_n, c_n)

# Test stacked bidirectional LSTM
model = StackedBiLSTM(input_size=64, hidden_size=128, num_layers=3)
x = torch.randn(8, 50, 64)  # batch=8, seq_len=50, features=64

output, (h_n, c_n) = model(x)

print("Stacked Bidirectional LSTM:")
print(f"  Input: {x.shape}")
print(f"  Output: {output.shape} (2*hidden = {2*128})")
print(f"  h_n: {h_n.shape} (2*num_layers = {2*3})")
print(f"  c_n: {c_n.shape}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Extract final states for classification
def get_final_states(h_n, num_layers, bidirectional=True):
    """
    Extract final hidden states from stacked LSTM.
    Returns: (batch, num_layers * num_directions * hidden_size)
    """
    if bidirectional:
        # h_n shape: (num_layers * 2, batch, hidden)
        # Reshape to (num_layers, 2, batch, hidden)
        num_directions = 2
        h_n = h_n.view(num_layers, num_directions, h_n.size(1), h_n.size(2))
        # Concatenate forward and backward for each layer
        # Then concatenate all layers
        final = torch.cat([h_n[i] for i in range(num_layers)], dim=0)
        final = final.transpose(0, 1).contiguous()  # (batch, layers*dirs, hidden)
        final = final.view(final.size(0), -1)  # (batch, layers*dirs*hidden)
    else:
        final = h_n.transpose(0, 1).contiguous().view(h_n.size(1), -1)

    return final

final = get_final_states(h_n, num_layers=3, bidirectional=True)
print(f"\nFinal states for classification: {final.shape}")

Practical Usage with PyTorch

PyTorch's LSTM and GRU implementations are highly optimized for GPU execution. Key considerations include proper initialization, handling variable-length sequences, and managing the distinction between hidden and cell states.

PYTHON
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class TextClassifier(nn.Module):
    """
    Complete text classification model with LSTM.
    """
    def __init__(self, vocab_size, embed_dim, hidden_size, num_classes,
                 num_layers=2, bidirectional=True, dropout=0.5):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout if num_layers > 1 else 0
        )

        # Classifier input size depends on bidirectional
        classifier_input = hidden_size * 2 if bidirectional else hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x, lengths):
        # Embed and apply dropout
        embedded = self.dropout(self.embedding(x))

        # Pack for efficient processing
        packed = pack_padded_sequence(embedded, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)

        # Process with LSTM
        packed_output, (h_n, c_n) = self.lstm(packed)

        # Unpack
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        # Use final hidden state for classification
        # Concatenate forward and backward final states
        if self.lstm.bidirectional:
            final_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)
        else:
            final_hidden = h_n[-1]

        # Classify
        logits = self.classifier(final_hidden)
        return logits

# Test the classifier
model = TextClassifier(
    vocab_size=10000,
    embed_dim=128,
    hidden_size=256,
    num_classes=5,
    num_layers=2,
    bidirectional=True
)

# Sample batch
x = torch.randint(1, 10000, (8, 100))  # 8 sequences, max length 100
lengths = torch.tensor([100, 85, 72, 90, 45, 88, 95, 60])

logits = model(x, lengths)
print(f"Text Classifier:")
print(f"  Input: {x.shape}")
print(f"  Output: {logits.shape}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

Key Takeaways

LSTM and GRU architectures solve the vanishing gradient problem through gating mechanisms that control information flow. LSTMs use three gates (forget, input, output) and a separate cell state pathway that enables long-term memory preservation with minimal gradient decay. GRUs simplify this to two gates (reset, update) and a single hidden state while achieving comparable performance. The choice between LSTM and GRU typically depends on computational constraints and the specific task, with GRUs offering 25% parameter reduction and faster training. Both architectures can be stacked for greater model capacity and configured as bidirectional for tasks where full context is available. While largely superseded by transformers for many NLP tasks, LSTMs and GRUs remain important for time series, streaming applications, and scenarios where their sequential processing and lower memory footprint offer advantages.

12.3 Sequence-to-Sequence Models (Intermediate)

Sequence-to-sequence (seq2seq) models transform one sequence into another, enabling applications where input and output lengths differ and have no direct element-wise correspondence. Machine translation converts sentences between languages, summarization compresses long documents into shorter versions, and speech recognition transcribes audio waveforms into text. The encoder-decoder architecture that underpins seq2seq learning represents one of the most influential ideas in neural network design, establishing patterns that persist even in transformer-based models.

The Encoder-Decoder Architecture

The encoder-decoder framework divides sequence-to-sequence learning into two distinct phases. The encoder processes the input sequence and compresses it into a fixed-dimensional representation called the context vector. The decoder then generates the output sequence conditioned on this context, producing one token at a time until generating an end-of-sequence marker.

The encoder is typically an RNN (LSTM or GRU) that reads the input sequence and produces a final hidden state summarizing the entire input. This hidden state becomes the context vector, encapsulating all information the decoder needs to generate the output. For bidirectional encoders, the forward and backward final states are typically concatenated or combined to form the context.

The decoder is another RNN that generates the output sequence autoregressively, meaning each output token depends on previously generated tokens. At each step, the decoder takes its previous hidden state and the previous output token, producing a new hidden state and predicting the next token. During training, the true previous tokens are used (teacher forcing); during inference, the model's own predictions are fed back.

PYTHON
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """
    Encoder: Compresses input sequence into context vector.
    """
    def __init__(self, input_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_size)
        outputs, (hidden, cell) = self.rnn(embedded)
        # hidden, cell: (num_layers, batch, hidden_size)
        return hidden, cell

class Decoder(nn.Module):
    """
    Decoder: Generates output sequence from context vector.
    """
    def __init__(self, output_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(output_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        # x: (batch, 1) - single token
        embedded = self.embedding(x)  # (batch, 1, embed_size)
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        prediction = self.fc(output.squeeze(1))  # (batch, output_size)
        return prediction, hidden, cell

class Seq2Seq(nn.Module):
    """
    Complete sequence-to-sequence model.
    """
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features

        # Store decoder outputs
        outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)

        # Encode input sequence
        hidden, cell = self.encoder(src)

        # First decoder input is <SOS> token
        dec_input = trg[:, 0:1]

        for t in range(1, trg_len):
            prediction, hidden, cell = self.decoder(dec_input, hidden, cell)
            outputs[:, t, :] = prediction

            # Teacher forcing: use actual target or predicted token
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio

            if teacher_force:
                dec_input = trg[:, t:t+1]
            else:
                dec_input = prediction.argmax(dim=1, keepdim=True)

        return outputs

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(input_size=5000, embed_size=256, hidden_size=512, num_layers=2)
decoder = Decoder(output_size=5000, embed_size=256, hidden_size=512, num_layers=2)
model = Seq2Seq(encoder, decoder, device).to(device)

# Test
src = torch.randint(1, 5000, (4, 20)).to(device)  # batch=4, src_len=20
trg = torch.randint(1, 5000, (4, 25)).to(device)  # batch=4, trg_len=25

outputs = model(src, trg)
print(f"Seq2Seq Model:")
print(f"  Source: {src.shape}")
print(f"  Target: {trg.shape}")
print(f"  Outputs: {outputs.shape}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
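
The forward pass above assumes a target sequence is available, which is only the case during training. At inference time the decoder must consume its own predictions; a minimal greedy-decoding sketch reusing the model above (the SOS_IDX start-token id is an assumed value for illustration, and a full implementation would also stop once an end-of-sequence token is produced):

PYTHON
import torch

SOS_IDX = 1  # assumed start-of-sequence token id, for illustration only

def greedy_decode(seq2seq, src, max_len=25):
    """Generate a target sequence token by token, feeding back predictions."""
    seq2seq.eval()
    with torch.no_grad():
        hidden, cell = seq2seq.encoder(src)
        dec_input = torch.full((src.size(0), 1), SOS_IDX,
                               dtype=torch.long, device=src.device)
        generated = []
        for _ in range(max_len):
            prediction, hidden, cell = seq2seq.decoder(dec_input, hidden, cell)
            next_token = prediction.argmax(dim=1, keepdim=True)
            generated.append(next_token)
            dec_input = next_token  # no teacher forcing at inference time
    return torch.cat(generated, dim=1)

translations = greedy_decode(model, src)
print(f"Greedy decoding output: {translations.shape}")  # (batch, max_len)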

The Information Bottleneck Problem

The basic encoder-decoder architecture suffers from an information bottleneck: the entire input sequence must be compressed into a single fixed-size context vector. For long sequences, this compression inevitably loses information. The decoder must reconstruct potentially hundreds of output tokens from this single vector, a challenging task when important details are far from the most recent encoder states.

This bottleneck manifests as performance degradation on long sequences. Machine translation quality drops noticeably for sentences longer than those seen during training. The encoder's final hidden state emphasizes recent information due to the vanishing gradient problem, even with LSTMs, meaning earlier parts of long inputs are under-represented in the context.

Several approaches partially address this limitation. Reversing the input sequence places the beginning of the source sentence closer to the beginning of the target, improving alignment for languages with similar word orders. Using bidirectional encoders incorporates both forward and backward context. However, the fundamental bottleneck remains until addressed by attention mechanisms.

PYTHON
import torch
import torch.nn as nn

def demonstrate_bottleneck():
    """
    Show how the bottleneck limits information flow.
    """
    print("Information Bottleneck Analysis")
    print("=" * 60)

    hidden_size = 512
    max_information = hidden_size * 32  # rough capacity estimate: 32 bits per float32 dimension

    sequence_lengths = [10, 50, 100, 200, 500]

    print(f"\nContext vector size: {hidden_size} dimensions")
    print(f"Approximate information capacity: {max_information} bits")
    print()

    for seq_len in sequence_lengths:
        # Assume each token carries ~10 bits of information
        token_info = 10
        total_info = seq_len * token_info
        compression_ratio = total_info / max_information

        status = "OK" if compression_ratio < 1 else "LOSSY"
        print(f"Sequence length {seq_len:4d}: "
              f"{total_info:5d} bits -> {hidden_size} dims "
              f"(compression: {compression_ratio:.2f}x) [{status}]")

    print("\n" + "-" * 60)
    print("As sequence length increases, more information is lost")
    print("in the bottleneck, leading to degraded performance.")

demonstrate_bottleneck()

# Demonstrate with actual encoding
def measure_reconstruction():
    """Test how well information is preserved through encoding."""
    print("\n" + "=" * 60)
    print("Reconstruction Quality vs Sequence Length")
    print("-" * 60)

    hidden_size = 256

    for seq_len in [10, 50, 100]:
        # Create encoder
        encoder = nn.LSTM(64, hidden_size, batch_first=True)

        # Random input sequence
        x = torch.randn(1, seq_len, 64)

        # Encode
        _, (h, c) = encoder(x)

        # The hidden state must represent all of x
        # More length = more compression
        print(f"Seq len {seq_len:3d}: input {x.numel()} values -> "
              f"context {h.numel() + c.numel()} values "
              f"(compression: {x.numel() / (h.numel() + c.numel()):.1f}x)")

measure_reconstruction()

Teacher Forcing and Exposure Bias

Teacher forcing is a training technique where the decoder receives the ground truth previous token rather than its own prediction. This approach accelerates training and improves stability because the model doesn't compound errors during the training sequence. However, it creates a discrepancy between training and inference conditions known as exposure bias.

During training with teacher forcing, the decoder always sees correct previous tokens. During inference, it sees its own potentially incorrect predictions. If the model makes an error, it enters a distribution of inputs it never encountered during training, potentially leading to cascading errors. This mismatch can cause significant quality degradation at inference time.

Several strategies mitigate exposure bias. Scheduled sampling gradually transitions from teacher forcing to using model predictions during training. Curriculum learning starts with easier, shorter sequences before progressing to harder ones. Sequence-level training objectives like REINFORCE or minimum risk training directly optimize for sequence quality rather than token-level cross-entropy.

PYTHON
import torch
import torch.nn as nn
import random

class ScheduledSamplingDecoder(nn.Module):
    """
    Decoder with scheduled sampling to reduce exposure bias.
    """
    def __init__(self, output_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_size, embed_size)
        self.rnn = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, trg, hidden, cell, teacher_forcing_ratio):
        batch_size = trg.size(0)
        trg_len = trg.size(1)
        output_size = self.fc.out_features

        outputs = torch.zeros(batch_size, trg_len, output_size, device=trg.device)

        dec_input = trg[:, 0:1]  # <SOS>

        for t in range(1, trg_len):
            embedded = self.embedding(dec_input)
            output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
            prediction = self.fc(output.squeeze(1))
            outputs[:, t, :] = prediction

            # Scheduled sampling
            if random.random() < teacher_forcing_ratio:
                dec_input = trg[:, t:t+1]
            else:
                dec_input = prediction.argmax(dim=1, keepdim=True)

        return outputs

def scheduled_sampling_schedule(epoch, total_epochs, strategy='linear'):
    """
    Compute the teacher forcing ratio based on training progress.
    """
    if strategy == 'linear':
        # Linear decay from 1.0 to 0.0
        return max(0.0, 1 - epoch / total_epochs)

    elif strategy == 'exponential':
        # Exponential decay: k^epoch with k < 1
        k = 0.9
        return k ** epoch

    elif strategy == 'inverse_sigmoid':
        # Inverse sigmoid decay: k / (k + exp(epoch / k)) with k >= 1
        k = 5
        return k / (k + torch.exp(torch.tensor(epoch / k)).item())

# Demonstrate schedules
print("Teacher Forcing Schedules:")
print("-" * 50)
total_epochs = 20
for epoch in range(0, total_epochs + 1, 5):
    linear = scheduled_sampling_schedule(epoch, total_epochs, 'linear')
    exp = scheduled_sampling_schedule(epoch, total_epochs, 'exponential')
    print(f"Epoch {epoch:2d}: linear={linear:.2f}, exponential={exp:.2f}")

Beam Search Decoding

Greedy decoding selects the highest-probability token at each step, but this locally optimal choice may not lead to the globally best sequence. Beam search maintains multiple candidate sequences (beams) during decoding, exploring a broader search space while remaining computationally tractable.

At each step, beam search expands each of the $k$ current beams by considering all possible next tokens, scoring the resulting $k \times V$ candidates (where $V$ is vocabulary size), and keeping only the top $k$ sequences. This process continues until all beams produce an end-of-sequence token or reach maximum length.

The beam width $k$ controls the trade-off between quality and computation. Larger beams explore more possibilities but require more memory and time. In practice, beam widths of 4-10 work well for many tasks, with diminishing returns beyond that. Length normalization adjusts scores to prevent the search from favoring shorter sequences.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class BeamSearchDecoder:
    """
    Beam search decoder for sequence generation.
    """
    def __init__(self, decoder, beam_width=5, max_len=50, sos_id=1, eos_id=2):
        self.decoder = decoder
        self.beam_width = beam_width
        self.max_len = max_len
        self.sos_id = sos_id
        self.eos_id = eos_id

    def decode(self, hidden, cell):
        """
        Perform beam search decoding.

        Args:
            hidden, cell: Initial decoder states from encoder

        Returns:
            best_sequence: Highest-scoring complete sequence
        """
        device = hidden.device
        # This simple implementation decodes one example at a time:
        # the encoder states passed in are expected to have batch size 1.

        # Initialize beams: (score, sequence, hidden, cell)
        beams = [(0.0, [self.sos_id], hidden, cell)]

        complete_sequences = []

        for _ in range(self.max_len):
            all_candidates = []

            for score, seq, h, c in beams:
                if seq[-1] == self.eos_id:
                    complete_sequences.append((score, seq))
                    continue

                # Get last token
                last_token = torch.tensor([[seq[-1]]], device=device)

                # Decode one step
                with torch.no_grad():
                    prediction, h_new, c_new = self.decoder(last_token, h, c)
                    log_probs = F.log_softmax(prediction, dim=-1)

                # Expand with top-k tokens
                top_log_probs, top_indices = log_probs.topk(self.beam_width)

                for i in range(self.beam_width):
                    new_score = score + top_log_probs[0, i].item()
                    new_token = top_indices[0, i].item()
                    new_seq = seq + [new_token]
                    all_candidates.append((new_score, new_seq, h_new, c_new))

            # Keep top beams
            all_candidates.sort(key=lambda x: x[0], reverse=True)
            beams = all_candidates[:self.beam_width]

            # Check if all beams are complete
            if all(seq[-1] == self.eos_id for _, seq, _, _ in beams):
                break

        # Add any remaining beams to complete sequences
        for score, seq, _, _ in beams:
            complete_sequences.append((score, seq))

        # Return best sequence (with length normalization)
        def length_normalize(score, seq):
            # Alpha for length penalty (typically 0.6-0.7)
            alpha = 0.6
            return score / (len(seq) ** alpha)

        complete_sequences.sort(key=lambda x: length_normalize(x[0], x[1]), reverse=True)

        return complete_sequences[0][1]

# Demonstrate beam search
print("Beam Search Decoding:")
print("-" * 50)
print("Beam width 1 = Greedy search")
print("Beam width k = Explore top-k candidates at each step")
print()
print("Example search tree (beam width 2):")
print("Step 0: <SOS>")
print("Step 1: 'The' (p=0.4), 'A' (p=0.3)")
print("Step 2: 'The cat' (p=0.15), 'The dog' (p=0.12), 'A cat' (p=0.10)...")
print("        Keep top 2: 'The cat', 'The dog'")
print("Step 3: Expand each, keep top 2 overall...")

Bidirectional and Multi-Layer Encoders

Strengthening the encoder helps capture more information before the bottleneck. Bidirectional encoders process the sequence in both directions, capturing context from both past and future at each position. The forward and backward hidden states are concatenated, providing richer representations.

Multi-layer encoders stack RNN layers to learn hierarchical representations. Each layer processes the outputs of the layer below, enabling the model to capture increasingly abstract patterns. Residual connections between layers help gradient flow in deep stacks.

For the decoder to use a bidirectional multi-layer encoder's state effectively, the hidden states must be transformed to match the decoder's expected input. Common approaches include using only the forward final state, concatenating and projecting through a linear layer, or using separate encoder final states to initialize each decoder layer.

PYTHON
import torch
import torch.nn as nn

class BidirectionalEncoder(nn.Module):
    """
    Bidirectional multi-layer encoder with projection for decoder.
    """
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2, dropout=0.3):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.dropout = nn.Dropout(dropout)

        self.rnn = nn.LSTM(
            embed_size, hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Project bidirectional states to decoder size
        self.fc_hidden = nn.Linear(hidden_size * 2, hidden_size)
        self.fc_cell = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x):
        embedded = self.dropout(self.embedding(x))
        outputs, (hidden, cell) = self.rnn(embedded)

        # hidden: (num_layers * 2, batch, hidden_size)
        # Reshape to (num_layers, batch, hidden_size * 2)
        batch_size = hidden.size(1)

        # Concatenate forward and backward for each layer
        hidden = hidden.view(self.num_layers, 2, batch_size, self.hidden_size)
        hidden = torch.cat([hidden[:, 0, :, :], hidden[:, 1, :, :]], dim=2)

        cell = cell.view(self.num_layers, 2, batch_size, self.hidden_size)
        cell = torch.cat([cell[:, 0, :, :], cell[:, 1, :, :]], dim=2)

        # Project to decoder size
        hidden = torch.tanh(self.fc_hidden(hidden))
        cell = torch.tanh(self.fc_cell(cell))

        return outputs, hidden, cell

class ImprovedSeq2Seq(nn.Module):
    """
    Seq2seq with bidirectional encoder.
    """
    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=2):
        super().__init__()

        self.encoder = BidirectionalEncoder(
            vocab_size, embed_size, hidden_size, num_layers
        )
        self.decoder = Decoder(
            vocab_size, embed_size, hidden_size, num_layers
        )

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        # Encode
        encoder_outputs, hidden, cell = self.encoder(src)

        batch_size = src.size(0)
        trg_len = trg.size(1)
        vocab_size = self.decoder.fc.out_features

        outputs = torch.zeros(batch_size, trg_len, vocab_size, device=src.device)
        dec_input = trg[:, 0:1]

        for t in range(1, trg_len):
            prediction, hidden, cell = self.decoder(dec_input, hidden, cell)
            outputs[:, t, :] = prediction

            if torch.rand(1).item() < teacher_forcing_ratio:
                dec_input = trg[:, t:t+1]
            else:
                dec_input = prediction.argmax(dim=1, keepdim=True)

        return outputs

# Test improved model
model = ImprovedSeq2Seq(vocab_size=5000, embed_size=256, hidden_size=512, num_layers=2)
src = torch.randint(1, 5000, (4, 20))
trg = torch.randint(1, 5000, (4, 25))

outputs = model(src, trg)
print(f"Improved Seq2Seq (bidirectional encoder):")
print(f"  Source: {src.shape}")
print(f"  Target: {trg.shape}")
print(f"  Outputs: {outputs.shape}")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

Training Sequence-to-Sequence Models

Training seq2seq models involves several practical considerations beyond standard neural network training. The loss function computes cross-entropy between predicted and true tokens, typically ignoring padding positions. Gradient clipping prevents exploding gradients common in long sequences. Learning rate scheduling and careful initialization improve convergence.

Handling variable-length sequences requires padding and masking. Inputs are padded to the batch's maximum length, and the loss function masks out padding positions. PyTorch's pack_padded_sequence and pad_packed_sequence enable efficient processing without explicitly handling padding in the RNN.

Validation should use the same decoding strategy as inference (greedy or beam search) rather than teacher forcing, providing a more accurate estimate of test-time performance. BLEU score and other sequence-level metrics complement token-level perplexity for evaluation.

PYTHON
import torch
import torch.nn as nn
import torch.optim as optim

class Seq2SeqTrainer:
    """
    Trainer for sequence-to-sequence models.
    """
    def __init__(self, model, vocab_size, pad_idx=0, device='cpu'):
        self.model = model.to(device)
        self.device = device

        # Ignore padding in loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
        self.optimizer = optim.Adam(model.parameters(), lr=0.001)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=3, factor=0.5
        )

    def train_epoch(self, data_loader, clip=1.0, teacher_forcing_ratio=0.5):
        self.model.train()
        total_loss = 0

        for src, trg in data_loader:
            src = src.to(self.device)
            trg = trg.to(self.device)

            self.optimizer.zero_grad()

            output = self.model(src, trg, teacher_forcing_ratio)

            # Reshape for loss: (batch * trg_len, vocab_size)
            output = output[:, 1:, :].contiguous().view(-1, output.size(-1))
            trg = trg[:, 1:].contiguous().view(-1)

            loss = self.criterion(output, trg)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip)

            self.optimizer.step()
            total_loss += loss.item()

        return total_loss / len(data_loader)

    def evaluate(self, data_loader):
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for src, trg in data_loader:
                src = src.to(self.device)
                trg = trg.to(self.device)

                # No teacher forcing during evaluation
                output = self.model(src, trg, teacher_forcing_ratio=0)

                output = output[:, 1:, :].contiguous().view(-1, output.size(-1))
                trg = trg[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, trg)
                total_loss += loss.item()

        return total_loss / len(data_loader)

# Training loop example
def train_seq2seq(model, train_loader, val_loader, epochs=10):
    trainer = Seq2SeqTrainer(model, vocab_size=5000, pad_idx=0)

    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Decay teacher forcing
        tf_ratio = max(0.5, 1 - epoch / epochs)

        train_loss = trainer.train_epoch(train_loader, teacher_forcing_ratio=tf_ratio)
        val_loss = trainer.evaluate(val_loader)

        trainer.scheduler.step(val_loss)

        print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, "
              f"val_loss={val_loss:.4f}, tf_ratio={tf_ratio:.2f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pt')

print("Training Configuration:")
print("-" * 50)
print("Loss: CrossEntropy (ignoring padding)")
print("Optimizer: Adam with lr scheduling")
print("Gradient clipping: max_norm=1.0")
print("Teacher forcing: Decay from 1.0 to 0.5")
print("Evaluation: No teacher forcing")

Key Takeaways

Sequence-to-sequence models enable transformations between sequences of different lengths through the encoder-decoder architecture. The encoder compresses the input into a context vector that the decoder uses to generate the output autoregressively. The information bottleneck in basic seq2seq limits performance on long sequences, motivating attention mechanisms.

Teacher forcing accelerates training but creates exposure bias between training and inference. Beam search decoding explores multiple candidates to find better sequences than greedy search. Bidirectional and multi-layer encoders capture richer input representations. Training requires careful handling of variable-length sequences, gradient clipping, and appropriate evaluation with inference-time decoding.

While transformers have largely superseded RNN-based seq2seq for many applications, the encoder-decoder framework remains foundational for understanding sequence-to-sequence learning.

12.4 Applications: NLP and Time Series Intermediate

Applications: NLP and Time Series

Recurrent neural networks have revolutionized two domains that depend fundamentally on sequential structure: natural language processing and time series analysis. While transformers have become the dominant architecture for many NLP tasks, understanding RNN-based approaches remains essential because they provide the conceptual foundation for sequence modeling, remain practical for resource-constrained deployments, and excel in streaming scenarios where processing must occur incrementally.

Natural Language Processing with RNNs

Natural language presents unique challenges for neural networks: vocabularies are large and discrete, sentences vary in length, and words carry meaning that depends heavily on the surrounding context. RNNs address these challenges by maintaining state that accumulates information as they process text sequentially.

The first step in any NLP application involves converting text into numerical representations. Modern systems typically use word embeddings that capture semantic relationships between words.
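
As a small illustration, the sketch below maps a tokenized sentence to integer indices using a toy vocabulary (both invented for this example) and looks them up in an embedding layer to obtain dense vectors.

PYTHON
import torch
import torch.nn as nn

# Toy vocabulary and sentence (hypothetical; real systems build these from a corpus)
vocab = {'<pad>': 0, '<unk>': 1, 'the': 2, 'movie': 3, 'was': 4, 'great': 5}
sentence = ['the', 'movie', 'was', 'great']

# Convert tokens to integer indices, falling back to <unk> for unknown words
indices = torch.tensor([[vocab.get(tok, vocab['<unk>']) for tok in sentence]])

# The embedding layer maps each index to a dense vector
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=8, padding_idx=0)
embedded = embedding(indices)
print(embedded.shape)  # torch.Size([1, 4, 8]) -> (batch, seq_len, embed_dim)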

Sentiment Analysis with LSTM

Sentiment analysis demonstrates the power of recurrent networks for document understanding. The model must read the entire text, accumulating evidence about sentiment, then produce a single classification based on the final hidden state.

Bidirectional processing helps because sentiment indicators can appear anywhere in the text. A review might say "Despite the hype, this product disappointed" where "disappointed" at the end colors the interpretation of earlier words.
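
A minimal bidirectional LSTM classifier along these lines might look as follows; the vocabulary size, dimensions, and the choice to classify from the concatenated final forward and backward states are assumptions made for illustration.

PYTHON
import torch
import torch.nn as nn

class SentimentLSTM(nn.Module):
    """Bidirectional LSTM sentiment classifier (illustrative sketch)."""
    def __init__(self, vocab_size, embed_size, hidden_size, num_classes=2, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        embedded = self.embedding(x)              # (batch, seq_len, embed_size)
        _, (hidden, _) = self.lstm(embedded)      # hidden: (2, batch, hidden_size)
        # Concatenate the final forward and backward hidden states
        features = torch.cat([hidden[0], hidden[1]], dim=1)
        return self.fc(features)                  # (batch, num_classes)

model = SentimentLSTM(vocab_size=5000, embed_size=128, hidden_size=256)
tokens = torch.randint(1, 5000, (4, 30))          # batch of 4 reviews, 30 tokens each
print(model(tokens).shape)                        # torch.Size([4, 2])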

Language Modeling

Language modeling predicts the next word given preceding context. A well-trained language model captures syntactic structure, semantic relationships, and even world knowledge implicitly encoded in text statistics.

The temperature parameter controls generation diversity. Lower temperatures produce more deterministic outputs while higher temperatures allow more creative but potentially less coherent generation.
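
The effect is easy to see on a single vector of made-up logits: dividing by the temperature before the softmax sharpens the distribution when T < 1 and flattens it when T > 1.

PYTHON
import torch
import torch.nn.functional as F

def sample_with_temperature(logits, temperature=1.0):
    """Sample a token index from logits scaled by temperature."""
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])  # hypothetical next-token scores

for temp in [0.5, 1.0, 2.0]:
    probs = F.softmax(logits / temp, dim=-1)
    print(f"T={temp}: {probs.numpy().round(3)}")
# Lower T concentrates probability on the top token; higher T spreads it out.

print("Sampled index at T=0.8:", sample_with_temperature(logits, temperature=0.8))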

Time Series Forecasting

Time series data differs from text in several important ways. Values are continuous rather than discrete, temporal spacing is typically uniform, and multiple correlated variables often need modeling simultaneously. Financial data, sensor readings, weather measurements, and energy consumption all exhibit patterns that RNNs can learn.

Time series models benefit from careful feature engineering. Beyond raw values, useful features include lag variables, rolling statistics, and time-based features.
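
The sketch below, with arbitrary window and lag sizes, turns a univariate series into supervised windows augmented with a lagged copy of the series and a short rolling mean; a real pipeline would add calendar features, multiple lags, and proper normalization.

PYTHON
import torch

def make_windows(series, window=24, lag=24, roll=7):
    """
    Build (input, target) pairs from a 1-D series, adding a lagged copy of
    the series and a rolling-mean feature to each window.
    """
    X, y = [], []
    start = max(lag, roll)
    for t in range(start, len(series) - window):
        window_vals = series[t:t + window]             # raw values in the window
        lag_vals = series[t - lag:t + window - lag]    # the same window, `lag` steps earlier
        roll_mean = series[t - roll:t].mean().item()   # rolling mean just before the window
        features = torch.stack([
            window_vals,
            lag_vals,
            torch.full_like(window_vals, roll_mean),
        ], dim=-1)                                     # (window, 3 features)
        X.append(features)
        y.append(series[t + window])                   # target: the value right after the window
    return torch.stack(X), torch.stack(y)

# Synthetic series for demonstration
series = torch.sin(torch.linspace(0, 20, 500)) + 0.1 * torch.randn(500)
X, y = make_windows(series)
print(X.shape, y.shape)  # torch.Size([452, 24, 3]) torch.Size([452])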

Multi-Step Forecasting Strategies

Predicting multiple steps into the future can follow several strategies. Direct forecasting trains separate models for each horizon. Recursive forecasting uses the model's own predictions as inputs for subsequent steps. Sequence-to-sequence approaches generate the entire forecast sequence at once.

Recursive forecasting accumulates errors as predictions build on predictions. Direct forecasting avoids error accumulation but requires more parameters.
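
A minimal sketch of the recursive strategy, assuming a simple one-step LSTM forecaster defined here for illustration: each prediction is appended to the input window and fed back in to produce the following step.

PYTHON
import torch
import torch.nn as nn

class OneStepForecaster(nn.Module):
    """LSTM that predicts the next value from a window of past values."""
    def __init__(self, hidden_size=64):
        super().__init__()
        self.lstm = nn.LSTM(1, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):                   # x: (batch, window, 1)
        _, (h, _) = self.lstm(x)
        return self.fc(h[-1])               # (batch, 1)

def recursive_forecast(model, history, steps):
    """Roll the one-step model forward, feeding predictions back as inputs."""
    window = history.clone()                # (1, window_len, 1)
    preds = []
    model.eval()
    with torch.no_grad():
        for _ in range(steps):
            next_val = model(window)        # (1, 1)
            preds.append(next_val.item())
            # Slide the window: drop the oldest value, append the prediction
            window = torch.cat([window[:, 1:, :], next_val.unsqueeze(1)], dim=1)
    return preds

model = OneStepForecaster()
history = torch.randn(1, 24, 1)             # last 24 observed values
print(recursive_forecast(model, history, steps=5))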

Handling Irregular Time Series

Real-world time series often have irregular sampling or missing values. Time-aware models can explicitly reason about the duration between observations, applying decay to hidden states or learning time-dependent dynamics.
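
One simple way to make a recurrent cell time-aware, sketched below with an assumed learnable per-unit decay, is to shrink the hidden state in proportion to the gap since the previous observation before applying a standard GRU update.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class TimeDecayGRUCell(nn.Module):
    """
    GRU cell whose hidden state decays with the elapsed time since the
    previous observation (an illustrative sketch of a time-aware RNN).
    """
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = nn.GRUCell(input_size, hidden_size)
        # Learnable per-unit decay rates, kept positive via softplus
        self.decay_rate = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x, h, delta_t):
        # delta_t: (batch, 1) time gap since the previous observation
        decay = torch.exp(-F.softplus(self.decay_rate) * delta_t)
        h = h * decay                       # older information fades more after longer gaps
        return self.cell(x, h)

cell = TimeDecayGRUCell(input_size=3, hidden_size=32)
h = torch.zeros(2, 32)
for _ in range(3):
    x = torch.randn(2, 3)
    dt = torch.tensor([[0.5], [2.0]])       # irregular gaps, one per batch element
    h = cell(x, h, dt)
print(h.shape)  # torch.Size([2, 32])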

Multivariate Time Series

Many forecasting problems involve multiple interrelated variables. Stock prices correlate across sectors. Weather variables interact physically. Cross-attention mechanisms allow the model to learn which series provide useful information for predicting others.
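
A basic multivariate forecaster simply treats the related series as input channels of a single LSTM, as in the sketch below; the cross-attention variant mentioned above is omitted for brevity, and the dimensions are arbitrary.

PYTHON
import torch
import torch.nn as nn

class MultivariateForecaster(nn.Module):
    """LSTM that reads several series jointly and forecasts one of them."""
    def __init__(self, num_series, hidden_size=64, horizon=1):
        super().__init__()
        self.lstm = nn.LSTM(num_series, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, horizon)

    def forward(self, x):                   # x: (batch, time, num_series)
        _, (h, _) = self.lstm(x)
        return self.fc(h[-1])               # (batch, horizon)

model = MultivariateForecaster(num_series=5, horizon=3)
x = torch.randn(8, 48, 5)                   # 8 examples, 48 timesteps, 5 related series
print(model(x).shape)                       # torch.Size([8, 3])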

Key Takeaways

Applying RNNs to NLP and time series requires domain-appropriate preprocessing, architecture choices, and evaluation strategies.

For NLP applications, tokenization and embedding quality significantly impact performance. Pretrained embeddings provide a strong starting point. Bidirectional processing captures context from both directions.

For time series forecasting, proper normalization prevents numerical instability. Multi-step forecasting requires choosing between direct, recursive, and seq2seq approaches. Walk-forward validation maintains temporal integrity during model selection.

Both domains benefit from careful feature engineering beyond raw inputs. Production deployment requires monitoring and retraining infrastructure beyond the initial model training.

12.5 Advanced RNN Techniques Advanced

Advanced RNN Techniques

While basic RNN architectures establish the foundation for sequence modeling, practical applications often require enhancements that address specific limitations. This section explores advanced techniques that improve training stability, model capacity, and generalization.

Attention Mechanisms for RNNs

The information bottleneck in encoder-decoder architectures motivated the development of attention mechanisms. Rather than compressing the entire input into a single fixed-size context vector, attention allows the decoder to look back at encoder hidden states dynamically.

The attention mechanism computes a context vector as a weighted sum of encoder hidden states. Weights are determined by a compatibility function that scores how well each encoder state matches the current decoder state.

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    """
    Additive (Bahdanau) attention over encoder hidden states.
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, attention_dim)
        self.decoder_proj = nn.Linear(decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch, src_len, encoder_dim)
        # decoder_hidden:  (batch, decoder_dim)
        encoder_proj = self.encoder_proj(encoder_outputs)              # (batch, src_len, attention_dim)
        decoder_proj = self.decoder_proj(decoder_hidden).unsqueeze(1)  # (batch, 1, attention_dim)
        scores = self.v(torch.tanh(encoder_proj + decoder_proj)).squeeze(-1)  # (batch, src_len)
        attention_weights = F.softmax(scores, dim=1)
        # Context vector: attention-weighted sum of encoder states
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)  # (batch, encoder_dim)
        return context, attention_weights
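
A quick shape check with made-up dimensions confirms the interface: the module returns one context vector per batch element and one normalized weight per source position.

PYTHON
attn = BahdanauAttention(encoder_dim=512, decoder_dim=512, attention_dim=256)
encoder_outputs = torch.randn(4, 20, 512)   # (batch, src_len, encoder_dim)
decoder_hidden = torch.randn(4, 512)        # (batch, decoder_dim)

context, weights = attn(encoder_outputs, decoder_hidden)
print(context.shape)        # torch.Size([4, 512])
print(weights.shape)        # torch.Size([4, 20])
print(weights.sum(dim=1))   # each row of attention weights sums to 1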

Attention dramatically improved machine translation quality, particularly for long sentences.

Deep and Stacked RNNs

Stacking multiple RNN layers increases model capacity by learning hierarchical representations. Each layer processes the hidden states of the layer below, capturing increasingly abstract patterns.

Residual connections add the input of each layer to its output, providing a direct gradient path. Layer normalization normalizes across features rather than across the batch dimension.

PYTHON
class ResidualLSTMLayer(nn.Module):
    def __init__(self, hidden_dim, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, states=None):
        output, states = self.lstm(x, states)
        output = self.dropout(output)
        output = self.layer_norm(output + x)
        return output, states

Variational Dropout

Standard dropout applies a different mask at each timestep, which can harm RNN performance. Variational dropout applies the same dropout mask across all timesteps, preserving temporal coherence.

PYTHON
class VariationalDropout(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.dropout_rate = dropout_rate

    def forward(self, x):
        if not self.training:
            return x
        batch_size, seq_len, features = x.shape
        mask = x.new_empty(batch_size, 1, features).bernoulli_(1 - self.dropout_rate)
        mask = mask / (1 - self.dropout_rate)
        return x * mask

Weight Tying

Weight tying shares parameters between the input embedding and output projection layers of language models. This reduces model size significantly and often improves generalization.

PYTHON
class WeightTiedLanguageModel(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_layers=2, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers,
                           batch_first=True, dropout=dropout if num_layers > 1 else 0)
        self.output_projection = nn.Linear(hidden_dim, vocab_size, bias=False)
        self.output_projection.weight = self.embedding.weight

    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.lstm(embedded)
        return self.output_projection(output)
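
The savings are easy to quantify: tying shares the vocab_size × hidden_dim output matrix with the embedding instead of storing a second copy. The comparison below uses made-up sizes and breaks the tie manually to obtain an untied baseline.

PYTHON
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

tied = WeightTiedLanguageModel(vocab_size=30000, hidden_dim=512)

untied = WeightTiedLanguageModel(vocab_size=30000, hidden_dim=512)
# Replace the shared weight with an independent copy to break the tie
untied.output_projection.weight = nn.Parameter(untied.embedding.weight.detach().clone())

print(f"Tied:   {count_parameters(tied):,} parameters")
print(f"Untied: {count_parameters(untied):,} parameters")
# The difference equals vocab_size * hidden_dim = 30,000 * 512 = 15,360,000 parameters.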

Recurrent Highway Networks

Recurrent highway networks extend the highway network concept to RNNs, adding gated connections that allow information to flow unchanged across many timesteps.
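
A single-depth cell capturing the core idea might look like the sketch below (the published architecture applies several such gated micro-steps per timestep): a transform gate decides how much of the candidate update to accept, and the coupled carry gate 1 − t passes the previous state through unchanged.

PYTHON
import torch
import torch.nn as nn

class RecurrentHighwayCell(nn.Module):
    """One-step recurrent highway cell with a coupled carry gate (sketch)."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.candidate = nn.Linear(input_size + hidden_size, hidden_size)
        self.transform_gate = nn.Linear(input_size + hidden_size, hidden_size)

    def forward(self, x, s):
        combined = torch.cat([x, s], dim=-1)
        h = torch.tanh(self.candidate(combined))          # candidate update
        t = torch.sigmoid(self.transform_gate(combined))  # transform gate
        # Highway update: gated mix of the new candidate and the carried-over state
        return h * t + s * (1 - t)

cell = RecurrentHighwayCell(input_size=16, hidden_size=32)
s = torch.zeros(4, 32)
for step in range(10):
    s = cell(torch.randn(4, 16), s)
print(s.shape)  # torch.Size([4, 32])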

Bidirectional Training Strategies

Bidirectional RNNs process sequences in both directions but require the full sequence at inference time. For streaming applications, chunked bidirectional processing segments the input into chunks, applying bidirectional processing within each chunk.
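
A rough sketch of chunked processing: the stream is cut into fixed-size chunks and a bidirectional LSTM runs within each chunk, so backward context and latency are bounded by the chunk length rather than the full sequence. The chunk size and dimensions here are arbitrary.

PYTHON
import torch
import torch.nn as nn

class ChunkedBiLSTM(nn.Module):
    """Apply a bidirectional LSTM independently within fixed-size chunks (sketch)."""
    def __init__(self, input_size, hidden_size, chunk_size=16):
        super().__init__()
        self.chunk_size = chunk_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x):                              # x: (batch, seq_len, input_size)
        outputs = []
        for start in range(0, x.size(1), self.chunk_size):
            chunk = x[:, start:start + self.chunk_size, :]
            out, _ = self.lstm(chunk)                  # backward context limited to the chunk
            outputs.append(out)
        return torch.cat(outputs, dim=1)               # (batch, seq_len, 2 * hidden_size)

model = ChunkedBiLSTM(input_size=40, hidden_size=128, chunk_size=16)
x = torch.randn(2, 100, 40)                            # e.g., a stream of feature frames
print(model(x).shape)                                  # torch.Size([2, 100, 256])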

Continuous-Time RNNs

Standard RNNs operate on discrete timesteps, but some applications involve irregularly sampled data. Neural ordinary differential equations model hidden state dynamics as continuous functions.
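
As a simple illustration of the continuous-time view, the toy cell below integrates the classic CTRNN dynamics dh/dt = (−h + tanh(W_in x + W_rec h)) / τ with explicit Euler steps, so the elapsed time between observations directly controls how far the state evolves; this is a sketch, not a neural ODE solver.

PYTHON
import torch
import torch.nn as nn

class EulerCTRNN(nn.Module):
    """Continuous-time RNN integrated with explicit Euler steps (toy sketch)."""
    def __init__(self, input_size, hidden_size, tau=1.0, step=0.1):
        super().__init__()
        self.w_in = nn.Linear(input_size, hidden_size)
        self.w_rec = nn.Linear(hidden_size, hidden_size, bias=False)
        self.tau = tau
        self.step = step

    def forward(self, x, h, delta_t):
        """Advance the hidden state by delta_t while the input x is held fixed."""
        n_steps = max(1, int(delta_t / self.step))
        for _ in range(n_steps):
            dh = (-h + torch.tanh(self.w_in(x) + self.w_rec(h))) / self.tau
            h = h + self.step * dh
        return h

cell = EulerCTRNN(input_size=3, hidden_size=32)
h = torch.zeros(1, 32)
# Irregularly spaced observations: gaps of 0.3, 1.2, and 0.05 time units
for x, dt in zip(torch.randn(3, 1, 3), [0.3, 1.2, 0.05]):
    h = cell(x, h, dt)
print(h.shape)  # torch.Size([1, 32])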

Key Takeaways

Advanced RNN techniques address fundamental limitations of basic architectures. Attention mechanisms remove the information bottleneck. Deep stacking with residual connections enables learning hierarchical representations. Variational dropout provides stronger regularization. Weight tying reduces parameters while enforcing semantic consistency.

These techniques represent the evolution of sequence modeling toward the transformer architecture. Understanding these intermediate steps provides insight into why certain design decisions work and when simpler RNN approaches may still be preferable.