this repo has no description
"""
Quick inference probe: run the model on test-set examples or on arbitrary images.

Usage:
    uv run probe [--checkpoint checkpoints/baseline/final] [--n 5]
    uv run probe --images path/to/a.png path/to/b.png
"""
8
9import argparse
10import torch
11from PIL import Image
12from unsloth import FastVisionModel
13
14from .data import BASE_MODEL, TEST_SPLITS, PROMPT, load_records
15from .eval import extract_assistant, normalize
16
17
def _run_image(img: Image.Image, model, processor, *, max_new_tokens: int = 512) -> str:
    """Generate the model's answer for a single image.

    Builds a one-turn user message (the image followed by ``PROMPT``), renders
    it with the processor's chat template, runs greedy generation and returns
    whatever ``extract_assistant`` pulls out of the decoded sequence.

    Args:
        img: Image to run inference on.
        model: Object exposing ``generate(**inputs, ...)``; presumably the
            model returned by ``FastVisionModel.from_pretrained`` -- confirm
            against the caller.
        processor: Callable processor that also provides
            ``apply_chat_template`` and ``decode``.
        max_new_tokens: Upper bound on the number of generated tokens.
            Defaults to 512, the value that used to be hard-coded.

    Returns:
        The result of ``extract_assistant`` applied to the decoded output.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]
    prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    # Send the inputs to wherever the model actually lives rather than assuming
    # the default CUDA device; fall back to "cuda" (the previous behaviour) for
    # model objects that do not expose a ``device`` attribute.
    device = getattr(model, "device", "cuda")
    inputs = processor(
        images=[img],
        text=prompt_text,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # do_sample=False -> greedy decoding.
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

    # Special tokens are deliberately kept in the decoded text.
    # NOTE(review): presumably extract_assistant relies on the chat markers to
    # locate the assistant turn -- confirm against .eval.
    decoded_full = processor.decode(out[0], skip_special_tokens=False)
    return extract_assistant(decoded_full)
37
38
def _load_rgb(path) -> Image.Image:
    """Open the image at *path* and return an RGB copy, closing the file handle."""
    with Image.open(path) as raw:
        # convert() returns a new image, so it stays usable after the file
        # behind ``raw`` has been closed.
        return raw.convert("RGB")


def main() -> None:
    """Command-line entry point for the inference probe.

    With ``--images`` every given file is run through the model and the
    prediction is printed next to its path. Otherwise the first ``--n``
    records returned by ``load_records(TEST_SPLITS, dedupe=False)`` are
    evaluated, and the expected value, the prediction and whether they match
    after ``normalize`` are printed for each.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/baseline/final",
                        help="Checkpoint to load (default: %(default)s)")
    parser.add_argument("--n", type=int, default=5,
                        help="Number of test records to probe when --images "
                             "is not given (default: %(default)s)")
    parser.add_argument("--images", nargs="+", metavar="IMG",
                        help="Arbitrary image files to run inference on")
    args = parser.parse_args()

    # A negative count would silently slice from the end of the record list;
    # reject it up front, before the expensive model load.
    if not args.images and args.n < 0:
        parser.error("--n must be a non-negative integer")

    model, processor = FastVisionModel.from_pretrained(args.checkpoint, load_in_4bit=True)
    FastVisionModel.for_inference(model)
    model.eval()

    if args.images:
        for path in args.images:
            pred = _run_image(_load_rgb(path), model, processor)
            print(f"{path}: {pred}")
        return

    records = load_records(TEST_SPLITS, dedupe=False)[: args.n]
    for i, r in enumerate(records):
        pred = _run_image(_load_rgb(r["image_path"]), model, processor)
        print(f"\n{'='*60}")
        print(f"[{i}] {r['image_path']}")
        print(f" EXPECTED : {r['typst']!r}")
        print(f" PREDICTED: {pred!r}")
        print(f" MATCH : {normalize(pred) == normalize(r['typst'])}")
66
67
# Script entry point: run the probe only when this module is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()