# Evaluation & Saving

## Evaluate Accuracy

python
def evaluate_model(model, tokenizer, test_data, device="cpu"):
    """Evaluate model accuracy on test data.

    Args:
        model: Trained model; put into eval mode and passed to ``generate``.
        tokenizer: Tokenizer passed through to ``generate``.
        test_data: Iterable of dicts with "input" (prompt text) and
            "output" (expected answer string) keys.
        device: Device string forwarded to ``generate`` (default "cpu").

    Returns:
        Tuple ``(accuracy, errors)`` where ``accuracy`` is the fraction of
        exact-match answers (0.0 for an empty test set) and ``errors`` is a
        list of dicts with "input", "expected", and "got" for each mismatch.
    """
    model.eval()
    correct = 0
    errors = []

    for item in test_data:
        prompt = item["input"] + " equals"
        result = generate(model, tokenizer, prompt, device=device)

        # Keep only the text after the final "equals" marker; fall back to
        # the raw generation if the marker is missing entirely.
        if "equals" in result:
            answer = result.split("equals")[-1].strip()
        else:
            answer = result

        expected = item["output"]

        if answer == expected:
            correct += 1
        else:
            errors.append({
                "input": item["input"],
                "expected": expected,
                "got": answer,
            })

    # Guard against ZeroDivisionError when test_data is empty.
    accuracy = correct / len(test_data) if test_data else 0.0
    return accuracy, errors
29
30
# Run the evaluation and report overall accuracy plus a few failures.
accuracy, errors = evaluate_model(model, tokenizer, test_data)
print(f"Test Accuracy: {accuracy:.1%}")

if errors:
    print("\nSample errors (first 5):")
    for mistake in errors[:5]:
        got, want = mistake["got"], mistake["expected"]
        print(f" {mistake['input']} = {got} (expected: {want})")

## Save & Load

python
# Persist the trained weights plus the config/vocab files needed to reload.
out = Path("output")
out.mkdir(exist_ok=True)

torch.save(model.state_dict(), out / "model.pt")
shutil.copy("config/config.json", out / "config.json")
shutil.copy("config/vocab.json", out / "vocab.json")
8
9
def load_model(model_dir: str | Path, device: str = "cpu"):
    """Load a trained Calculator LLM model.

    Reads ``config.json``, ``vocab.json``, and ``model.pt`` from
    *model_dir* and returns ``(model, tokenizer, config)`` with the model
    moved to *device* and switched to eval mode.
    """
    root = Path(model_dir)

    # Model hyperparameters were saved alongside the weights.
    config = json.loads((root / "config.json").read_text())

    tokenizer = Tokenizer.from_file(root / "vocab.json")

    model = CalculatorLLM(
        vocab_size=config["vocab_size"],
        embed_dim=config["embed_dim"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        ff_dim=config["ff_dim"],
        max_seq_len=config["max_seq_len"],
    )

    # weights_only=True restricts torch.load to tensor data (safe loading).
    state = torch.load(root / "model.pt", map_location=device, weights_only=True)
    model.load_state_dict(state)
    model.to(device)
    model.eval()

    return model, tokenizer, config
35
36
# Reload the saved artifacts and run a quick sanity check.
model, tokenizer, config = load_model("output")
answer = solve(model, tokenizer, "two plus three")
print(f"two plus three = {answer}")
Helpful?