Multi-Head Attention
Multi-head attention runs multiple attention heads in parallel, each learning different patterns. Here's the implementation from src/model.py:
python
1class MultiHeadAttention(nn.Module):2 """Multi-head self-attention mechanism."""3
4 def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):5 super().__init__()6 self.embed_dim = embed_dim7 self.num_heads = num_heads8 self.head_dim = embed_dim // num_heads9
10 self.q_proj = nn.Linear(embed_dim, embed_dim)11 self.k_proj = nn.Linear(embed_dim, embed_dim)12 self.v_proj = nn.Linear(embed_dim, embed_dim)13 self.out_proj = nn.Linear(embed_dim, embed_dim)14 self.dropout = nn.Dropout(dropout)15 self.scale = math.sqrt(self.head_dim)16
17 def forward(18 self, x: torch.Tensor, mask: torch.Tensor | None = None19 ) -> tuple[torch.Tensor, torch.Tensor]:20 batch_size, seq_len, _ = x.shape21
22 # Project to Q, K, V and split into heads23 Q = (24 self.q_proj(x)25 .view(batch_size, seq_len, self.num_heads, self.head_dim)26 .transpose(1, 2)27 )28 K = (29 self.k_proj(x)30 .view(batch_size, seq_len, self.num_heads, self.head_dim)31 .transpose(1, 2)32 )33 V = (34 self.v_proj(x)35 .view(batch_size, seq_len, self.num_heads, self.head_dim)36 .transpose(1, 2)37 )38
39 # Attention scores40 scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale41 if mask is not None:42 scores = scores.masked_fill(mask == 0, float("-inf"))43
44 attn_weights = F.softmax(scores, dim=-1)45 attn_weights = self.dropout(attn_weights)46 attn_output = torch.matmul(attn_weights, V)47
48 # Concatenate heads49 attn_output = (50 attn_output.transpose(1, 2)51 .contiguous()52 .view(batch_size, seq_len, self.embed_dim)53 )54 return self.out_proj(attn_output), attn_weightsTests
Tests (from tests/test_model.py):
# tests/test_model.py
def test_attention_output_shape():
    """The attention output must preserve the input's (batch, seq, embed) shape."""
    module = MultiHeadAttention(embed_dim=64, num_heads=4)
    inputs = torch.randn(2, 10, 64)
    result, _ = module(inputs)
    assert result.shape == inputs.shape
def test_attention_weights_sum_to_one():
    """Each attention row must sum to 1 across the key dimension.

    The module is put in eval mode first: it returns the weights *after*
    dropout, and in training mode (the nn.Module default) dropout zeroes
    some entries and rescales the rest by 1/(1-p), so row sums would not
    equal 1 and the test would fail nondeterministically.
    """
    attn = MultiHeadAttention(embed_dim=64, num_heads=4)
    attn.eval()  # disable dropout so the returned weights are the raw softmax
    x = torch.randn(2, 10, 64)
    output, weights = attn(x)
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)
def test_masked_attention():
    """Causally masked (future) positions must get ~zero attention weight."""
    module = MultiHeadAttention(embed_dim=64, num_heads=4)
    inputs = torch.randn(1, 5, 64)
    causal = create_causal_mask(5)
    _, weights = module(inputs, mask=causal)
    # Query 0 may not attend to the future key at position 1, in any head.
    assert weights[0, :, 0, 1].max() < 1e-5