this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Fix image injection bug; add DeepSeek eval/mining scripts

collate_deepseek: pass zeros for local_crops so forward() takes the
global-only branch (145 features), not local+global (289 features).
The old code passed the real image for both slots, causing masked_scatter_
to inject local-crop features and discard the global view entirely.

Also: add eval_deepseek, mine_failures, infer_debug, train_hnm scripts;
add split tracking to data.py; save every 250 steps (keep 10); backfill
TensorBoard epoch + learning_rate scalars.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+408 -10
+2
pyproject.toml
··· 28 28 train = "src.train:main" 29 29 train-deepseek = "src.train_deepseek:main" 30 30 evaluate = "src.eval:main" 31 + eval-deepseek = "src.eval_deepseek:main" 32 + infer-debug = "src.infer_debug:main" 31 33 mine = "src.mine_failures:main" 32 34 train-hnm = "src.train_hnm:main" 33 35 export = "src.export:main"
+6 -4
src/backfill_tb.py
··· 50 50 51 51 for i, d in enumerate(train_entries): 52 52 step = (i + 1) * LOGGING_STEPS 53 - writer.add_scalar("train/loss", d["loss"], step) 54 - writer.add_scalar("train/grad_norm", d["grad_norm"], step) 55 - writer.add_scalar("train/lr", d["learning_rate"], step) 53 + writer.add_scalar("train/loss", d["loss"], step) 54 + writer.add_scalar("train/grad_norm", d["grad_norm"], step) 55 + writer.add_scalar("train/learning_rate", d["learning_rate"], step) 56 + writer.add_scalar("train/epoch", d["epoch"], step) 56 57 57 58 for i, d in enumerate(eval_entries): 58 59 step = (i + 1) * EVAL_STEPS 59 - writer.add_scalar("eval/loss", d["eval_loss"], step) 60 + writer.add_scalar("eval/loss", d["eval_loss"], step) 61 + writer.add_scalar("eval/epoch", d["epoch"], step) 60 62 61 63 writer.close() 62 64 print(f"Wrote {len(train_entries)} train + {len(eval_entries)} eval entries to {out_dir}")
+6 -4
src/collate_deepseek.py
··· 146 146 attn_mask = (input_ids != self.pad_id) 147 147 148 148 # images: list of (local_crops, global_view) tuples, one per sample. 149 - # With crop_mode=False there is one tile; the same tensor serves both 150 - # slots. Shape per slot: [1, 3, 768, 768]. 151 - # TODO: validate tuple format against model source once loaded locally. 149 + # local_crops must be zeros so forward() takes the crop_mode=False else-branch, 150 + # producing exactly N_IMAGE_TOKENS=145 global features to match images_seq_mask. 151 + # Passing the real image for both slots triggers the if-branch which concatenates 152 + # local+global+sep = 289 features; masked_scatter_ then injects the wrong subset. 152 153 imgs_stacked = torch.stack(img_tensors) # [B, 3, 768, 768] 154 + zeros = torch.zeros_like(imgs_stacked[0].unsqueeze(0)) 153 155 images = [ 154 - (imgs_stacked[i].unsqueeze(0), imgs_stacked[i].unsqueeze(0)) 156 + (zeros, imgs_stacked[i].unsqueeze(0)) 155 157 for i in range(len(batch)) 156 158 ] 157 159
+1
src/data.py
··· 98 98 records.append({ 99 99 "image_path": str(base / r["image"]), 100 100 "typst": typst, 101 + "split": name, 101 102 }) 102 103 return records 103 104
+75
src/eval_deepseek.py
··· 1 + """ 2 + ExpRate evaluation of a fine-tuned DeepSeek-OCR-2 checkpoint on the test splits. 3 + 4 + ExpRate = fraction of exact string matches after whitespace normalisation. 5 + 6 + Usage: 7 + uv run eval-deepseek [--checkpoint checkpoints/deepseek/final] 8 + [--splits mathwriting_test typeset_test typeset_mixed_test] 9 + [--n 100] 10 + """ 11 + 12 + import argparse 13 + 14 + import torch 15 + from PIL import Image 16 + from tqdm import tqdm 17 + 18 + import re 19 + 20 + from .data import TEST_SPLITS, load_records 21 + from .mine_failures import _infer, _load 22 + 23 + 24 + def normalize(s: str) -> str: 25 + return re.sub(r"\s+", " ", s).strip() 26 + 27 + 28 + def evaluate( 29 + checkpoint: str, 30 + splits: list[str] | None = None, 31 + n: int | None = None, 32 + ) -> float: 33 + model, tokenizer = _load(checkpoint) 34 + prompt_ids = tokenizer.encode( 35 + "\nTranscribe this image to Typst notation.\n", add_special_tokens=False 36 + ) 37 + 38 + records = load_records(splits or TEST_SPLITS, dedupe=False) 39 + if n is not None: 40 + records = records[:n] 41 + 42 + correct = 0 43 + per_split: dict[str, list[bool]] = {} 44 + 45 + for r in tqdm(records, desc="Evaluating"): 46 + img = Image.open(r["image_path"]).convert("RGB") 47 + pred = normalize(_infer(model, tokenizer, prompt_ids, img)) 48 + gt = normalize(r["typst"]) 49 + hit = pred == gt 50 + correct += hit 51 + per_split.setdefault(r.get("split", "unknown"), []).append(hit) 52 + print(f" GT : {repr(gt)}") 53 + print(f" PRED: {repr(pred)}") 54 + print(f" HIT : {hit}\n") 55 + 56 + total = len(records) 57 + print(f"\nExpRate: {correct/total:.4f} ({correct}/{total})") 58 + for split, hits in sorted(per_split.items()): 59 + print(f" {split}: {sum(hits)/len(hits):.4f} ({sum(hits)}/{len(hits)})") 60 + return correct / total 61 + 62 + 63 + def main() -> None: 64 + parser = argparse.ArgumentParser() 65 + parser.add_argument("--checkpoint", default="checkpoints/deepseek/final") 66 + parser.add_argument("--splits", nargs="+", default=None, 67 + metavar="SPLIT", help="Override test splits") 68 + parser.add_argument("--n", type=int, default=None, 69 + help="Evaluate only first N records") 70 + args = parser.parse_args() 71 + evaluate(args.checkpoint, args.splits, args.n) 72 + 73 + 74 + if __name__ == "__main__": 75 + main()
+82
src/infer_debug.py
··· 1 + """Quick diagnostic: run one inference and print raw token IDs + image injection status.""" 2 + import torch 3 + from PIL import Image 4 + from transformers import AutoTokenizer, BitsAndBytesConfig 5 + from peft import PeftModel 6 + 7 + from .collate_deepseek import IMAGE_TOKEN_ID, N_IMAGE_TOKENS, _PROMPT, preprocess_image 8 + from .data import TEST_SPLITS, load_records 9 + from .deepseek_ocr2.modeling_deepseekocr2 import DeepseekOCR2ForCausalLM 10 + 11 + MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 12 + CHECKPOINT = "checkpoints/deepseek/checkpoint-10500" 13 + 14 + 15 + def main(): 16 + bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", 17 + bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True) 18 + base = DeepseekOCR2ForCausalLM.from_pretrained(MODEL_ID, quantization_config=bnb, 19 + use_safetensors=True, _attn_implementation="eager") 20 + model = PeftModel.from_pretrained(base, CHECKPOINT) 21 + for name, param in model.named_parameters(): 22 + if param.dtype == torch.float16: 23 + param.data = param.data.to(torch.bfloat16) 24 + model.eval() 25 + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) 26 + 27 + prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False) 28 + print(f"Prompt token IDs: {prompt_ids}") 29 + print(f"Prompt decoded: {repr(tokenizer.decode(prompt_ids))}") 30 + print(f"IMAGE_TOKEN_ID={IMAGE_TOKEN_ID}, N_IMAGE_TOKENS={N_IMAGE_TOKENS}") 31 + 32 + r = load_records(TEST_SPLITS, dedupe=False)[0] 33 + img = Image.open(r["image_path"]).convert("RGB") 34 + img_t = preprocess_image(img).cuda() 35 + print(f"Image tensor: dtype={img_t.dtype}, sum={img_t.sum().item():.2f}, shape={img_t.shape}") 36 + 37 + img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS 38 + ids = torch.tensor([img_ids + prompt_ids], dtype=torch.long, device="cuda") 39 + n_input = ids.shape[1] 40 + seq_mask = torch.zeros(1, n_input, dtype=torch.bool, device="cuda") 41 + seq_mask[0, :N_IMAGE_TOKENS] = True 42 + images = [(img_t.unsqueeze(0), img_t.unsqueeze(0))] 43 + spatial = torch.tensor([[1, 1]], dtype=torch.long, device="cuda") 44 + attn = torch.ones_like(ids) 45 + 46 + # Monkey-patch masked_scatter_ to verify image injection fires 47 + injection_happened = [False] 48 + original_mscatter = torch.Tensor.masked_scatter_ 49 + 50 + def patched_mscatter(self_t, mask, source): 51 + injection_happened[0] = True 52 + print(f" masked_scatter_ called: self={self_t.shape} {self_t.dtype}, " 53 + f"mask={mask.shape}, source={source.shape} {source.dtype}") 54 + return original_mscatter(self_t, mask, source) 55 + 56 + torch.Tensor.masked_scatter_ = patched_mscatter 57 + 58 + with torch.no_grad(): 59 + out = model.generate( 60 + input_ids=ids, 61 + attention_mask=attn, 62 + images=images, 63 + images_seq_mask=seq_mask, 64 + images_spatial_crop=spatial, 65 + max_new_tokens=30, 66 + do_sample=False, 67 + ) 68 + 69 + torch.Tensor.masked_scatter_ = original_mscatter 70 + print(f"Image injection fired: {injection_happened[0]}") 71 + generated = out[0][n_input:] 72 + print(f"\nGenerated token IDs ({len(generated)} tokens): {generated[:20].tolist()}") 73 + print(f"Decoded (skip_special=True): {repr(tokenizer.decode(generated, skip_special_tokens=True))}") 74 + print(f"Decoded (skip_special=False): {repr(tokenizer.decode(generated, skip_special_tokens=False))}") 75 + print(f"\nGT: {repr(r['typst'])}") 76 + print(f"\nEOS token id (tokenizer): {tokenizer.eos_token_id}") 77 + print(f"EOS token id (model.config): {model.config.eos_token_id}") 78 + print(f"BOS token id (model.config): {model.config.bos_token_id}") 79 + 80 + 81 + if __name__ == "__main__": 82 + main()
+143
src/mine_failures.py
··· 1 + """ 2 + HNM Step 1: identify hard failures from the training pool. 3 + 4 + Runs inference on a random sample of training records (default 10k), 5 + computes normalized edit distance between prediction and ground truth, 6 + and writes the top-N hardest failures to data/hnm_pool.jsonl. 7 + 8 + Samples from the training pool (not val/test) to keep evaluation sets clean. 9 + 10 + Usage: uv run mine [--checkpoint checkpoints/deepseek/final] 11 + [--sample 10000] 12 + [--top-n 3000] 13 + [--out data/hnm_pool.jsonl] 14 + """ 15 + 16 + import argparse 17 + import json 18 + import random 19 + from pathlib import Path 20 + 21 + import editdistance 22 + import torch 23 + from peft import PeftModel 24 + from PIL import Image 25 + from tqdm import tqdm 26 + from transformers import AutoTokenizer, BitsAndBytesConfig 27 + 28 + import re 29 + 30 + from .collate_deepseek import IMAGE_TOKEN_ID, N_IMAGE_TOKENS, _PROMPT, preprocess_image 31 + from .data import TRAIN_SPLITS, load_records 32 + from .deepseek_ocr2.modeling_deepseekocr2 import DeepseekOCR2ForCausalLM 33 + 34 + 35 + def normalize(s: str) -> str: 36 + return re.sub(r"\s+", " ", s).strip() 37 + 38 + MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 39 + 40 + 41 + def _load(checkpoint: str): 42 + bnb = BitsAndBytesConfig( 43 + load_in_4bit=True, 44 + bnb_4bit_quant_type="nf4", 45 + bnb_4bit_compute_dtype=torch.bfloat16, 46 + bnb_4bit_use_double_quant=True, 47 + ) 48 + base = DeepseekOCR2ForCausalLM.from_pretrained( 49 + MODEL_ID, 50 + quantization_config=bnb, 51 + use_safetensors=True, 52 + _attn_implementation="eager", 53 + ) 54 + model = PeftModel.from_pretrained(base, checkpoint) 55 + 56 + for name, param in model.named_parameters(): 57 + if param.dtype == torch.float16: 58 + param.data = param.data.to(torch.bfloat16) 59 + 60 + model.eval() 61 + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) 62 + return model, tokenizer 63 + 64 + 65 + def _infer(model, tokenizer, prompt_ids: list[int], img: Image.Image) -> str: 66 + img_t = preprocess_image(img).cuda() # [3, 768, 768] 67 + img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS 68 + ids = torch.tensor([img_ids + prompt_ids], dtype=torch.long, device="cuda") 69 + n_input = ids.shape[1] 70 + 71 + seq_mask = torch.zeros(1, n_input, dtype=torch.bool, device="cuda") 72 + seq_mask[0, :N_IMAGE_TOKENS] = True 73 + 74 + # zeros for local crops: forces the crop_mode=False else-branch → 145 global features 75 + images = [(torch.zeros_like(img_t.unsqueeze(0)), img_t.unsqueeze(0))] 76 + spatial = torch.tensor([[1, 1]], dtype=torch.long, device="cuda") 77 + attn = torch.ones_like(ids) 78 + 79 + with torch.no_grad(): 80 + out = model.generate( 81 + input_ids=ids, 82 + attention_mask=attn, 83 + images=images, 84 + images_seq_mask=seq_mask, 85 + images_spatial_crop=spatial, 86 + max_new_tokens=512, 87 + do_sample=False, 88 + ) 89 + return tokenizer.decode(out[0][n_input:], skip_special_tokens=True) 90 + 91 + 92 + def mine( 93 + checkpoint: str, 94 + sample_size: int = 10_000, 95 + top_n: int = 3_000, 96 + out_path: str = "data/hnm_pool.jsonl", 97 + ) -> None: 98 + model, tokenizer = _load(checkpoint) 99 + prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False) 100 + 101 + all_records = load_records(TRAIN_SPLITS, dedupe=True) 102 + sample = random.sample(all_records, min(sample_size, len(all_records))) 103 + print(f"Mining over {len(sample):,} training samples") 104 + 105 + scored: list[tuple[float, dict]] = [] 106 + 107 + for r in tqdm(sample, desc="Inferring"): 108 + img = Image.open(r["image_path"]).convert("RGB") 109 + pred = normalize(_infer(model, tokenizer, prompt_ids, img)) 110 + gt = normalize(r["typst"]) 111 + 112 + if pred == gt: 113 + continue 114 + 115 + max_len = max(len(pred), len(gt), 1) 116 + dist = editdistance.eval(pred, gt) / max_len 117 + scored.append((dist, r)) 118 + 119 + scored.sort(key=lambda x: x[0], reverse=True) 120 + failures = [r for _, r in scored[:top_n]] 121 + 122 + out = Path(out_path) 123 + out.parent.mkdir(parents=True, exist_ok=True) 124 + with out.open("w") as f: 125 + for r in failures: 126 + f.write(json.dumps(r) + "\n") 127 + 128 + print(f"Wrote {len(failures)} hard negatives to {out_path}") 129 + print(f"Failure rate in sample: {len(scored)}/{len(sample)} = {len(scored)/len(sample):.2%}") 130 + 131 + 132 + def main() -> None: 133 + parser = argparse.ArgumentParser() 134 + parser.add_argument("--checkpoint", default="checkpoints/deepseek/final") 135 + parser.add_argument("--sample", type=int, default=10_000) 136 + parser.add_argument("--top-n", type=int, default=3_000) 137 + parser.add_argument("--out", default="data/hnm_pool.jsonl") 138 + args = parser.parse_args() 139 + mine(args.checkpoint, args.sample, args.top_n, args.out) 140 + 141 + 142 + if __name__ == "__main__": 143 + main()
+2 -2
src/train_deepseek.py
··· 219 219 logging_steps=50, 220 220 eval_strategy="steps", 221 221 eval_steps=500, 222 - save_steps=500, 223 - save_total_limit=5, 222 + save_steps=250, 223 + save_total_limit=10, 224 224 load_best_model_at_end=False, 225 225 remove_unused_columns=False, # collator uses non-standard keys 226 226 report_to=["tensorboard"],
+91
src/train_hnm.py
··· 1 + """ 2 + HNM Step 2: one targeted fine-tuning pass on hard failures. 3 + 4 + Loads the baseline checkpoint and does a short (1 epoch, lower LR) pass 5 + on the records from data/hnm_pool.jsonl. Saves to checkpoints/hnm/final. 6 + 7 + Evaluate this checkpoint with: uv run evaluate --checkpoint checkpoints/hnm/final 8 + 9 + Usage: uv run train-hnm [--checkpoint checkpoints/baseline/final] 10 + [--pool data/hnm_pool.jsonl] 11 + """ 12 + 13 + import argparse 14 + import json 15 + import random 16 + from pathlib import Path 17 + 18 + from datasets import Dataset 19 + from PIL import Image 20 + from unsloth import FastVisionModel 21 + from unsloth.trainer import UnslothVisionDataCollator 22 + from trl import SFTTrainer, SFTConfig 23 + 24 + from .data import VAL_SPLITS, PROMPT, load_records, make_dataset 25 + 26 + 27 + def load_hnm_records(pool_path: str) -> list[dict]: 28 + records = [] 29 + for line in Path(pool_path).read_text().splitlines(): 30 + r = json.loads(line) 31 + if r.get("typst") and not r["typst"].startswith("ERROR:"): 32 + records.append(r) 33 + return records 34 + 35 + 36 + def main() -> None: 37 + parser = argparse.ArgumentParser() 38 + parser.add_argument("--checkpoint", default="checkpoints/baseline/final") 39 + parser.add_argument("--pool", default="data/hnm_pool.jsonl") 40 + args = parser.parse_args() 41 + 42 + model, processor = FastVisionModel.from_pretrained( 43 + args.checkpoint, 44 + load_in_4bit=True, 45 + use_gradient_checkpointing="unsloth", 46 + ) 47 + # Checkpoint is already a PEFT model; re-enable training mode 48 + model.train() 49 + 50 + hnm_records = load_hnm_records(args.pool) 51 + val_records = load_records(VAL_SPLITS, dedupe=False) 52 + print(f"HNM pool: {len(hnm_records):,} Val: {len(val_records):,}") 53 + 54 + hnm_ds = make_dataset(hnm_records, do_augment=True) 55 + val_ds = make_dataset(val_records, do_augment=False) 56 + 57 + trainer = SFTTrainer( 58 + model=model, 59 + tokenizer=processor, 60 + data_collator=UnslothVisionDataCollator(model, processor), 61 + train_dataset=hnm_ds, 62 + eval_dataset=val_ds, 63 + args=SFTConfig( 64 + per_device_train_batch_size=1, 65 + gradient_accumulation_steps=8, 66 + num_train_epochs=1, 67 + learning_rate=5e-5, # lower LR to avoid catastrophic forgetting 68 + warmup_ratio=0.1, 69 + lr_scheduler_type="cosine", 70 + bf16=True, 71 + fp16=False, 72 + dataloader_num_workers=0, 73 + logging_steps=50, 74 + eval_steps=200, 75 + save_steps=200, 76 + save_total_limit=2, 77 + load_best_model_at_end=True, 78 + metric_for_best_model="eval_loss", 79 + output_dir="checkpoints/hnm", 80 + dataset_kwargs={"skip_prepare_dataset": True}, 81 + ), 82 + ) 83 + 84 + trainer.train() 85 + model.save_pretrained("checkpoints/hnm/final") 86 + processor.save_pretrained("checkpoints/hnm/final") 87 + print("Saved to checkpoints/hnm/final") 88 + 89 + 90 + if __name__ == "__main__": 91 + main()