this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

V1 train

+90 -11
+1
pyproject.toml
··· 19 19 train-hnm = "src.train_hnm:main" 20 20 export = "src.export:main" 21 21 generate-typeset = "src.generate_typeset:main" 22 + probe = "src.probe:main" 22 23 23 24 [build-system] 24 25 requires = ["hatchling"]
+20 -10
src/eval.py
··· 24 24 25 25 def extract_assistant(decoded: str) -> str: 26 26 """Pull the assistant turn out of a full decoded sequence.""" 27 - # Gemma chat format ends assistant turn after <end_of_turn> 28 - if "<start_of_turn>model" in decoded: 29 - decoded = decoded.split("<start_of_turn>model")[-1] 30 - if "<end_of_turn>" in decoded: 31 - decoded = decoded.split("<end_of_turn>")[0] 27 + # Gemma 4 uses <|turn>model / <turn|>; older Gemma used <start_of_turn>model / <end_of_turn> 28 + for start_marker in ("<|turn>model", "<start_of_turn>model"): 29 + if start_marker in decoded: 30 + decoded = decoded.split(start_marker)[-1] 31 + break 32 + for end_marker in ("<turn|>", "<end_of_turn>"): 33 + if end_marker in decoded: 34 + decoded = decoded.split(end_marker)[0] 35 + break 32 36 return decoded.strip() 33 37 34 38 35 - def evaluate(checkpoint: str, batch_size: int = 8) -> float: 39 + def evaluate(checkpoint: str, batch_size: int = 8, n: int | None = None) -> float: 36 40 model, processor = FastVisionModel.from_pretrained(checkpoint, load_in_4bit=True) 37 41 FastVisionModel.for_inference(model) 38 42 model.eval() 39 43 40 44 records = load_records(TEST_SPLITS, dedupe=False) 45 + if n is not None: 46 + records = records[:n] 41 47 correct = 0 42 48 43 49 from PIL import Image ··· 54 60 ], 55 61 } 56 62 ] 57 - inputs = processor.apply_chat_template( 58 - messages, 59 - add_generation_prompt=True, 63 + inputs = processor( 64 + images=[img], 65 + text=processor.apply_chat_template( 66 + messages, 67 + add_generation_prompt=True, 68 + ), 60 69 return_tensors="pt", 61 70 ).to("cuda") 62 71 ··· 78 87 parser = argparse.ArgumentParser() 79 88 parser.add_argument("--checkpoint", default="checkpoints/baseline/final") 80 89 parser.add_argument("--batch-size", type=int, default=8) 90 + parser.add_argument("--n", type=int, default=None, help="Evaluate only first N examples") 81 91 args = parser.parse_args() 82 - evaluate(args.checkpoint, args.batch_size) 92 + evaluate(args.checkpoint, args.batch_size, args.n) 83 93 84 94 85 95 if __name__ == "__main__":
+62
src/probe.py
··· 1 + """ 2 + Quick inference probe: run N examples and print raw output vs ground truth. 3 + 4 + Usage: uv run probe [--checkpoint checkpoints/baseline/final] [--n 5] 5 + """ 6 + 7 + import argparse 8 + import torch 9 + from PIL import Image 10 + from unsloth import FastVisionModel 11 + 12 + from .data import BASE_MODEL, TEST_SPLITS, PROMPT, load_records 13 + from .eval import extract_assistant, normalize 14 + 15 + 16 + def main() -> None: 17 + parser = argparse.ArgumentParser() 18 + parser.add_argument("--checkpoint", default="checkpoints/baseline/final") 19 + parser.add_argument("--n", type=int, default=5) 20 + args = parser.parse_args() 21 + 22 + model, processor = FastVisionModel.from_pretrained(args.checkpoint, load_in_4bit=True) 23 + FastVisionModel.for_inference(model) 24 + model.eval() 25 + 26 + records = load_records(TEST_SPLITS, dedupe=False)[: args.n] 27 + 28 + for i, r in enumerate(records): 29 + img = Image.open(r["image_path"]).convert("RGB") 30 + messages = [ 31 + { 32 + "role": "user", 33 + "content": [ 34 + {"type": "image", "image": img}, 35 + {"type": "text", "text": PROMPT}, 36 + ], 37 + } 38 + ] 39 + inputs = processor( 40 + images=[img], 41 + text=processor.apply_chat_template(messages, add_generation_prompt=True), 42 + return_tensors="pt", 43 + ).to("cuda") 44 + 45 + with torch.no_grad(): 46 + out = model.generate(**inputs, max_new_tokens=256, do_sample=False) 47 + 48 + decoded_full = processor.decode(out[0], skip_special_tokens=False) 49 + decoded_clean = processor.decode(out[0], skip_special_tokens=True) 50 + pred = extract_assistant(decoded_full) 51 + 52 + print(f"\n{'='*60}") 53 + print(f"[{i}] image: {r['image_path']}") 54 + print(f" EXPECTED : {repr(r['typst'])}") 55 + print(f" EXTRACTED: {repr(pred)}") 56 + print(f" CLEAN : {repr(decoded_clean)}") 57 + print(f" MATCH : {normalize(pred) == normalize(r['typst'])}") 58 + print(f" RAW (last 300 chars): {repr(decoded_full[-300:])}") 59 + 60 + 61 + if __name__ == "__main__": 62 + main()
+7 -1
src/train.py
··· 43 43 typeset_val_records = load_records(TYPESET_VAL_SPLITS, dedupe=False, root=TYPESET_ROOT) 44 44 val_records += typeset_val_records 45 45 print(f"Typeset val: {len(typeset_val_records):,} records mixed in") 46 + 47 + import random as _random 48 + _rng = _random.Random(42) 49 + if len(val_records) > 1000: 50 + val_records = _rng.sample(val_records, 1000) 51 + 46 52 print(f"Train: {len(train_records):,} Val: {len(val_records):,}") 47 53 48 54 train_ds = make_dataset(train_records, do_augment=True) ··· 76 82 ), 77 83 ) 78 84 79 - trainer.train() 85 + trainer.train(resume_from_checkpoint=True) 80 86 model.save_pretrained("checkpoints/baseline/final") 81 87 processor.save_pretrained("checkpoints/baseline/final") 82 88 print("Saved to checkpoints/baseline/final")