Embeddings & Positions

The TokenEmbedding class in src/model.py converts token IDs to vectors and adds position information:

python
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    Implements the encoding from "Attention Is All You Need":
        PE(pos, 2i)   = sin(pos / 10000^(2i/d))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
    """

    def __init__(self, embed_dim: int, max_seq_len: int, dropout: float = 0.1):
        """
        Args:
            embed_dim: size of each embedding vector (odd values supported).
            max_seq_len: maximum sequence length the table covers.
            dropout: dropout probability applied after adding positions.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_seq_len, embed_dim)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # Geometric progression of inverse frequencies, one per even channel.
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice to embed_dim // 2 columns: for odd embed_dim the cos term has
        # one more column than the odd-index slice, which previously raised a
        # shape-mismatch error. Even embed_dim is unaffected by the slice.
        pe[:, 1::2] = torch.cos(position * div_term)[:, : embed_dim // 2]
        # Buffer, not a Parameter: moves with .to(device)/.cuda() and is saved
        # in state_dict, but receives no gradients.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to x of shape (batch, seq_len, embed_dim)."""
        return self.dropout(x + self.pe[:, : x.size(1), :])
17
class TokenEmbedding(nn.Module):
    """Map token IDs to scaled embedding vectors with position information.

    Embeddings are multiplied by sqrt(embed_dim) before the sinusoidal
    positions are added, following the Transformer convention.
    """

    def __init__(self, vocab_size: int, embed_dim: int, max_seq_len: int, dropout: float = 0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoding = PositionalEncoding(embed_dim, max_seq_len, dropout)
        self.scale = math.sqrt(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed token IDs: (batch, seq_len) -> (batch, seq_len, embed_dim)."""
        embedded = self.token_embedding(x)
        scaled = embedded * self.scale
        return self.pos_encoding(scaled)

Tests

python
1# tests/test_model.py
def test_embedding_output_shape():
    """Output must be (batch, seq_len, embed_dim) for a batch of token IDs."""
    model = TokenEmbedding(vocab_size=36, embed_dim=64, max_seq_len=16)
    tokens = torch.tensor([[1, 6, 32, 7, 2]])  # 5 tokens
    result = model(tokens)
    assert result.shape == (1, 5, 64)  # batch=1, seq=5, dim=64
7
def test_different_tokens_different_embeddings():
    """Distinct token IDs must not map to identical output vectors."""
    model = TokenEmbedding(vocab_size=36, embed_dim=64, max_seq_len=16)
    out = model(torch.tensor([[6, 7]]))  # "two", "three"
    first, second = out[0, 0], out[0, 1]
    assert not torch.allclose(first, second)
13
def test_positional_encoding_adds_position_info():
    """With all-zero input and no dropout, the output is the raw encoding."""
    encoder = PositionalEncoding(embed_dim=64, max_seq_len=100, dropout=0.0)
    zeros = torch.zeros(1, 5, 64)
    encoded = encoder(zeros)
    assert not torch.allclose(encoded, zeros)  # Position info added

Run the tests with: `pytest tests/test_model.py -v`

Helpful?