Causal Masking
The causal mask prevents the model from "peeking" at future tokens during training:
```python
import torch

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create a causal mask to prevent attending to future tokens."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # shape: (1, 1, seq_len, seq_len)

# Example for seq_len=4:
# [[[[1., 0., 0., 0.],    ← Token 0 sees only itself
#    [1., 1., 0., 0.],    ← Token 1 sees tokens 0, 1
#    [1., 1., 1., 0.],    ← Token 2 sees tokens 0, 1, 2
#    [1., 1., 1., 1.]]]]  ← Token 3 sees all
```

Applied in the attention forward pass:
```python
if mask is not None:
    scores = scores.masked_fill(mask == 0, float("-inf"))
attn_weights = F.softmax(scores, dim=-1)
# softmax(-inf) = 0, so masked positions get 0% attention
```