"""
ExpRate evaluation on mathwriting_Test.

ExpRate = fraction of exact string matches after whitespace normalization,
matching the eff-mer definition (tokenizer encode→decode round-trip omitted
here since we're comparing raw model output strings directly).

Usage: uv run evaluate [--checkpoint checkpoints/baseline/final]
"""
10
11import argparse
12import re
13
14import torch
15from tqdm import tqdm
16from unsloth import FastVisionModel
17
18from .data import BASE_MODEL, TEST_SPLITS, PROMPT, load_records
19
20
def normalize(s: str) -> str:
    """Collapse each run of whitespace in *s* to one space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", s)
    return collapsed.strip()
23
24
def extract_assistant(decoded: str) -> str:
    """Return the assistant turn from a fully decoded chat sequence.

    Tries both marker families noted in the original: Gemma 4's
    "<|turn>model" / "<turn|>" pair, then the older "<start_of_turn>model" /
    "<end_of_turn>" pair. Everything before the *last* start marker and after
    the *first* end marker is discarded; if a family's marker is absent the
    text passes through unchanged apart from the final strip.
    """
    for start in ("<|turn>model", "<start_of_turn>model"):
        _, found, tail = decoded.rpartition(start)
        if found:
            decoded = tail
            break
    for end in ("<turn|>", "<end_of_turn>"):
        head, found, _ = decoded.partition(end)
        if found:
            decoded = head
            break
    return decoded.strip()
37
38
def _collect_records(n: int | None) -> list:
    """Gather test records; if *n* is set, sample n per split (seed 42, reproducible)."""
    import random

    if n is None:
        return load_records(TEST_SPLITS, dedupe=False)
    rng = random.Random(42)
    records = []
    for split in TEST_SPLITS:
        split_recs = load_records([split], dedupe=False)
        records.extend(rng.sample(split_recs, min(n, len(split_recs))))
    return records


def _predict(model, processor, record) -> str:
    """Greedy-decode the model's transcription for one record's image."""
    from PIL import Image

    img = Image.open(record["image_path"]).convert("RGB")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]
    inputs = processor(
        images=[img],
        text=processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
        ),
        return_tensors="pt",
    ).to("cuda")
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # Keep special tokens so extract_assistant can find the turn markers.
    decoded = processor.decode(out[0], skip_special_tokens=False)
    return extract_assistant(decoded)


def evaluate(checkpoint: str, batch_size: int = 8, n: int | None = None, verbose: bool = False) -> float:
    """Compute ExpRate (exact-match rate after whitespace normalization) on the test splits.

    Args:
        checkpoint: Path or hub id of the (fine-tuned) vision model to load in 4-bit.
        batch_size: Records consumed per tqdm step. NOTE(review): generation runs
            one record at a time below, so this only affects progress-bar
            granularity, not throughput — TODO: true batched generation.
        n: If given, evaluate only ``n`` randomly sampled examples per split.
        verbose: Print ground truth and prediction for every example.

    Returns:
        ExpRate in [0, 1]; 0.0 when there are no records (previously this
        raised ZeroDivisionError).
    """
    model, processor = FastVisionModel.from_pretrained(checkpoint, load_in_4bit=True)
    FastVisionModel.for_inference(model)
    model.eval()

    records = _collect_records(n)
    if not records:
        # Guard: an empty test set used to crash with ZeroDivisionError.
        print("ExpRate: no records to evaluate")
        return 0.0

    correct = 0
    for i in tqdm(range(0, len(records), batch_size), desc="Evaluating"):
        for r in records[i : i + batch_size]:
            pred = _predict(model, processor, r)
            match = normalize(pred) == normalize(r["typst"])
            if match:
                correct += 1
            if verbose:
                status = "OK" if match else "FAIL"
                print(f"[{status}] split={r['split']}")
                print(f" GT: {r['typst']}")
                print(f" PRED: {pred}")

    exprate = correct / len(records)
    split_label = f" ({n}/split)" if n is not None else ""
    print(f"ExpRate{split_label}: {exprate:.4f} ({correct}/{len(records)})")
    return exprate
97
98
def main() -> None:
    """CLI entry point: parse arguments and run the evaluation."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", default="checkpoints/baseline/final")
    cli.add_argument("--batch-size", type=int, default=8)
    cli.add_argument("--n", type=int, default=None, help="Evaluate N random examples per split")
    cli.add_argument("--verbose", action="store_true", help="Print GT and predicted label for each example")
    ns = cli.parse_args()
    evaluate(ns.checkpoint, batch_size=ns.batch_size, n=ns.n, verbose=ns.verbose)
107
108
# Script entry point (e.g. `uv run evaluate ...` per the module docstring).
if __name__ == "__main__":
    main()