Understanding Attention Mechanisms

The attention mechanism is the cornerstone of modern language models. In this post, we'll explore how it works from the ground up.

The Core Intuition

At its heart, attention answers a simple question: "Given a query, which parts of the input should I focus on?"

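Think of it like a librarian searching a collection. When you ask a question (the query), the librarian checks it against every book's catalog entry (the keys) and hands back their contents (the values). Unlike a real librarian, attention doesn't return only the best match. It returns a blend of all the values, weighted by how well each key matches the query.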
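This retrieval can be sketched in a few lines of plain Python. Each key is scored against the query with a dot product, a softmax turns the scores into weights that sum to 1, and the output is the weighted average of the values. The keys, values, and query below are hypothetical toy numbers, and `soft_lookup` is an illustrative helper, not a library function.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def soft_lookup(query, keys, values):
    """Score each key against the query, then blend the values by those weights."""
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    # Weighted average of the value vectors (one weight per key/value pair)
    dim = len(values[0])
    output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return output, weights

# Hypothetical toy data: three "books", each with a 2-d key and a 2-d value
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, -1.0]  # dot-product scores are 2, -1, 1: closest to the first key

output, weights = soft_lookup(query, keys, values)
print([round(w, 3) for w in weights])
print([round(o, 3) for o in output])
```

The first book gets most of the weight, but the others still contribute, which is what makes the lookup differentiable.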

Mathematical Foundation

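Scaled dot-product attention is defined as Attention(Q, K, V) = softmax(QKᵀ / √d_k) V, where Q, K, and V hold the queries, keys, and values as rows, d_k is the key dimension, and the softmax is taken over the keys. In PyTorch: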

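import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        query: (batch, seq_len_q, d_k)
        key: (batch, seq_len_k, d_k)
        value: (batch, seq_len_k, d_v)
        mask: optional tensor broadcastable to (batch, seq_len_q, seq_len_k);
              positions where mask == 0 are excluded from attention

    Returns:
        output: (batch, seq_len_q, d_v)
        attention_weights: (batch, seq_len_q, seq_len_k)
    """
    d_k = query.size(-1)
    
    # Compute attention scores scaled by 1/sqrt(d_k).
    # d_k is a Python int, so use math.sqrt (torch.sqrt only accepts tensors).
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Apply mask if provided: masked positions get -inf, i.e. zero weight after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Softmax over the key dimension to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    return torch.matmul(attention_weights, value), attention_weights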

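The key insight is the scaling factor 1/√d_k. If the components of a query q and a key k are independent with mean 0 and variance 1, their dot product q · k = Σᵢ qᵢkᵢ has mean 0 and variance d_k, so raw scores grow with dimension. Large scores push softmax into saturated regions, where almost all the weight lands on a single key and the gradients are extremely small. Dividing by √d_k restores unit variance.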
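The variance argument is easy to check numerically. Below is a minimal sketch in plain Python (standard library only, so it runs without PyTorch), assuming query and key entries drawn i.i.d. from N(0, 1). The helper name `dot_product_variance` and the sample count are illustrative choices.

```python
import random
import statistics

def dot_product_variance(d_k, n_samples=2000, scale=False, seed=0):
    """Empirical variance of q . k for random q, k with i.i.d. N(0, 1) entries."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        score = sum(qi * ki for qi, ki in zip(q, k))
        if scale:
            score /= d_k ** 0.5  # the 1/sqrt(d_k) factor
        samples.append(score)
    return statistics.pvariance(samples)

for d_k in (4, 64, 256):
    raw = dot_product_variance(d_k)
    scaled = dot_product_variance(d_k, scale=True)
    print(f"d_k={d_k:4d}  var(raw)={raw:8.2f}  var(scaled)={scaled:5.2f}")
```

The raw variance should land near d_k for each size, while the scaled variance stays near 1. Because both calls share a seed, the scaled scores are exactly the raw ones divided by √d_k.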

Multi-Head Attention

Rather than performing a single attention function, we can run multiple attention heads in parallel:

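import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # Each head operates on a subspace of size d_k
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # Linear projections for Q, K, V, plus the output projection W_o
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        # Inputs: (batch, seq_len, d_model); mask must broadcast to (batch, heads, seq_q, seq_k)
        b = query.size(0)
        # Project, then split into heads: (batch, num_heads, seq_len, d_k)
        q, k, v = (W(x).view(b, -1, self.num_heads, self.d_k).transpose(1, 2)
                   for W, x in ((self.W_q, query), (self.W_k, key), (self.W_v, value)))
        out, _ = scaled_dot_product_attention(q, k, v, mask)  # defined above; all heads at once
        # Merge heads back to (batch, seq_len, d_model), then apply W_o
        return self.W_o(out.transpose(1, 2).contiguous().view(b, -1, self.num_heads * self.d_k))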

"Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions." (Vaswani et al., "Attention Is All You Need", 2017)
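To see what running each head on its own subspace means, here is a simplified, dependency-free sketch. For brevity it assumes identity projections, so the learned W_q, W_k, W_v, and W_o are dropped. Each head slices out its own d_k columns, runs scaled dot-product attention on that slice alone, and the head outputs are concatenated back to d_model columns. The helper names (`split_heads`, `multi_head_self_attention`) and the toy input are hypothetical.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention on lists of row vectors."""
    d_k = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[i] for w, v in zip(weights, V)) for i in range(len(V[0]))])
    return out

def split_heads(X, num_heads):
    """Slice each d_model-wide row into num_heads chunks of width d_k."""
    assert len(X[0]) % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = len(X[0]) // num_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in X] for h in range(num_heads)]

def multi_head_self_attention(X, num_heads):
    # Simplification: identity projections stand in for the learned W_q, W_k, W_v, W_o
    heads = split_heads(X, num_heads)
    per_head = [attention(H, H, H) for H in heads]  # each head sees only its slice
    # Concatenate the head outputs position by position, back to d_model columns
    return [[x for head_out in per_head for x in head_out[t]] for t in range(len(X))]

# Hypothetical toy input: 3 positions, d_model = 4
X = [[1.0, 0.0, 0.0, 1.0],
     [0.0, 1.0, 1.0, 0.0],
     [1.0, 1.0, 0.0, 0.0]]
Y = multi_head_self_attention(X, num_heads=2)
print(Y)
```

Because each head sees only its own slice, changing the columns owned by one head leaves the other head's output untouched. In the real module, the learned projections decide what information each subspace carries.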

Why This Matters

Understanding attention is crucial because:

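  1. Interpretability: Attention weights show which positions the model draws on, although they are not a complete explanation of its behavior
  2. Efficiency: Self-attention has no sequential dependency between positions, so all positions are processed in parallel (at a cost that grows quadratically with sequence length)
  3. Long-range dependencies: Any two positions are connected by a single attention step, so gradients do not have to pass through the many recurrent steps where RNNs suffer from vanishing gradients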
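Points 1 and 2 can be made concrete with a small, dependency-free sketch. The function name `attention_weights` and the toy embeddings are hypothetical. It builds the full matrix of attention weights for a sequence. Each row depends only on its own query, so rows can be computed independently and in parallel, and the resulting matrix is what gets inspected when attention is visualized. The optional `causal` flag mirrors passing a lower-triangular `mask` to `scaled_dot_product_attention`, so each position attends only to itself and earlier positions.

```python
import math

def attention_weights(Q, K, causal=False):
    """Full len(Q) x len(K) matrix of softmax(Q K^T / sqrt(d_k)), one row per query."""
    d_k = len(Q[0])
    rows = []
    for i, q in enumerate(Q):  # rows are independent, so they could run in parallel
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        if causal:
            # Same effect as masked_fill(mask == 0, -inf) with a lower-triangular mask
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows

# Hypothetical token embeddings, for illustration only
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
W = attention_weights(X, X, causal=True)
for i, row in enumerate(W):
    # Inspect which position token i attends to most
    print(i, [round(w, 2) for w in row], "strongest:", row.index(max(row)))
```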

Key Takeaways

Concept      Description
Query        What we're looking for
Key          What we're searching through
Value        What we retrieve
Scaling      Keeps softmax out of saturation, where gradients become tiny
Multi-head   Parallel attention in subspaces

Next, we'll explore how attention combines with layer normalization and feed-forward networks to form the full transformer architecture.
