# Multi-Head Attention

Multi-head attention runs multiple attention heads in parallel, each learning different patterns. Here's the implementation from src/model.py:

python
1class MultiHeadAttention(nn.Module):
2 """Multi-head self-attention mechanism."""
3
4 def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
5 super().__init__()
6 self.embed_dim = embed_dim
7 self.num_heads = num_heads
8 self.head_dim = embed_dim // num_heads
9
10 self.q_proj = nn.Linear(embed_dim, embed_dim)
11 self.k_proj = nn.Linear(embed_dim, embed_dim)
12 self.v_proj = nn.Linear(embed_dim, embed_dim)
13 self.out_proj = nn.Linear(embed_dim, embed_dim)
14 self.dropout = nn.Dropout(dropout)
15 self.scale = math.sqrt(self.head_dim)
16
17 def forward(
18 self, x: torch.Tensor, mask: torch.Tensor | None = None
19 ) -> tuple[torch.Tensor, torch.Tensor]:
20 batch_size, seq_len, _ = x.shape
21
22 # Project to Q, K, V and split into heads
23 Q = (
24 self.q_proj(x)
25 .view(batch_size, seq_len, self.num_heads, self.head_dim)
26 .transpose(1, 2)
27 )
28 K = (
29 self.k_proj(x)
30 .view(batch_size, seq_len, self.num_heads, self.head_dim)
31 .transpose(1, 2)
32 )
33 V = (
34 self.v_proj(x)
35 .view(batch_size, seq_len, self.num_heads, self.head_dim)
36 .transpose(1, 2)
37 )
38
39 # Attention scores
40 scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
41 if mask is not None:
42 scores = scores.masked_fill(mask == 0, float("-inf"))
43
44 attn_weights = F.softmax(scores, dim=-1)
45 attn_weights = self.dropout(attn_weights)
46 attn_output = torch.matmul(attn_weights, V)
47
48 # Concatenate heads
49 attn_output = (
50 attn_output.transpose(1, 2)
51 .contiguous()
52 .view(batch_size, seq_len, self.embed_dim)
53 )
54 return self.out_proj(attn_output), attn_weights

## Tests

python
1# tests/test_model.py
def test_attention_output_shape():
    """The attention output keeps the input's (batch, seq, embed) shape."""
    layer = MultiHeadAttention(embed_dim=64, num_heads=4)
    inputs = torch.randn(2, 10, 64)
    result, _ = layer(inputs)
    assert result.shape == inputs.shape
7
def test_attention_weights_sum_to_one():
    """Each attention row is a probability distribution over key positions."""
    attn = MultiHeadAttention(embed_dim=64, num_heads=4)
    # eval() disables dropout: in the default train mode nn.Dropout
    # zeroes and rescales entries of the returned weights, so row sums
    # would not equal 1 and this test would fail nondeterministically.
    attn.eval()
    x = torch.randn(2, 10, 64)
    output, weights = attn(x)
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)
14
def test_masked_attention():
    """A causal mask must zero attention from a query to future keys."""
    model = MultiHeadAttention(embed_dim=64, num_heads=4)
    hidden = torch.randn(1, 5, 64)
    causal = create_causal_mask(5)
    _, attn_weights = model(hidden, mask=causal)
    # Query position 0 must not attend to the future key at position 1.
    assert attn_weights[0, :, 0, 1].max() < 1e-5
Helpful?