Evaluation & Saving
Evaluate Accuracy
python
def evaluate_model(model, tokenizer, test_data, device="cpu"):
    """Evaluate model accuracy on test data.

    For every item the prompt ``item["input"] + " equals"`` is passed to
    ``generate``. The text following the last ``"equals"`` in the generated
    string (or the whole string when that marker is absent) is stripped of
    surrounding whitespace and compared for exact equality with
    ``item["output"]``.

    Args:
        model: Model to evaluate; put into eval mode via ``model.eval()``.
        tokenizer: Tokenizer forwarded unchanged to ``generate``.
        test_data: Sized, iterable collection of mappings that provide the
            keys ``"input"`` and ``"output"``.
        device: Device identifier forwarded to ``generate``.

    Returns:
        A ``(accuracy, errors)`` tuple. ``accuracy`` is the fraction of items
        answered correctly, or ``0.0`` when ``test_data`` is empty. ``errors``
        is a list with one ``{"input", "expected", "got"}`` dict per
        incorrectly answered item.
    """
    model.eval()

    total = len(test_data)
    if total == 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0, []

    correct = 0
    errors = []

    for item in test_data:
        prompt = item["input"] + " equals"
        result = generate(model, tokenizer, prompt, device=device)

        # rpartition leaves the whole string in the tail when "equals" is
        # missing, so both cases share one code path and both get stripped.
        answer = result.rpartition("equals")[2].strip()
        expected = item["output"]

        if answer == expected:
            correct += 1
        else:
            errors.append(
                {
                    "input": item["input"],
                    "expected": expected,
                    "got": answer,
                }
            )

    return correct / total, errors
30
# Evaluate the model on the held-out test data and report the accuracy.
accuracy, errors = evaluate_model(model, tokenizer, test_data)
print(f"Test Accuracy: {accuracy:.1%}")

# Show at most the first five mismatches, if there were any.
sample = errors[:5]
if sample:
    print("\nSample errors (first 5):")
for e in sample:
    print(f" {e['input']} = {e['got']} (expected: {e['expected']})")

# Save & Load
python
# Save model and config to output directory
output_dir = Path("output")
output_dir.mkdir(exist_ok=True)

# Model weights are written directly; config and vocab are copied as files.
torch.save(model.state_dict(), output_dir / "model.pt")
for source, target_name in (
    ("config/config.json", "config.json"),
    ("config/vocab.json", "vocab.json"),
):
    shutil.copy(source, output_dir / target_name)
9
def load_model(model_dir: str | Path, device: str = "cpu"):
    """Load a trained Calculator LLM model.

    Reads three files from ``model_dir``: ``config.json`` (model
    hyperparameters), ``vocab.json`` (handed to ``Tokenizer.from_file``) and
    ``model.pt`` (the saved state dict).

    Args:
        model_dir: Directory containing ``config.json``, ``vocab.json`` and
            ``model.pt``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        A ``(model, tokenizer, config)`` tuple: the ``CalculatorLLM`` with
        weights loaded, moved to ``device`` and set to eval mode; the
        tokenizer; and the parsed contents of ``config.json``.

    Raises:
        KeyError: If ``config.json`` lacks one of the required
            hyperparameter keys.
    """
    model_dir = Path(model_dir)

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(model_dir / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    tokenizer = Tokenizer.from_file(model_dir / "vocab.json")

    model = CalculatorLLM(
        vocab_size=config["vocab_size"],
        embed_dim=config["embed_dim"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        ff_dim=config["ff_dim"],
        max_seq_len=config["max_seq_len"],
    )

    # weights_only=True limits unpickling to tensor data and plain containers.
    state_dict = torch.load(
        model_dir / "model.pt", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    return model, tokenizer, config
36
# Load the saved model back from disk and ask it one question.
model, tokenizer, config = load_model("output")
question = "two plus three"
answer = solve(model, tokenizer, question)
print(f"{question} = {answer}")