# Generation

Your model predicts one token at a time. Generation means running the model in a loop: feed it the sequence so far, take its predicted next token, append that token to the sequence, and repeat.

python
def generate(
    model: CalculatorLLM,
    tokenizer: Tokenizer,
    prompt: str,
    max_new_tokens: int = 10,
    device: str = "cpu",
) -> str:
    """Greedily decode up to ``max_new_tokens`` tokens continuing ``prompt``.

    Each step runs the model on the whole sequence so far, appends the
    highest-scoring next token, and stops early when the end token is
    predicted. Returns the decoded full sequence (prompt + continuation).
    """
    model.eval()

    # Encode with special tokens, then drop the trailing end token so the
    # model treats the prompt as unfinished and keeps generating.
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)[:-1]
    sequence = torch.tensor([prompt_ids]).to(device)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(sequence)
            # Greedy choice: argmax over the vocabulary at the last position.
            best = logits[0, -1, :].argmax().item()

            if best == tokenizer.end_token_id:
                break

            best_column = torch.tensor([[best]]).to(device)
            sequence = torch.cat([sequence, best_column], dim=1)

    return tokenizer.decode(sequence[0].tolist())
28
29
def solve(
    model: CalculatorLLM,
    tokenizer: Tokenizer,
    problem: str,
    device: str = "cpu",
) -> str:
    """Answer an English-worded math problem, returning only the answer text."""
    # Canonicalize the prompt: lowercase, trimmed, and terminated by
    # "equals" so the model's next tokens are the answer itself.
    normalized = problem.lower().strip()
    if not normalized.endswith("equals"):
        normalized = normalized + " equals"

    generated = generate(model, tokenizer, normalized, device=device)

    # Keep only whatever follows the final "equals"; fall back to the
    # raw generation if the marker is somehow absent.
    if "equals" in generated:
        return generated.split("equals")[-1].strip()
    return generated
48
49
# Try it!
for question in ("two plus three", "seven times six"):
    print(solve(model, tokenizer, question))  # → "five", then "forty two"
| Temperature | Effect        | Use Case             |
|-------------|---------------|----------------------|
| 0 (greedy)  | Deterministic | Math, code, facts    |
| 0.7–1.0     | Balanced      | General conversation |
| 1.5+        | Creative      | Brainstorming        |

## Tests

python
# tests/test_generate.py
def test_generate_returns_string(model, tokenizer):
    """generate() always hands back plain text."""
    output = generate(model, tokenizer, "two plus three equals")
    assert isinstance(output, str)
5
def test_solve_returns_string(model, tokenizer):
    """solve() produces a string answer."""
    answer = solve(model, tokenizer, "two plus three")
    assert isinstance(answer, str)
9
def test_solve_handles_uppercase(model, tokenizer):
    """Input casing must not matter: solve() lowercases before generating."""
    result = solve(model, tokenizer, "TWO PLUS THREE")
    # The old `result is not None` assertion was vacuous — solve() always
    # returns a str. Assert the actual contract instead: uppercase input
    # yields a string, and the same answer as the lowercase form (greedy
    # decoding is deterministic).
    assert isinstance(result, str)
    assert result == solve(model, tokenizer, "two plus three")
13
def test_evaluate_returns_accuracy_and_errors(model, tokenizer):
    """evaluate_model() yields an accuracy in [0, 1] plus a list of errors."""
    dataset = [{"input": "two plus three", "output": "five"}]
    accuracy, errors = evaluate_model(model, tokenizer, dataset)
    assert 0 <= accuracy <= 1
    assert isinstance(errors, list)

Run the tests with: `pytest tests/test_generate.py -v`

Helpful?