Putting It Together: Single-Head Attention

Everything we've discussed—Query/Key matching and Value blending—is called single-head attention. Let's see it in code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim: int):
        """
        Args:
            embed_dim: Dimension of input embeddings (64 for our model)
        """
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)  # Query projection
        self.W_k = nn.Linear(embed_dim, embed_dim)  # Key projection
        self.W_v = nn.Linear(embed_dim, embed_dim)  # Value projection
        self.scale = math.sqrt(embed_dim)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        Returns:
            [batch_size, seq_len, embed_dim]
        """
        Q = self.W_q(x)  # [batch, seq_len, embed_dim]
        K = self.W_k(x)  # [batch, seq_len, embed_dim]
        V = self.W_v(x)  # [batch, seq_len, embed_dim]

        # Attention scores: Q @ K^T, scaled by sqrt(embed_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # [batch, seq_len, seq_len]

        # Convert each row of scores to probabilities
        attention_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values
        output = torch.matmul(attention_weights, V)
        # [batch, seq_len, embed_dim]

        return output
```
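
As a quick sanity check, we can feed a random batch through the module and confirm the output shape matches the input shape. (The class is repeated here in condensed form so the snippet runs on its own; the random input simply stands in for real token embeddings.)

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadAttention(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.scale = math.sqrt(embed_dim)

    def forward(self, x):
        Q, K, V = self.W_q(x), self.W_k(x), self.W_v(x)
        scores = Q @ K.transpose(-2, -1) / self.scale
        return F.softmax(scores, dim=-1) @ V

attn = SingleHeadAttention(embed_dim=64)
x = torch.randn(1, 3, 64)  # a batch of one 3-token sequence
out = attn(x)
print(out.shape)  # torch.Size([1, 3, 64]) — same shape in, same shape out
```

Because attention only mixes information *across* the sequence, the tensor shape never changes; that is what lets us stack these layers.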

Mapping Code to Concepts

Let's trace through what happens when "two plus three" goes through this code:

```
Input x: shape [1, 3, 64]  (1 example, 3 tokens, 64-dim embeddings)
         "two"   → [0.8, 0.1, ...]
         "plus"  → [0.1, 0.9, ...]
         "three" → [0.7, 0.2, ...]

Q = W_q(x)  →  Each token now has a "query" (what it's looking for)
K = W_k(x)  →  Each token now has a "key" (what it offers)
V = W_v(x)  →  Each token now has a "value" (its content)

scores = Q @ K.T  →  3×3 matrix (each token's query vs all keys)
                     This is the attention matrix we visualized!

softmax(scores)   →  Normalize each row to sum to 1
                     "plus" row: [0.45, 0.10, 0.45]

output = weights @ V  →  Each token gathers info from others
                         "plus" now knows it's adding 2 and 3
```

That's it! The entire attention mechanism in about 15 lines of code. The complexity comes from doing this at scale: GPT-3, for example, uses 12,288-dimensional embeddings instead of 64, which works out to roughly 450 million parameters per layer just for the Q, K, V projections.
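
That parameter count is simple arithmetic: each of the three projections is a square weight matrix of size `embed_dim × embed_dim`.

```python
d_model = 12288  # GPT-3's published embedding dimension

# Three square projection matrices (weights only; biases add a
# comparatively tiny 3 * d_model parameters)
qkv_params = 3 * d_model * d_model
print(qkv_params)  # 452984832 — about 450 million, per layer
```

Multiply that by the number of layers in the model and it becomes clear why the projections alone account for billions of parameters.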