Causal Masking

The causal mask prevents the model from "peeking" at future tokens during training:

```python
import torch

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a causal mask to prevent attending to future tokens."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # shape: (1, 1, seq_len, seq_len)

# Example for seq_len=4:
# [[[[1., 0., 0., 0.],    ← Token 0 sees only itself
#    [1., 1., 0., 0.],    ← Token 1 sees tokens 0, 1
#    [1., 1., 1., 0.],    ← Token 2 sees tokens 0, 1, 2
#    [1., 1., 1., 1.]]]]  ← Token 3 sees all
```

Applied in the attention forward pass:

```python
# scores: raw attention logits, shape (batch, heads, seq_len, seq_len)
# F is torch.nn.functional
if mask is not None:
    scores = scores.masked_fill(mask == 0, float("-inf"))
attn_weights = F.softmax(scores, dim=-1)
# softmax maps -inf to 0, so masked positions get 0% attention
```
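Putting the two snippets together, an end-to-end sketch of masked attention (the tensor shapes and variable names here are illustrative, not from the original):

```python
import torch
import torch.nn.functional as F

def create_causal_mask(seq_len: int) -> torch.Tensor:
    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

torch.manual_seed(0)
batch, heads, seq_len, d_head = 1, 2, 4, 8  # illustrative sizes
q = torch.randn(batch, heads, seq_len, d_head)
k = torch.randn(batch, heads, seq_len, d_head)

# Scaled dot-product attention logits: (batch, heads, seq_len, seq_len)
scores = q @ k.transpose(-2, -1) / d_head ** 0.5
mask = create_causal_mask(seq_len)
scores = scores.masked_fill(mask == 0, float("-inf"))
attn_weights = F.softmax(scores, dim=-1)

# Every row still sums to 1, and all weight above the diagonal is exactly 0.
print(torch.allclose(attn_weights.sum(-1), torch.ones(batch, heads, seq_len)))  # True
print(attn_weights[0, 0, 0, 1:].abs().sum().item())  # 0.0
```

Because the `-inf` entries are zeroed out by the softmax, the remaining (visible) positions automatically renormalize to sum to 1, so no separate rescaling is needed.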