this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at c80071a330c79663dea79fd8fc958885fa422d94 160 lines 5.3 kB view raw
1""" 2Data loading for Gemma 4 vision fine-tuning. 3 4Manifests live in data/<split>/manifest.jsonl. 5Each line: {"image": "images/xxx.png", "latex": "...", "typst": "..."} 6 7Images are loaded lazily via Dataset.set_transform -- never all in memory. 8Augmentation is applied on-the-fly during training only. 9""" 10 11import json 12import re 13import random 14from collections import defaultdict 15from pathlib import Path 16 17from torch.utils.data import Dataset 18from PIL import Image, ImageFilter 19import torchvision.transforms.functional as TF 20 21DATA_ROOT = Path(__file__).parent.parent / "data" 22 23TRAIN_SPLITS = [ 24 "crohme_gen_2019", 25 "crohme_gen_2023", 26 "crohme_gen_syntactic", 27 "crohme_real_train", 28 "mathwriting_train", 29 "mathwriting_synthetic", 30 "mathwriting_symbols", 31 "typeset_train", 32 "typeset_mixed_train", 33] 34VAL_SPLITS = ["mathwriting_val", "typeset_val", "typeset_mixed_val"] 35TEST_SPLITS = ["mathwriting_test", "typeset_test", "typeset_mixed_test"] 36 37# Splits whose manifest typst field is a bare math expression (no $ delimiters). 38# These are wrapped as display math at load time so the training target is valid Typst. 39# Mixed splits already contain full body content with inline $...$ where needed. 40_MATH_ONLY_SPLITS = { 41 "crohme_gen_2019", "crohme_gen_2023", "crohme_gen_syntactic", "crohme_real_train", 42 "mathwriting_train", "mathwriting_synthetic", "mathwriting_symbols", 43 "typeset_train", "typeset_val", "typeset_test", 44} 45 46PROMPT = "Transcribe this image to Typst notation." 47BASE_MODEL = "unsloth/gemma-4-E2B-it" 48 49 50_NUM_RE = re.compile(r"\d+(\.\d+)?") 51_VAR_RE = re.compile(r"\b[a-zA-Z]\b") 52_MAX_PER_KEY = 5 53 54 55def _structural_key(typst: str) -> str: 56 """Normalize numbers and single-letter variables to expose structural pattern. 57 58 x^2, y^3, a^9 all map to v^# -- same template, different surface tokens. 59 Fractions like 1/2 vs 123/456 also collapse to #/#. 60 """ 61 s = _NUM_RE.sub("#", typst) 62 s = _VAR_RE.sub("v", s) 63 return s 64 65 66def load_records(split_names: list[str], dedupe: bool = True, 67 root: Path = DATA_ROOT) -> list[dict]: 68 """ 69 Load records from manifests. 70 71 With dedupe=True applies two passes: 72 1. Exact typst string dedup -- removes identical outputs across splits. 73 2. Structural key cap (MAX_PER_KEY=5) -- allows limited numeric/variable 74 variation per template while capping enumeration spam. 75 """ 76 seen_exact: set[str] = set() 77 struct_counts: defaultdict[str, int] = defaultdict(int) 78 records: list[dict] = [] 79 for name in split_names: 80 manifest = root / name / "manifest.jsonl" 81 base = (root / name).resolve() 82 math_only = name in _MATH_ONLY_SPLITS 83 for line in manifest.read_text().splitlines(): 84 r = json.loads(line) 85 typst = r.get("typst", "") 86 if not typst or typst.startswith("ERROR:"): 87 continue 88 if math_only: 89 typst = f"$ {typst} $" 90 if dedupe: 91 if typst in seen_exact: 92 continue 93 sk = _structural_key(typst) 94 if struct_counts[sk] >= _MAX_PER_KEY: 95 continue 96 seen_exact.add(typst) 97 struct_counts[sk] += 1 98 records.append({ 99 "image_path": str(base / r["image"]), 100 "typst": typst, 101 }) 102 return records 103 104 105def _augment(img: Image.Image) -> Image.Image: 106 """ 107 Mild augmentation for synthetic-to-phone-photo robustness. 108 Mirrors eff-mer's augmentation plus brightness/contrast jitter. 109 """ 110 angle = random.uniform(-5, 5) 111 scale = random.uniform(0.9, 1.1) 112 tx = int(random.uniform(-0.05, 0.05) * img.width) 113 ty = int(random.uniform(-0.05, 0.05) * img.height) 114 img = TF.affine(img, angle=angle, translate=(tx, ty), scale=scale, shear=0, fill=255) 115 116 if random.random() < 0.6: 117 img = img.filter(ImageFilter.GaussianBlur(random.uniform(0.0, 0.8))) 118 119 img = TF.adjust_brightness(img, random.uniform(0.75, 1.25)) 120 img = TF.adjust_contrast(img, random.uniform(0.75, 1.25)) 121 return img 122 123 124class MathOCRDataset(Dataset): 125 """ 126 Torch Dataset with lazy image loading. Avoids HuggingFace datasets 127 fingerprinting, which uses dill and breaks on Python 3.14. 128 """ 129 130 def __init__(self, records: list[dict], do_augment: bool = False) -> None: 131 self.records = records 132 self.do_augment = do_augment 133 134 def __len__(self) -> int: 135 return len(self.records) 136 137 def __getitem__(self, idx: int) -> dict: 138 r = self.records[idx] 139 img = Image.open(r["image_path"]).convert("RGB") 140 if self.do_augment: 141 img = _augment(img) 142 return { 143 "messages": [ 144 { 145 "role": "user", 146 "content": [ 147 {"type": "image", "image": img}, 148 {"type": "text", "text": PROMPT}, 149 ], 150 }, 151 { 152 "role": "assistant", 153 "content": [{"type": "text", "text": r["typst"]}], 154 }, 155 ] 156 } 157 158 159def make_dataset(records: list[dict], do_augment: bool = False) -> MathOCRDataset: 160 return MathOCRDataset(records, do_augment)