Understanding Attention Mechanisms
The attention mechanism is the cornerstone of modern language models. In this post, we'll explore how it works from the ground up.
The Core Intuition
At its heart, attention answers a simple question: "Given a query, which parts of the input should I focus on?"
Think of it like a librarian searching through documents. When you ask a question (query), the librarian scans through all available books (keys) and retrieves the most relevant passages (values).
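To make the analogy concrete, here is a toy sketch (the scores and values below are made up purely for illustration): three candidate "books" get relevance scores against a single question, softmax converts those scores into weights, and the result is a weighted blend of the books' values rather than a hard lookup.

```python
import torch

# Hypothetical relevance scores of three "books" (keys) for one question (query)
scores = torch.tensor([0.9, 0.1, 0.4])
weights = torch.softmax(scores, dim=-1)   # ≈ [0.49, 0.22, 0.29]; weights sum to 1

# Values: what each book contributes if retrieved (toy 2-d vectors)
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [0.5, 0.5]])

# The "retrieved passage" is a soft blend, dominated by the most relevant book
result = weights @ values
print(weights, result)
```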
Mathematical Foundation
The scaled dot-product attention is defined as:
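$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
$$

where Q, K, and V are the query, key, and value matrices and d_k is the dimensionality of the keys. In PyTorch, this maps almost line for line onto the following implementation: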
```python
import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.

    Args:
        query: (batch, seq_len, d_k)
        key: (batch, seq_len, d_k)
        value: (batch, seq_len, d_v)
        mask: optional attention mask

    Returns:
        (output, attention_weights)
    """
    d_k = query.size(-1)
    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    # Weighted sum of values
    return torch.matmul(attention_weights, value), attention_weights
```

The key insight is the scaling factor 1/√d_k. Without it, dot products grow large for high-dimensional vectors, pushing softmax into regions with extremely small gradients.
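To see why the scaling matters, note that if the components of a query and a key are independent with unit variance, their dot product has variance d_k, so raw scores have a standard deviation of about √d_k. A quick empirical sketch (the dimension and sample count here are arbitrary):

```python
import math

import torch

d_k = 512
q = torch.randn(10_000, d_k)   # unit-variance components
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)        # unscaled dot products
scaled = raw / math.sqrt(d_k)    # scaled as in the attention formula

print(raw.std())     # roughly sqrt(512) ≈ 22.6
print(scaled.std())  # roughly 1.0

# With scores this large, softmax collapses onto a single position,
# leaving near-zero gradient everywhere else.
print(torch.softmax(raw[:8], dim=-1))     # typically close to one-hot
print(torch.softmax(scaled[:8], dim=-1))  # a much softer distribution
```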
Multi-Head Attention
Rather than performing a single attention function, we can run multiple attention heads in parallel:
```python
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Each head operates on a subspace
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are concatenated
        self.W_o = nn.Linear(d_model, d_model)
```

> The multi-head mechanism allows the model to jointly attend to information from different representation subspaces at different positions. — Vaswani et al., "Attention Is All You Need"
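The snippet above only sets up the projections. To complete the picture, here is one way the forward pass could look, continuing the class: each projection is reshaped into num_heads heads of size d_k, the scaled_dot_product_attention function from earlier runs on all heads in parallel via broadcasting, and the heads are concatenated and passed through W_o. Treat it as a minimal sketch rather than the canonical implementation:

```python
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Project, then reshape to (batch, num_heads, seq_len, d_k)
        def split_heads(x):
            return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.W_q(query))
        k = split_heads(self.W_k(key))
        v = split_heads(self.W_v(value))

        # Each head attends independently; broadcasting handles the extra head dimension
        out, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Concatenate heads back to (batch, seq_len, d_model) and apply the output projection
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.W_o(out), attention_weights
```

With d_model = 512 and num_heads = 8, for example, each head attends in a 64-dimensional subspace, which is the base configuration in Vaswani et al. Returning the per-head attention weights alongside the output also makes them easy to inspect later.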
Why This Matters
Understanding attention is crucial because:
- Interpretability: Attention weights reveal what the model focuses on (see the sketch after this list)
- Efficiency: Self-attention processes all positions in parallel
- Long-range dependencies: Any two positions are connected by a single attention step, so gradients need not survive a long chain of recurrent updates
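On the interpretability point, the attention_weights returned by scaled_dot_product_attention above can be inspected directly. A minimal sketch with random inputs (the shapes are arbitrary):

```python
import torch

batch, seq_len, d_k = 1, 5, 16
q = torch.randn(batch, seq_len, d_k)
k = torch.randn(batch, seq_len, d_k)
v = torch.randn(batch, seq_len, d_k)

_, attention_weights = scaled_dot_product_attention(q, k, v)

# Row i shows how strongly position i attends to every position; each row sums to 1
print(attention_weights[0])
print(attention_weights[0].sum(dim=-1))  # all ones
```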
Key Takeaways
| Concept | Description |
|---|---|
| Query | What we're looking for |
| Key | What we're searching through |
| Value | What we retrieve |
| Scaling | Prevents gradient saturation |
| Multi-head | Parallel attention in subspaces |
Next, we'll explore how attention enables the full transformer architecture with layer normalization and feed-forward networks.