Evaluation script: DeepSeek-OCR-2 ExpRate on the held-out test splits.
1"""
2ExpRate evaluation of DeepSeek-OCR-2 on the test splits.
3
4Inference uses model.infer() (the model's native API), which takes a file path,
5so images are written to a temp file per sample.
6
7Usage: uv run evaluate-deepseek [--checkpoint checkpoints/deepseek-ocr2/final]
8"""
9
import argparse
import os
import re
import tempfile

# NOTE: unsloth reads this flag at import time, so it must be set *before*
# `from unsloth import ...` — previously it was set after the import and
# therefore had no effect.
os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"

import torch
from tqdm import tqdm
from transformers import AutoModel
from unsloth import FastVisionModel

from .data import DEEPSEEK_MODEL_DIR, TEST_SPLITS, load_records
23
24
def normalize(s: str) -> str:
    """Collapse every run of whitespace to a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", s)
    return collapsed.strip()
27
28
def evaluate(
    checkpoint: str,
    n: int | None = None,
    verbose: bool = False,
    image_size: int = 768,
    base_size: int = 1024,
    crop_mode: bool = True,
) -> float:
    """Compute exact-match ExpRate of a DeepSeek-OCR-2 checkpoint on TEST_SPLITS.

    Args:
        checkpoint: Path or hub id of the fine-tuned checkpoint to load.
        n: If given, evaluate at most ``n`` randomly sampled records per split
            (seeded for reproducibility); otherwise evaluate every record.
        verbose: Print ground truth and prediction for every sample.
        image_size: Image size forwarded to ``model.infer``.
        base_size: Base resolution forwarded to ``model.infer``.
        crop_mode: Enable the model's dynamic cropping.

    Returns:
        Exact-match rate in [0, 1] after whitespace normalization.

    Raises:
        ValueError: If no evaluation records are found.
    """
    # Hoisted here from mid-function / inside the per-sample loop:
    # re-running `import shutil` on every iteration was wasteful.
    import random
    import shutil

    model, tokenizer = FastVisionModel.from_pretrained(
        checkpoint,
        load_in_4bit=False,
        auto_model=AutoModel,
        trust_remote_code=True,
        unsloth_force_compile=True,
        use_gradient_checkpointing="unsloth",
    )
    FastVisionModel.for_inference(model)
    model.eval()

    rng = random.Random(42)  # fixed seed so --n sampling is reproducible
    records = []
    for split in TEST_SPLITS:
        split_recs = load_records([split], dedupe=False)
        if n is not None:
            split_recs = rng.sample(split_recs, min(n, len(split_recs)))
        records.extend(split_recs)
    if not records:
        # Fail loudly instead of hitting ZeroDivisionError in the rate below.
        raise ValueError("No evaluation records found in TEST_SPLITS")

    correct = 0
    prompt = "<image>\nFree OCR. "

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_img = os.path.join(tmpdir, "img.png")
        for r in tqdm(records, desc="Evaluating"):
            # model.infer requires a file path, so stage each image in tmpdir.
            shutil.copy(r["image_path"], tmp_img)

            result = model.infer(
                tokenizer,
                prompt=prompt,
                image_file=tmp_img,
                output_path=tmpdir,
                image_size=image_size,
                base_size=base_size,
                crop_mode=crop_mode,
                save_results=False,
                test_compress=False,
            )
            # model.infer returns the transcription string directly; coerce
            # defensively in case a future version returns something else.
            pred = result if isinstance(result, str) else str(result)

            match = normalize(pred) == normalize(r["typst"])
            if match:
                correct += 1
            if verbose:
                status = "OK" if match else "FAIL"
                print(f"[{status}] split={r['split']}")
                print(f"  GT: {r['typst']}")
                print(f"  PRED: {pred}")

    exprate = correct / len(records)
    split_label = f" ({n}/split)" if n is not None else ""
    print(f"ExpRate{split_label}: {exprate:.4f} ({correct}/{len(records)})")
    return exprate
93
94
def main() -> None:
    """CLI entry point: parse arguments and run the evaluation."""
    parser = argparse.ArgumentParser(
        description="ExpRate evaluation of DeepSeek-OCR-2 on the test splits."
    )
    # Plain string — the original used an f-string with no placeholders (F541).
    parser.add_argument("--checkpoint", default="checkpoints/deepseek-ocr2/final")
    parser.add_argument("--n", type=int, default=None, help="Evaluate N random examples per split")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--image-size", type=int, default=768)
    parser.add_argument("--base-size", type=int, default=1024)
    parser.add_argument("--no-crop", action="store_true", help="Disable dynamic cropping")
    args = parser.parse_args()
    evaluate(
        args.checkpoint,
        n=args.n,
        verbose=args.verbose,
        image_size=args.image_size,
        base_size=args.base_size,
        crop_mode=not args.no_crop,
    )
112
113
# Allow direct execution in addition to the `uv run evaluate-deepseek` entry point.
if __name__ == "__main__":
    main()