this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Add handwriting-font data pipeline; expand body grammar; drop typeset_* splits

- src/download_hw_fonts.py: downloads 6 Google Fonts TTFs, strips WOFF wrappers,
instantiates variable fonts at wght=400 for full character coverage
- src/generate_handwritten.py: hw mode (whole-doc font) and mix mode (per-block
font mixing); 7-way uniform font sampling including Typst default; manifest
records clean body (no font directives)
- src/generate_mixed.py: expand generate_body -- add 18% bare math, 15% short
inline (1-2 tokens); reduce complex structured weight; min complexity now n=1
- src/data.py: replace typeset_* splits with hw_structured_* and hw_mixed_*;
update val sampling to use VAL_SPLITS
- src/train.py: fix val loading to use VAL_SPLITS from data.py; move import to
top level
- pyproject.toml: add generate-hw and download-hw-fonts entry points

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+506 -18
+2
pyproject.toml
··· 34 34 export = "src.export:main" 35 35 generate-typeset = "src.generate_typeset:main" 36 36 generate-mixed = "src.generate_mixed:main" 37 + generate-hw = "src.generate_handwritten:main" 38 + download-hw-fonts = "src.download_hw_fonts:main" 37 39 probe = "src.probe:main" 38 40 app = "src.app:main" 39 41 probe-deepseek = "src.probe_deepseek:main"
+9 -4
src/data.py
··· 28 28 "mathwriting_train", 29 29 "mathwriting_synthetic", 30 30 "mathwriting_symbols", 31 - "typeset_train", 32 - "typeset_mixed_train", 31 + "hw_structured_train", "hw_mixed_train", 32 + ] 33 + VAL_SPLITS = [ 34 + "mathwriting_val", 35 + "hw_structured_val", "hw_mixed_val", 36 + ] 37 + TEST_SPLITS = [ 38 + "mathwriting_test", 39 + "hw_structured_test", "hw_mixed_test", 33 40 ] 34 - VAL_SPLITS = ["mathwriting_val", "typeset_val", "typeset_mixed_val"] 35 - TEST_SPLITS = ["mathwriting_test", "typeset_test", "typeset_mixed_test"] 36 41 37 42 # Splits whose manifest typst field is a bare math expression (no $ delimiters). 38 43 # These are wrapped as display math at load time so the training target is valid Typst.
+150
src/download_hw_fonts.py
··· 1 + """ 2 + Download handwriting fonts from Google Fonts for use in generate_handwritten.py. 3 + 4 + Saves TTF files to data/fonts/handwriting/ (or --out). Typst discovers them 5 + via --font-path at compile time. 6 + 7 + Usage: 8 + uv run download-hw-fonts 9 + uv run download-hw-fonts --out data/fonts/handwriting 10 + """ 11 + 12 + import argparse 13 + import re 14 + import time 15 + import urllib.parse 16 + import urllib.request 17 + from pathlib import Path 18 + 19 + from fontTools import ttLib 20 + from fontTools.varLib.instancer import instantiateVariableFont 21 + 22 + # Maps display name -> Google Fonts family string (used in CSS API query) 23 + _FAMILIES: dict[str, str] = { 24 + "Comic Neue": "Comic+Neue", 25 + "Gochi Hand": "Gochi+Hand", 26 + "Handlee": "Handlee", 27 + "Oswald": "Oswald", 28 + "Dancing Script": "Dancing+Script", 29 + "Special Elite": "Special+Elite", 30 + } 31 + 32 + # GitHub raw URLs used as primary source for variable fonts (CSS API only 33 + # returns character-range subsets for these) and as fallback for others. 34 + _GITHUB_URLS: dict[str, str] = { 35 + "Comic Neue": "https://github.com/google/fonts/raw/main/ofl/comicneue/ComicNeue-Regular.ttf", 36 + "Gochi Hand": "https://github.com/google/fonts/raw/main/ofl/gochihand/GochiHand-Regular.ttf", 37 + "Handlee": "https://github.com/google/fonts/raw/main/ofl/handlee/Handlee-Regular.ttf", 38 + # Variable fonts: CSS API returns subsets; download full variable font and instantiate. 39 + "Oswald": "https://github.com/google/fonts/raw/main/ofl/oswald/Oswald%5Bwght%5D.ttf", 40 + "Dancing Script": "https://github.com/google/fonts/raw/main/ofl/dancingscript/DancingScript%5Bwght%5D.ttf", 41 + "Special Elite": "https://github.com/google/fonts/raw/main/apache/specialelite/SpecialElite-Regular.ttf", 42 + } 43 + 44 + # Fonts where the CSS API returns the full font (not a subset). 45 + # For others we skip the API and go straight to GitHub. 46 + _CSS_API_OK: set[str] = {"Comic Neue", "Gochi Hand", "Handlee", "Special Elite"} 47 + 48 + _UA = "Mozilla/5.0 (X11; Linux x86_64; rv:125.0) Gecko/20100101 Firefox/125.0" 49 + _GSTATIC_RE = re.compile(r"url\((https://fonts\.gstatic\.com/[^)]+\.(?:ttf|woff2?))\)") 50 + 51 + 52 + def _css_api_url(family_str: str) -> str: 53 + return f"https://fonts.googleapis.com/css2?family={family_str}:wght@400&display=swap" 54 + 55 + 56 + def _fetch_gstatic_url(family_str: str) -> str | None: 57 + """Call Google Fonts CSS2 API and return the first ttf/woff2 URL found.""" 58 + url = _css_api_url(family_str) 59 + req = urllib.request.Request(url, headers={"User-Agent": _UA}) 60 + try: 61 + with urllib.request.urlopen(req, timeout=12) as resp: 62 + css = resp.read().decode("utf-8", errors="replace") 63 + matches = _GSTATIC_RE.findall(css) 64 + # Prefer .ttf over .woff2 65 + for m in matches: 66 + if m.endswith(".ttf"): 67 + return m 68 + return matches[0] if matches else None 69 + except Exception as exc: 70 + print(f" CSS API error: {exc}") 71 + return None 72 + 73 + 74 + def _download(url: str, dest: Path) -> bool: 75 + """Download url -> dest, convert to plain static TTF, return True on success.""" 76 + req = urllib.request.Request(url, headers={"User-Agent": _UA}) 77 + try: 78 + with urllib.request.urlopen(req, timeout=30) as resp: 79 + data = resp.read() 80 + dest.write_bytes(data) 81 + tt = ttLib.TTFont(dest) 82 + # Strip WOFF/WOFF2 wrapper -- Typst requires plain TTF/OTF. 83 + if tt.flavor in ("woff", "woff2"): 84 + tt.flavor = None 85 + # Instantiate variable fonts at Regular weight -- subsetted downloads 86 + # from the CSS API often lack lowercase; variable fonts from GitHub are 87 + # complete but need to be pinned to a single instance. 88 + if "fvar" in tt: 89 + instantiateVariableFont(tt, {"wght": 400}) 90 + tt.save(dest) 91 + # Verify the result has basic Latin coverage 92 + cmap = tt.getBestCmap() or {} 93 + if ord("a") not in cmap: 94 + print(f" warning: 'a' missing from cmap after processing") 95 + return True 96 + except Exception as exc: 97 + print(f" download error: {exc}") 98 + return False 99 + 100 + 101 + def main() -> None: 102 + parser = argparse.ArgumentParser() 103 + parser.add_argument("--out", default="data/fonts/handwriting", 104 + help="Directory to save font files") 105 + args = parser.parse_args() 106 + 107 + out = Path(args.out) 108 + out.mkdir(parents=True, exist_ok=True) 109 + 110 + ok = 0 111 + fail = 0 112 + for name, family_str in _FAMILIES.items(): 113 + # Derive expected filename from family string 114 + slug = family_str.replace("+", "") 115 + dest = out / f"{slug}-Regular.ttf" 116 + 117 + if dest.exists(): 118 + print(f" skip {name} (already exists)") 119 + ok += 1 120 + continue 121 + 122 + print(f" fetch {name} ...", end=" ", flush=True) 123 + 124 + # For variable fonts the CSS API returns subsets; go straight to GitHub. 125 + if name in _CSS_API_OK: 126 + gstatic_url = _fetch_gstatic_url(family_str) 127 + if gstatic_url and _download(gstatic_url, dest): 128 + print(f"ok ({dest.stat().st_size // 1024} KB)") 129 + ok += 1 130 + time.sleep(0.3) 131 + continue 132 + print("CSS API failed, trying GitHub ...", end=" ", flush=True) 133 + 134 + gh_url = _GITHUB_URLS.get(name) 135 + if gh_url and _download(gh_url, dest): 136 + print(f"ok via GitHub ({dest.stat().st_size // 1024} KB)") 137 + ok += 1 138 + else: 139 + print("FAILED") 140 + fail += 1 141 + 142 + time.sleep(0.5) 143 + 144 + print(f"\nDone: {ok} fonts saved to {out}, {fail} failed.") 145 + if ok: 146 + print("Run `uv run generate-hw --help` to generate training data.") 147 + 148 + 149 + if __name__ == "__main__": 150 + main()
+328
src/generate_handwritten.py
··· 1 + """ 2 + Generate structured text+math documents rendered in handwriting-style fonts. 3 + 4 + Fills two training data gaps: 5 + 6 + hw -- all text rendered in a handwriting font; math stays typeset. 7 + Realistic: handwritten notes typically have careful notation. 8 + Target splits: hw_structured_{train,val,test} 9 + 10 + mix -- per-paragraph font mixing within the same document: each block 11 + independently gets a handwriting font or the default typeset font. 12 + Models the appearance of partially annotated / hybrid documents. 13 + Target splits: hw_mixed_{train,val,test} 14 + 15 + In both modes the manifest typst field records the CLEAN body (no font 16 + directives), matching the format of typeset_mixed_* splits so data.py can 17 + load them without changes. Add the new split names to TRAIN_SPLITS / 18 + VAL_SPLITS / TEST_SPLITS in data.py when ready to train. 19 + 20 + Prerequisites: 21 + uv run download-hw-fonts # fetch fonts into data/fonts/handwriting/ 22 + 23 + Usage: 24 + uv run generate-hw --mode hw --count 15000 --out data/hw_structured_train 25 + uv run generate-hw --mode mix --count 10000 --out data/hw_mixed_train 26 + uv run generate-hw --mode hw --count 500 --out data/hw_structured_val --seed 100 27 + uv run generate-hw --mode mix --count 500 --out data/hw_mixed_val --seed 100 28 + uv run generate-hw --mode hw --count 500 --out data/hw_structured_test --seed 200 29 + uv run generate-hw --mode mix --count 500 --out data/hw_mixed_test --seed 200 30 + """ 31 + 32 + import argparse 33 + import hashlib 34 + import json 35 + import random 36 + import subprocess 37 + import tempfile 38 + from concurrent.futures import ThreadPoolExecutor, as_completed 39 + from pathlib import Path 40 + 41 + from tqdm import tqdm 42 + 43 + from .generate_mixed import generate_body 44 + 45 + 46 + # ── Font pool ───────────────────────────────────────────────────────────────── 47 + 48 + _DEFAULT_FONT_DIR = Path("data/fonts/handwriting") 49 + 50 + # Maps the Typst family name (what Typst sees in font metadata) to the filename 51 + # slug used when checking whether the font was downloaded. 52 + _FONTS: dict[str, str] = { 53 + "Comic Neue": "ComicNeue-Regular.ttf", 54 + "Gochi Hand": "GochiHand-Regular.ttf", 55 + "Handlee": "Handlee-Regular.ttf", 56 + "Oswald": "Oswald-Regular.ttf", 57 + "Dancing Script": "DancingScript-Regular.ttf", 58 + "Special Elite": "SpecialElite-Regular.ttf", 59 + } 60 + 61 + # Font sizes (pt) sampled for diversity. Heavier/larger fonts look more 62 + # "blackboard-style"; smaller sizes give tighter notes appearance. 63 + _FONT_SIZES = [10, 10, 11, 11, 11, 12, 12, 13, 14] 64 + 65 + # Ink colours: mostly black, occasional dark blue / dark grey 66 + _INK_COLOURS = [ 67 + "#000000", "#000000", "#000000", "#000000", # 4/7 pure black 68 + "#1a1a1a", # near-black 69 + "#0d0d4d", # dark navy (pen ink) 70 + "#2b2b2b", # dark grey 71 + ] 72 + 73 + 74 + def _available_fonts(font_dir: Path) -> list[str]: 75 + """Return Typst family names whose TTF files exist in font_dir.""" 76 + if not font_dir.exists(): 77 + return [] 78 + present = {p.name for p in font_dir.iterdir()} 79 + return [family for family, fname in _FONTS.items() if fname in present] 80 + 81 + 82 + def _pick_font(rng: random.Random, available: list[str]) -> str | None: 83 + """Sample uniformly from available handwriting fonts + Typst default (None).""" 84 + pool = available + [None] # None -> New Computer Modern (Typst default) 85 + return rng.choice(pool) 86 + 87 + 88 + # ── Typst templates ─────────────────────────────────────────────────────────── 89 + 90 + # hw mode: override the text font for the whole document. 91 + # Math equations continue using Typst's built-in math fonts (realistic). 92 + _TEMPLATE_HW = ( 93 + "#set page(width: {width}, height: auto, " 94 + "margin: (x: 10pt, y: 8pt), fill: white)\n" 95 + '#set text(font: ("{font}", "New Computer Modern"), ' 96 + "size: {size}pt, fill: rgb(\"{ink}\"), fallback: true)\n" 97 + "#set list(spacing: 1.2em)\n" 98 + "#set enum(spacing: 1.2em)\n" 99 + "{body}\n" 100 + ) 101 + 102 + # mix mode: default document font (typeset), individual blocks may be 103 + # wrapped in #text(font: ...) -- handled in _apply_mixed_fonts(). 104 + _TEMPLATE_MIX = ( 105 + "#set page(width: {width}, height: auto, " 106 + "margin: (x: 10pt, y: 8pt), fill: white)\n" 107 + "#set list(spacing: 1.0em)\n" 108 + "#set enum(spacing: 1.0em)\n" 109 + "{body}\n" 110 + ) 111 + 112 + 113 + # ── Block-level font mixing ─────────────────────────────────────────────────── 114 + 115 + def _apply_mixed_fonts(body: str, rng: random.Random, font: str, 116 + hw_prob: float = 0.5, size: int = 11, 117 + ink: str = "#000000") -> str: 118 + """ 119 + Wrap random paragraph-level blocks in a Typst scoped content block that 120 + overrides the text font. Returns a modified body for rendering only -- 121 + the label (manifest typst) records the original clean body. 122 + 123 + Blocks are delimited by double newlines (Typst paragraph breaks). Single- 124 + block bodies (inline sequences, simple tables) are treated as one unit. 125 + """ 126 + blocks = body.split("\n\n") 127 + 128 + hw_style = ( 129 + f'#set text(font: ("{font}", "New Computer Modern"), ' 130 + f'size: {size}pt, fill: rgb("{ink}"), fallback: true)' 131 + ) 132 + 133 + result: list[str] = [] 134 + for block in blocks: 135 + stripped = block.strip() 136 + if stripped and rng.random() < hw_prob: 137 + # Scoped content block: #[#set text(...); content] 138 + # This keeps list/table markup valid inside the block. 139 + result.append(f"#[{hw_style}; {block}]") 140 + else: 141 + result.append(block) 142 + 143 + return "\n\n".join(result) 144 + 145 + 146 + # ── Rendering ───────────────────────────────────────────────────────────────── 147 + 148 + def _render( 149 + body: str, 150 + out_path: Path, 151 + page_width: str, 152 + mode: str, 153 + font: str | None, 154 + font_dir: Path, 155 + size: int, 156 + ink: str, 157 + rng: random.Random, 158 + ) -> tuple[bool, str]: 159 + """ 160 + Compile one sample to PNG. 161 + 162 + mode='hw' -- set hw font globally in the page header; body is used as-is. 163 + mode='mix' -- body may contain per-block #[#set text(...); ...] wrappers. 164 + """ 165 + if mode == "hw" and font is not None: 166 + src = _TEMPLATE_HW.format( 167 + width=page_width, font=font, size=size, ink=ink, body=body 168 + ) 169 + else: 170 + # font=None (default) or mix mode: use plain template 171 + src = _TEMPLATE_MIX.format(width=page_width, body=body) 172 + 173 + with tempfile.NamedTemporaryFile(suffix=".typ", mode="w", delete=False) as f: 174 + f.write(src) 175 + typ_path = Path(f.name) 176 + 177 + cmd = [ 178 + "typst", "compile", 179 + "--format", "png", 180 + "--ppi", "150", 181 + str(typ_path), 182 + str(out_path), 183 + ] 184 + if font is not None and font_dir.exists(): 185 + cmd += ["--font-path", str(font_dir.resolve())] 186 + 187 + try: 188 + result = subprocess.run(cmd, capture_output=True, timeout=15) 189 + return result.returncode == 0, result.stderr.decode(errors="replace") 190 + except subprocess.TimeoutExpired: 191 + return False, "timeout" 192 + except FileNotFoundError: 193 + return False, "typst not found" 194 + finally: 195 + typ_path.unlink(missing_ok=True) 196 + 197 + 198 + # ── Main ────────────────────────────────────────────────────────────────────── 199 + 200 + def main() -> None: 201 + parser = argparse.ArgumentParser( 202 + description="Generate handwriting-font structured math+text data." 203 + ) 204 + parser.add_argument("--mode", choices=["hw", "mix"], default="hw", 205 + help="hw=whole doc handwritten; mix=per-block mixing") 206 + parser.add_argument("--count", type=int, default=15_000) 207 + parser.add_argument("--out", default="data/hw_structured_train") 208 + parser.add_argument("--jobs", type=int, default=4) 209 + parser.add_argument("--seed", type=int, default=42) 210 + parser.add_argument("--font-dir", default=str(_DEFAULT_FONT_DIR), 211 + help="Directory containing downloaded handwriting TTFs") 212 + parser.add_argument("--hw-prob", type=float, default=0.55, 213 + help="(mix mode) probability each block gets hw font") 214 + parser.add_argument("--show-failures", type=int, default=0, metavar="N") 215 + args = parser.parse_args() 216 + 217 + font_dir = Path(args.font_dir) 218 + available = _available_fonts(font_dir) 219 + if not available: 220 + print( 221 + f"No handwriting fonts found in {font_dir}.\n" 222 + "Run `uv run download-hw-fonts` first." 223 + ) 224 + raise SystemExit(1) 225 + 226 + print(f"Available fonts ({len(available)}): {', '.join(available)}") 227 + 228 + out = Path(args.out) 229 + img_dir = out / "images" 230 + img_dir.mkdir(parents=True, exist_ok=True) 231 + 232 + rng = random.Random(args.seed) 233 + 234 + # ── Phase 1: generate unique bodies ────────────────────────────────────── 235 + print(f"Generating {args.count:,} unique bodies (mode={args.mode}) ...") 236 + 237 + seen: set[str] = set() 238 + # (clean_body, page_width, font, size, ink) 239 + candidates: list[tuple[str, str, str, int, str]] = [] 240 + attempts = 0 241 + 242 + with tqdm(total=args.count, unit="body") as pbar: 243 + while len(candidates) < args.count: 244 + attempts += 1 245 + body, page_width = generate_body(rng) 246 + 247 + if body in seen: 248 + continue 249 + 250 + # mix mode: require at least two blocks so within-doc mixing is 251 + # meaningful (single-block bodies would always be fully hw or typeset). 252 + if args.mode == "mix" and "\n\n" not in body: 253 + continue 254 + 255 + seen.add(body) 256 + font = _pick_font(rng, available) 257 + size = rng.choice(_FONT_SIZES) 258 + ink = rng.choice(_INK_COLOURS) 259 + candidates.append((body, page_width, font, size, ink)) 260 + pbar.update(1) 261 + 262 + print(f" {attempts:,} attempts ({attempts / len(candidates):.1f}x overhead)") 263 + 264 + # ── Phase 2: render ─────────────────────────────────────────────────────── 265 + print(f"Rendering {len(candidates):,} images with {args.jobs} workers ...") 266 + 267 + records: list[dict] = [] 268 + failures = 0 269 + shown_failures = 0 270 + 271 + def _task( 272 + clean_body: str, page_width: str, font: str, size: int, ink: str 273 + ) -> tuple[str, str, bool, str]: 274 + if args.mode == "hw": 275 + render_body = clean_body 276 + else: 277 + render_body = _apply_mixed_fonts( 278 + clean_body, rng, font, hw_prob=args.hw_prob, size=size, ink=ink 279 + ) 280 + 281 + # Hash over render body + font + size + ink + mode for uniqueness 282 + h_key = f"{args.mode}:{font}:{size}:{ink}:{page_width}:{render_body}" 283 + h = hashlib.sha1(h_key.encode()).hexdigest()[:16] 284 + out_path = img_dir / f"{h}.png" 285 + 286 + ok, err = _render( 287 + render_body, out_path, page_width, 288 + mode=args.mode, font=font, font_dir=font_dir, size=size, ink=ink, 289 + rng=rng, 290 + ) 291 + # Manifest always records the CLEAN body (no font directives). 292 + return clean_body, f"images/{h}.png", ok, err, font 293 + 294 + with ThreadPoolExecutor(max_workers=args.jobs) as pool: 295 + futs = { 296 + pool.submit(_task, body, pw, font, size, ink): body 297 + for body, pw, font, size, ink in candidates 298 + } 299 + with tqdm(total=len(candidates), unit="img") as pbar: 300 + for fut in as_completed(futs): 301 + clean_body, rel_path, ok, err, used_font = fut.result() 302 + if ok: 303 + records.append({"image": rel_path, "typst": clean_body}) 304 + else: 305 + failures += 1 306 + if shown_failures < args.show_failures: 307 + tqdm.write(f"\n--- failure ---\nbody: {clean_body!r}\n{err.strip()}") 308 + shown_failures += 1 309 + pbar.update(1) 310 + 311 + # ── Phase 3: manifest ───────────────────────────────────────────────────── 312 + manifest = out / "manifest.jsonl" 313 + with manifest.open("w") as f: 314 + for r in records: 315 + f.write(json.dumps(r) + "\n") 316 + 317 + split_hint = out.name # e.g. hw_structured_train 318 + print(f"Wrote {len(records):,} records to {manifest} ({failures} render failures)") 319 + print( 320 + f"\nNext steps:\n" 321 + f" 1. Add '{split_hint}' to TRAIN_SPLITS in data.py\n" 322 + f" (it is NOT in _MATH_ONLY_SPLITS -- body already contains $ delimiters)\n" 323 + f" 2. Set a sampling cap in train.py if needed (e.g. 15k–20k)" 324 + ) 325 + 326 + 327 + if __name__ == "__main__": 328 + main()
+10 -5
src/generate_mixed.py
··· 252 252 def generate_body(rng: random.Random) -> tuple[str, str]: 253 253 """Returns (body, page_width) where page_width is 'auto' or e.g. '280pt'.""" 254 254 r = rng.random() 255 - if r < 0.15: 255 + if r < 0.18: # ~18% bare single equation 256 + return f"$ {generate_expr(rng)} $", "auto" 257 + elif r < 0.33: # ~15% short inline (1-2 tokens) 258 + n = rng.choices([1, 2], weights=[3, 2])[0] 259 + return _inline_seq(rng, n, require_math=True), "auto" 260 + elif r < 0.45: # ~12% table 256 261 return _generate_table(rng), "auto" 257 - elif r < 0.30: # ~15% multi-paragraph 262 + elif r < 0.57: # ~12% multi-paragraph 258 263 width = rng.choice(_PARA_WIDTHS) 259 264 return _multi_paragraph(rng), f"{width}pt" 260 - elif r < 0.40: # ~10% para→list or list→para 265 + elif r < 0.65: # ~8% para→list or list→para 261 266 width = rng.choice(_PARA_WIDTHS) 262 267 body = (_para_then_list if rng.random() < 0.5 else _list_then_para)(rng) 263 268 return body, f"{width}pt" 264 - elif r < 0.58: # ~18% plain lists 269 + elif r < 0.77: # ~12% plain lists 265 270 width = rng.choice(_PARA_WIDTHS) 266 271 return "\n".join(_list_body(rng)), f"{width}pt" 267 - n = rng.choices([2, 3, 4, 5, 6, 7], weights=[4, 8, 7, 5, 3, 1])[0] 272 + n = rng.choices([3, 4, 5, 6, 7], weights=[8, 7, 5, 3, 1])[0] # ~23% longer inline 268 273 return _inline_seq(rng, n, require_math=True), "auto" 269 274 270 275
+7 -9
src/train.py
··· 12 12 from unsloth.trainer import UnslothVisionDataCollator 13 13 from trl import SFTTrainer, SFTConfig 14 14 15 - from .data import (BASE_MODEL, TRAIN_SPLITS, load_records, make_dataset) 15 + from .data import (BASE_MODEL, TRAIN_SPLITS, VAL_SPLITS, load_records, make_dataset) 16 16 17 17 # Per-split record caps. Synthetic-heavy splits are capped to prevent them 18 18 # from dominating the training mix. Real and document-structure splits are ··· 69 69 train_records.extend(recs) 70 70 rng.shuffle(train_records) 71 71 72 - # Stratified val: 250 mathwriting + 250 typeset + all mixed (500) = 1000 72 + # Stratified val: 250 per split, capped to available 73 73 val_rng = random.Random(42) 74 - mw_val = load_records(["mathwriting_val"], dedupe=False) 75 - ts_val = load_records(["typeset_val"], dedupe=False) 76 - mixed_val = load_records(["typeset_mixed_val"], dedupe=False) 77 - val_records = (val_rng.sample(mw_val, min(250, len(mw_val))) 78 - + val_rng.sample(ts_val, min(250, len(ts_val))) 79 - + mixed_val) 74 + val_records = [] 75 + for split in VAL_SPLITS: 76 + recs = load_records([split], dedupe=False) 77 + val_records += val_rng.sample(recs, min(250, len(recs))) 80 78 val_rng.shuffle(val_records) 81 79 82 80 print(f"Train: {len(train_records):,} Val: {len(val_records):,}") ··· 110 108 eval_strategy="steps", 111 109 eval_steps=500, 112 110 save_steps=500, 113 - save_total_limit=5, 111 + save_total_limit=7, 114 112 load_best_model_at_end=False, 115 113 output_dir=out_dir, 116 114 run_name="gemma-4-e2b",