this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Add train_deepseek: QLoRA fine-tuning script for DeepSeek-OCR-2

4-bit NF4 + LoRA r=16 targeting MLA attention and MLP layers.
Freezes SAM vision encoder (already strongly pretrained).
Custom DeepSeekTrainer subclass moves list-of-tuple images to device.
Smoke-test flag validates forward+backward before full training run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+235
+1
pyproject.toml
··· 23 23 24 24 [project.scripts] 25 25 train = "src.train:main" 26 + train-deepseek = "src.train_deepseek:main" 26 27 evaluate = "src.eval:main" 27 28 mine = "src.mine_failures:main" 28 29 train-hnm = "src.train_hnm:main"
+234
src/train_deepseek.py
··· 1 + """ 2 + QLoRA fine-tuning of DeepSeek-OCR-2 for Typst math OCR. 3 + 4 + Architecture notes: 5 + - DeepseekOCR2ForCausalLM extends DeepseekV2ForCausalLM with a SAM-based 6 + image encoder. Standard Trainer + PEFT work because forward() accepts 7 + `labels` and returns a loss in the CausalLM style. 8 + - MLA (Multi-head Latent Attention) in the LM backbone uses compressed KV 9 + projections. LoRA targets cover both the MLA attention matrices and the 10 + dense MLP gates. The vision encoder is frozen -- it is already strongly 11 + pretrained and contributes no new information for Typst notation style. 12 + - With 4-bit NF4 + LoRA r=16, peak VRAM on an RTX 3060 12 GB is roughly 13 + 6-8 GB for batch=1, grad_accum=8, seq≤512 tokens. 14 + 15 + Usage: 16 + uv run train-deepseek [--smoke-test] [--output-dir checkpoints/deepseek] 17 + uv run train-deepseek --model /path/to/local/checkpoint ... 18 + """ 19 + 20 + import argparse 21 + from pathlib import Path 22 + 23 + import torch 24 + from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 25 + from transformers import ( 26 + AutoModel, 27 + AutoTokenizer, 28 + BitsAndBytesConfig, 29 + Trainer, 30 + TrainingArguments, 31 + ) 32 + 33 + from .collate_deepseek import DeepSeekOCRCollator 34 + from .data import TRAIN_SPLITS, VAL_SPLITS, load_records, make_dataset 35 + 36 + MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 37 + 38 + # DeepSeek V2 MLA attention matrices + shared/routed MLP gates. 39 + # MLA splits QKV into low-rank compressed projections (a/b pairs). 40 + # If any name is absent the model will raise; adjust to match the loaded arch. 41 + LORA_TARGET_MODULES = [ 42 + "q_a_proj", "q_b_proj", 43 + "kv_a_proj_with_mqa", "kv_b_proj", 44 + "o_proj", 45 + "gate_proj", "up_proj", "down_proj", 46 + ] 47 + 48 + 49 + # ── Model loading ────────────────────────────────────────────────────────────── 50 + 51 + def _bnb_config() -> BitsAndBytesConfig: 52 + return BitsAndBytesConfig( 53 + load_in_4bit=True, 54 + bnb_4bit_quant_type="nf4", 55 + bnb_4bit_compute_dtype=torch.bfloat16, 56 + bnb_4bit_use_double_quant=True, 57 + ) 58 + 59 + 60 + def load_model_and_tokenizer(model_id: str): 61 + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) 62 + 63 + model = AutoModel.from_pretrained( 64 + model_id, 65 + quantization_config=_bnb_config(), 66 + trust_remote_code=True, 67 + use_safetensors=True, 68 + # eager: flash_attn2 + grad-checkpointing combination can be fragile; 69 + # switch to flash_attention_2 once smoke-test confirms stability. 70 + _attn_implementation="eager", 71 + ) 72 + return model, tokenizer 73 + 74 + 75 + def _freeze_vision_encoder(model) -> None: 76 + """Freeze SAM image encoder. Adapt substring if model attr name differs.""" 77 + frozen = 0 78 + for name, param in model.named_parameters(): 79 + if "image_encoder" in name or "vision_encoder" in name or "vit" in name: 80 + param.requires_grad_(False) 81 + frozen += param.numel() 82 + if frozen: 83 + print(f"Frozen {frozen / 1e6:.1f} M vision encoder params.") 84 + else: 85 + print("WARNING: no vision encoder params matched -- check module names.") 86 + 87 + 88 + # ── Custom Trainer (handles non-tensor 'images' key) ────────────────────────── 89 + 90 + class DeepSeekTrainer(Trainer): 91 + """ 92 + Overrides _prepare_inputs to move the `images` list-of-tuples to the 93 + correct device. The base Trainer's send_to_device recurses through dicts 94 + and lists but not tuples-inside-lists reliably across all TF versions. 95 + """ 96 + 97 + def _prepare_inputs(self, inputs: dict) -> dict: 98 + inputs = super()._prepare_inputs(inputs) 99 + if "images" in inputs: 100 + dev = self.args.device 101 + inputs["images"] = [ 102 + (lc.to(dev, dtype=torch.bfloat16), 103 + gv.to(dev, dtype=torch.bfloat16)) 104 + for lc, gv in inputs["images"] 105 + ] 106 + return inputs 107 + 108 + 109 + # ── Main ─────────────────────────────────────────────────────────────────────── 110 + 111 + def main() -> None: 112 + parser = argparse.ArgumentParser() 113 + parser.add_argument("--model", default=MODEL_ID, 114 + help="HF model ID or local path") 115 + parser.add_argument("--smoke-test", action="store_true", 116 + help="One forward+backward pass then exit") 117 + parser.add_argument("--output-dir", default="checkpoints/deepseek") 118 + parser.add_argument("--epochs", type=int, default=3) 119 + parser.add_argument("--lr", type=float, default=1e-4) 120 + parser.add_argument("--lora-r", type=int, default=16) 121 + args = parser.parse_args() 122 + 123 + print(f"Loading {args.model} ...") 124 + model, tokenizer = load_model_and_tokenizer(args.model) 125 + 126 + # prepare_model_for_kbit_training: casts LayerNorm + embeddings to fp32, 127 + # enables grad checkpointing on the model itself. 128 + model = prepare_model_for_kbit_training( 129 + model, use_gradient_checkpointing=True, 130 + ) 131 + 132 + _freeze_vision_encoder(model) 133 + 134 + lora_cfg = LoraConfig( 135 + r=args.lora_r, 136 + lora_alpha=args.lora_r * 2, 137 + target_modules=LORA_TARGET_MODULES, 138 + lora_dropout=0.0, 139 + bias="none", 140 + task_type="CAUSAL_LM", 141 + ) 142 + model = get_peft_model(model, lora_cfg) 143 + model.print_trainable_parameters() 144 + 145 + train_records = load_records(TRAIN_SPLITS, dedupe=True) 146 + val_records = load_records(VAL_SPLITS, dedupe=False) 147 + 148 + import random as _random 149 + _rng = _random.Random(42) 150 + if len(val_records) > 500: 151 + val_records = _rng.sample(val_records, 500) 152 + 153 + print(f"Train: {len(train_records):,} Val: {len(val_records):,}") 154 + 155 + train_ds = make_dataset(train_records, do_augment=True) 156 + val_ds = make_dataset(val_records, do_augment=False) 157 + 158 + collator = DeepSeekOCRCollator(tokenizer) 159 + 160 + if args.smoke_test: 161 + _run_smoke_test(model, collator, train_ds) 162 + return 163 + 164 + out_dir = args.output_dir 165 + training_args = TrainingArguments( 166 + output_dir=out_dir, 167 + per_device_train_batch_size=1, 168 + per_device_eval_batch_size=1, 169 + gradient_accumulation_steps=8, # effective batch size 8 170 + num_train_epochs=args.epochs, 171 + learning_rate=args.lr, 172 + warmup_steps=200, 173 + lr_scheduler_type="cosine", 174 + bf16=True, 175 + fp16=False, 176 + gradient_checkpointing=True, 177 + dataloader_num_workers=0, # PIL + lazy load requires 0 178 + logging_steps=50, 179 + eval_strategy="steps", 180 + eval_steps=500, 181 + save_steps=500, 182 + save_total_limit=5, 183 + load_best_model_at_end=False, 184 + remove_unused_columns=False, # collator uses non-standard keys 185 + report_to="none", 186 + ) 187 + 188 + has_checkpoint = any(Path(out_dir).glob("checkpoint-*")) if Path(out_dir).exists() else False 189 + 190 + trainer = DeepSeekTrainer( 191 + model=model, 192 + args=training_args, 193 + train_dataset=train_ds, 194 + eval_dataset=val_ds, 195 + data_collator=collator, 196 + ) 197 + 198 + trainer.train(resume_from_checkpoint=has_checkpoint) 199 + 200 + final_dir = f"{out_dir}/final" 201 + model.save_pretrained(final_dir) 202 + tokenizer.save_pretrained(final_dir) 203 + print(f"Saved to {final_dir}") 204 + 205 + 206 + def _run_smoke_test(model, collator, train_ds) -> None: 207 + print("Running smoke test (2 samples, 1 forward+backward) ...") 208 + batch = collator([train_ds[0], train_ds[1]]) 209 + 210 + dev = next(model.parameters()).device 211 + batch_cuda = {} 212 + for k, v in batch.items(): 213 + if k == "images": 214 + batch_cuda[k] = [ 215 + (lc.to(dev, dtype=torch.bfloat16), 216 + gv.to(dev, dtype=torch.bfloat16)) 217 + for lc, gv in v 218 + ] 219 + elif isinstance(v, torch.Tensor): 220 + batch_cuda[k] = v.to(dev) 221 + else: 222 + batch_cuda[k] = v 223 + 224 + model.train() 225 + out = model(**batch_cuda) 226 + loss = out.loss 227 + print(f" forward OK -- loss: {loss.item():.4f}") 228 + loss.backward() 229 + print(" backward OK") 230 + print("Smoke test passed.") 231 + 232 + 233 + if __name__ == "__main__": 234 + main()