This repo has no description.
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at b1a457691b1cf8d90b2c0a56ffbb6c1ac60f9665 69 lines 2.3 kB view raw
"""
Quick inference probe: run against test set examples or arbitrary images.

Usage:
    uv run probe [--checkpoint checkpoints/baseline/final] [--n 5]
    uv run probe --images path/to/a.png path/to/b.png
"""

import argparse
import torch
from PIL import Image
from unsloth import FastVisionModel

from .data import BASE_MODEL, TEST_SPLITS, PROMPT, load_records
from .eval import extract_assistant, normalize


def _load_rgb(path) -> Image.Image:
    """Open the image at *path* as RGB and release the file handle right away.

    ``Image.convert`` loads the pixel data and returns a new image, so the
    source file can be closed immediately instead of being left open until
    the original ``Image`` object is garbage-collected.
    """
    with Image.open(path) as src:
        return src.convert("RGB")


def _run_image(img: Image.Image, model, processor, max_new_tokens: int = 512) -> str:
    """Run greedy generation on a single image and return the model's reply.

    Args:
        img: Image to send alongside ``PROMPT`` in a single user turn.
        model: Vision-language model, already put into inference mode.
        processor: Processor matching ``model``; must provide
            ``apply_chat_template``, ``__call__`` and ``decode``.
        max_new_tokens: Upper bound on generated tokens (default 512, the
            previously hard-coded value).

    Returns:
        The assistant portion of the decoded output, as extracted by
        ``extract_assistant``.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": PROMPT},
            ],
        }
    ]
    prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    # Follow the model's placement rather than assuming a "cuda" device.
    inputs = processor(
        images=[img],
        text=prompt_text,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Special tokens are kept; presumably extract_assistant relies on the chat
    # template's markers to locate the reply -- TODO confirm.
    decoded_full = processor.decode(out[0], skip_special_tokens=False)
    return extract_assistant(decoded_full)


def main() -> None:
    """Load a checkpoint and print predictions for images or test-set records.

    With ``--images``, each listed file is transcribed and printed as
    ``path: prediction``. Otherwise the first ``--n`` records from
    ``TEST_SPLITS`` are run, and each is printed with the expected value,
    the prediction, and whether their normalized forms match.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/baseline/final",
                        help="Checkpoint to load with FastVisionModel")
    parser.add_argument("--n", type=int, default=5,
                        help="Number of test-set records to probe (ignored with --images)")
    parser.add_argument("--images", nargs="+", metavar="IMG",
                        help="Arbitrary image files to run inference on")
    parser.add_argument("--max-new-tokens", type=int, default=512,
                        help="Maximum number of tokens to generate per image")
    args = parser.parse_args()
    # A negative --n would silently slice "all but the last |n|" records.
    if args.n < 0:
        parser.error("--n must be non-negative")
    if args.max_new_tokens <= 0:
        parser.error("--max-new-tokens must be positive")

    model, processor = FastVisionModel.from_pretrained(args.checkpoint, load_in_4bit=True)
    FastVisionModel.for_inference(model)
    model.eval()

    if args.images:
        for path in args.images:
            img = _load_rgb(path)
            pred = _run_image(img, model, processor, args.max_new_tokens)
            print(f"{path}: {pred}")
    else:
        records = load_records(TEST_SPLITS, dedupe=False)[: args.n]
        for i, r in enumerate(records):
            img = _load_rgb(r["image_path"])
            pred = _run_image(img, model, processor, args.max_new_tokens)
            print(f"\n{'='*60}")
            print(f"[{i}] {r['image_path']}")
            print(f" EXPECTED : {r['typst']!r}")
            print(f" PREDICTED: {pred!r}")
            print(f" MATCH : {normalize(pred) == normalize(r['typst'])}")


if __name__ == "__main__":
    main()