This repository has no description.
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

At commit 7ed4e562cd2ae4b2bf80fb17992106f346ea9d38 · 115 lines · 3.5 kB · view raw
"""
ExpRate evaluation of DeepSeek-OCR-2 on the test splits.

ExpRate is the fraction of samples whose whitespace-normalized prediction
exactly equals the whitespace-normalized ground-truth Typst string.

Inference uses model.infer() (the model's native API), which takes a file path,
so each image is copied to a temp file per sample.

Usage: uv run evaluate-deepseek [--checkpoint checkpoints/deepseek-ocr2/final]
"""

import argparse
import os
import random
import re
import shutil
import tempfile

# Set before unsloth is imported so the flag is visible even if it is read at
# import time (setting it after the import may silently have no effect).
os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"

# Unsloth asks to be imported before transformers so its patches are applied.
from unsloth import FastVisionModel  # noqa: E402

import torch  # noqa: E402,F401
from tqdm import tqdm  # noqa: E402
from transformers import AutoModel  # noqa: E402

from .data import DEEPSEEK_MODEL_DIR, TEST_SPLITS, load_records  # noqa: E402,F401

# Seed for per-split subsampling; fixed so `--n` runs are reproducible.
SAMPLE_SEED = 42
# Plain-transcription prompt passed to model.infer().
PROMPT = "<image>\nFree OCR. "
_WHITESPACE_RE = re.compile(r"\s+")


def normalize(s: str) -> str:
    """Collapse every whitespace run in *s* to a single space and strip the ends."""
    return _WHITESPACE_RE.sub(" ", s).strip()


def _load_model(checkpoint: str):
    """Load *checkpoint* through Unsloth and switch it to inference mode."""
    model, tokenizer = FastVisionModel.from_pretrained(
        checkpoint,
        load_in_4bit=False,
        auto_model=AutoModel,
        trust_remote_code=True,
        unsloth_force_compile=True,
        use_gradient_checkpointing="unsloth",
    )
    FastVisionModel.for_inference(model)
    model.eval()
    return model, tokenizer


def _sample_records(n: int | None, seed: int = SAMPLE_SEED) -> list:
    """Return up to *n* seeded-random records per test split, or all if n is None."""
    rng = random.Random(seed)
    records = []
    for split in TEST_SPLITS:
        split_recs = load_records([split], dedupe=False)
        if n is None:
            records.extend(split_recs)
        else:
            records.extend(rng.sample(split_recs, min(n, len(split_recs))))
    return records


def evaluate(
    checkpoint: str,
    n: int | None = None,
    verbose: bool = False,
    image_size: int = 768,
    base_size: int = 1024,
    crop_mode: bool = True,
) -> float:
    """Compute the ExpRate of *checkpoint* on the test splits.

    Args:
        checkpoint: Path or id passed to ``FastVisionModel.from_pretrained``.
        n: If given, evaluate at most ``n`` randomly chosen records per split
            (seeded with ``SAMPLE_SEED``); otherwise evaluate every record.
        verbose: Print status, ground truth and prediction for every sample.
        image_size: Forwarded to ``model.infer``.
        base_size: Forwarded to ``model.infer``.
        crop_mode: Forwarded to ``model.infer``.

    Returns:
        Fraction of samples whose normalized prediction equals the normalized
        ``typst`` ground truth.

    Raises:
        ValueError: If ``n`` is negative or no records were selected. Both are
            checked before the model is loaded. Previously an empty selection
            ended in ZeroDivisionError.
    """
    if n is not None and n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    # Sample first so an empty selection fails before the expensive model load.
    records = _sample_records(n)
    if not records:
        raise ValueError("no test records selected; nothing to evaluate")

    model, tokenizer = _load_model(checkpoint)

    correct = 0
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_img = os.path.join(tmpdir, "img.png")
        for r in tqdm(records, desc="Evaluating"):
            # model.infer takes a file path. NOTE(review): the copy into tmpdir
            # may be unnecessary since r["image_path"] is already a path; confirm.
            shutil.copy(r["image_path"], tmp_img)

            result = model.infer(
                tokenizer,
                prompt=PROMPT,
                image_file=tmp_img,
                output_path=tmpdir,
                image_size=image_size,
                base_size=base_size,
                crop_mode=crop_mode,
                save_results=False,
                test_compress=False,
            )
            # Assumed to return the transcription string. NOTE(review): if the
            # remote code returns None (e.g. it only streams to stdout), every
            # prediction becomes "None" and scores as a failure. Confirm.
            pred = result if isinstance(result, str) else str(result)

            match = normalize(pred) == normalize(r["typst"])
            if match:
                correct += 1
            if verbose:
                status = "OK" if match else "FAIL"
                print(f"[{status}] split={r['split']}")
                print(f"  GT:   {r['typst']}")
                print(f"  PRED: {pred}")

    exprate = correct / len(records)
    split_label = f" ({n}/split)" if n is not None else ""
    print(f"ExpRate{split_label}: {exprate:.4f} ({correct}/{len(records)})")
    return exprate


def main() -> None:
    """Parse command-line arguments and run :func:`evaluate`."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/deepseek-ocr2/final")
    parser.add_argument("--n", type=int, default=None, help="Evaluate N random examples per split")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--image-size", type=int, default=768)
    parser.add_argument("--base-size", type=int, default=1024)
    parser.add_argument("--no-crop", action="store_true", help="Disable dynamic cropping")
    args = parser.parse_args()
    evaluate(
        args.checkpoint,
        n=args.n,
        verbose=args.verbose,
        image_size=args.image_size,
        base_size=args.base_size,
        crop_mode=not args.no_crop,
    )


if __name__ == "__main__":
    main()