this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Fix DeepSeek-OCR-2 for training; load from local patched module

Two patches to modeling_deepseekocr2.py for training compatibility:
1. .clone(): breaks autograd leaf-variable link so masked_scatter_ on
inputs_embeds slice doesn't raise during backprop
2. .to(bfloat16): matches vision encoder dtype (prepare_model_for_kbit_training
upcasts embedding table to fp32; vision encoder stays bfloat16)

train_deepseek.py now imports DeepseekOCR2ForCausalLM directly from the local
src/deepseek_ocr2 module instead of trust_remote_code -- weights still fetched
from hub, only the forward() code is local and version-controlled.

Smoke test: forward OK (loss 16.85), backward OK.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+35 -11
+8 -3
src/deepseek_ocr2/modeling_deepseekocr2.py
··· 390 390 391 391 392 392 if inputs_embeds is None: 393 - # inputs_embeds = self.embed_tokens(input_ids) 394 - # inputs_embeds = self.embed_tokens(input_ids) 395 - inputs_embeds = self.get_input_embeddings()(input_ids) 393 + # .clone().to(bfloat16): two training-only fixes -- 394 + # 1. clone() breaks the autograd leaf-variable link so the 395 + # masked_scatter_ below does not raise "in-place on leaf". 396 + # 2. to(bfloat16) matches the dtype of image features produced 397 + # by the vision encoder (prepare_model_for_kbit_training 398 + # upcasts the embedding table to float32; the vision encoder 399 + # stays bfloat16 after our explicit cast in the training script). 400 + inputs_embeds = self.get_input_embeddings()(input_ids).clone().to(torch.bfloat16) 396 401 397 402 398 403
+27 -8
src/train_deepseek.py
··· 23 23 import torch 24 24 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 25 25 from transformers import ( 26 - AutoModel, 27 26 AutoTokenizer, 28 27 BitsAndBytesConfig, 29 28 Trainer, ··· 32 31 33 32 from .collate_deepseek import DeepSeekOCRCollator 34 33 from .data import TRAIN_SPLITS, VAL_SPLITS, load_records, make_dataset 34 + from .deepseek_ocr2.modeling_deepseekocr2 import DeepseekOCR2ForCausalLM 35 35 36 36 MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 37 37 ··· 58 58 59 59 60 60 def load_model_and_tokenizer(model_id: str): 61 + # Tokenizer still loaded from hub (no code, just vocab/config files). 61 62 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) 62 63 63 - model = AutoModel.from_pretrained( 64 + # Model loaded via our local patched class -- no trust_remote_code needed. 65 + # Weights still fetched from hub; only the Python forward() code is local. 66 + model = DeepseekOCR2ForCausalLM.from_pretrained( 64 67 model_id, 65 68 quantization_config=_bnb_config(), 66 - trust_remote_code=True, 67 69 use_safetensors=True, 68 70 # eager: flash_attn2 + grad-checkpointing combination can be fragile; 69 71 # switch to flash_attention_2 once smoke-test confirms stability. ··· 72 74 return model, tokenizer 73 75 74 76 75 - def _freeze_vision_encoder(model) -> None: 76 - """Freeze SAM image encoder. Adapt substring if model attr name differs.""" 77 + def _freeze_and_cast_vision_encoder(model) -> None: 78 + """ 79 + Freeze the SAM / Qwen2 vision encoders and projector, and cast all their 80 + parameters to bfloat16. 81 + 82 + BnB only quantizes nn.Linear layers; conv layers (e.g. patch_embed.proj in 83 + SAM) are left in float32. prepare_model_for_kbit_training does not cast 84 + them either, so the conv receives bfloat16 activations but has a float32 85 + bias -- RuntimeError. Casting to bf16 before training fixes this. 86 + 87 + Module structure (DeepseekOCR2Model): 88 + model.sam_model -- SAM ViT-B image encoder (conv + attention) 89 + model.qwen2_model -- Qwen2 decoder used as a second encoder 90 + model.projector -- MLP mapping vision features to LM embedding dim 91 + """ 92 + # Substrings that uniquely identify vision encoder params (post-PEFT naming) 93 + VISION_SUBSTRINGS = ("sam_model", "qwen2_model", "projector", "view_seperator") 94 + 77 95 frozen = 0 78 96 for name, param in model.named_parameters(): 79 - if "image_encoder" in name or "vision_encoder" in name or "vit" in name: 97 + if any(s in name for s in VISION_SUBSTRINGS): 80 98 param.requires_grad_(False) 99 + param.data = param.data.to(torch.bfloat16) 81 100 frozen += param.numel() 82 101 if frozen: 83 - print(f"Frozen {frozen / 1e6:.1f} M vision encoder params.") 102 + print(f"Frozen + cast to bf16: {frozen / 1e6:.1f} M vision params.") 84 103 else: 85 104 print("WARNING: no vision encoder params matched -- check module names.") 86 105 ··· 129 148 model, use_gradient_checkpointing=True, 130 149 ) 131 150 132 - _freeze_vision_encoder(model) 151 + _freeze_and_cast_vision_encoder(model) 133 152 134 153 lora_cfg = LoraConfig( 135 154 r=args.lora_r,