Build Your First LLM from Scratch · Part 4 · Section 7 of 7

Complete Attention Module

Putting It All Together

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """
    Complete attention module with:
    - Multi-head attention
    - Optional causal masking
    - Dropout for regularization
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0
    ):
        super().__init__()
        self.num_heads = num_heads
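        # Each head gets an equal slice of the embedding, so the split must be exact
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads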

        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
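            mask: Optional boolean mask, broadcastable to
                [batch_size, num_heads, seq_len, seq_len]; True marks
                positions to block (e.g. future tokens for causal masking)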

        Returns:
            [batch_size, seq_len, embed_dim]
        """
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V and split into heads
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Apply attention to values
        out = torch.matmul(attention, V)

        # Concatenate heads and project
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.W_o(out)
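
The forward pass above fills scores with -inf wherever the mask is True, so a causal mask needs True above the diagonal. Here is a minimal sketch of building such a mask and applying it to the scores. The random Q and K tensors are stand-ins for the projected heads, and the sizes (2 sequences, 4 heads, 5 tokens, 16 dims per head) are just assumptions for illustration.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 16

# Boolean causal mask: True above the diagonal marks future positions to block
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Stand-ins for the per-head Q and K computed inside forward()
Q = torch.randn(batch_size, num_heads, seq_len, head_dim)
K = torch.randn(batch_size, num_heads, seq_len, head_dim)

scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(head_dim)
# The [seq_len, seq_len] mask broadcasts over the batch and head dimensions
scores = scores.masked_fill(causal_mask, float("-inf"))
weights = F.softmax(scores, dim=-1)

print(weights.shape)  # torch.Size([2, 4, 5, 5])
print(weights[0, 0])  # zeros above the diagonal, each row sums to 1
```

The same causal_mask can be passed as the mask argument of the module's forward method, since it broadcasts against the [batch_size, num_heads, seq_len, seq_len] scores.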

Visual Summary

Input: embeddings [batch, seq_len, 64]
              |
    +---------+---------+---------+
    |         |         |         |
    v         v         v         v
  Head 1    Head 2    Head 3    Head 4
  (16d)     (16d)     (16d)     (16d)
    |         |         |         |
    +---------+---------+---------+
              |
        Concatenate (64d)
              |
        Output projection
              |
Output: [batch, seq_len, 64]
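
To make the diagram concrete, here is a small sketch that traces only the shapes of the split and the concatenation, leaving out the projections and the attention itself. The batch size of 2 and sequence length of 5 are assumed values for illustration.

```python
import torch

batch, seq_len, embed_dim, num_heads = 2, 5, 64, 4
head_dim = embed_dim // num_heads  # 16

x = torch.randn(batch, seq_len, embed_dim)

# Split: [2, 5, 64] -> [2, 5, 4, 16] -> [2, 4, 5, 16]
heads = x.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 5, 16])

# Head 0 sees the first 16 features of every token
print(torch.equal(heads[:, 0], x[..., :head_dim]))  # True

# Concatenate: [2, 4, 5, 16] -> [2, 5, 4, 16] -> [2, 5, 64]
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, embed_dim)
print(torch.equal(merged, x))  # True
```

Splitting and merging are pure reshapes, so without the attention step in between they recover the input exactly.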

Summary Comparison

Component          Our Model    GPT-3 Scale (175B)
Embed dim          64           12,288
Num heads          4            96
Head dim           16           128
Attention params   ~16K         ~604M per layer
Attention matrix   5×5          2048×2048
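
The parameter counts in the table can be checked with a little arithmetic. This is a rough sketch, assuming the layer consists of square embed_dim × embed_dim projections like W_q, W_k, W_v and W_o in our module. The helper name attention_params and its arguments are made up for this example.

```python
def attention_params(embed_dim: int, num_projections: int = 4, bias: bool = True) -> int:
    """Count the parameters in the square linear projections of one attention layer."""
    per_projection = embed_dim * embed_dim + (embed_dim if bias else 0)
    return num_projections * per_projection


# Our model: four 64x64 projections, each with a bias vector
print(attention_params(64))  # 16640

# Large model, weights only: all four projections
print(attention_params(12288, bias=False))  # 603979776

# Large model, weights only: just the Q, K, V projections
print(attention_params(12288, num_projections=3, bias=False))  # 452984832
```

Counting only Q, K and V gives roughly 453M at embed_dim 12,288, and including the output projection raises it to roughly 604M, so it matters which convention a quoted figure uses.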

What You Can Now Do

  • Explain what attention does (intuitively)
  • Understand Query, Key, Value concepts
  • Implement single-head attention
  • Implement multi-head attention
  • Apply causal masking for generation

Next up: In Part 5, we'll combine attention with feed-forward networks to build complete Transformer blocks!