Build Your First LLM from Scratch · Part 4 · Section 7 of 7
Complete Attention Module
Putting It All Together
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """
    Complete attention module with:
    - Multi-head attention
    - Optional causal masking
    - Dropout for regularization
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0
    ):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
            mask: Optional causal mask (True marks positions to block)

        Returns:
            [batch_size, seq_len, embed_dim]
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V and split into heads:
        # [batch, seq_len, embed_dim] -> [batch, num_heads, seq_len, head_dim]
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores, scaled by sqrt(head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention to values
        out = torch.matmul(attention, V)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(out)
```
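To sanity-check the module, here is a minimal usage sketch. The 64-dimensional embeddings, 4 heads, and sequence length of 5 mirror the toy configuration used throughout this part; the input tensor is just a random placeholder.

```python
# Toy configuration used throughout this part: 64-dim embeddings, 4 heads
attn = Attention(embed_dim=64, num_heads=4, dropout=0.1)

x = torch.randn(2, 5, 64)  # [batch_size=2, seq_len=5, embed_dim=64]

# Causal mask: True above the diagonal means "don't attend to future tokens"
causal_mask = torch.triu(torch.ones(5, 5), diagonal=1).bool()

out = attn(x, mask=causal_mask)
print(out.shape)  # torch.Size([2, 5, 64]) -- same shape in, same shape out
```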
Visual Summary

```
Input: embeddings [batch, seq_len, 64]
                  |
      +----------+----------+----------+
      |          |          |          |
      v          v          v          v
   Head 1     Head 2     Head 3     Head 4
   (16d)      (16d)      (16d)      (16d)
      |          |          |          |
      +----------+----------+----------+
                 |
        Concatenate (64d)
                 |
        Output projection
                 |
Output: [batch, seq_len, 64]
```
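If the split-and-concatenate step in the diagram feels abstract, the shape bookkeeping can be checked on its own. This standalone sketch (not part of the module above) follows one tensor through the same view/transpose round trip:

```python
import torch

batch, seq_len, embed_dim, num_heads = 2, 5, 64, 4
head_dim = embed_dim // num_heads  # 16

x = torch.randn(batch, seq_len, embed_dim)

# Split: [2, 5, 64] -> [2, 4, 5, 16]  (one 16-d slice per head)
heads = x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 16])

# Concatenate: [2, 4, 5, 16] -> [2, 5, 64]
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, embed_dim)
print(torch.equal(merged, x))  # True: splitting and merging is lossless
```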
Summary Comparison

| Component | Our Model | GPT-4 Scale |
|---|---|---|
| Embed dim | 64 | 12,288 |
| Num heads | 4 | 96 |
| Head dim | 16 | 128 |
| Attention params | ~16K | ~450M per layer |
| Attention matrix (seq_len × seq_len) | 5×5 | 8000×8000 |
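The "~16K" figure comes from the four linear projections in our 64-dimensional model. A quick back-of-the-envelope check (weights plus biases), cross-checked against the module defined above:

```python
embed_dim = 64
# W_q, W_k, W_v, W_o: each has a 64x64 weight matrix plus a 64-d bias
params_per_projection = embed_dim * embed_dim + embed_dim   # 4,160
print(4 * params_per_projection)                            # 16640, i.e. ~16K

attn = Attention(embed_dim=64, num_heads=4)
print(sum(p.numel() for p in attn.parameters()))            # 16640
```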
What You Can Now Do
- Explain what attention does (intuitively)
- Understand Query, Key, Value concepts
- Implement single-head attention (recapped in the sketch below this list)
- Implement multi-head attention
- Apply causal masking for generation
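As a recap of the single-head case from earlier in this part, the core mechanism fits in a few lines. This sketch assumes Q, K, and V have already been projected, and it skips masking and dropout:

```python
import math
import torch
import torch.nn.functional as F

def single_head_attention(Q, K, V):
    # Q, K, V: [batch_size, seq_len, head_dim]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    weights = F.softmax(scores, dim=-1)   # each row sums to 1
    return torch.matmul(weights, V)       # weighted mix of the values
```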
Next up: In Part 5, we'll combine attention with feed-forward networks to build complete Transformer blocks!