Feed-Forward Network

After attention mixes information across tokens, each token passes independently through a feed-forward network (FFN): a simple two-layer MLP with a ReLU nonlinearity in between. Because the same weights are applied to every position, it is often called a position-wise FFN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)  # expand to hidden dim
        self.linear2 = nn.Linear(ff_dim, embed_dim)  # contract back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # expand -> ReLU -> dropout -> contract
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
```
Why the 4× expansion? The FFN temporarily projects each token into a larger hidden dimension (64 → 256 here, i.e. `embed_dim=64` and `ff_dim=256`) to give the network more capacity to learn complex per-token transformations, then contracts back to the embedding size so the output can feed into the next layer.
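A quick sketch of that shape round-trip, using the illustrative dimensions above (`embed_dim=64`, `ff_dim=256`; the batch and sequence sizes are arbitrary choices for the demo): the input and output shapes match, because the expansion happens only inside the block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Position-wise feed-forward network."""

    def __init__(self, embed_dim: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, ff_dim)  # expand to hidden dim
        self.linear2 = nn.Linear(ff_dim, embed_dim)  # contract back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

ffn = FeedForward(embed_dim=64, ff_dim=256)
x = torch.randn(2, 10, 64)          # (batch, seq_len, embed_dim)
hidden = F.relu(ffn.linear1(x))     # expanded: (2, 10, 256)
out = ffn(x)                        # contracted back: (2, 10, 64)
print(hidden.shape, out.shape)
```

Note that `nn.Linear` acts on the last dimension only, so each of the 2×10 token vectors is transformed by the same weights with no interaction between positions: all token mixing happens in the attention layer.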