At commit 7ed4e562cd2ae4b2bf80fb17992106f346ea9d38 (110 lines, 3.8 kB):
1""" 2ExpRate evaluation on mathwriting_Test. 3 4ExpRate = fraction of exact string matches after whitespace normalization, 5matching the eff-mer definition (tokenizer encode→decode round-trip omitted 6here since we're comparing raw model output strings directly). 7 8Usage: uv run evaluate [--checkpoint checkpoints/baseline/final] 9""" 10 11import argparse 12import re 13 14import torch 15from tqdm import tqdm 16from unsloth import FastVisionModel 17 18from .data import BASE_MODEL, TEST_SPLITS, PROMPT, load_records 19 20 21def normalize(s: str) -> str: 22 return re.sub(r"\s+", " ", s).strip() 23 24 25def extract_assistant(decoded: str) -> str: 26 """Pull the assistant turn out of a full decoded sequence.""" 27 # Gemma 4 uses <|turn>model / <turn|>; older Gemma used <start_of_turn>model / <end_of_turn> 28 for start_marker in ("<|turn>model", "<start_of_turn>model"): 29 if start_marker in decoded: 30 decoded = decoded.split(start_marker)[-1] 31 break 32 for end_marker in ("<turn|>", "<end_of_turn>"): 33 if end_marker in decoded: 34 decoded = decoded.split(end_marker)[0] 35 break 36 return decoded.strip() 37 38 39def evaluate(checkpoint: str, batch_size: int = 8, n: int | None = None, verbose: bool = False) -> float: 40 model, processor = FastVisionModel.from_pretrained(checkpoint, load_in_4bit=True) 41 FastVisionModel.for_inference(model) 42 model.eval() 43 44 import random 45 if n is not None: 46 rng = random.Random(42) 47 records = [] 48 for split in TEST_SPLITS: 49 split_recs = load_records([split], dedupe=False) 50 records.extend(rng.sample(split_recs, min(n, len(split_recs)))) 51 else: 52 records = load_records(TEST_SPLITS, dedupe=False) 53 correct = 0 54 55 from PIL import Image 56 for i in tqdm(range(0, len(records), batch_size), desc="Evaluating"): 57 batch = records[i : i + batch_size] 58 for r in batch: 59 img = Image.open(r["image_path"]).convert("RGB") 60 messages = [ 61 { 62 "role": "user", 63 "content": [ 64 {"type": "image", "image": img}, 65 {"type": "text", "text": PROMPT}, 66 ], 67 } 68 ] 69 inputs = processor( 70 images=[img], 71 text=processor.apply_chat_template( 72 messages, 73 add_generation_prompt=True, 74 ), 75 return_tensors="pt", 76 ).to("cuda") 77 78 with torch.no_grad(): 79 out = model.generate(**inputs, max_new_tokens=512, do_sample=False) 80 81 decoded = processor.decode(out[0], skip_special_tokens=False) 82 pred = extract_assistant(decoded) 83 84 match = normalize(pred) == normalize(r["typst"]) 85 if match: 86 correct += 1 87 if verbose: 88 status = "OK" if match else "FAIL" 89 print(f"[{status}] split={r['split']}") 90 print(f" GT: {r['typst']}") 91 print(f" PRED: {pred}") 92 93 exprate = correct / len(records) 94 split_label = f" ({n}/split)" if n is not None else "" 95 print(f"ExpRate{split_label}: {exprate:.4f} ({correct}/{len(records)})") 96 return exprate 97 98 99def main() -> None: 100 parser = argparse.ArgumentParser() 101 parser.add_argument("--checkpoint", default="checkpoints/baseline/final") 102 parser.add_argument("--batch-size", type=int, default=8) 103 parser.add_argument("--n", type=int, default=None, help="Evaluate N random examples per split") 104 parser.add_argument("--verbose", action="store_true", help="Print GT and predicted label for each example") 105 args = parser.parse_args() 106 evaluate(args.checkpoint, args.batch_size, args.n, args.verbose) 107 108 109if __name__ == "__main__": 110 main()