Overhaul Gemma training; fix eval stratification; fix mathwriting label format

- train.py: remove dedupe, port per-split caps from DeepSeek script, add
argparse (--epochs, --lr, --output-dir, --cap), default output dir to
checkpoints/gemma-4-e2b, add tensorboard logging
- eval_deepseek.py: --n is now per-split cap (stratified sampling) instead
of a head slice across all splits combined
- data.py: add mathwriting_val/mathwriting_test to _MATH_ONLY_SPLITS so
bare-expression labels get $ ... $ wrapping at eval time

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+69 -20
+1
src/data.py
···
 _MATH_ONLY_SPLITS = {
     "crohme_gen_2019", "crohme_gen_2023", "crohme_gen_syntactic", "crohme_real_train",
     "mathwriting_train", "mathwriting_synthetic", "mathwriting_symbols",
+    "mathwriting_val", "mathwriting_test",
     "typeset_train", "typeset_val", "typeset_test",
 }
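The hunk above only extends the split set; the actual `$ ... $` wrapping happens wherever data.py normalizes labels, which this diff does not show. A minimal sketch of that behavior, using a hypothetical helper name:

# Hypothetical sketch (not part of the diff): how _MATH_ONLY_SPLITS is
# presumably consumed when labels are prepared for eval.
def _normalize_label(label: str, split: str) -> str:
    # Math-only splits store bare expressions such as "x^2 + 1"; wrap them
    # in Typst math delimiters so targets match document-style labels.
    if split in _MATH_ONLY_SPLITS and not label.strip().startswith("$"):
        return f"$ {label.strip()} $"
    return label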
+12 -5
src/eval_deepseek.py
···
 Usage:
     uv run eval-deepseek [--checkpoint checkpoints/deepseek/final]
                          [--splits mathwriting_test typeset_test typeset_mixed_test]
-                         [--n 100]
+                         [--n 500]
 """

 import argparse
···
 from tqdm import tqdm

 import re
+
+import random

 from .data import TEST_SPLITS, load_records
 from .mine_failures import _infer, _load
···
         "\nTranscribe this image to Typst notation.\n", add_special_tokens=False
     )

-    records = load_records(splits or TEST_SPLITS, dedupe=False)
-    if n is not None:
-        records = records[:n]
+    split_names = splits or TEST_SPLITS
+    rng = random.Random(0)
+    records = []
+    for name in split_names:
+        recs = load_records([name], dedupe=False)
+        if n is not None and len(recs) > n:
+            recs = rng.sample(recs, n)
+        records.extend(recs)

     correct = 0
     per_split: dict[str, list[bool]] = {}
···
     parser.add_argument("--splits", nargs="+", default=None,
                         metavar="SPLIT", help="Override test splits")
     parser.add_argument("--n", type=int, default=None,
-                        help="Evaluate only first N records")
+                        help="Max records per split (stratified); omit for full test set")
     args = parser.parse_args()
     evaluate(args.checkpoint, args.splits, args.n)
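The hunk keeps a `per_split: dict[str, list[bool]]` of per-record hits, but the reporting code sits outside the diff. One plausible way to summarize those buckets, sketched as a hypothetical helper rather than the script's actual code:

def report(per_split: dict[str, list[bool]]) -> None:
    # Sketch only: turn the per_split hit lists into per-split and overall
    # exact-match accuracy lines.
    total_hits = total = 0
    for name, hits in sorted(per_split.items()):
        total_hits += sum(hits)
        total += len(hits)
        print(f"{name:>24}  {sum(hits) / len(hits):6.1%}  ({sum(hits)}/{len(hits)})")
    print(f"{'overall':>24}  {total_hits / total:6.1%}")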
+56 -15
src/train.py
···
 """
 Baseline QLoRA fine-tuning of Gemma 4 E2B for Typst math OCR.

-Usage: uv run train
+Usage: uv run train [--epochs 2] [--cap mathwriting_synthetic 20000]
 """
+
+import argparse
+import random

 from unsloth import FastVisionModel
 from unsloth.trainer import UnslothVisionDataCollator
···
 from .data import (BASE_MODEL, TRAIN_SPLITS, VAL_SPLITS, load_records, make_dataset)

+# Per-split record caps. Synthetic-heavy splits are capped to prevent them
+# from dominating the training mix. Real and document-structure splits are
+# uncapped. Override any cap with --cap SPLIT N.
+_DEFAULT_CAPS = {
+    "mathwriting_synthetic": 20_000,
+    "crohme_gen_2019": 15_000,
+    "mathwriting_train": 10_000,
+}
+

 def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--epochs", type=int, default=2)
+    parser.add_argument("--lr", type=float, default=2e-4)
+    parser.add_argument("--output-dir", default="checkpoints/gemma-4-e2b")
+    parser.add_argument("--cap", nargs=2, action="append", metavar=("SPLIT", "N"),
+                        default=None,
+                        help="Override per-split cap, e.g. --cap mathwriting_synthetic 20000")
+    args = parser.parse_args()
+
+    caps = dict(_DEFAULT_CAPS)
+    if args.cap:
+        for split, n in args.cap:
+            caps[split] = int(n)
+
     model, processor = FastVisionModel.from_pretrained(
         BASE_MODEL,
         load_in_4bit=True,
···
         bias="none",
     )

-    train_records = load_records(TRAIN_SPLITS, dedupe=True)
-    val_records = load_records(VAL_SPLITS, dedupe=False)
+    rng = random.Random(42)
+
+    train_records: list[dict] = []
+    for split in TRAIN_SPLITS:
+        recs = load_records([split], dedupe=False)
+        cap = caps.get(split)
+        if cap and len(recs) > cap:
+            recs = rng.sample(recs, cap)
+        train_records.extend(recs)
+    rng.shuffle(train_records)

-    import random as _random
-    _rng = _random.Random(42)
+    val_rng = random.Random(42)
+    val_records = load_records(VAL_SPLITS, dedupe=False)
     if len(val_records) > 1000:
-        val_records = _rng.sample(val_records, 1000)
+        val_records = val_rng.sample(val_records, 1000)

     print(f"Train: {len(train_records):,} Val: {len(val_records):,}")
+    for split in TRAIN_SPLITS:
+        n = sum(1 for r in train_records if r["split"] == split)
+        cap = caps.get(split)
+        print(f" {split}: {n:,}" + (f" (cap {cap:,})" if cap else ""))

     train_ds = make_dataset(train_records, do_augment=True)
     val_ds = make_dataset(val_records, do_augment=False)

+    out_dir = args.output_dir
     trainer = SFTTrainer(
         model=model,
         tokenizer=processor,
···
         args=SFTConfig(
             per_device_train_batch_size=4,
             per_device_eval_batch_size=4,
-            gradient_accumulation_steps=4,  # 4 * 4 = 16, same effective batch
-            num_train_epochs=3,
-            learning_rate=2e-4,
+            gradient_accumulation_steps=4,  # effective batch 16
+            num_train_epochs=args.epochs,
+            learning_rate=args.lr,
             warmup_steps=500,
             lr_scheduler_type="cosine",
             bf16=True,
             fp16=False,
-            dataloader_num_workers=0,  # required for set_transform + PIL
+            dataloader_num_workers=0,  # required for lazy PIL loading
             logging_steps=50,
             eval_strategy="steps",
             eval_steps=500,
-            save_steps=100,  # ~14 min to first checkpoint; eval every 500
+            save_steps=100,
             save_total_limit=10,
             load_best_model_at_end=False,
-            output_dir="checkpoints/baseline",
+            output_dir=out_dir,
+            run_name="gemma-4-e2b",
+            report_to=["tensorboard"],
             dataset_kwargs={"skip_prepare_dataset": True},
         ),
     )

     trainer.train(resume_from_checkpoint=True)
-    model.save_pretrained("checkpoints/baseline/final")
-    processor.save_pretrained("checkpoints/baseline/final")
-    print("Saved to checkpoints/baseline/final")
+    model.save_pretrained(f"{out_dir}/final")
+    processor.save_pretrained(f"{out_dir}/final")
+    print(f"Saved to {out_dir}/final")


 if __name__ == "__main__":
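For reference, the --cap flag is a plain dict merge on top of _DEFAULT_CAPS. A standalone illustration with made-up override values (the merge logic is copied from main() above):

# Example: `uv run train --cap mathwriting_synthetic 5000 --cap crohme_gen_2023 8000`
# argparse delivers --cap as a list of [SPLIT, N] string pairs.
_DEFAULT_CAPS = {
    "mathwriting_synthetic": 20_000,
    "crohme_gen_2019": 15_000,
    "mathwriting_train": 10_000,
}
overrides = [["mathwriting_synthetic", "5000"], ["crohme_gen_2023", "8000"]]

caps = dict(_DEFAULT_CAPS)
for split, n in overrides:
    caps[split] = int(n)

assert caps == {
    "mathwriting_synthetic": 5_000,   # overridden
    "crohme_gen_2019": 15_000,        # default kept
    "mathwriting_train": 10_000,      # default kept
    "crohme_gen_2023": 8_000,         # new cap added
}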