Introduction to Attention
The attention mechanism represents one of the most significant advances in deep learning, fundamentally changing how neural networks process sequential and structured data. Before attention, sequence-to-sequence models compressed entire inputs into fixed-size vectors, creating an information bottleneck that degraded performance on long sequences.
The Information Bottleneck Problem
Traditional encoder-decoder architectures suffered from a fundamental limitation. The encoder processed the entire input sequence and produced a single fixed-dimensional context vector. The decoder then had to generate the entire output conditioned only on this vector, regardless of input length.
Consider translating a long paragraph. The encoder must compress all nuances, entities, and grammatical structures into perhaps 512 numbers. Information inevitably gets lost, particularly for elements early in the sequence.
import torch
import torch.nn as nn
import torch.nn.functional as F

def visualize_bottleneck():
    # Rough illustration: a fixed-size hidden state can carry only so much
    # information, no matter how long the input sequence is.
    hidden_size = 256
    max_bits = hidden_size * 32          # 32 bits per float in the context vector
    print("Information Bottleneck Analysis")
    print("=" * 50)
    for seq_len in [10, 50, 100, 200]:
        input_bits = seq_len * 50        # assume roughly 50 bits of content per token
        compression = input_bits / max_bits
        status = "OK" if compression <= 1 else "LOSSY"
        print(f"Length {seq_len:3d}: {compression:.2f}x compression [{status}]")

visualize_bottleneck()

The Core Attention Idea
Attention solves the bottleneck by giving the decoder direct access to all encoder hidden states. At each decoding step, the model computes a weighted combination of encoder states, where weights indicate which input positions are most relevant.
The mechanism works as follows:
- The encoder produces a hidden state for each input position
- The decoder state is scored against each encoder state to measure compatibility
- The scores are normalized with softmax to give attention weights
- The attention weights define a weighted sum of encoder states, the context vector
- The context vector, combined with the decoder state, predicts the next output
class BasicAttention(nn.Module):
    """Dot-product attention over encoder outputs (assumes encoder_dim == decoder_dim)."""

    def __init__(self, encoder_dim, decoder_dim):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (batch, src_len, encoder_dim)
        # decoder_hidden:  (batch, decoder_dim)
        scores = torch.bmm(encoder_outputs, decoder_hidden.unsqueeze(-1)).squeeze(-1)
        attention_weights = F.softmax(scores, dim=1)    # (batch, src_len)
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, attention_weights               # (batch, encoder_dim), (batch, src_len)
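A minimal usage sketch of the class above; the batch size, sequence length, and dimensions are arbitrary illustrative choices:

batch, src_len, dim = 2, 6, 256
encoder_outputs = torch.randn(batch, src_len, dim)
decoder_hidden = torch.randn(batch, dim)

attn = BasicAttention(encoder_dim=dim, decoder_dim=dim)
context, weights = attn(encoder_outputs, decoder_hidden)
print(context.shape)        # torch.Size([2, 256])
print(weights.shape)        # torch.Size([2, 6])
print(weights.sum(dim=1))   # each row of attention weights sums to 1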
Query, Key, and Value
Modern attention uses the query-key-value (QKV) paradigm:
- Query (Q): What we are looking for
- Key (K): What each position offers for matching
- Value (V): The actual information to retrieve
Scaled dot-product attention computes softmax(Q K^T / sqrt(d_k)) V, where d_k is the key dimension. The sqrt(d_k) scaling factor prevents the dot products from growing too large as d_k increases; without it, the softmax saturates and gradients through the attention weights become very small.
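A quick numerical sketch of why the scaling helps: for random query and key vectors with unit-variance components, the dot product has a standard deviation of roughly sqrt(d_k), so unscaled scores grow with dimension while scaled scores stay near unit scale (the dimensions and sample count below are arbitrary):

torch.manual_seed(0)
for d_k in [16, 64, 256, 1024]:
    q = torch.randn(10000, d_k)
    k = torch.randn(10000, d_k)
    raw = (q * k).sum(dim=-1)          # unscaled dot products
    scaled = raw / (d_k ** 0.5)        # scaled dot products
    print(f"d_k={d_k:4d}  raw std={raw.std().item():6.1f}  scaled std={scaled.std().item():4.2f}")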
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        # query: (batch, q_len, d_k), key: (batch, k_len, d_k), value: (batch, k_len, d_v)
        d_k = query.size(-1)
        scores = torch.bmm(query, key.transpose(-2, -1)) / (d_k ** 0.5)
        if mask is not None:
            # Positions where mask == 0 receive -inf and get zero weight after softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.bmm(attention_weights, value)    # (batch, q_len, d_v)
        return output, attention_weights
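A usage sketch with a causal mask, where each position may attend only to itself and earlier positions; the tensor sizes are arbitrary, and torch.tril builds the lower-triangular mask:

batch, seq_len, d_model = 2, 5, 64
x = torch.randn(batch, seq_len, d_model)

attn = ScaledDotProductAttention(dropout=0.0)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)  # (1, seq_len, seq_len)
output, weights = attn(x, x, x, mask=causal_mask)

print(output.shape)   # torch.Size([2, 5, 64])
print(weights[0])     # upper triangle is zero: no attention to future positions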
Additive vs Multiplicative Attention
Additive Attention (Bahdanau): Projects the query and key into a shared space, combines them through a tanh nonlinearity, and scores the result with a learned vector (equivalent to a small feedforward network over the concatenated query and key). More expressive and tolerant of mismatched query and key dimensions, but slower.
Multiplicative Attention (Luong): Computes a (scaled) dot product between query and key, as in ScaledDotProductAttention above. Faster and easier to scale in practice because it reduces to highly optimized matrix multiplication.
class AdditiveAttention(nn.Module):
    def __init__(self, query_dim, key_dim, hidden_dim):
        super().__init__()
        self.query_proj = nn.Linear(query_dim, hidden_dim, bias=False)
        self.key_proj = nn.Linear(key_dim, hidden_dim, bias=False)
        self.v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, query, key, value, mask=None):
        # query: (batch, q_len, query_dim), key/value: (batch, k_len, key_dim/value_dim)
        query_proj = self.query_proj(query).unsqueeze(2)   # (batch, q_len, 1, hidden)
        key_proj = self.key_proj(key).unsqueeze(1)         # (batch, 1, k_len, hidden)
        scores = self.v(torch.tanh(query_proj + key_proj)).squeeze(-1)  # (batch, q_len, k_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.bmm(attention_weights, value)
        return output, attention_weights
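For comparison, here is a minimal sketch of Luong's "general" multiplicative variant, which inserts a learned weight matrix between query and key so that mismatched dimensions can still be scored. The class name is illustrative, not from a particular library:

class GeneralMultiplicativeAttention(nn.Module):
    """score(q, k) = q . (W k), the 'general' form from Luong et al. (2015)."""

    def __init__(self, query_dim, key_dim):
        super().__init__()
        self.W = nn.Linear(key_dim, query_dim, bias=False)

    def forward(self, query, key, value, mask=None):
        # query: (batch, q_len, query_dim), key/value: (batch, k_len, key_dim/value_dim)
        scores = torch.bmm(query, self.W(key).transpose(1, 2))  # (batch, q_len, k_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.bmm(attention_weights, value)
        return output, attention_weights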
Visualizing Attention
Attention weights reveal what the model focuses on when generating each output. For translation, these weights often correspond to word alignments.
import numpy as np

source = ['The', 'cat', 'sat', 'on', 'the', 'mat']
target = ['Le', 'chat', 'assis', 'sur', 'le', 'tapis']

# Attention weights show alignment: rows are target tokens, columns are source
# tokens, and each row sums to 1.
attention = np.array([
    [0.8, 0.1, 0.0, 0.0, 0.1, 0.0],  # Le -> The
    [0.1, 0.8, 0.1, 0.0, 0.0, 0.0],  # chat -> cat
    [0.0, 0.1, 0.8, 0.1, 0.0, 0.0],  # assis -> sat
    [0.0, 0.0, 0.1, 0.8, 0.1, 0.0],  # sur -> on
    [0.1, 0.0, 0.0, 0.1, 0.7, 0.1],  # le -> the
    [0.0, 0.0, 0.0, 0.0, 0.1, 0.9],  # tapis -> mat
])

print("Attention shows word alignments:")
for i, tgt in enumerate(target):
    max_idx = attention[i].argmax()
    print(f"  '{tgt}' attends to '{source[max_idx]}' ({attention[i, max_idx]:.1f})")
Key Takeaways
Attention mechanisms removed the information bottleneck in encoder-decoder architectures. The query-key-value formulation provides a general framework where queries specify what to look for, keys specify what each position offers, and values contain the information to retrieve.
Attention weights provide interpretable insights into model behavior. This interpretability aids debugging and builds trust in model predictions.
The attention mechanism laid the groundwork for transformers, which replace recurrence entirely with attention.