Putting It Together: Single-Head Attention
Everything we've discussed—Query/Key matching and Value blending—is called single-head attention. Let's see it in code.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim: Dimension of input embeddings (64 for our model)
        """
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim)  # Value projection
        self.scale = math.sqrt(embed_dim)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        Returns:
            [batch_size, seq_len, embed_dim]
        """
        Q = self.W_q(x)  # [batch, seq_len, embed_dim]
        K = self.W_k(x)  # [batch, seq_len, embed_dim]
        V = self.W_v(x)  # [batch, seq_len, embed_dim]

        # Attention scores: Q @ K^T, scaled by sqrt(embed_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # [batch, seq_len, seq_len]

        # Convert to probabilities
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)
        # [batch, seq_len, embed_dim]

        return output
```

Mapping Code to Concepts
Let's trace through what happens when "two plus three" goes through this code:
Input x: shape [1, 3, 64] (1 example, 3 tokens, 64-dim embeddings)
"two" → [0.8, 0.1, ...]
"plus" → [0.1, 0.9, ...]
"three" → [0.7, 0.2, ...]
Q = W_q(x) → Each token now has a "query" (what it's looking for)
K = W_k(x) → Each token now has a "key" (what it offers)
V = W_v(x) → Each token now has a "value" (its content)
scores = Q @ K.T → 3×3 matrix (each token's query vs all keys)
This is the attention matrix we visualized!
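The score computation is just a scaled matrix product. A minimal sketch with made-up 2-dimensional queries and keys for three tokens (the numbers are illustrative, not the model's actual projections):

```python
import math
import torch

# Toy 2-dim queries and keys for three tokens (hypothetical values, just to show shapes)
Q = torch.tensor([[0.8, 0.1],
                  [0.1, 0.9],
                  [0.7, 0.2]])
K = torch.tensor([[0.2, 0.7],
                  [0.9, 0.0],
                  [0.3, 0.6]])

# Each query dotted with every key, scaled by sqrt(key dimension)
scores = Q @ K.T / math.sqrt(2)
print(scores.shape)  # torch.Size([3, 3])
```

Row i, column j of `scores` is how strongly token i's query matches token j's key.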
softmax(scores) → Normalize each row to sum to 1
"plus" row: [0.45, 0.10, 0.45]
output = weights @ V → Each token gathers info from others
"plus" now knows it's adding 2 and 3That's it! The entire attention mechanism in ~15 lines. The complexity comes from doing this at scale—GPT-4 uses 12,288-dimensional embeddings instead of 64, which means 450 million parameters just for Q, K, V projections.