this repo has no description
1"""
2Data loading for Gemma 4 vision fine-tuning.
3
4Manifests live in data/<split>/manifest.jsonl.
5Each line: {"image": "images/xxx.png", "latex": "...", "typst": "..."}
6
Images are loaded lazily in MathOCRDataset.__getitem__ -- never all in memory.
8Augmentation is applied on-the-fly during training only.
9"""
10
11import json
12import re
13import random
14from collections import defaultdict
15from pathlib import Path
16
17from torch.utils.data import Dataset
18from PIL import Image, ImageFilter
19import torchvision.transforms.functional as TF
20
# Default root directory for all splits: <two levels above this file>/data.
# load_records() looks for <root>/<split>/manifest.jsonl beneath it.
DATA_ROOT = Path(__file__).parent.parent / "data"

# Split names grouped by role. Each name is used as a directory name under the
# data root when manifests are loaded.
TRAIN_SPLITS = [
    "crohme_gen_2019",
    "crohme_gen_2023",
    "crohme_gen_syntactic",
    "crohme_real_train",
    "mathwriting_train",
    "mathwriting_synthetic",
    "mathwriting_symbols",
    "typeset_train",
    "typeset_mixed_train",
]
VAL_SPLITS = ["mathwriting_val", "typeset_val", "typeset_mixed_val"]
TEST_SPLITS = ["mathwriting_test", "typeset_test", "typeset_mixed_test"]

# Splits whose manifest typst field is a bare math expression (no $ delimiters).
# These are wrapped as display math at load time so the training target is valid Typst.
# Mixed splits already contain full body content with inline $...$ where needed.
# NOTE(review): "mathwriting_val" and "mathwriting_test" are not listed here even
# though every mathwriting training split is -- confirm their manifests already
# carry $ delimiters; otherwise evaluation targets will not match the wrapped
# format used for training.
_MATH_ONLY_SPLITS = {
    "crohme_gen_2019", "crohme_gen_2023", "crohme_gen_syntactic", "crohme_real_train",
    "mathwriting_train", "mathwriting_synthetic", "mathwriting_symbols",
    "typeset_train", "typeset_val", "typeset_test",
}

# Instruction text placed in the user turn of every sample.
PROMPT = "Transcribe this image to Typst notation."
# Identifier of the base model; not referenced by the code visible in this file section.
BASE_MODEL = "unsloth/gemma-4-E2B-it"


# Integer or decimal literal (e.g. "12", "3.14"); replaced by "#" in structural keys.
_NUM_RE = re.compile(r"\d+(\.\d+)?")
# A single ASCII letter standing alone as a word; replaced by "v" in structural keys.
_VAR_RE = re.compile(r"\b[a-zA-Z]\b")
# Maximum number of records kept per structural key when deduplicating.
_MAX_PER_KEY = 5
53
54
def _structural_key(typst: str) -> str:
    """Return a structural fingerprint of a Typst string.

    Every number (integer or decimal) is replaced by ``#`` and every
    stand-alone single letter by ``v``, so expressions that differ only in
    those surface tokens share one key: x^2, y^3 and a^9 all give v^#, and
    1/2 and 123/456 both give #/#.
    """
    # Numbers are rewritten first (inner call). Once a digit run has become
    # '#', a letter that was glued to it (the x in "x2") stands alone as a
    # word and is picked up by the variable pass (outer call).
    return _VAR_RE.sub("v", _NUM_RE.sub("#", typst))
64
65
def load_records(split_names: list[str], dedupe: bool = True,
                 root: Path = DATA_ROOT) -> list[dict]:
    """
    Load records from the manifests of the given splits.

    Each manifest is read from ``<root>/<split>/manifest.jsonl`` as UTF-8, one
    JSON object per line; blank lines are ignored. Records whose ``typst``
    field is missing, empty, or starts with ``"ERROR:"`` are dropped. For
    splits listed in ``_MATH_ONLY_SPLITS`` the typst string is wrapped as
    ``$ ... $`` before any further processing.

    With dedupe=True two filters are applied, in manifest order, with state
    shared across all requested splits:
      1. Exact typst string dedup -- removes identical outputs across splits.
      2. Structural key cap (_MAX_PER_KEY) -- allows limited numeric/variable
         variation per template while capping enumeration spam.

    Args:
        split_names: Split directory names under ``root``, processed in order.
        dedupe: Whether to apply the two filters described above.
        root: Directory containing one sub-directory per split.

    Returns:
        A list of ``{"image_path": str, "typst": str}`` dicts, where
        ``image_path`` is the manifest's ``image`` value joined onto the
        resolved split directory.

    Raises:
        FileNotFoundError: If a split's manifest does not exist.
        json.JSONDecodeError: If a non-blank manifest line is not valid JSON.
        KeyError: If a kept record has no ``image`` field.
    """
    seen_exact: set[str] = set()
    struct_counts: defaultdict[str, int] = defaultdict(int)
    records: list[dict] = []
    for name in split_names:
        split_dir = root / name
        base = split_dir.resolve()
        math_only = name in _MATH_ONLY_SPLITS
        # Iterate the file object rather than read_text().splitlines():
        #  * the manifest is streamed instead of held in memory twice;
        #  * only real newlines end a record, whereas str.splitlines() also
        #    breaks on characters such as U+2028 that may appear unescaped
        #    inside a JSON string and would cut a record in half;
        #  * the encoding is pinned so decoding does not depend on the locale.
        with (split_dir / "manifest.jsonl").open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                r = json.loads(line)
                typst = r.get("typst", "")
                if not typst or typst.startswith("ERROR:"):
                    continue
                if math_only:
                    typst = f"$ {typst} $"
                if dedupe:
                    if typst in seen_exact:
                        continue
                    sk = _structural_key(typst)
                    if struct_counts[sk] >= _MAX_PER_KEY:
                        # Dropped by the cap: deliberately not added to
                        # seen_exact, it was never emitted.
                        continue
                    seen_exact.add(typst)
                    struct_counts[sk] += 1
                records.append({
                    "image_path": str(base / r["image"]),
                    "typst": typst,
                })
    return records
103
104
def _augment(img: Image.Image) -> Image.Image:
    """
    Apply a mild random distortion to one image and return the result.

    Steps, in order:
      1. Affine warp: rotation in [-5, 5] degrees, scale in [0.9, 1.1],
         translation of up to 5% of the width/height, no shear, fill value 255.
      2. With probability 0.6, a Gaussian blur with radius in [0.0, 0.8].
      3. Brightness factor in [0.75, 1.25].
      4. Contrast factor in [0.75, 1.25].

    Per the original author's note this is meant for synthetic-to-phone-photo
    robustness and mirrors eff-mer's augmentation plus brightness/contrast
    jitter.
    """
    # Random draws are kept in this exact order so a seeded run reproduces
    # the same sequence of parameters.
    rotation = random.uniform(-5, 5)
    zoom = random.uniform(0.9, 1.1)
    shift_x = int(random.uniform(-0.05, 0.05) * img.width)
    shift_y = int(random.uniform(-0.05, 0.05) * img.height)
    out = TF.affine(
        img,
        angle=rotation,
        translate=(shift_x, shift_y),
        scale=zoom,
        shear=0,
        fill=255,
    )

    blur_selected = random.random() < 0.6
    if blur_selected:
        # The radius is only drawn when the blur is actually applied.
        radius = random.uniform(0.0, 0.8)
        out = out.filter(ImageFilter.GaussianBlur(radius))

    brightness = random.uniform(0.75, 1.25)
    out = TF.adjust_brightness(out, brightness)
    contrast = random.uniform(0.75, 1.25)
    return TF.adjust_contrast(out, contrast)
122
123
class MathOCRDataset(Dataset):
    """
    Torch Dataset with lazy image loading. Avoids HuggingFace datasets
    fingerprinting, which uses dill and breaks on Python 3.14.

    Only the record list (paths and target strings) is kept in memory; each
    image is opened from disk when its index is requested.
    """

    def __init__(self, records: list[dict], do_augment: bool = False) -> None:
        """
        Args:
            records: Dicts providing "image_path" and "typst" keys (the shape
                returned by load_records).
            do_augment: When True, every image is passed through _augment on
                access.
        """
        self.records = records
        self.do_augment = do_augment

    def __len__(self) -> int:
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        """
        Load the image for record ``idx`` and build one chat-style sample.

        Returns:
            ``{"messages": [user, assistant]}`` where the user turn carries the
            RGB image followed by PROMPT, and the assistant turn carries the
            record's typst string.
        """
        r = self.records[idx]
        # Image.open is lazy and keeps the file open. The with-block releases
        # the handle deterministically for every format instead of relying on
        # Pillow's close-after-load (single-frame files only) or on garbage
        # collection. convert() returns an independent in-memory image, so it
        # stays valid after the source is closed.
        with Image.open(r["image_path"]) as src:
            img = src.convert("RGB")
        if self.do_augment:
            img = _augment(img)
        return {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": img},
                        {"type": "text", "text": PROMPT},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": r["typst"]}],
                },
            ]
        }
157
158
def make_dataset(records: list[dict], do_augment: bool = False) -> MathOCRDataset:
    """Wrap ``records`` in a MathOCRDataset, optionally enabling augmentation."""
    dataset = MathOCRDataset(records, do_augment)
    return dataset