# The Tokenizer

The tokenizer in `src/tokenizer.py` converts text to token IDs and back:

python
class Tokenizer:
    """Word-level tokenizer mapping whitespace-separated words to integer IDs.

    Encoding lowercases the input, splits on whitespace, wraps the sequence
    in [START]/[END] markers, and maps out-of-vocabulary words to [UNK].
    Decoding drops all special tokens.
    """

    # Tokens that carry no surface text; stripped out by decode().
    _SPECIAL = frozenset({"[PAD]", "[START]", "[END]", "[UNK]"})

    def __init__(self, vocab: dict[str, int]):
        """vocab maps word -> ID; must contain [START], [END], and [UNK]."""
        self.vocab = vocab
        # Inverse mapping for decode(). NOTE: assumes IDs are unique —
        # duplicate IDs in `vocab` would silently collide here.
        self.id_to_word = {v: k for k, v in vocab.items()}

    def encode(self, text: str) -> list[int]:
        """Convert text to token IDs, bracketed by [START]/[END].

        Unknown words map to the [UNK] id. An empty/whitespace-only
        string encodes to just [START, END].
        """
        unk = self.vocab["[UNK]"]
        ids = [self.vocab["[START]"]]
        ids.extend(self.vocab.get(word, unk) for word in text.lower().split())
        ids.append(self.vocab["[END]"])
        return ids

    def decode(self, ids: list[int]) -> str:
        """Convert token IDs back to text, dropping special tokens.

        IDs absent from the vocabulary are treated as [UNK] (and therefore
        dropped) instead of raising KeyError, so decode never crashes on
        out-of-range model outputs.
        """
        words = (self.id_to_word.get(i, "[UNK]") for i in ids)
        return " ".join(w for w in words if w not in self._SPECIAL)

## Tests

python
1# tests/test_tokenizer.py
2def test_encode_simple(tokenizer):
3 ids = tokenizer.encode("two plus three")
4 assert ids == [1, 6, 32, 7, 2] # [START] two plus three [END]
5
6def test_decode_simple(tokenizer):
7 text = tokenizer.decode([1, 6, 32, 7, 2])
8 assert text == "two plus three"
9
10def test_encode_decode_roundtrip(tokenizer):
11 original = "five times seven"
12 ids = tokenizer.encode(original)
13 decoded = tokenizer.decode(ids)
14 assert decoded == original
15
16def test_unknown_word(tokenizer):
17 ids = tokenizer.encode("two plus banana")
18 assert ids == [1, 6, 32, 3, 2] # [UNK] = 3

Run the tests with `pytest tests/test_tokenizer.py -v`.

Helpful?