Build Your First LLM from Scratch · Part 3 · Section 8 of 13

The Embedding Layer

A language model can't work directly with token IDs; each ID first has to be mapped to a dense vector of numbers the model can learn from. The embedding layer is that mapping: a simple lookup table with one vector per token in the vocabulary.

import torch
import torch.nn as nn

class Embedding(nn.Module):
    def __init__(self, vocab_size: int = 36, embed_dim: int = 64):
        super().__init__()
        # Create a lookup table: vocab_size rows, embed_dim columns
        self.embedding = nn.Embedding(vocab_size, embed_dim)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Look up each token ID to get its vector
        return self.embedding(token_ids)
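
Under the hood, nn.Embedding is just a learnable weight matrix with vocab_size rows and embed_dim columns: row i holds the vector returned for token ID i. The rows start out random and are trained along with the rest of the model. A minimal sketch of that view, using the class defined above (the layer variable is just a throwaway instance for inspection):

layer = Embedding(vocab_size=36, embed_dim=64)

# One row per token in the vocabulary, one column per embedding dimension
print(layer.embedding.weight.shape)  # torch.Size([36, 64])

# Looking up a token ID is the same as selecting that row of the weight matrix
print(torch.allclose(layer(torch.tensor(5)), layer.embedding.weight[5]))  # True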

Usage:

embed = Embedding(vocab_size=36, embed_dim=64)

token_ids = torch.tensor([5, 31, 6])  # "two plus three"
vectors = embed(token_ids)

print(vectors.shape)  # torch.Size([3, 64])
# 3 tokens, each represented by 64 numbers
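
In practice, token IDs arrive in batches rather than one sequence at a time. nn.Embedding accepts input of any shape and appends the embedding dimension to it, so a [batch, seq_len] tensor of IDs comes out as [batch, seq_len, embed_dim]. A quick sketch reusing the embed layer above (the second sequence's IDs are made up for illustration):

batch_ids = torch.stack([token_ids, torch.tensor([7, 32, 2])])
print(batch_ids.shape)      # torch.Size([2, 3])

batch_vectors = embed(batch_ids)
print(batch_vectors.shape)  # torch.Size([2, 3, 64])
# 2 sequences, 3 tokens each, 64 numbers per token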