this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Remove previous deepseek-ocr 2 finetune attempt

-5826
-5
pyproject.toml
··· 25 25 26 26 [project.scripts] 27 27 train = "src.train:main" 28 - train-deepseek = "src.train_deepseek:main" 29 28 evaluate = "src.eval:main" 30 - eval-deepseek = "src.eval_deepseek:main" 31 - infer-debug = "src.infer_debug:main" 32 - mine = "src.mine_failures:main" 33 29 train-hnm = "src.train_hnm:main" 34 30 export = "src.export:main" 35 31 generate-typeset = "src.generate_typeset:main" 36 32 download-hw-fonts = "src.download_hw_fonts:main" 37 33 probe = "src.probe:main" 38 34 app = "src.app:main" 39 - probe-deepseek = "src.probe_deepseek:main" 40 35 review = "src.review_app:main" 41 36 apply-edits = "src.apply_edits:main" 42 37 search-labels = "src.search_labels:main"
-347
scripts/train_deepseek.py
··· 1 - #!/usr/bin/env python3 2 - # /// script 3 - # requires-python = ">=3.13" 4 - # dependencies = [ 5 - # "torch==2.6.0", 6 - # "torchvision", 7 - # "transformers==4.46.3", 8 - # "accelerate", 9 - # "peft", 10 - # "bitsandbytes", 11 - # "einops", 12 - # "addict", 13 - # "easydict", 14 - # "pillow", 15 - # "tqdm", 16 - # ] 17 - # /// 18 - """ 19 - QLoRA fine-tuning of DeepSeek-OCR-2 for Typst math OCR. 20 - 21 - The LM backbone (LlamaFlashAttention2 / LlamaAttention) is targeted with LoRA; 22 - the visual encoder runs frozen (the model's forward pass wraps it in no_grad). 23 - 24 - Usage: 25 - uv run scripts/train_deepseek.py \\ 26 - --train-manifest ../eff-mer/data/raster/mathwriting_Train/manifest.jsonl \\ 27 - --val-manifest ../eff-mer/data/raster/mathwriting_Val/manifest.jsonl \\ 28 - --output checkpoints/deepseek-typst \\ 29 - --steps 5000 30 - 31 - Optional: 32 - --lora-r 32 LoRA rank (default 32) 33 - --lr 2e-4 Learning rate 34 - --grad-accum 8 Gradient accumulation steps 35 - --save-steps 500 Checkpoint every N steps 36 - --val-steps 100 Validate every N steps (0 = disable) 37 - --n-val 64 Number of val examples per eval 38 - --bits 4 Quantisation: 4 (NF4) or 8 (INT8) 39 - --list-modules Print trainable module names and exit 40 - """ 41 - 42 - import argparse 43 - import json 44 - import math 45 - import os 46 - import random 47 - from pathlib import Path 48 - 49 - import torch 50 - import torch.nn as nn 51 - from PIL import Image, ImageOps 52 - from torchvision import transforms 53 - from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig 54 - from peft import LoraConfig, get_peft_model, PeftModel 55 - from tqdm import tqdm 56 - 57 - # --------------------------------------------------------------------------- 58 - # Constants matching infer() with crop_mode=False, image_size=768 59 - # --------------------------------------------------------------------------- 60 - MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 61 - IMAGE_TOKEN_ID = 128815 62 - BOS_ID = 0 63 - EOS_ID = 1 64 - IMAGE_SIZE = 768 # global view resolution 65 - BASE_SIZE = 1024 # base_size arg (only needed for zero-pad of unused crop tensor) 66 - PATCH_SIZE = 16 67 - DOWNSAMPLE_RATIO = 4 68 - # num_queries = ceil((IMAGE_SIZE // PATCH_SIZE) / DOWNSAMPLE_RATIO) = 12 69 - NUM_QUERIES = math.ceil((IMAGE_SIZE // PATCH_SIZE) / DOWNSAMPLE_RATIO) 70 - # image tokens per image: 12*12 + 1 = 145 71 - N_IMAGE_TOKENS = NUM_QUERIES * NUM_QUERIES + 1 72 - 73 - PROMPT_SUFFIX = "\nConvert this mathematical expression to Typst math notation. " 74 - 75 - _IMAGE_TRANSFORM = transforms.Compose([ 76 - transforms.ToTensor(), 77 - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 78 - ]) 79 - _PAD_COLOR = (127, 127, 127) # int(0.5 * 255) 80 - 81 - 82 - # --------------------------------------------------------------------------- 83 - # Preprocessing 84 - # --------------------------------------------------------------------------- 85 - 86 - def make_sample(image_path: str, typst: str, tokenizer) -> dict: 87 - """ 88 - Build a training sample. 89 - 90 - Returns a dict with: 91 - input_ids LongTensor (seq_len,) 92 - labels LongTensor (seq_len,) prompt positions = -100 93 - image_ori bfloat16 Tensor (3, IMAGE_SIZE, IMAGE_SIZE) 94 - images_seq_mask bool Tensor (seq_len,) 95 - images_spatial_crop LongTensor (1, 2) 96 - """ 97 - image = Image.open(image_path).convert("RGB") 98 - global_view = ImageOps.pad(image, (IMAGE_SIZE, IMAGE_SIZE), color=_PAD_COLOR) 99 - image_ori = _IMAGE_TRANSFORM(global_view).to(torch.bfloat16) # (3, 768, 768) 100 - 101 - image_tokens = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS # 145 tokens 102 - 103 - prompt_ids = tokenizer.encode(PROMPT_SUFFIX, add_special_tokens=False) 104 - response_ids = tokenizer.encode(typst, add_special_tokens=False) + [EOS_ID] 105 - 106 - # Full sequence: BOS | image_tokens | prompt_suffix | response | EOS-already-in-response_ids 107 - input_ids = [BOS_ID] + image_tokens + prompt_ids + response_ids 108 - # Mask everything before the response with -100 109 - n_prompt = 1 + N_IMAGE_TOKENS + len(prompt_ids) 110 - labels = [-100] * n_prompt + response_ids 111 - 112 - images_seq_mask = ( 113 - [False] # BOS 114 - + [True] * N_IMAGE_TOKENS # image tokens 115 - + [False] * (len(prompt_ids) + len(response_ids)) 116 - ) 117 - 118 - return { 119 - "input_ids": torch.tensor(input_ids, dtype=torch.long), 120 - "labels": torch.tensor(labels, dtype=torch.long), 121 - "image_ori": image_ori, 122 - "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool), 123 - "images_spatial_crop": torch.tensor([[1, 1]], dtype=torch.long), 124 - } 125 - 126 - 127 - # --------------------------------------------------------------------------- 128 - # Data loading 129 - # --------------------------------------------------------------------------- 130 - 131 - def load_manifest(manifest_path: str) -> list[dict]: 132 - records = [] 133 - base = Path(manifest_path).parent 134 - for line in Path(manifest_path).read_text().splitlines(): 135 - r = json.loads(line) 136 - if not r.get("typst") or r["typst"].startswith("ERROR:"): 137 - continue 138 - records.append({ 139 - "image_path": str(base / r["image"]), 140 - "typst": r["typst"], 141 - }) 142 - return records 143 - 144 - 145 - # --------------------------------------------------------------------------- 146 - # Model loading 147 - # --------------------------------------------------------------------------- 148 - 149 - def load_model(model_id: str, bits: int): 150 - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) 151 - 152 - if bits == 4: 153 - bnb_cfg = BitsAndBytesConfig( 154 - load_in_4bit=True, 155 - bnb_4bit_quant_type="nf4", 156 - bnb_4bit_compute_dtype=torch.bfloat16, 157 - bnb_4bit_use_double_quant=True, 158 - ) 159 - elif bits == 8: 160 - bnb_cfg = BitsAndBytesConfig(load_in_8bit=True) 161 - else: 162 - bnb_cfg = None 163 - 164 - kwargs = dict(trust_remote_code=True, use_safetensors=True) 165 - if bnb_cfg is not None: 166 - model = AutoModel.from_pretrained(model_id, quantization_config=bnb_cfg, **kwargs) 167 - else: 168 - model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16, 169 - device_map="auto", **kwargs) 170 - 171 - return model, tokenizer 172 - 173 - 174 - # --------------------------------------------------------------------------- 175 - # LoRA setup 176 - # --------------------------------------------------------------------------- 177 - 178 - LORA_TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"] 179 - 180 - 181 - def apply_lora(model, lora_r: int) -> nn.Module: 182 - lora_cfg = LoraConfig( 183 - r=lora_r, 184 - lora_alpha=lora_r * 2, 185 - target_modules=LORA_TARGETS, 186 - lora_dropout=0.05, 187 - bias="none", 188 - ) 189 - model = get_peft_model(model, lora_cfg) 190 - model.print_trainable_parameters() 191 - return model 192 - 193 - 194 - # --------------------------------------------------------------------------- 195 - # Forward pass helper 196 - # --------------------------------------------------------------------------- 197 - 198 - def forward_sample(model, sample: dict, device: str = "cuda"): 199 - input_ids = sample["input_ids"].unsqueeze(0).to(device) 200 - labels = sample["labels"].unsqueeze(0).to(device) 201 - images_seq_mask = sample["images_seq_mask"].unsqueeze(0).to(device) 202 - images_spatial_crop = sample["images_spatial_crop"].to(device) 203 - image_ori = sample["image_ori"].unsqueeze(0).to(device) 204 - # Zeros for crop tensor: sum==0 triggers the no-crop branch in forward() 205 - images_crop = torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE, 206 - dtype=torch.bfloat16, device=device) 207 - 208 - with torch.autocast("cuda", dtype=torch.bfloat16): 209 - out = model( 210 - input_ids=input_ids, 211 - labels=labels, 212 - images=[(images_crop, image_ori)], 213 - images_seq_mask=images_seq_mask, 214 - images_spatial_crop=images_spatial_crop, 215 - use_cache=False, 216 - ) 217 - return out.loss 218 - 219 - 220 - # --------------------------------------------------------------------------- 221 - # Training loop 222 - # --------------------------------------------------------------------------- 223 - 224 - def train(args): 225 - print(f"Loading {args.model} at {args.bits}-bit ...") 226 - model, tokenizer = load_model(args.model, args.bits) 227 - print("Applying LoRA ...") 228 - model = apply_lora(model, args.lora_r) 229 - model.train() 230 - 231 - train_records = load_manifest(args.train_manifest) 232 - print(f"Training records: {len(train_records)}") 233 - val_records = [] 234 - if args.val_manifest: 235 - val_records = load_manifest(args.val_manifest) 236 - print(f"Val records: {len(val_records)}") 237 - 238 - optimizer = torch.optim.AdamW( 239 - (p for p in model.parameters() if p.requires_grad), 240 - lr=args.lr, 241 - weight_decay=0.01, 242 - ) 243 - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 244 - optimizer, T_max=args.steps 245 - ) 246 - 247 - output_dir = Path(args.output) 248 - output_dir.mkdir(parents=True, exist_ok=True) 249 - 250 - step = 0 251 - accum_loss = 0.0 252 - optimizer.zero_grad() 253 - 254 - pbar = tqdm(total=args.steps, desc="training") 255 - 256 - while step < args.steps: 257 - random.shuffle(train_records) 258 - for r in train_records: 259 - if step >= args.steps: 260 - break 261 - try: 262 - sample = make_sample(r["image_path"], r["typst"], tokenizer) 263 - except Exception as e: 264 - print(f" skip {r['image_path']}: {e}") 265 - continue 266 - 267 - loss = forward_sample(model, sample) / args.grad_accum 268 - loss.backward() 269 - accum_loss += loss.item() 270 - 271 - if (step + 1) % args.grad_accum == 0: 272 - torch.nn.utils.clip_grad_norm_( 273 - (p for p in model.parameters() if p.requires_grad), 1.0 274 - ) 275 - optimizer.step() 276 - scheduler.step() 277 - optimizer.zero_grad() 278 - 279 - pbar.set_postfix( 280 - loss=f"{accum_loss * args.grad_accum:.4f}", 281 - lr=f"{scheduler.get_last_lr()[0]:.2e}", 282 - ) 283 - step += 1 284 - pbar.update(1) 285 - 286 - if args.val_steps and step % args.val_steps == 0 and val_records: 287 - model.eval() 288 - subset = random.sample(val_records, min(args.n_val, len(val_records))) 289 - val_loss = 0.0 290 - with torch.no_grad(): 291 - for vr in subset: 292 - try: 293 - vs = make_sample(vr["image_path"], vr["typst"], tokenizer) 294 - val_loss += forward_sample(model, vs).item() 295 - except Exception: 296 - pass 297 - print(f"\n[step {step}] val_loss = {val_loss / len(subset):.4f}") 298 - model.train() 299 - accum_loss = 0.0 300 - 301 - if args.save_steps and step % args.save_steps == 0: 302 - ckpt = output_dir / f"step-{step}" 303 - model.save_pretrained(ckpt) 304 - tokenizer.save_pretrained(ckpt) 305 - print(f"\nSaved checkpoint → {ckpt}") 306 - 307 - pbar.close() 308 - final = output_dir / "final" 309 - model.save_pretrained(final) 310 - tokenizer.save_pretrained(final) 311 - print(f"Done. Final checkpoint → {final}") 312 - 313 - 314 - # --------------------------------------------------------------------------- 315 - # Entry point 316 - # --------------------------------------------------------------------------- 317 - 318 - def main(): 319 - parser = argparse.ArgumentParser() 320 - parser.add_argument("--model", default=MODEL_ID) 321 - parser.add_argument("--bits", type=int, default=4, choices=[4, 8, 16]) 322 - parser.add_argument("--lora-r", type=int, default=32) 323 - parser.add_argument("--lr", type=float, default=2e-4) 324 - parser.add_argument("--grad-accum", type=int, default=8) 325 - parser.add_argument("--steps", type=int, default=5000) 326 - parser.add_argument("--save-steps", type=int, default=500) 327 - parser.add_argument("--val-steps", type=int, default=200) 328 - parser.add_argument("--n-val", type=int, default=64) 329 - parser.add_argument("--train-manifest", required=True) 330 - parser.add_argument("--val-manifest", default=None) 331 - parser.add_argument("--output", default="checkpoints/deepseek-typst") 332 - parser.add_argument("--list-modules", action="store_true", 333 - help="Print model module names and exit") 334 - args = parser.parse_args() 335 - 336 - if args.list_modules: 337 - model, _ = load_model(args.model, args.bits) 338 - for name, mod in model.named_modules(): 339 - if isinstance(mod, nn.Linear): 340 - print(f" {name} [{mod.in_features} → {mod.out_features}]") 341 - return 342 - 343 - train(args) 344 - 345 - 346 - if __name__ == "__main__": 347 - main()
-78
src/backfill_tb.py
··· 1 - """ 2 - Backfill TensorBoard events from training stdout. 3 - 4 - Usage: 5 - # Save your terminal output to a file, then: 6 - uv run python -m src.backfill_tb training.log checkpoints/deepseek/runs/backfill 7 - 8 - The log file should contain lines like: 9 - {'loss': 9.56, 'grad_norm': 11.6, 'learning_rate': 9.99e-05, 'epoch': 0.04} 10 - {'eval_loss': 2.72, 'eval_runtime': 542.7, ..., 'epoch': 0.04} 11 - 12 - Train entries are assumed to occur every logging_steps=50 steps. 13 - Eval entries are assumed to occur every eval_steps=500 steps. 14 - Adjust LOGGING_STEPS / EVAL_STEPS below if different. 15 - """ 16 - 17 - import ast 18 - import re 19 - import sys 20 - from pathlib import Path 21 - 22 - LOGGING_STEPS = 50 23 - EVAL_STEPS = 500 24 - 25 - 26 - def parse_log(path: str) -> tuple[list[dict], list[dict]]: 27 - train_entries: list[dict] = [] 28 - eval_entries: list[dict] = [] 29 - dict_re = re.compile(r"^\{.*\}$") 30 - for line in Path(path).read_text().splitlines(): 31 - line = line.strip() 32 - if not dict_re.match(line): 33 - continue 34 - try: 35 - d = ast.literal_eval(line) 36 - except Exception: 37 - continue 38 - if not isinstance(d, dict): 39 - continue 40 - if "eval_loss" in d: 41 - eval_entries.append(d) 42 - elif "loss" in d: 43 - train_entries.append(d) 44 - return train_entries, eval_entries 45 - 46 - 47 - def write_events(train_entries, eval_entries, out_dir: str) -> None: 48 - from torch.utils.tensorboard import SummaryWriter 49 - writer = SummaryWriter(log_dir=out_dir) 50 - 51 - for i, d in enumerate(train_entries): 52 - step = (i + 1) * LOGGING_STEPS 53 - writer.add_scalar("train/loss", d["loss"], step) 54 - writer.add_scalar("train/grad_norm", d["grad_norm"], step) 55 - writer.add_scalar("train/learning_rate", d["learning_rate"], step) 56 - writer.add_scalar("train/epoch", d["epoch"], step) 57 - 58 - for i, d in enumerate(eval_entries): 59 - step = (i + 1) * EVAL_STEPS 60 - writer.add_scalar("eval/loss", d["eval_loss"], step) 61 - writer.add_scalar("eval/epoch", d["epoch"], step) 62 - 63 - writer.close() 64 - print(f"Wrote {len(train_entries)} train + {len(eval_entries)} eval entries to {out_dir}") 65 - 66 - 67 - def main() -> None: 68 - if len(sys.argv) < 3: 69 - print("Usage: python -m src.backfill_tb <log_file> <out_dir>") 70 - sys.exit(1) 71 - log_file, out_dir = sys.argv[1], sys.argv[2] 72 - train_entries, eval_entries = parse_log(log_file) 73 - print(f"Parsed {len(train_entries)} train entries, {len(eval_entries)} eval entries") 74 - write_events(train_entries, eval_entries, out_dir) 75 - 76 - 77 - if __name__ == "__main__": 78 - main()
-167
src/collate_deepseek.py
··· 1 - """ 2 - Data collator for DeepSeek-OCR-2 fine-tuning. 3 - 4 - Preprocesses image+typst pairs into the tensor format expected by 5 - DeepseekOCR2ForCausalLM.forward(): 6 - 7 - - PIL image letterbox-padded to 768×768 (neutral gray fill, mean=0.5) 8 - - 145 image tokens at the start of the sequence (12² spatial + 1 separator) 9 - - prompt tokens follow, then response tokens + EOS 10 - - labels mask the image+prompt prefix with -100 (train on response only) 11 - 12 - Design decisions: 13 - - crop_mode=False: single global view, no dynamic tiling. Our math 14 - expression images are small fragments -- tiling adds no information and 15 - wastes tokens. 16 - - image_size=768: fixed encoder input (SAM ViT-B, patch=16, downsample=4 17 - → 12×12 grid = 144 + 1 separator = 145 tokens). 18 - - The `images` argument to forward() is a list of (local_crops, global_view) 19 - tuples. With crop_mode=False there are no distinct local crops; we pass 20 - the same 768×768 tensor for both slots. This matches the single-tile 21 - case in infer() but should be validated against the model source once 22 - available locally. 23 - """ 24 - 25 - import math 26 - 27 - import torch 28 - from PIL import Image, ImageOps 29 - from torch.nn.utils.rnn import pad_sequence 30 - from torchvision import transforms 31 - 32 - # ── Constants ────────────────────────────────────────────────────────────────── 33 - 34 - IMAGE_TOKEN_ID = 128815 # <image> placeholder token in the DeepSeek-OCR-2 vocab 35 - IMAGE_SIZE = 768 # fixed encoder input resolution (px) 36 - PATCH_SIZE = 16 # SAM ViT-B patch size (px) 37 - DOWNSAMPLE = 4 # spatial downsample ratio inside the encoder 38 - 39 - # Number of image tokens inserted per image: 40 - # ceil(768 / 16 / 4) = 12 → 12×12 + 1 separator = 145 41 - _N_GRID = math.ceil(IMAGE_SIZE / PATCH_SIZE / DOWNSAMPLE) # 12 42 - N_IMAGE_TOKENS = _N_GRID * _N_GRID + 1 # 145 43 - 44 - _MEAN = (0.5, 0.5, 0.5) 45 - _STD = (0.5, 0.5, 0.5) 46 - _PAD_COLOR = tuple(int(x * 255) for x in _MEAN) # (127, 127, 127) gray 47 - 48 - _transform = transforms.Compose([ 49 - transforms.ToTensor(), 50 - transforms.Normalize(mean=_MEAN, std=_STD), 51 - ]) 52 - 53 - # Prompt appended after the image tokens. Leading \n because <image> is the 54 - # very first token and the model expects a newline before the text instruction. 55 - _PROMPT = "\nTranscribe this image to Typst notation.\n" 56 - 57 - 58 - # ── Image preprocessing ──────────────────────────────────────────────────────── 59 - 60 - def preprocess_image(img: Image.Image) -> torch.Tensor: 61 - """ 62 - Letterbox-pad a PIL image to IMAGE_SIZE×IMAGE_SIZE, normalize. 63 - 64 - ImageOps.pad preserves aspect ratio by centering and filling borders with 65 - _PAD_COLOR, so non-square images are not distorted. 66 - 67 - Returns float16 tensor of shape [3, IMAGE_SIZE, IMAGE_SIZE]. 68 - """ 69 - img = img.convert("RGB") 70 - img = ImageOps.pad(img, (IMAGE_SIZE, IMAGE_SIZE), color=_PAD_COLOR) 71 - return _transform(img).to(torch.bfloat16) 72 - 73 - 74 - # ── Collator ─────────────────────────────────────────────────────────────────── 75 - 76 - class DeepSeekOCRCollator: 77 - """ 78 - Collates MathOCRDataset items into inputs for DeepseekOCR2ForCausalLM. 79 - 80 - Each dataset item has the structure produced by MathOCRDataset.__getitem__: 81 - { 82 - "messages": [ 83 - {"role": "user", "content": [{"type": "image", "image": <PIL>}, 84 - {"type": "text", "text": PROMPT}]}, 85 - {"role": "assistant", "content": [{"type": "text", "text": <typst>}]}, 86 - ] 87 - } 88 - 89 - Output dict keys match DeepseekOCR2ForCausalLM.forward() signature: 90 - input_ids, attention_mask, labels, 91 - images, images_seq_mask, images_spatial_crop 92 - """ 93 - 94 - def __init__(self, tokenizer) -> None: 95 - self.tokenizer = tokenizer 96 - self.prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False) 97 - self.eos_id = tokenizer.eos_token_id 98 - self.pad_id = (tokenizer.pad_token_id 99 - if tokenizer.pad_token_id is not None 100 - else tokenizer.eos_token_id) 101 - 102 - def __call__(self, batch: list[dict]) -> dict: 103 - input_ids_list = [] 104 - labels_list = [] 105 - seq_mask_list = [] 106 - img_tensors = [] 107 - spatial_crops = [] 108 - 109 - for item in batch: 110 - user_content = item["messages"][0]["content"] 111 - pil_img = user_content[0]["image"] # PIL Image 112 - typst = item["messages"][1]["content"][0]["text"] # target string 113 - 114 - # Image → tensor ─────────────────────────────────────────────────── 115 - img_t = preprocess_image(pil_img) # [3, 768, 768] bfloat16 116 - img_tensors.append(img_t) 117 - 118 - # Token sequence ─────────────────────────────────────────────────── 119 - # Layout: [img×145] [prompt] [response] [EOS] 120 - response_ids = self.tokenizer.encode(typst, add_special_tokens=False) 121 - img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS 122 - ids = img_ids + self.prompt_ids + response_ids + [self.eos_id] 123 - 124 - # Labels: -100 on image+prompt; train only on response+EOS 125 - n_prefix = N_IMAGE_TOKENS + len(self.prompt_ids) 126 - lbl = [-100] * n_prefix + response_ids + [self.eos_id] 127 - 128 - # images_seq_mask: True at the N_IMAGE_TOKENS image token positions 129 - mask = [False] * len(ids) 130 - for i in range(N_IMAGE_TOKENS): 131 - mask[i] = True 132 - 133 - input_ids_list.append(torch.tensor(ids, dtype=torch.long)) 134 - labels_list.append(torch.tensor(lbl, dtype=torch.long)) 135 - seq_mask_list.append(torch.tensor(mask, dtype=torch.bool)) 136 - # [width_crops, height_crops] = [1, 1] for single global view 137 - spatial_crops.append([1, 1]) 138 - 139 - # Pad sequences ──────────────────────────────────────────────────────── 140 - input_ids = pad_sequence(input_ids_list, batch_first=True, 141 - padding_value=self.pad_id) 142 - labels = pad_sequence(labels_list, batch_first=True, 143 - padding_value=-100) 144 - seq_mask = pad_sequence(seq_mask_list, batch_first=True, 145 - padding_value=False) 146 - attn_mask = (input_ids != self.pad_id) 147 - 148 - # images: list of (local_crops, global_view) tuples, one per sample. 149 - # local_crops must be zeros so forward() takes the crop_mode=False else-branch, 150 - # producing exactly N_IMAGE_TOKENS=145 global features to match images_seq_mask. 151 - # Passing the real image for both slots triggers the if-branch which concatenates 152 - # local+global+sep = 289 features; masked_scatter_ then injects the wrong subset. 153 - imgs_stacked = torch.stack(img_tensors) # [B, 3, 768, 768] 154 - zeros = torch.zeros_like(imgs_stacked[0].unsqueeze(0)) 155 - images = [ 156 - (zeros, imgs_stacked[i].unsqueeze(0)) 157 - for i in range(len(batch)) 158 - ] 159 - 160 - return { 161 - "input_ids": input_ids, # [B, T] 162 - "attention_mask": attn_mask, # [B, T] 163 - "labels": labels, # [B, T] 164 - "images": images, # list[B] of (tensor, tensor) 165 - "images_seq_mask": seq_mask, # [B, T] 166 - "images_spatial_crop": torch.tensor(spatial_crops, dtype=torch.long), # [B, 2] 167 - }
src/deepseek_ocr2/__init__.py

This is a binary file and will not be displayed.

-210
src/deepseek_ocr2/configuration_deepseek_v2.py
··· 1 - from transformers.configuration_utils import PretrainedConfig 2 - from transformers.utils import logging 3 - 4 - logger = logging.get_logger(__name__) 5 - 6 - DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 7 - class DeepseekV2Config(PretrainedConfig): 8 - r""" 9 - This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek 10 - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 11 - defaults will yield a similar configuration to that of the DeepSeek-V2 with multi-latent attention. 12 - 13 - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 14 - documentation from [`PretrainedConfig`] for more information. 15 - 16 - 17 - Args: 18 - vocab_size (`int`, *optional*, defaults to 102400): 19 - Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the 20 - `inputs_ids` passed when calling [`DeepseekV2Model`] 21 - hidden_size (`int`, *optional*, defaults to 4096): 22 - Dimension of the hidden representations. 23 - intermediate_size (`int`, *optional*, defaults to 11008): 24 - Dimension of the MLP representations. 25 - moe_intermediate_size (`int`, *optional*, defaults to 1407): 26 - Dimension of the MoE representations. 27 - num_hidden_layers (`int`, *optional*, defaults to 32): 28 - Number of hidden layers in the Transformer decoder. 29 - num_attention_heads (`int`, *optional*, defaults to 32): 30 - Number of attention heads for each attention layer in the Transformer decoder. 31 - n_shared_experts (`int`, *optional*, defaults to None): 32 - Number of shared experts, None means dense model. 33 - n_routed_experts (`int`, *optional*, defaults to None): 34 - Number of routed experts, None means dense model. 35 - routed_scaling_factor (`float`, *optional*, defaults to 1.0): 36 - Scaling factor or routed experts. 37 - topk_method (`str`, *optional*, defaults to `gready`): 38 - Topk method used in routed gate. 39 - n_group (`int`, *optional*, defaults to None): 40 - Number of groups for routed experts. 41 - topk_group (`int`, *optional*, defaults to None): 42 - Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). 43 - num_experts_per_tok (`int`, *optional*, defaults to None): 44 - Number of selected experts, None means dense model. 45 - moe_layer_freq (`int`, *optional*, defaults to 1): 46 - The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. 47 - first_k_dense_replace (`int`, *optional*, defaults to 0): 48 - Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). 49 - \--k dense layers--/ 50 - norm_topk_prob (`bool`, *optional*, defaults to False): 51 - Whether to normalize the weights of the routed experts. 52 - scoring_func (`str`, *optional*, defaults to 'softmax'): 53 - Method of computing expert weights. 54 - aux_loss_alpha (`float`, *optional*, defaults to 0.001): 55 - Auxiliary loss weight coefficient. 56 - seq_aux = (`bool`, *optional*, defaults to True): 57 - Whether to compute the auxiliary loss for each individual sample. 58 - num_key_value_heads (`int`, *optional*): 59 - This is the number of key_value heads that should be used to implement Grouped Query Attention. If 60 - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 61 - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 62 - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 63 - by meanpooling all the original heads within that group. For more details checkout [this 64 - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 65 - `num_attention_heads`. 66 - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 67 - The non-linear activation function (function or string) in the decoder. 68 - max_position_embeddings (`int`, *optional*, defaults to 2048): 69 - The maximum sequence length that this model might ever be used with. 70 - initializer_range (`float`, *optional*, defaults to 0.02): 71 - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 72 - rms_norm_eps (`float`, *optional*, defaults to 1e-06): 73 - The epsilon used by the rms normalization layers. 74 - use_cache (`bool`, *optional*, defaults to `True`): 75 - Whether or not the model should return the last key/values attentions (not used by all models). Only 76 - relevant if `config.is_decoder=True`. 77 - pad_token_id (`int`, *optional*): 78 - Padding token id. 79 - bos_token_id (`int`, *optional*, defaults to 1): 80 - Beginning of stream token id. 81 - eos_token_id (`int`, *optional*, defaults to 2): 82 - End of stream token id. 83 - pretraining_tp (`int`, *optional*, defaults to 1): 84 - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 85 - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 86 - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 87 - issue](https://github.com/pytorch/pytorch/issues/76232). 88 - tie_word_embeddings (`bool`, *optional*, defaults to `False`): 89 - Whether to tie weight embeddings 90 - rope_theta (`float`, *optional*, defaults to 10000.0): 91 - The base period of the RoPE embeddings. 92 - rope_scaling (`Dict`, *optional*): 93 - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 94 - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is 95 - `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 96 - `max_position_embeddings` to the expected new maximum. 97 - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 98 - Whether to use a bias in the query, key, value and output projection layers during self-attention. 99 - attention_dropout (`float`, *optional*, defaults to 0.0): 100 - The dropout ratio for the attention probabilities. 101 - use_mla (`bool`, *optional*, defaults to `True`): Use multi-latent attention or multi-head attention. If True, 102 - the model will use multi-latent attention, otherwise, it will use multi-head attention. 103 - 104 - ```python 105 - >>> from transformers import DeepseekV2Model, DeepseekV2Config 106 - 107 - >>> # Initializing a Deepseek-V2 style configuration 108 - >>> configuration = DeepseekV2Config() 109 - 110 - >>> # Accessing the model configuration 111 - >>> configuration = model.config 112 - ```""" 113 - 114 - model_type = "deepseek_v2" 115 - keys_to_ignore_at_inference = ["past_key_values"] 116 - 117 - def __init__( 118 - self, 119 - vocab_size=102400, 120 - hidden_size=4096, 121 - intermediate_size=11008, 122 - moe_intermediate_size = 1407, 123 - num_hidden_layers=30, 124 - num_attention_heads=32, 125 - num_key_value_heads=32, 126 - n_shared_experts = None, 127 - n_routed_experts = None, 128 - ep_size = 1, 129 - routed_scaling_factor = 1.0, 130 - kv_lora_rank = 512, 131 - q_lora_rank = 1536, 132 - qk_rope_head_dim = 64, 133 - v_head_dim = 128, 134 - qk_nope_head_dim = 128, 135 - topk_method = 'gready', 136 - n_group = None, 137 - topk_group = None, 138 - num_experts_per_tok = None, 139 - moe_layer_freq = 1, 140 - first_k_dense_replace = 0, 141 - norm_topk_prob = False, 142 - scoring_func = 'softmax', 143 - aux_loss_alpha = 0.001, 144 - seq_aux = True, 145 - hidden_act="silu", 146 - max_position_embeddings=2048, 147 - initializer_range=0.02, 148 - rms_norm_eps=1e-6, 149 - use_cache=True, 150 - pad_token_id=None, 151 - bos_token_id=100000, 152 - eos_token_id=100001, 153 - pretraining_tp=1, 154 - tie_word_embeddings=False, 155 - rope_theta=10000.0, 156 - rope_scaling=None, 157 - attention_bias=False, 158 - attention_dropout=0.0, 159 - use_mla=True, 160 - **kwargs, 161 - ): 162 - self.vocab_size = vocab_size 163 - self.max_position_embeddings = max_position_embeddings 164 - self.hidden_size = hidden_size 165 - self.intermediate_size = intermediate_size 166 - self.moe_intermediate_size = moe_intermediate_size 167 - self.num_hidden_layers = num_hidden_layers 168 - self.num_attention_heads = num_attention_heads 169 - self.n_shared_experts = n_shared_experts 170 - self.n_routed_experts = n_routed_experts 171 - self.ep_size = ep_size 172 - self.routed_scaling_factor = routed_scaling_factor 173 - self.kv_lora_rank = kv_lora_rank 174 - self.q_lora_rank = q_lora_rank 175 - self.qk_rope_head_dim = qk_rope_head_dim 176 - self.v_head_dim = v_head_dim 177 - self.qk_nope_head_dim = qk_nope_head_dim 178 - self.topk_method = topk_method 179 - self.n_group = n_group 180 - self.topk_group = topk_group 181 - self.num_experts_per_tok = num_experts_per_tok 182 - self.moe_layer_freq = moe_layer_freq 183 - self.first_k_dense_replace = first_k_dense_replace 184 - self.norm_topk_prob = norm_topk_prob 185 - self.scoring_func = scoring_func 186 - self.aux_loss_alpha = aux_loss_alpha 187 - self.seq_aux = seq_aux 188 - # for backward compatibility 189 - if num_key_value_heads is None: 190 - num_key_value_heads = num_attention_heads 191 - 192 - self.num_key_value_heads = num_key_value_heads 193 - self.hidden_act = hidden_act 194 - self.initializer_range = initializer_range 195 - self.rms_norm_eps = float(rms_norm_eps) 196 - self.pretraining_tp = pretraining_tp 197 - self.use_cache = use_cache 198 - self.rope_theta = rope_theta 199 - self.rope_scaling = rope_scaling 200 - self.attention_bias = attention_bias 201 - self.attention_dropout = attention_dropout 202 - self.use_mla = use_mla 203 - 204 - super().__init__( 205 - pad_token_id=pad_token_id, 206 - bos_token_id=bos_token_id, 207 - eos_token_id=eos_token_id, 208 - tie_word_embeddings=tie_word_embeddings, 209 - **kwargs, 210 - )
-280
src/deepseek_ocr2/conversation.py
··· 1 - """ 2 - From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py 3 - """ 4 - 5 - import dataclasses 6 - from enum import IntEnum, auto 7 - from typing import Any, Dict, List 8 - 9 - 10 - class SeparatorStyle(IntEnum): 11 - """Separator styles.""" 12 - 13 - DeepSeek = auto() 14 - DeepSeekV2 = auto() 15 - PLAIN = auto() 16 - ALIGNMENT = auto() 17 - 18 - 19 - @dataclasses.dataclass 20 - class Conversation: 21 - """A class that manages prompt templates and keeps all conversation history.""" 22 - 23 - # The name of this template 24 - name: str 25 - # The template of the system prompt 26 - system_template: str = "{system_message}" 27 - # The system message 28 - system_message: str = "" 29 - # The names of two roles 30 - roles: List[str] = (("USER", "ASSISTANT"),) 31 - # All messages. Each item is (role, message). 32 - messages: List[List[str]] = () 33 - # The number of few shot examples 34 - offset: int = 0 35 - # The separator style and configurations 36 - sep_style: SeparatorStyle = SeparatorStyle.DeepSeek 37 - sep: str = "\n" 38 - sep2: str = None 39 - # Stop criteria (the default one is EOS token) 40 - stop_str: str = None 41 - # Stops generation if meeting any token in this list 42 - stop_token_ids: List[int] = None 43 - 44 - def get_prompt(self) -> str: 45 - """Get the prompt for generation.""" 46 - system_prompt = self.system_template.format(system_message=self.system_message) 47 - if self.sep_style == SeparatorStyle.DeepSeek: 48 - seps = [self.sep, self.sep2] 49 - if system_prompt == "" or system_prompt is None: 50 - ret = "" 51 - else: 52 - ret = system_prompt + seps[0] 53 - for i, (role, message) in enumerate(self.messages): 54 - if message: 55 - ret += role + ": " + message + seps[i % 2] 56 - else: 57 - ret += role + ":" 58 - return ret 59 - elif self.sep_style == SeparatorStyle.DeepSeekV2: 60 - seps = [self.sep, self.sep2] 61 - if system_prompt == "" or system_prompt is None: 62 - ret = "" 63 - else: 64 - ret = system_prompt + seps[0] 65 - for i, (role, message) in enumerate(self.messages): 66 - if message: 67 - if role == "User": 68 - ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|> 69 - else: 70 - ret += message + self.sep2 71 - else: 72 - ret = ret 73 - return ret 74 - 75 - elif self.sep_style == SeparatorStyle.PLAIN: 76 - seps = [self.sep, self.sep2] 77 - ret = "" 78 - for i, (role, message) in enumerate(self.messages): 79 - if message: 80 - if type(message) is tuple: 81 - message, _, _ = message 82 - if i % 2 == 0: 83 - ret += message + seps[i % 2] 84 - else: 85 - ret += message + seps[i % 2] 86 - else: 87 - ret += "" 88 - return ret 89 - elif self.sep_style == SeparatorStyle.ALIGNMENT: 90 - seps = [self.sep, self.sep2] 91 - ret = "" 92 - for i, (role, message) in enumerate(self.messages): 93 - if message: 94 - if type(message) is tuple: 95 - message, _, _ = message 96 - if i % 2 == 0: 97 - ret += '<image>\n' + seps[i % 2] 98 - else: 99 - ret += message + seps[i % 2] 100 - else: 101 - ret += "" 102 - return ret 103 - else: 104 - raise ValueError(f"Invalid style: {self.sep_style}") 105 - 106 - def set_system_message(self, system_message: str): 107 - """Set the system message.""" 108 - self.system_message = system_message 109 - 110 - def append_message(self, role: str, message: str): 111 - """Append a new message.""" 112 - self.messages.append([role, message]) 113 - 114 - def update_last_message(self, message: str): 115 - """Update the last output. 116 - 117 - The last message is typically set to be None when constructing the prompt, 118 - so we need to update it in-place after getting the response from a model. 119 - """ 120 - self.messages[-1][1] = message 121 - 122 - def reset_message(self): 123 - """Reset a new message.""" 124 - self.messages = [] 125 - 126 - def to_gradio_chatbot(self): 127 - """Convert the conversation to gradio chatbot format.""" 128 - ret = [] 129 - for i, (role, msg) in enumerate(self.messages[self.offset :]): 130 - if i % 2 == 0: 131 - ret.append([msg, None]) 132 - else: 133 - ret[-1][-1] = msg 134 - return ret 135 - 136 - def to_openai_api_messages(self): 137 - """Convert the conversation to OpenAI chat completion format.""" 138 - system_prompt = self.system_template.format(system_message=self.system_message) 139 - ret = [{"role": "system", "content": system_prompt}] 140 - 141 - for i, (_, msg) in enumerate(self.messages[self.offset :]): 142 - if i % 2 == 0: 143 - ret.append({"role": "user", "content": msg}) 144 - else: 145 - if msg is not None: 146 - ret.append({"role": "assistant", "content": msg}) 147 - return ret 148 - 149 - def copy(self): 150 - return Conversation( 151 - name=self.name, 152 - system_template=self.system_template, 153 - system_message=self.system_message, 154 - roles=self.roles, 155 - messages=[[x, y] for x, y in self.messages], 156 - offset=self.offset, 157 - sep_style=self.sep_style, 158 - sep=self.sep, 159 - sep2=self.sep2, 160 - stop_str=self.stop_str, 161 - stop_token_ids=self.stop_token_ids, 162 - ) 163 - 164 - def dict(self): 165 - return { 166 - "template_name": self.name, 167 - "system_message": self.system_message, 168 - "roles": self.roles, 169 - "messages": self.messages, 170 - "offset": self.offset, 171 - } 172 - 173 - 174 - # A global registry for all conversation templates 175 - conv_templates: Dict[str, Conversation] = {} 176 - 177 - 178 - def register_conv_template(template: Conversation, override: bool = False): 179 - """Register a new conversation template.""" 180 - if not override: 181 - assert template.name not in conv_templates, f"{template.name} has been registered." 182 - 183 - conv_templates[template.name] = template 184 - 185 - 186 - def get_conv_template(name: str) -> Conversation: 187 - """Get a conversation template.""" 188 - return conv_templates[name].copy() 189 - 190 - 191 - register_conv_template( 192 - Conversation( 193 - name="deepseek", 194 - system_template="{system_message}", 195 - # system_message="You are a helpful assistant. Please answer truthfully and write out your " 196 - # "thinking step by step to be sure you get the right answer.", 197 - system_message="", 198 - roles=("<|User|>", "<|Assistant|>"), 199 - messages=(), 200 - offset=0, 201 - sep_style=SeparatorStyle.DeepSeek, 202 - sep="\n\n", 203 - sep2="<|end▁of▁sentence|>", 204 - stop_token_ids=[100001], 205 - stop_str=["User:", "<|end▁of▁sentence|>"] 206 - ) 207 - ) 208 - register_conv_template( 209 - Conversation( 210 - name="deepseekv2", 211 - system_template="{system_message}", 212 - # system_message="You are a helpful assistant. Please answer truthfully and write out your " 213 - # "thinking step by step to be sure you get the right answer.", 214 - system_message="", 215 - roles=("<|User|>", "<|Assistant|>"), 216 - messages=(), 217 - offset=0, 218 - sep_style=SeparatorStyle.DeepSeek, 219 - sep="", 220 - sep2="<|end▁of▁sentence|>", 221 - stop_token_ids=[100001], 222 - stop_str=["User:", "<|end▁of▁sentence|>"] 223 - ) 224 - ) 225 - 226 - 227 - register_conv_template( 228 - Conversation( 229 - name="plain", 230 - system_template="", 231 - system_message="", 232 - roles=("", ""), 233 - messages=(), 234 - offset=0, 235 - sep_style=SeparatorStyle.PLAIN, 236 - sep="", 237 - sep2="", 238 - stop_token_ids=[100001], 239 - stop_str=['</s>'], 240 - ) 241 - ) 242 - 243 - 244 - register_conv_template( 245 - Conversation( 246 - name="alignment", 247 - system_template="", 248 - system_message="", 249 - roles=("", ""), 250 - messages=(), 251 - offset=0, 252 - sep_style=SeparatorStyle.ALIGNMENT, 253 - sep="", 254 - sep2="", 255 - stop_token_ids=[100001], 256 - stop_str=['</s>'], 257 - ) 258 - ) 259 - 260 - 261 - if __name__ == "__main__": 262 - print("deepseek template:") 263 - conv = get_conv_template("deepseek") 264 - conv.append_message(conv.roles[0], "Hello!") 265 - conv.append_message(conv.roles[1], "Hi! This is Tony.") 266 - conv.append_message(conv.roles[0], "Who are you?") 267 - conv.append_message(conv.roles[1], "I am a helpful assistant.") 268 - conv.append_message(conv.roles[0], "How are you?") 269 - conv.append_message(conv.roles[1], None) 270 - print(conv.get_prompt()) 271 - 272 - print("deepseekv2 template:") 273 - conv = get_conv_template("deepseekv2") 274 - conv.append_message(conv.roles[0], "Hello!") 275 - conv.append_message(conv.roles[1], "Hi! This is Tony.") 276 - conv.append_message(conv.roles[0], "Who are you?") 277 - conv.append_message(conv.roles[1], "I am a helpful assistant.") 278 - conv.append_message(conv.roles[0], "How are you?") 279 - conv.append_message(conv.roles[1], None) 280 - print(conv.get_prompt())
-1015
src/deepseek_ocr2/deepencoderv2.py
··· 1 - import torch.nn as nn 2 - import torch 3 - import torch.nn.functional as F 4 - import copy 5 - 6 - 7 - from typing import Optional, Tuple 8 - 9 - # from megatron.model import LayerNorm 10 - 11 - import transformers 12 - 13 - 14 - from typing import Optional, Tuple, Type 15 - from functools import partial 16 - 17 - 18 - 19 - class MlpProjector(nn.Module): 20 - 21 - def __init__(self, cfg): 22 - 23 - super().__init__() 24 - 25 - self.cfg = cfg 26 - 27 - if cfg.projector_type == "identity": 28 - modules = nn.Identity() 29 - 30 - elif cfg.projector_type == "linear": 31 - modules = nn.Linear(cfg.input_dim, cfg.n_embed) 32 - 33 - elif cfg.projector_type == "mlp_gelu": 34 - mlp_depth = cfg.get("depth", 1) 35 - modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] 36 - for _ in range(1, mlp_depth): 37 - modules.append(nn.GELU()) 38 - modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 39 - modules = nn.Sequential(*modules) 40 - 41 - elif cfg.projector_type == "normlayer_downsample_mlp_gelu": 42 - mlp_depth = cfg.get("depth", 1) 43 - mlp_ratio = cfg.get("mlp_ratio", 1) 44 - modules = [ 45 - nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio), 46 - nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio) 47 - ] 48 - for _ in range(1, mlp_depth - 1): 49 - modules.append(nn.GELU()) 50 - modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) 51 - modules.append(nn.GELU()) 52 - modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) 53 - modules = nn.Sequential(*modules) 54 - 55 - elif cfg.projector_type == "downsample_mlp_gelu": 56 - mlp_depth = cfg.get("depth", 1) 57 - mlp_ratio = cfg.get("mlp_ratio", 1) 58 - modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)] 59 - for _ in range(1, mlp_depth - 1): 60 - modules.append(nn.GELU()) 61 - modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) 62 - modules.append(nn.GELU()) 63 - modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) 64 - modules = nn.Sequential(*modules) 65 - 66 - elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 67 - mlp_depth = cfg.get("depth", 1) 68 - self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 69 - self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 70 - 71 - modules = [] 72 - for _ in range(1, mlp_depth): 73 - modules.append(nn.GELU()) 74 - modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 75 - modules = nn.Sequential(*modules) 76 - 77 - elif cfg.projector_type == "hybrid_split_feature_mlp_gelu": 78 - mlp_depth = cfg.get("depth", 1) 79 - channel_div = cfg.get("channel_div", 0.5) 80 - self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div)) 81 - self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div)) 82 - 83 - modules = [] 84 - for _ in range(1, mlp_depth): 85 - modules.append(nn.GELU()) 86 - modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 87 - modules = nn.Sequential(*modules) 88 - 89 - elif cfg.projector_type == "low_high_split_mlp_gelu": 90 - mlp_depth = cfg.get("depth", 1) 91 - modules = [] 92 - for _ in range(1, mlp_depth): 93 - modules.append(nn.GELU()) 94 - modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2)) 95 - modules = nn.Sequential(*modules) 96 - self.high_layers = nn.Sequential(*modules) 97 - self.low_layers = copy.deepcopy(modules) 98 - 99 - else: 100 - raise ValueError(f"Unknown projector type: {cfg.projector_type}") 101 - 102 - if cfg.get("token_pooling", False): 103 - self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim) 104 - 105 - if cfg.get("conv_fusion_high_low_features", False): 106 - self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim) 107 - self.layers = modules 108 - 109 - def forward(self, x): 110 - if self.cfg.get("token_pooling", False): 111 - batch_size, wxh, channels = x.shape 112 - w = h = int(wxh**0.5) 113 - x = x.view(batch_size, w, h, channels) 114 - x = x.permute(0, 3, 1, 2) 115 - # import ipdb; ipdb.set_trace() 116 - patches = x.unfold(2, 2, 2).unfold(3, 2, 2) 117 - batch_size, channels, h_patches, w_patches, _, _ = patches.size() 118 - # 在通道维度上拼接 119 - patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1) 120 - 121 - # 通过线性层 122 - patches = patches.permute(0, 2, 1, 3).contiguous() 123 - patches = patches.view(batch_size, h_patches * w_patches, channels * 4) 124 - 125 - x = self.token_pooling_layer(patches) 126 - 127 - if self.cfg.get("conv_fusion_high_low_features", False): 128 - x = self.fusion_layer(x[:, 0]) + x[:, 1] 129 - 130 - if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu': 131 - high_x, low_x = x[0], x[1] 132 - high_x = self.high_up_proj(high_x) 133 - low_x = self.low_up_proj(low_x) 134 - x = torch.concat([high_x, low_x], dim=-1) 135 - 136 - if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu': 137 - high_x = x[...,:self.cfg.input_dim[0]] 138 - low_x = x[...,self.cfg.input_dim[0]:] 139 - high_x = self.high_up_proj(high_x) 140 - low_x = self.low_up_proj(low_x) 141 - x = torch.concat([high_x, low_x], dim=-1) 142 - 143 - if self.cfg.projector_type == 'low_high_split_mlp_gelu': 144 - high_x, low_x = x[0], x[1] 145 - high_x = self.high_layers(high_x) 146 - low_x = self.low_layers(low_x) 147 - x = torch.concat([high_x, low_x], dim=-1) 148 - return x 149 - 150 - if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu': 151 - bs, hw, input_dim = x.shape 152 - h = w = int((hw) ** 0.5) 153 - 154 - """compute padding""" 155 - if h % self.cfg.downsample_ratio: 156 - pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio 157 - else: 158 - pad = 0 159 - x = x.reshape(bs, h, w, input_dim) 160 - if pad > 0: 161 - x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) 162 - 163 - """4 to 1 concat""" 164 - x = x.permute(0, 3, 1, 2) # B, C, H, W 165 - x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4 166 - x = x.permute(0, 2, 1) 167 - 168 - return self.layers(x) 169 - 170 - @staticmethod 171 - def get_flops_per_sample(cfg): 172 - if cfg.projector_type == "linear": 173 - fwd = 2 * cfg.input_dim * cfg.n_embed 174 - 175 - elif "mlp_gelu" in cfg.projector_type : 176 - mlp_depth = cfg.get("depth", 1) 177 - downsample_ratio = cfg.get("downsample_ratio", 1) 178 - input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim 179 - input_dim = input_dim * downsample_ratio * downsample_ratio 180 - fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed 181 - else: 182 - fwd = 0 183 - 184 - return fwd * 3 185 - 186 - 187 - #===================qwen2================================ 188 - 189 - class CustomQwen2Decoder(nn.Module): 190 - """ 191 - Qwen2 visual encoder 192 - non-causal attention + causal attention 193 - token_type_ids :0=non-causal, 1=causal 194 - """ 195 - 196 - def __init__( 197 - self, 198 - decoder_layer: int = 24, 199 - max_position_embeddings: int = 131072, 200 - hidden_dimension: int = 896, 201 - num_attention_heads: int = 14, 202 - num_key_value_heads: int = 2, 203 - intermediate_size: int = 4864, 204 - vocab_size: int = 151936, 205 - attn_implementation: str = "sdpa", # ⭐ 206 - rms_norm_eps: float = 1e-06, 207 - rope_theta: float = 1000000.0, 208 - attention_dropout: float = 0.0, 209 - hidden_act: str = "silu", 210 - initializer_range: float = 0.02, 211 - ): 212 - super().__init__() 213 - 214 - # attn_implementation check 215 - if attn_implementation == "flash_attention_2": 216 - raise ValueError( 217 - "CustomQwen2Decoder do not support flash_attention_2," 218 - "new attention mask needs 'sdpa' or 'eager'" 219 - ) 220 - 221 - # load 222 - Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model') 223 - Qwen2Config = getattr(transformers, 'Qwen2Config') 224 - 225 - # config 226 - config = Qwen2Config( 227 - hidden_size=hidden_dimension, 228 - num_hidden_layers=decoder_layer, 229 - num_attention_heads=num_attention_heads, 230 - num_key_value_heads=num_key_value_heads, 231 - intermediate_size=intermediate_size, 232 - max_position_embeddings=max_position_embeddings, 233 - vocab_size=vocab_size, 234 - rms_norm_eps=rms_norm_eps, 235 - rope_theta=rope_theta, 236 - attention_dropout=attention_dropout, 237 - hidden_act=hidden_act, 238 - initializer_range=initializer_range, 239 - _attn_implementation=attn_implementation, # ⭐ 240 - ) 241 - 242 - # 243 - self.model = self._create_custom_model(Qwen2Model, config) 244 - 245 - del self.model.embed_tokens 246 - 247 - def _create_custom_model(self, Qwen2Model, config): 248 - """ Qwen2Model """ 249 - 250 - class CustomQwen2ModelInner(Qwen2Model): 251 - 252 - 253 - def forward( 254 - self, 255 - input_ids=None, 256 - attention_mask=None, 257 - position_ids=None, 258 - past_key_values=None, 259 - inputs_embeds=None, 260 - token_type_ids=None, # ⭐ 261 - use_cache=None, 262 - output_attentions=None, 263 - output_hidden_states=None, 264 - return_dict=None, 265 - cache_position=None, 266 - ): 267 - # token_type_ids 268 - self._current_token_type_ids = token_type_ids 269 - 270 - outputs = super().forward( 271 - input_ids=input_ids, 272 - attention_mask=attention_mask, 273 - position_ids=position_ids, 274 - past_key_values=past_key_values, 275 - inputs_embeds=inputs_embeds, 276 - use_cache=use_cache, 277 - output_attentions=output_attentions, 278 - output_hidden_states=output_hidden_states, 279 - return_dict=return_dict, 280 - cache_position=cache_position, 281 - ) 282 - 283 - return outputs 284 - 285 - def _update_causal_mask( 286 - self, 287 - attention_mask, 288 - input_tensor, 289 - cache_position, 290 - past_key_values, 291 - output_attentions, 292 - ): 293 - dtype, device = input_tensor.dtype, input_tensor.device 294 - min_dtype = torch.finfo(dtype).min 295 - batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1] 296 - 297 - token_type_ids = self._current_token_type_ids 298 - 299 - # attention mask 300 - causal_mask = self._create_custom_4d_mask( 301 - sequence_length=sequence_length, 302 - dtype=dtype, 303 - device=device, 304 - batch_size=batch_size, 305 - token_type_ids=token_type_ids, 306 - ) 307 - 308 - # padding mask 309 - if attention_mask is not None and attention_mask.dim() == 2: 310 - padding_mask = attention_mask[:, None, None, :].to(dtype=dtype) 311 - padding_mask = (1.0 - padding_mask) * min_dtype 312 - causal_mask = causal_mask + padding_mask 313 - 314 - return causal_mask 315 - 316 - def _create_custom_4d_mask( 317 - self, 318 - sequence_length, 319 - dtype, 320 - device, 321 - batch_size, 322 - token_type_ids, 323 - ): 324 - min_dtype = torch.finfo(dtype).min 325 - 326 - masks = [] 327 - for b in range(batch_size): 328 - mask = torch.full( 329 - (sequence_length, sequence_length), 330 - fill_value=min_dtype, 331 - dtype=dtype, 332 - device=device 333 - ) 334 - 335 - type_ids = token_type_ids[b] 336 - 337 - image_positions = (type_ids == 0).nonzero(as_tuple=True)[0] 338 - text_positions = (type_ids == 1).nonzero(as_tuple=True)[0] 339 - 340 - # non-casual 341 - if len(image_positions) > 0: 342 - mask[image_positions[:, None], image_positions] = 0.0 343 - 344 - # causal 345 - for i, text_pos in enumerate(text_positions): 346 - if len(image_positions) > 0: 347 - mask[text_pos, image_positions] = 0.0 348 - mask[text_pos, text_positions[:i+1]] = 0.0 349 - 350 - masks.append(mask) 351 - 352 - mask = torch.stack(masks, dim=0).unsqueeze(1) 353 - return mask 354 - 355 - return CustomQwen2ModelInner(config) 356 - 357 - def forward( 358 - self, 359 - inputs_embeds, 360 - token_type_ids, 361 - attention_mask=None, 362 - **kwargs 363 - ): 364 - """ 365 - Args: 366 - inputs_embeds: [batch_size, seq_len, hidden_dim] 367 - token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal 368 - attention_mask: [batch_size, seq_len], optional 369 - """ 370 - return self.model( 371 - inputs_embeds=inputs_embeds, 372 - token_type_ids=token_type_ids, 373 - attention_mask=attention_mask, 374 - **kwargs 375 - ) 376 - 377 - 378 - 379 - 380 - 381 - # batch_size = 2 382 - # inputs_embeds = torch.randn(batch_size, 512, 896).cuda() 383 - 384 - # inputs_embeds = torch.randn(batch_size, 512, 896).cuda() 385 - # token_type_ids = torch.cat([ 386 - # torch.zeros(batch_size, 256, dtype=torch.long), 387 - # torch.ones(batch_size, 256, dtype=torch.long), 388 - # ], dim=1).cuda() 389 - 390 - # # start = time.time() 391 - # with torch.no_grad(): 392 - # outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids) 393 - # print(outputs_sdpa[0].shape) 394 - # print(f"SDPA time: {time.time() - start:.4f}s") 395 - 396 - 397 - 398 - class Qwen2Decoder2Encoder(nn.Module): 399 - """ 400 - Decoder based on Multilingual BART 401 - Set the initial weights and configuration with a pretrained multilingual BART model, 402 - and modify the detailed configurations as a Nougat decoder 403 - """ 404 - 405 - def __init__( 406 - self, 407 - decoder_layer: int, 408 - hidden_dimension: int, 409 - num_attention_heads: int, 410 - num_key_value_heads: int, 411 - intermediate_size: int, 412 - max_query: int, 413 - ): 414 - super().__init__() 415 - 416 - self.model = CustomQwen2Decoder( 417 - decoder_layer=decoder_layer, 418 - hidden_dimension=hidden_dimension, 419 - num_attention_heads=num_attention_heads, 420 - num_key_value_heads=num_key_value_heads, 421 - intermediate_size=intermediate_size, 422 - attn_implementation="sdpa", 423 - ) 424 - 425 - 426 - 427 - 428 - self.query_768 = nn.Embedding(144, hidden_dimension) 429 - self.query_1024 = nn.Embedding(256, hidden_dimension) 430 - 431 - 432 - # self.query_refixation = nn.Embedding(int(math.sqrt(max_query)), hidden_dimension) 433 - 434 - 435 - def forward(self, x: torch.Tensor) -> torch.Tensor: 436 - x = x.flatten(2).transpose(1, 2) 437 - 438 - bs, n_query, _ = x.shape 439 - 440 - if n_query == 144: 441 - param_img = self.query_768.weight 442 - elif n_query == 256: 443 - param_img = self.query_1024.weight 444 - 445 - batch_query_imgs = param_img.unsqueeze(0).expand( 446 - bs, -1, -1 447 - ) # (batch_size, num_queries, hidden_size) 448 - 449 - 450 - 451 - x_combined = torch.cat([x, batch_query_imgs], dim=1) 452 - 453 - token_type_ids = torch.cat([ 454 - torch.zeros(bs, n_query, dtype=torch.long), 455 - torch.ones(bs, n_query, dtype=torch.long), 456 - ], dim=1) 457 - 458 - 459 - y = self.model(x_combined, token_type_ids)[0] 460 - 461 - 462 - y = y[:, n_query:, :] # causal flow query 463 - 464 - 465 - return y 466 - 467 - 468 - def build_qwen2_decoder_as_encoder( 469 - decoder_layer=24, 470 - hidden_dimension=896, 471 - num_attention_heads=14, 472 - num_key_value_heads=2, 473 - intermediate_size=4864, 474 - max_query = 400, 475 - checkpoint=None, 476 - ): 477 - 478 - decoder_as_encoder = Qwen2Decoder2Encoder( 479 - decoder_layer=decoder_layer, 480 - hidden_dimension = hidden_dimension, 481 - num_attention_heads = num_attention_heads, 482 - num_key_value_heads = num_key_value_heads, 483 - intermediate_size = intermediate_size, 484 - max_query = max_query 485 - ) 486 - 487 - 488 - 489 - 490 - if checkpoint is not None: 491 - # with open(checkpoint, "rb") as f: 492 - state_dict = torch.load(checkpoint) 493 - 494 - decoder_as_encoder.load_state_dict(state_dict, strict=True) 495 - # tob 496 - print(checkpoint) 497 - return decoder_as_encoder 498 - 499 - 500 - 501 - 502 - #=========================Sam-Vary================================= 503 - 504 - 505 - def get_abs_pos_sam(abs_pos, tgt_size): 506 - 507 - dtype = abs_pos.dtype 508 - 509 - src_size = abs_pos.size(1) 510 - 511 - if src_size != tgt_size: 512 - old_pos_embed = abs_pos.permute(0, 3, 1, 2) 513 - old_pos_embed = old_pos_embed.to(torch.float32) 514 - new_pos_embed = F.interpolate( 515 - old_pos_embed, 516 - size=(tgt_size, tgt_size), 517 - mode='bicubic', 518 - antialias=True, 519 - align_corners=False, 520 - ).to(dtype) 521 - new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) 522 - return new_pos_embed 523 - else: 524 - return abs_pos 525 - 526 - 527 - 528 - 529 - class MLPBlock(nn.Module): 530 - def __init__( 531 - self, 532 - embedding_dim: int, 533 - mlp_dim: int, 534 - act: Type[nn.Module] = nn.GELU, 535 - ) -> None: 536 - super().__init__() 537 - self.lin1 = nn.Linear(embedding_dim, mlp_dim) 538 - self.lin2 = nn.Linear(mlp_dim, embedding_dim) 539 - self.act = act() 540 - 541 - def forward(self, x: torch.Tensor) -> torch.Tensor: 542 - return self.lin2(self.act(self.lin1(x))) 543 - 544 - 545 - # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 546 - # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 547 - class LayerNorm2d(nn.Module): 548 - def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 549 - super().__init__() 550 - self.weight = nn.Parameter(torch.ones(num_channels)) 551 - self.bias = nn.Parameter(torch.zeros(num_channels)) 552 - self.eps = eps 553 - 554 - def forward(self, x: torch.Tensor) -> torch.Tensor: 555 - u = x.mean(1, keepdim=True) 556 - s = (x - u).pow(2).mean(1, keepdim=True) 557 - x = (x - u) / torch.sqrt(s + self.eps) 558 - x = self.weight[:, None, None] * x + self.bias[:, None, None] 559 - return x 560 - 561 - 562 - # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 563 - class ImageEncoderViT(nn.Module): 564 - def __init__( 565 - self, 566 - img_size: int = 1024, 567 - patch_size: int = 16, 568 - in_chans: int = 3, 569 - embed_dim: int = 768, 570 - depth: int = 12, 571 - num_heads: int = 12, 572 - mlp_ratio: float = 4.0, 573 - out_chans: int = 256, 574 - qkv_bias: bool = True, 575 - norm_layer: Type[nn.Module] = nn.LayerNorm, 576 - act_layer: Type[nn.Module] = nn.GELU, 577 - use_abs_pos: bool = True, 578 - use_rel_pos: bool = False, 579 - rel_pos_zero_init: bool = True, 580 - window_size: int = 0, 581 - global_attn_indexes: Tuple[int, ...] = (), 582 - ) -> None: 583 - """ 584 - Args: 585 - img_size (int): Input image size. 586 - patch_size (int): Patch size. 587 - in_chans (int): Number of input image channels. 588 - embed_dim (int): Patch embedding dimension. 589 - depth (int): Depth of ViT. 590 - num_heads (int): Number of attention heads in each ViT block. 591 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 592 - qkv_bias (bool): If True, add a learnable bias to query, key, value. 593 - norm_layer (nn.Module): Normalization layer. 594 - act_layer (nn.Module): Activation layer. 595 - use_abs_pos (bool): If True, use absolute positional embeddings. 596 - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 597 - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 598 - window_size (int): Window size for window attention blocks. 599 - global_attn_indexes (list): Indexes for blocks using global attention. 600 - """ 601 - super().__init__() 602 - self.img_size = img_size 603 - 604 - self.patch_embed = PatchEmbed( 605 - kernel_size=(patch_size, patch_size), 606 - stride=(patch_size, patch_size), 607 - in_chans=in_chans, 608 - embed_dim=embed_dim, 609 - ) 610 - 611 - self.pos_embed: Optional[nn.Parameter] = None 612 - if use_abs_pos: 613 - # Initialize absolute positional embedding with pretrain image size. 614 - self.pos_embed = nn.Parameter( 615 - torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 616 - ) 617 - 618 - self.blocks = nn.ModuleList() 619 - for i in range(depth): 620 - block = Block( 621 - dim=embed_dim, 622 - num_heads=num_heads, 623 - mlp_ratio=mlp_ratio, 624 - qkv_bias=qkv_bias, 625 - norm_layer=norm_layer, 626 - act_layer=act_layer, 627 - use_rel_pos=use_rel_pos, 628 - rel_pos_zero_init=rel_pos_zero_init, 629 - window_size=window_size if i not in global_attn_indexes else 0, 630 - input_size=(img_size // patch_size, img_size // patch_size), 631 - ) 632 - self.blocks.append(block) 633 - 634 - self.neck = nn.Sequential( 635 - nn.Conv2d( 636 - embed_dim, 637 - out_chans, 638 - kernel_size=1, 639 - bias=False, 640 - ), 641 - LayerNorm2d(out_chans), 642 - nn.Conv2d( 643 - out_chans, 644 - out_chans, 645 - kernel_size=3, 646 - padding=1, 647 - bias=False, 648 - ), 649 - LayerNorm2d(out_chans), 650 - ) 651 - 652 - self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) 653 - self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False) 654 - 655 - def forward(self, x: torch.Tensor) -> torch.Tensor: 656 - x = self.patch_embed(x) 657 - if self.pos_embed is not None: 658 - # x = x + self.pos_embed 659 - x = x + get_abs_pos_sam(self.pos_embed, x.size(1)) 660 - 661 - for blk in self.blocks: 662 - x = blk(x) 663 - 664 - x = self.neck(x.permute(0, 3, 1, 2)) 665 - x2 = self.net_2(x) 666 - x3 = self.net_3(x2.clone()) 667 - 668 - return x3 669 - 670 - 671 - class Block(nn.Module): 672 - """Transformer blocks with support of window attention and residual propagation blocks""" 673 - 674 - def __init__( 675 - self, 676 - dim: int, 677 - num_heads: int, 678 - mlp_ratio: float = 4.0, 679 - qkv_bias: bool = True, 680 - norm_layer: Type[nn.Module] = nn.LayerNorm, 681 - act_layer: Type[nn.Module] = nn.GELU, 682 - use_rel_pos: bool = False, 683 - rel_pos_zero_init: bool = True, 684 - window_size: int = 0, 685 - input_size: Optional[Tuple[int, int]] = None, 686 - ) -> None: 687 - """ 688 - Args: 689 - dim (int): Number of input channels. 690 - num_heads (int): Number of attention heads in each ViT block. 691 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 692 - qkv_bias (bool): If True, add a learnable bias to query, key, value. 693 - norm_layer (nn.Module): Normalization layer. 694 - act_layer (nn.Module): Activation layer. 695 - use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 696 - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 697 - window_size (int): Window size for window attention blocks. If it equals 0, then 698 - use global attention. 699 - input_size (tuple(int, int) or None): Input resolution for calculating the relative 700 - positional parameter size. 701 - """ 702 - super().__init__() 703 - self.norm1 = norm_layer(dim) 704 - self.attn = Attention( 705 - dim, 706 - num_heads=num_heads, 707 - qkv_bias=qkv_bias, 708 - use_rel_pos=use_rel_pos, 709 - rel_pos_zero_init=rel_pos_zero_init, 710 - input_size=input_size if window_size == 0 else (window_size, window_size), 711 - ) 712 - 713 - self.norm2 = norm_layer(dim) 714 - self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 715 - 716 - self.window_size = window_size 717 - 718 - def forward(self, x: torch.Tensor) -> torch.Tensor: 719 - shortcut = x 720 - x = self.norm1(x) 721 - # Window partition 722 - if self.window_size > 0: 723 - H, W = x.shape[1], x.shape[2] 724 - x, pad_hw = window_partition(x, self.window_size) 725 - 726 - x = self.attn(x) 727 - # Reverse window partition 728 - if self.window_size > 0: 729 - x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 730 - 731 - x = shortcut + x 732 - x = x + self.mlp(self.norm2(x)) 733 - 734 - return x 735 - 736 - 737 - class Attention(nn.Module): 738 - """Multi-head Attention block with relative position embeddings.""" 739 - 740 - def __init__( 741 - self, 742 - dim: int, 743 - num_heads: int = 8, 744 - qkv_bias: bool = True, 745 - use_rel_pos: bool = False, 746 - rel_pos_zero_init: bool = True, 747 - input_size: Optional[Tuple[int, int]] = None, 748 - ) -> None: 749 - """ 750 - Args: 751 - dim (int): Number of input channels. 752 - num_heads (int): Number of attention heads. 753 - qkv_bias (bool): If True, add a learnable bias to query, key, value. 754 - rel_pos (bool): If True, add relative positional embeddings to the attention map. 755 - rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 756 - input_size (tuple(int, int) or None): Input resolution for calculating the relative 757 - positional parameter size. 758 - """ 759 - super().__init__() 760 - self.num_heads = num_heads 761 - head_dim = dim // num_heads 762 - self.scale = head_dim**-0.5 763 - 764 - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 765 - self.proj = nn.Linear(dim, dim) 766 - 767 - self.use_rel_pos = use_rel_pos 768 - if self.use_rel_pos: 769 - assert ( 770 - input_size is not None 771 - ), "Input size must be provided if using relative positional encoding." 772 - # initialize relative positional embeddings 773 - self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 774 - self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 775 - 776 - def forward(self, x: torch.Tensor) -> torch.Tensor: 777 - B, H, W, _ = x.shape 778 - # qkv with shape (3, B, nHead, H * W, C) 779 - qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 780 - # q, k, v with shape (B * nHead, H * W, C) 781 - q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 782 - 783 - rel_h, rel_w = None, None 784 - if self.use_rel_pos: 785 - rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 786 - 787 - q = q.view(B, self.num_heads, H * W, -1) 788 - k = k.view(B, self.num_heads, H * W, -1) 789 - v = v.view(B, self.num_heads, H * W, -1) 790 - 791 - if self.use_rel_pos: 792 - rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)) 793 - rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)) 794 - attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)) 795 - x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) 796 - # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w) 797 - else: 798 - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 799 - 800 - x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 801 - 802 - x = self.proj(x) 803 - 804 - return x 805 - 806 - 807 - def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 808 - """ 809 - Partition into non-overlapping windows with padding if needed. 810 - Args: 811 - x (tensor): input tokens with [B, H, W, C]. 812 - window_size (int): window size. 813 - 814 - Returns: 815 - windows: windows after partition with [B * num_windows, window_size, window_size, C]. 816 - (Hp, Wp): padded height and width before partition 817 - """ 818 - B, H, W, C = x.shape 819 - 820 - pad_h = (window_size - H % window_size) % window_size 821 - pad_w = (window_size - W % window_size) % window_size 822 - if pad_h > 0 or pad_w > 0: 823 - x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 824 - Hp, Wp = H + pad_h, W + pad_w 825 - 826 - x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 827 - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 828 - return windows, (Hp, Wp) 829 - 830 - 831 - def window_unpartition( 832 - windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 833 - ) -> torch.Tensor: 834 - """ 835 - Window unpartition into original sequences and removing padding. 836 - Args: 837 - windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 838 - window_size (int): window size. 839 - pad_hw (Tuple): padded height and width (Hp, Wp). 840 - hw (Tuple): original height and width (H, W) before padding. 841 - 842 - Returns: 843 - x: unpartitioned sequences with [B, H, W, C]. 844 - """ 845 - Hp, Wp = pad_hw 846 - H, W = hw 847 - B = windows.shape[0] // (Hp * Wp // window_size // window_size) 848 - x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 849 - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 850 - 851 - if Hp > H or Wp > W: 852 - x = x[:, :H, :W, :].contiguous() 853 - return x 854 - 855 - 856 - def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 857 - """ 858 - Get relative positional embeddings according to the relative positions of 859 - query and key sizes. 860 - Args: 861 - q_size (int): size of query q. 862 - k_size (int): size of key k. 863 - rel_pos (Tensor): relative position embeddings (L, C). 864 - 865 - Returns: 866 - Extracted positional embeddings according to relative positions. 867 - """ 868 - max_rel_dist = int(2 * max(q_size, k_size) - 1) 869 - # Interpolate rel pos if needed. 870 - if rel_pos.shape[0] != max_rel_dist: 871 - # Interpolate rel pos. 872 - dtype = rel_pos.dtype 873 - rel_pos = rel_pos.to(torch.float32) 874 - rel_pos_resized = F.interpolate( 875 - rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 876 - size=max_rel_dist, 877 - mode="linear", 878 - ).to(dtype) 879 - rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 880 - else: 881 - rel_pos_resized = rel_pos 882 - 883 - # Scale the coords with short length if shapes for q and k are different. 884 - q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0) 885 - k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0) 886 - relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 887 - 888 - return rel_pos_resized[relative_coords.long()] 889 - 890 - 891 - def add_decomposed_rel_pos( 892 - q: torch.Tensor, 893 - rel_pos_h: torch.Tensor, 894 - rel_pos_w: torch.Tensor, 895 - q_size: Tuple[int, int], 896 - k_size: Tuple[int, int], 897 - ) -> torch.Tensor: 898 - """ 899 - Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 900 - https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 901 - Args: 902 - q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 903 - rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 904 - rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 905 - q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 906 - k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 907 - 908 - Returns: 909 - attn (Tensor): attention map with added relative positional embeddings. 910 - """ 911 - q_h, q_w = q_size 912 - k_h, k_w = k_size 913 - Rh = get_rel_pos(q_h, k_h, rel_pos_h) 914 - Rw = get_rel_pos(q_w, k_w, rel_pos_w) 915 - 916 - B, _, dim = q.shape 917 - r_q = q.reshape(B, q_h, q_w, dim) 918 - rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 919 - rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 920 - rel_h = rel_h.unsqueeze(-1) 921 - rel_w = rel_w.unsqueeze(-2) 922 - rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1) 923 - rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w) 924 - 925 - return rel_h, rel_w 926 - 927 - 928 - class PatchEmbed(nn.Module): 929 - """ 930 - Image to Patch Embedding. 931 - """ 932 - 933 - def __init__( 934 - self, 935 - kernel_size: Tuple[int, int] = (16, 16), 936 - stride: Tuple[int, int] = (16, 16), 937 - padding: Tuple[int, int] = (0, 0), 938 - in_chans: int = 3, 939 - embed_dim: int = 768, 940 - ) -> None: 941 - """ 942 - Args: 943 - kernel_size (Tuple): kernel size of the projection layer. 944 - stride (Tuple): stride of the projection layer. 945 - padding (Tuple): padding size of the projection layer. 946 - in_chans (int): Number of input image channels. 947 - embed_dim (int): Patch embedding dimension. 948 - """ 949 - super().__init__() 950 - 951 - self.proj = nn.Conv2d( 952 - in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 953 - ) 954 - 955 - def forward(self, x: torch.Tensor) -> torch.Tensor: 956 - x = self.proj(x) 957 - # B C H W -> B H W C 958 - x = x.permute(0, 2, 3, 1) 959 - return x 960 - 961 - 962 - def build_sam_vit_b(checkpoint=None): 963 - return _build_sam( 964 - encoder_embed_dim=768, 965 - encoder_depth=12, 966 - encoder_num_heads=12, 967 - encoder_global_attn_indexes=[2, 5, 8, 11], 968 - checkpoint=checkpoint, 969 - ) 970 - 971 - def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16): 972 - image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype) 973 - # sam = _apply_eval_dtype_sam(sam, dtype) 974 - image_encoder = torch.compile(image_encoder, mode=compile_mode) 975 - return image_encoder 976 - 977 - 978 - def _build_sam( 979 - encoder_embed_dim, 980 - encoder_depth, 981 - encoder_num_heads, 982 - encoder_global_attn_indexes, 983 - checkpoint=None, 984 - ): 985 - prompt_embed_dim = 256 986 - image_size = 1024 987 - vit_patch_size = 16 988 - image_embedding_size = image_size // vit_patch_size 989 - image_encoder=ImageEncoderViT( 990 - depth=encoder_depth, 991 - embed_dim=encoder_embed_dim, 992 - img_size=image_size, 993 - mlp_ratio=4, 994 - norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 995 - num_heads=encoder_num_heads, 996 - patch_size=vit_patch_size, 997 - qkv_bias=True, 998 - use_rel_pos=True, 999 - global_attn_indexes=encoder_global_attn_indexes, 1000 - window_size=14, 1001 - out_chans=prompt_embed_dim, 1002 - ) 1003 - image_encoder.eval() 1004 - if checkpoint is not None: 1005 - # with open(checkpoint, "rb") as f: 1006 - state_dict = torch.load(checkpoint) 1007 - # print(state_dict.keys()) 1008 - # for key in state_dict: 1009 - # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False) 1010 - # ocr-anyting 1011 - # image_encoder.load_state_dict(state_dict, strict=True) 1012 - # tob 1013 - image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True) 1014 - print(checkpoint) 1015 - return image_encoder
-1029
src/deepseek_ocr2/modeling_deepseekocr2.py
··· 1 - from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM 2 - from .configuration_deepseek_v2 import DeepseekV2Config 3 - from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 4 - from typing import List, Optional, Tuple, Union 5 - from transformers.cache_utils import Cache 6 - import requests 7 - from PIL import Image, ImageOps, ImageDraw, ImageFont 8 - from io import BytesIO 9 - import torch 10 - import torch.nn as nn 11 - from torch.nn import CrossEntropyLoss 12 - from torchvision import transforms 13 - # from torchvision.transforms.functional import InterpolationMode 14 - import os 15 - from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector 16 - from addict import Dict 17 - from transformers import TextStreamer 18 - from .conversation import get_conv_template 19 - from abc import ABC 20 - import math 21 - import re 22 - from tqdm import tqdm 23 - import numpy as np 24 - # import time 25 - 26 - 27 - 28 - def load_image(image_path): 29 - 30 - try: 31 - image = Image.open(image_path) 32 - 33 - corrected_image = ImageOps.exif_transpose(image) 34 - 35 - return corrected_image 36 - 37 - except Exception as e: 38 - print(f"error: {e}") 39 - try: 40 - return Image.open(image_path) 41 - except: 42 - return None 43 - 44 - 45 - def re_match(text): 46 - pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' 47 - matches = re.findall(pattern, text, re.DOTALL) 48 - 49 - # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n' 50 - # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL) 51 - 52 - mathes_image = [] 53 - mathes_other = [] 54 - for a_match in matches: 55 - if '<|ref|>image<|/ref|>' in a_match[0]: 56 - mathes_image.append(a_match[0]) 57 - else: 58 - mathes_other.append(a_match[0]) 59 - return matches, mathes_image, mathes_other 60 - 61 - 62 - def extract_coordinates_and_label(ref_text, image_width, image_height): 63 - 64 - try: 65 - label_type = ref_text[1] 66 - cor_list = eval(ref_text[2]) 67 - except Exception as e: 68 - print(e) 69 - return None 70 - 71 - return (label_type, cor_list) 72 - 73 - 74 - def draw_bounding_boxes(image, refs, ouput_path): 75 - 76 - image_width, image_height = image.size 77 - 78 - img_draw = image.copy() 79 - draw = ImageDraw.Draw(img_draw) 80 - 81 - overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) 82 - draw2 = ImageDraw.Draw(overlay) 83 - 84 - # try: 85 - # except IOError: 86 - # try: 87 - # font = ImageFont.truetype("DejaVuSans.ttf", 20) 88 - # except IOError: 89 - font = ImageFont.load_default() 90 - 91 - img_idx = 0 92 - 93 - for i, ref in enumerate(refs): 94 - try: 95 - result = extract_coordinates_and_label(ref, image_width, image_height) 96 - if result: 97 - label_type, points_list = result 98 - 99 - color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) 100 - 101 - color_a = color + (20, ) 102 - for points in points_list: 103 - x1, y1, x2, y2 = points 104 - 105 - x1 = int(x1 / 999 * image_width) 106 - y1 = int(y1 / 999 * image_height) 107 - 108 - x2 = int(x2 / 999 * image_width) 109 - y2 = int(y2 / 999 * image_height) 110 - 111 - if label_type == 'image': 112 - try: 113 - cropped = image.crop((x1, y1, x2, y2)) 114 - cropped.save(f"{ouput_path}/images/{img_idx}.jpg") 115 - except Exception as e: 116 - print(e) 117 - pass 118 - img_idx += 1 119 - 120 - try: 121 - if label_type == 'title': 122 - draw.rectangle([x1, y1, x2, y2], outline=color, width=4) 123 - draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) 124 - else: 125 - draw.rectangle([x1, y1, x2, y2], outline=color, width=2) 126 - draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) 127 - text_x = x1 128 - text_y = max(0, y1 - 15) 129 - 130 - 131 - text_bbox = draw.textbbox((0, 0), label_type, font=font) 132 - text_width = text_bbox[2] - text_bbox[0] 133 - text_height = text_bbox[3] - text_bbox[1] 134 - draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], 135 - fill=(255, 255, 255, 30)) 136 - 137 - draw.text((text_x, text_y), label_type, font=font, fill=color) 138 - except: 139 - pass 140 - except: 141 - continue 142 - img_draw.paste(overlay, (0, 0), overlay) 143 - return img_draw 144 - 145 - 146 - def process_image_with_refs(image, ref_texts, output_path): 147 - 148 - result_image = draw_bounding_boxes(image, ref_texts, output_path) 149 - 150 - return result_image 151 - 152 - 153 - 154 - 155 - 156 - def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 157 - best_ratio_diff = float('inf') 158 - best_ratio = (1, 1) 159 - area = width * height 160 - for ratio in target_ratios: 161 - target_aspect_ratio = ratio[0] / ratio[1] 162 - ratio_diff = abs(aspect_ratio - target_aspect_ratio) 163 - if ratio_diff < best_ratio_diff: 164 - best_ratio_diff = ratio_diff 165 - best_ratio = ratio 166 - elif ratio_diff == best_ratio_diff: 167 - if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 168 - best_ratio = ratio 169 - # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') 170 - return best_ratio 171 - 172 - 173 - def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False): 174 - orig_width, orig_height = image.size 175 - aspect_ratio = orig_width / orig_height 176 - 177 - # calculate the existing image aspect ratio 178 - target_ratios = set( 179 - (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 180 - i * j <= max_num and i * j >= min_num) 181 - # print(target_ratios) 182 - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 183 - 184 - # find the closest aspect ratio to the target 185 - target_aspect_ratio = find_closest_aspect_ratio( 186 - aspect_ratio, target_ratios, orig_width, orig_height, image_size) 187 - 188 - # print(target_aspect_ratio) 189 - # calculate the target width and height 190 - target_width = image_size * target_aspect_ratio[0] 191 - target_height = image_size * target_aspect_ratio[1] 192 - blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 193 - 194 - # resize the image 195 - resized_img = image.resize((target_width, target_height)) 196 - processed_images = [] 197 - for i in range(blocks): 198 - box = ( 199 - (i % (target_width // image_size)) * image_size, 200 - (i // (target_width // image_size)) * image_size, 201 - ((i % (target_width // image_size)) + 1) * image_size, 202 - ((i // (target_width // image_size)) + 1) * image_size 203 - ) 204 - # split the image 205 - split_img = resized_img.crop(box) 206 - processed_images.append(split_img) 207 - assert len(processed_images) == blocks 208 - if use_thumbnail and len(processed_images) != 1: 209 - thumbnail_img = image.resize((image_size, image_size)) 210 - processed_images.append(thumbnail_img) 211 - return processed_images, target_aspect_ratio 212 - 213 - 214 - 215 - def normalize_transform(mean, std): 216 - if mean is None and std is None: 217 - transform = None 218 - elif mean is None and std is not None: 219 - mean = [0.] * len(std) 220 - transform = transforms.Normalize(mean=mean, std=std) 221 - elif mean is not None and std is None: 222 - std = [1.] * len(mean) 223 - transform = transforms.Normalize(mean=mean, std=std) 224 - else: 225 - transform = transforms.Normalize(mean=mean, std=std) 226 - 227 - return transform 228 - 229 - 230 - 231 - def format_messages( 232 - conversations: List[Dict[str, str]], 233 - sft_format: str = "deepseek", 234 - system_prompt: str = "", 235 - ): 236 - """ 237 - Applies the SFT template to conversation. 238 - 239 - Args: 240 - conversations (List[Dict]): A List of messages. 241 - sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". 242 - system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". 243 - 244 - Returns: 245 - sft_prompt (str): The formatted text. 246 - """ 247 - 248 - conv = get_conv_template(sft_format) 249 - conv.set_system_message(system_prompt) 250 - for message in conversations: 251 - conv.append_message(message["role"], message["content"].strip()) 252 - sft_prompt = conv.get_prompt().strip() 253 - 254 - return sft_prompt 255 - 256 - 257 - def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False): 258 - t = tokenizer.encode(text, add_special_tokens=False) 259 - bos_id = 0 260 - eos_id = 1 261 - if bos: 262 - t = [bos_id] + t 263 - if eos: 264 - t = t + [eos_id] 265 - 266 - return t 267 - 268 - def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]: 269 - """ 270 - 271 - Args: 272 - conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : 273 - [ 274 - { 275 - "role": "User", 276 - "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.", 277 - "images": ["./examples/table_datasets.png"] 278 - }, 279 - {"role": "Assistant", "content": ""}, 280 - ] 281 - 282 - Returns: 283 - pil_images (List[PIL.Image.Image]): the list of PIL images. 284 - 285 - """ 286 - 287 - pil_images = [] 288 - 289 - for message in conversations: 290 - if "images" not in message: 291 - continue 292 - 293 - for image_path in message["images"]: 294 - # print('----------------') 295 - # print(image_path) 296 - # print('----------------') 297 - # exit() 298 - 299 - # pil_img = Image.open(image_path) 300 - pil_img = load_image(image_path) 301 - pil_img = pil_img.convert("RGB") 302 - pil_images.append(pil_img) 303 - 304 - return pil_images 305 - 306 - 307 - class BaseTransform(ABC): 308 - 309 - def set_rng(self, *args, **kwargs): 310 - pass 311 - 312 - def __call__(self, *args, **kwargs) -> torch.Tensor: 313 - pass 314 - 315 - @property 316 - def default_shape(self): 317 - raise NotImplementedError 318 - 319 - 320 - class BasicImageTransform(BaseTransform): 321 - def __init__( 322 - self, 323 - mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), 324 - std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), 325 - normalize: bool = True 326 - ): 327 - self.mean = mean 328 - self.std = std 329 - 330 - transform_pipelines = [ 331 - transforms.ToTensor() 332 - ] 333 - 334 - normalize = normalize_transform(mean, std) if normalize else nn.Identity() 335 - if normalize is not None: 336 - transform_pipelines.append(normalize) 337 - 338 - self.transform = transforms.Compose(transform_pipelines) 339 - 340 - def __call__(self, x): 341 - x = self.transform(x) 342 - return x 343 - 344 - class NoEOSTextStreamer(TextStreamer): 345 - def on_finalized_text(self, text: str, stream_end: bool = False): 346 - 347 - eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False) 348 - text = text.replace(eos_text, "\n") 349 - print(text, flush=True, end="") 350 - 351 - 352 - class DeepseekOCR2Config(DeepseekV2Config): 353 - model_type = "DeepseekOCR2" 354 - 355 - class DeepseekOCR2Model(DeepseekV2Model): 356 - config_class = DeepseekOCR2Config 357 - 358 - def __init__(self, config: DeepseekV2Config): 359 - super(DeepseekOCR2Model, self).__init__(config) 360 - 361 - self.sam_model = build_sam_vit_b() 362 - self.qwen2_model = build_qwen2_decoder_as_encoder() 363 - # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2) 364 - n_embed = 1280 365 - self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed)) 366 - embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) 367 - # self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) 368 - self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) 369 - 370 - 371 - 372 - 373 - def forward( 374 - self, 375 - input_ids: torch.LongTensor = None, 376 - attention_mask: Optional[torch.Tensor] = None, 377 - position_ids: Optional[torch.LongTensor] = None, 378 - past_key_values: Optional[List[torch.FloatTensor]] = None, 379 - inputs_embeds: Optional[torch.FloatTensor] = None, 380 - use_cache: Optional[bool] = None, 381 - output_attentions: Optional[bool] = None, 382 - output_hidden_states: Optional[bool] = None, 383 - images: Optional[torch.FloatTensor] = None, 384 - images_seq_mask: Optional[torch.FloatTensor] = None, 385 - images_spatial_crop: Optional[torch.FloatTensor] = None, 386 - return_dict: Optional[bool] = None, 387 - ) -> Union[Tuple, BaseModelOutputWithPast]: 388 - 389 - 390 - 391 - 392 - if inputs_embeds is None: 393 - # .clone().to(bfloat16): two training-only fixes -- 394 - # 1. clone() breaks the autograd leaf-variable link so the 395 - # masked_scatter_ below does not raise "in-place on leaf". 396 - # 2. to(bfloat16) matches the dtype of image features produced 397 - # by the vision encoder (prepare_model_for_kbit_training 398 - # upcasts the embedding table to float32; the vision encoder 399 - # stays bfloat16 after our explicit cast in the training script). 400 - inputs_embeds = self.get_input_embeddings()(input_ids).clone().to(torch.bfloat16) 401 - 402 - 403 - 404 - sam_model = getattr(self, 'sam_model', None) 405 - # sam_model = self.sam_model 406 - qwen2_model = getattr(self, 'qwen2_model', None) 407 - 408 - 409 - 410 - if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0: 411 - 412 - idx = 0 413 - 414 - # sam_model = torch.jit.script(sam_model) 415 - 416 - # start_time = time.time() 417 - for image, crop_shape in zip(images, images_spatial_crop): 418 - images_in_this_batch = [] 419 - 420 - patches = image[0] 421 - image_ori = image[1] 422 - 423 - with torch.no_grad(): 424 - # with torch.inference_mode(): 425 - 426 - if torch.sum(patches).item() != 0: 427 - # P, C, H, W = patches.shape 428 - crop_flag = 1 429 - local_features_1 = sam_model(patches) 430 - 431 - local_features_2 = qwen2_model(local_features_1) 432 - # vit_time = time.time() 433 - local_features = local_features_2 434 - local_features = self.projector(local_features) 435 - 436 - 437 - global_features_1 = sam_model(image_ori) 438 - global_features_2 = qwen2_model(global_features_1) 439 - global_features = global_features_2 440 - global_features = self.projector(global_features) 441 - 442 - pass 443 - 444 - _, hw, n_dim = global_features.shape 445 - # h = w = int(hw ** 0.5) 446 - 447 - _2, hw2, n_dim2 = local_features.shape 448 - # h2 = w2 = int(hw2 ** 0.5) 449 - 450 - 451 - global_features = global_features.view(-1, n_dim) 452 - 453 - 454 - local_features = local_features.view(-1, n_dim2) 455 - 456 - global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0) 457 - 458 - # end_time = time.time() 459 - 460 - # print('sam: ', sam_time - start_time) 461 - # print('vit: ', vit_time - sam_time) 462 - # print('all: ', end_time - start_time) 463 - 464 - # exit() 465 - 466 - else: 467 - global_features_1 = sam_model(image_ori) 468 - global_features_2 = qwen2_model(global_features_1) 469 - global_features = global_features_2 470 - global_features = self.projector(global_features) 471 - pass 472 - _, hw, n_dim = global_features.shape 473 - # h = w = int(hw ** 0.5) 474 - 475 - 476 - # global_features = global_features.view(h, w, n_dim) 477 - 478 - # global_features = torch.cat( 479 - # [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 480 - # ) 481 - 482 - global_features = global_features.view(-1, n_dim) 483 - 484 - global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0) 485 - 486 - images_in_this_batch.append(global_local_features) 487 - 488 - 489 - # print(inputs_embeds.shape) 490 - 491 - if images_in_this_batch: 492 - images_in_this_batch = torch.cat(images_in_this_batch, dim=0) 493 - # exit() 494 - 495 - inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch) 496 - 497 - idx += 1 498 - 499 - 500 - return super(DeepseekOCR2Model, self).forward( 501 - input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 502 - inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, 503 - output_attentions=output_attentions, output_hidden_states=output_hidden_states, 504 - return_dict=return_dict 505 - ) 506 - 507 - 508 - class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM): 509 - 510 - config_class = DeepseekOCR2Config 511 - # supports_gradient_checkpointing = True 512 - 513 - def __init__(self, config): 514 - super(DeepseekV2ForCausalLM, self).__init__(config) 515 - self.model = DeepseekOCR2Model(config) 516 - 517 - self.vocab_size = config.vocab_size 518 - 519 - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 520 - 521 - # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 522 - 523 - # Initialize weights and apply final processing 524 - self.post_init() 525 - 526 - def get_model(self): 527 - return self.model 528 - 529 - 530 - def forward( 531 - self, 532 - input_ids: torch.LongTensor = None, 533 - attention_mask: Optional[torch.Tensor] = None, 534 - position_ids: Optional[torch.LongTensor] = None, 535 - past_key_values: Optional[List[torch.FloatTensor]] = None, 536 - inputs_embeds: Optional[torch.FloatTensor] = None, 537 - labels: Optional[torch.LongTensor] = None, 538 - use_cache: Optional[bool] = None, 539 - output_attentions: Optional[bool] = None, 540 - output_hidden_states: Optional[bool] = None, 541 - images: Optional[torch.FloatTensor] = None, 542 - images_seq_mask: Optional[torch.FloatTensor] = None, 543 - images_spatial_crop: Optional[torch.FloatTensor] = None, 544 - return_dict: Optional[bool] = None, 545 - 546 - ) -> Union[Tuple, CausalLMOutputWithPast]: 547 - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 548 - output_hidden_states = ( 549 - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 550 - ) 551 - return_dict = return_dict if return_dict is not None else self.config.use_return_dict 552 - 553 - 554 - 555 - outputs = self.model( 556 - input_ids=input_ids, 557 - past_key_values=past_key_values, 558 - attention_mask=attention_mask, 559 - position_ids=position_ids, 560 - inputs_embeds=inputs_embeds, 561 - use_cache=use_cache, 562 - output_attentions=output_attentions, 563 - output_hidden_states=output_hidden_states, 564 - images=images, 565 - images_seq_mask = images_seq_mask, 566 - images_spatial_crop = images_spatial_crop, 567 - return_dict=return_dict 568 - 569 - ) 570 - 571 - 572 - 573 - # print(transformer_outputs) 574 - 575 - hidden_states = outputs[0] 576 - logits = self.lm_head(hidden_states) 577 - logits = logits.float() 578 - 579 - # logits 580 - 581 - loss = None 582 - if labels is not None: 583 - # Shift so that tokens < n predict n 584 - shift_logits = logits[..., :-1, :].contiguous() 585 - shift_labels = labels[..., 1:].contiguous() 586 - # Flatten the tokens 587 - loss_fct = CrossEntropyLoss() 588 - shift_logits = shift_logits.view(-1, self.config.vocab_size) 589 - shift_labels = shift_labels.view(-1) 590 - # Enable model parallelism 591 - shift_labels = shift_labels.to(shift_logits.device) 592 - loss = loss_fct(shift_logits, shift_labels) 593 - 594 - if not return_dict: 595 - output = (logits,) + outputs[1:] 596 - return (loss,) + output if loss is not None else output 597 - 598 - return CausalLMOutputWithPast( 599 - loss=loss, 600 - logits=logits, 601 - past_key_values=outputs.past_key_values, 602 - hidden_states=outputs.hidden_states, 603 - attentions=outputs.attentions, 604 - ) 605 - 606 - 607 - def prepare_inputs_for_generation( 608 - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 609 - ): 610 - # Omit tokens covered by past_key_values 611 - past_length = 0 612 - if past_key_values is not None: 613 - if isinstance(past_key_values, Cache): 614 - cache_length = past_key_values.get_seq_length() 615 - past_length = past_key_values.seen_tokens 616 - max_cache_length = past_key_values.get_max_length() 617 - else: 618 - cache_length = past_length = past_key_values[0][0].shape[2] 619 - max_cache_length = None 620 - 621 - # Keep only the unprocessed tokens: 622 - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 623 - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 624 - # input) 625 - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 626 - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 627 - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 628 - # input_ids based on the past_length. 629 - elif past_length < input_ids.shape[1]: 630 - input_ids = input_ids[:, past_length:] 631 - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 632 - 633 - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 634 - if ( 635 - max_cache_length is not None 636 - and attention_mask is not None 637 - and cache_length + input_ids.shape[1] > max_cache_length 638 - ): 639 - attention_mask = attention_mask[:, -max_cache_length:] 640 - 641 - position_ids = kwargs.get("position_ids", None) 642 - if attention_mask is not None and position_ids is None: 643 - # create position_ids on the fly for batch generation 644 - position_ids = attention_mask.long().cumsum(-1) - 1 645 - position_ids.masked_fill_(attention_mask == 0, 1) 646 - if past_key_values: 647 - position_ids = position_ids[:, -input_ids.shape[1] :] 648 - 649 - # if self.generation_config.cache_implementation == "static": 650 - # # generation with static cache 651 - # cache_position = kwargs.get("cache_position", None) 652 - # if cache_position is None: 653 - # past_length = 0 654 - # else: 655 - # past_length = cache_position[-1] + 1 656 - # input_ids = input_ids[:, past_length:] 657 - # position_ids = position_ids[:, past_length:] 658 - 659 - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. 660 - # same goes for position ids. Could also help with continued generation. 661 - cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) 662 - 663 - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 664 - if inputs_embeds is not None and past_key_values is None: 665 - model_inputs = {"inputs_embeds": inputs_embeds} 666 - else: 667 - model_inputs = {"input_ids": input_ids} 668 - 669 - model_inputs.update( 670 - { 671 - "position_ids": position_ids, 672 - "past_key_values": past_key_values, 673 - "use_cache": kwargs.get("use_cache"), 674 - "attention_mask": attention_mask, 675 - "images": kwargs.get("images", None), 676 - "images_seq_mask": kwargs.get("images_seq_mask", None), 677 - "images_spatial_crop": kwargs.get("images_spatial_crop", None), 678 - } 679 - ) 680 - return model_inputs 681 - 682 - 683 - def disable_torch_init(self): 684 - """ 685 - Disable the redundant torch default initialization to accelerate model creation. 686 - """ 687 - import torch 688 - setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 689 - setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 690 - 691 - 692 - 693 - def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): 694 - self.disable_torch_init() 695 - 696 - os.makedirs(output_path, exist_ok=True) 697 - os.makedirs(f'{output_path}/images', exist_ok=True) 698 - 699 - if prompt and image_file: 700 - conversation = [ 701 - { 702 - "role": "<|User|>", 703 - # "content": "<image>\n<|grounding|>Given the layout of the image. ", 704 - "content": f'{prompt}', 705 - # "content": "君不见黄河之水天上来的下一句是什么?", 706 - # "content": "<image>\nFree OCR. ", 707 - # "content": "<image>\nParse the figure. ", 708 - # "content": "<image>\nExtract the text in the image. ", 709 - "images": [f'{image_file}'], 710 - }, 711 - {"role": "<|Assistant|>", "content": ""}, 712 - ] 713 - 714 - elif prompt: 715 - conversation = [ 716 - { 717 - "role": "<|User|>", 718 - # "content": "<image>\n<|grounding|>Given the layout of the image. ", 719 - "content": f'{prompt}', 720 - # "content": "君不见黄河之水天上来的下一句是什么?", 721 - # "content": "<image>\nFree OCR. ", 722 - # "content": "<image>\nParse the figure. ", 723 - # "content": "<image>\nExtract the text in the image. ", 724 - # "images": [f'{image_file}'], 725 - }, 726 - {"role": "<|Assistant|>", "content": ""}, 727 - ] 728 - else: 729 - assert False, f'prompt is none!' 730 - 731 - prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='') 732 - 733 - patch_size = 16 734 - downsample_ratio = 4 735 - images = load_pil_images(conversation) 736 - 737 - valid_img_tokens = 0 738 - ratio = 1 739 - 740 - image_draw = images[0].copy() 741 - 742 - w,h = image_draw.size 743 - # print(w, h) 744 - ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h))) 745 - 746 - 747 - image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) 748 - images_seq_mask = [] 749 - 750 - image_token = '<image>' 751 - image_token_id = 128815 752 - text_splits = prompt.split(image_token) 753 - 754 - images_list, images_crop_list, images_seq_mask = [], [], [] 755 - tokenized_str = [] 756 - images_spatial_crop = [] 757 - for text_sep, image in zip(text_splits, images): 758 - 759 - tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) 760 - tokenized_str += tokenized_sep 761 - images_seq_mask += [False] * len(tokenized_sep) 762 - 763 - if crop_mode: 764 - 765 - if image.size[0] <= 768 and image.size[1] <= 768: 766 - crop_ratio = [1, 1] 767 - 768 - else: 769 - if crop_mode: 770 - # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions) 771 - images_crop_raw, crop_ratio = dynamic_preprocess(image) 772 - else: 773 - # best_width, best_height = self.image_size, self.image_size 774 - crop_ratio = [1, 1] 775 - 776 - """process the global view""" 777 - # image = image.resize((base_size, base_size)) 778 - global_view = ImageOps.pad(image, (base_size, base_size), 779 - color=tuple(int(x * 255) for x in image_transform.mean)) 780 - 781 - if base_size == 1024: 782 - valid_img_tokens += int(256 * ratio) 783 - elif base_size == 1280: 784 - valid_img_tokens += int(400 * ratio) 785 - # elif base_size == 640: 786 - # valid_img_tokens += int(100 * ratio) 787 - 788 - 789 - 790 - 791 - 792 - images_list.append(image_transform(global_view).to(torch.bfloat16)) 793 - 794 - # global_view_tensor = image_transform(global_view).to(torch.bfloat16) 795 - 796 - width_crop_num, height_crop_num = crop_ratio 797 - 798 - images_spatial_crop.append([width_crop_num, height_crop_num]) 799 - 800 - 801 - if width_crop_num > 1 or height_crop_num > 1: 802 - """process the local views""" 803 - 804 - for i in range(len(images_crop_raw)): 805 - images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16)) 806 - 807 - if image_size == 768: 808 - valid_img_tokens += len(images_crop_list) * 144 809 - 810 - num_queries = math.ceil((image_size // patch_size) / downsample_ratio) 811 - num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio) 812 - 813 - 814 - 815 - """add image tokens""" 816 - 817 - 818 - 819 - tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base 820 - tokenized_image += [image_token_id] 821 - if width_crop_num > 1 or height_crop_num > 1: 822 - tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * ( 823 - num_queries * height_crop_num) 824 - tokenized_str += tokenized_image 825 - images_seq_mask += [True] * len(tokenized_image) 826 - # num_image_tokens.append(len(tokenized_image)) 827 - 828 - else: 829 - # best_width, best_height = self.image_size, self.image_size 830 - # print(image.size, (best_width, best_height)) # check the select_best_resolutions func 831 - 832 - """process the global view""" 833 - if image_size <= 768: 834 - print('directly resize') 835 - image = image.resize((image_size, image_size)) 836 - # else: 837 - global_view = ImageOps.pad(image, (image_size, image_size), 838 - color=tuple(int(x * 255) for x in image_transform.mean)) 839 - images_list.append(image_transform(global_view).to(torch.bfloat16)) 840 - 841 - if base_size == 1024: 842 - valid_img_tokens += int(256 * ratio) 843 - elif base_size == 1280: 844 - valid_img_tokens += int(400 * ratio) 845 - elif base_size == 640: 846 - valid_img_tokens += int(100 * 1) 847 - elif base_size == 512: 848 - valid_img_tokens += int(64 * 1) 849 - elif base_size == 768: 850 - valid_img_tokens += int(144 * 1) 851 - 852 - width_crop_num, height_crop_num = 1, 1 853 - 854 - images_spatial_crop.append([width_crop_num, height_crop_num]) 855 - 856 - 857 - """add image tokens""" 858 - num_queries = math.ceil((image_size // patch_size) / downsample_ratio) 859 - 860 - tokenized_image = ([image_token_id] * num_queries) * num_queries 861 - tokenized_image += [image_token_id] 862 - # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * ( 863 - # num_queries * height_crop_num) 864 - tokenized_str += tokenized_image 865 - images_seq_mask += [True] * len(tokenized_image) 866 - # num_image_tokens.append(len(tokenized_image)) 867 - 868 - 869 - """process the last text split""" 870 - tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) 871 - tokenized_str += tokenized_sep 872 - images_seq_mask += [False] * len(tokenized_sep) 873 - 874 - """add the bos tokens""" 875 - bos_id = 0 876 - tokenized_str = [bos_id] + tokenized_str 877 - images_seq_mask = [False] + images_seq_mask 878 - 879 - 880 - 881 - input_ids = torch.LongTensor(tokenized_str) 882 - 883 - 884 - 885 - 886 - images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) 887 - 888 - 889 - if len(images_list) == 0: 890 - images_ori = torch.zeros((1, 3, image_size, image_size)) 891 - images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) 892 - images_crop = torch.zeros((1, 3, base_size, base_size)) 893 - 894 - else: 895 - images_ori = torch.stack(images_list, dim=0) 896 - images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) 897 - if images_crop_list: 898 - images_crop = torch.stack(images_crop_list, dim=0) 899 - else: 900 - images_crop = torch.zeros((1, 3, base_size, base_size)) 901 - 902 - 903 - 904 - if not eval_mode: 905 - streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) 906 - with torch.autocast("cuda", dtype=torch.bfloat16): 907 - with torch.no_grad(): 908 - output_ids = self.generate( 909 - input_ids.unsqueeze(0).cuda(), 910 - images=[(images_crop.cuda(), images_ori.cuda())], 911 - images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), 912 - images_spatial_crop = images_spatial_crop, 913 - # do_sample=False, 914 - # num_beams = 1, 915 - temperature=0.0, 916 - eos_token_id=tokenizer.eos_token_id, 917 - streamer=streamer, 918 - max_new_tokens=8192, 919 - no_repeat_ngram_size = 20, 920 - use_cache = True 921 - ) 922 - 923 - else: 924 - with torch.autocast("cuda", dtype=torch.bfloat16): 925 - with torch.no_grad(): 926 - output_ids = self.generate( 927 - input_ids.unsqueeze(0).cuda(), 928 - images=[(images_crop.cuda(), images_ori.cuda())], 929 - images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), 930 - images_spatial_crop = images_spatial_crop, 931 - # do_sample=False, 932 - # num_beams = 1, 933 - temperature=0.0, 934 - eos_token_id=tokenizer.eos_token_id, 935 - max_new_tokens=8192, 936 - no_repeat_ngram_size = 35, 937 - use_cache = True 938 - ) 939 - 940 - 941 - if '<image>' in conversation[0]['content'] and eval_mode: 942 - outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 943 - stop_str = '<|end▁of▁sentence|>' 944 - if outputs.endswith(stop_str): 945 - outputs = outputs[:-len(stop_str)] 946 - # re_match 947 - outputs = outputs.strip() 948 - 949 - return outputs 950 - 951 - if '<image>' in conversation[0]['content'] and test_compress: 952 - outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 953 - pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False)) 954 - print('='*50) 955 - print('image size: ', (w, h)) 956 - print('valid image tokens: ', int(valid_img_tokens)) 957 - print('output texts tokens (valid): ', pure_texts_outputs_token_length) 958 - print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2)) 959 - print('='*50) 960 - 961 - 962 - if '<image>' in conversation[0]['content'] and save_results: 963 - outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 964 - stop_str = '<|end▁of▁sentence|>' 965 - 966 - print('='*15 + 'save results:' + '='*15) 967 - 968 - # # # # conv.messages[-1][-1] = outputs 969 - if outputs.endswith(stop_str): 970 - outputs = outputs[:-len(stop_str)] 971 - outputs = outputs.strip() 972 - 973 - matches_ref, matches_images, mathes_other = re_match(outputs) 974 - # print(matches_ref) 975 - result = process_image_with_refs(image_draw, matches_ref, output_path) 976 - 977 - 978 - for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): 979 - outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n') 980 - 981 - for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): 982 - outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') 983 - 984 - 985 - # if 'structural formula' in conversation[0]['content']: 986 - # outputs = '<smiles>' + outputs + '</smiles>' 987 - with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile: 988 - afile.write(outputs) 989 - 990 - if 'line_type' in outputs: 991 - import matplotlib.pyplot as plt 992 - lines = eval(outputs)['Line']['line'] 993 - 994 - line_type = eval(outputs)['Line']['line_type'] 995 - # print(lines) 996 - 997 - endpoints = eval(outputs)['Line']['line_endpoint'] 998 - 999 - fig, ax = plt.subplots(figsize=(3,3), dpi=200) 1000 - ax.set_xlim(-15, 15) 1001 - ax.set_ylim(-15, 15) 1002 - 1003 - for idx, line in enumerate(lines): 1004 - try: 1005 - p0 = eval(line.split(' -- ')[0]) 1006 - p1 = eval(line.split(' -- ')[-1]) 1007 - 1008 - if line_type[idx] == '--': 1009 - ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') 1010 - else: 1011 - ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') 1012 - 1013 - ax.scatter(p0[0], p0[1], s=5, color = 'k') 1014 - ax.scatter(p1[0], p1[1], s=5, color = 'k') 1015 - except: 1016 - pass 1017 - 1018 - for endpoint in endpoints: 1019 - 1020 - label = endpoint.split(': ')[0] 1021 - (x, y) = eval(endpoint.split(': ')[1]) 1022 - ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points', 1023 - fontsize=5, fontweight='light') 1024 - 1025 - 1026 - plt.savefig(f'{output_path}/geo.jpg') 1027 - plt.close() 1028 - 1029 - result.save(f"{output_path}/result_with_boxes.jpg")
-1992
src/deepseek_ocr2/modeling_deepseekv2.py
··· 1 - # coding=utf-8 2 - # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. 3 - # 4 - # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 - # and OPT implementations in this library. It has been modified from its 6 - # original forms to accommodate minor architectural differences compared 7 - # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 - # 9 - # Licensed under the Apache License, Version 2.0 (the "License"); 10 - # you may not use this file except in compliance with the License. 11 - # You may obtain a copy of the License at 12 - # 13 - # http://www.apache.org/licenses/LICENSE-2.0 14 - # 15 - # Unless required by applicable law or agreed to in writing, software 16 - # distributed under the License is distributed on an "AS IS" BASIS, 17 - # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 - # See the License for the specific language governing permissions and 19 - # limitations under the License. 20 - """ PyTorch DeepSeek model and compatible with both DeepSeekV2 and DeepSeekV3""" 21 - import math 22 - import warnings 23 - from typing import List, Optional, Tuple, Union 24 - import numpy as np 25 - 26 - import torch 27 - import torch.nn.functional as F 28 - import torch.utils.checkpoint 29 - import torch.distributed as dist 30 - from einops import repeat 31 - from torch import nn 32 - from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 33 - 34 - from transformers.activations import ACT2FN 35 - from transformers.cache_utils import Cache, DynamicCache 36 - from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask 37 - from transformers.models.llama.modeling_llama import ( 38 - LlamaAttention, 39 - LlamaFlashAttention2 40 - ) 41 - from transformers.modeling_outputs import ( 42 - BaseModelOutputWithPast, 43 - CausalLMOutputWithPast, 44 - SequenceClassifierOutputWithPast, 45 - ) 46 - from transformers.modeling_utils import PreTrainedModel 47 - from transformers.pytorch_utils import ( 48 - ALL_LAYERNORM_LAYERS, 49 - is_torch_greater_or_equal_than_1_13, 50 - ) 51 - from transformers.utils import ( 52 - add_start_docstrings, 53 - add_start_docstrings_to_model_forward, 54 - is_flash_attn_2_available, 55 - is_flash_attn_greater_or_equal_2_10, 56 - logging, 57 - replace_return_docstrings, 58 - ) 59 - from transformers.utils.import_utils import is_torch_fx_available 60 - 61 - from .configuration_deepseek_v2 import DeepseekV2Config 62 - 63 - if is_flash_attn_2_available(): 64 - from flash_attn import flash_attn_func, flash_attn_varlen_func 65 - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 66 - 67 - # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. 68 - # It means that the function will not be traced through and simply appear as a node in the graph. 69 - if is_torch_fx_available(): 70 - if not is_torch_greater_or_equal_than_1_13: 71 - import torch.fx 72 - 73 - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) 74 - 75 - logger = logging.get_logger(__name__) 76 - 77 - _CONFIG_FOR_DOC = "DeepseekV2Config" 78 - 79 - 80 - def _get_unpad_data(attention_mask): 81 - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 82 - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 83 - max_seqlen_in_batch = seqlens_in_batch.max().item() 84 - cu_seqlens = F.pad( 85 - torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) 86 - ) 87 - return ( 88 - indices, 89 - cu_seqlens, 90 - max_seqlen_in_batch, 91 - ) 92 - 93 - 94 - class DeepseekV2RMSNorm(nn.Module): 95 - def __init__(self, hidden_size, eps=1e-6): 96 - """ 97 - DeepseekV2RMSNorm is equivalent to T5LayerNorm 98 - """ 99 - super().__init__() 100 - self.weight = nn.Parameter(torch.ones(hidden_size)) 101 - self.variance_epsilon = eps 102 - 103 - def forward(self, hidden_states): 104 - input_dtype = hidden_states.dtype 105 - hidden_states = hidden_states.to(torch.float32) 106 - variance = hidden_states.pow(2).mean(-1, keepdim=True) 107 - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 108 - return self.weight * hidden_states.to(input_dtype) 109 - 110 - 111 - ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) 112 - 113 - 114 - 115 - 116 - class DeepseekV2RotaryEmbedding(nn.Module): 117 - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 118 - super().__init__() 119 - 120 - self.dim = dim 121 - self.max_position_embeddings = max_position_embeddings 122 - self.base = base 123 - inv_freq = 1.0 / ( 124 - self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 125 - ) 126 - self.register_buffer("inv_freq", inv_freq, persistent=False) 127 - 128 - # Build here to make `torch.jit.trace` work. 129 - self._set_cos_sin_cache( 130 - seq_len=max_position_embeddings, 131 - device=self.inv_freq.device, 132 - dtype=torch.get_default_dtype(), 133 - ) 134 - self.max_seq_len_cached = None 135 - 136 - def _set_cos_sin_cache(self, seq_len, device, dtype): 137 - self.max_seq_len_cached = seq_len 138 - t = torch.arange( 139 - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 140 - ) 141 - 142 - freqs = torch.outer(t, self.inv_freq.to(t.device)) 143 - # Different from paper, but it uses a different permutation in order to obtain the same calculation 144 - emb = torch.cat((freqs, freqs), dim=-1) 145 - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 146 - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 147 - 148 - def forward(self, x, seq_len=None): 149 - # x: [bs, num_attention_heads, seq_len, head_size] 150 - if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: 151 - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 152 - 153 - return ( 154 - self.cos_cached[:seq_len].to(dtype=x.dtype), 155 - self.sin_cached[:seq_len].to(dtype=x.dtype), 156 - ) 157 - 158 - 159 - # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 160 - class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): 161 - """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 162 - 163 - def __init__( 164 - self, 165 - dim, 166 - max_position_embeddings=2048, 167 - base=10000, 168 - device=None, 169 - scaling_factor=1.0, 170 - ): 171 - self.scaling_factor = scaling_factor 172 - super().__init__(dim, max_position_embeddings, base, device) 173 - 174 - def _set_cos_sin_cache(self, seq_len, device, dtype): 175 - self.max_seq_len_cached = seq_len 176 - t = torch.arange( 177 - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 178 - ) 179 - t = t / self.scaling_factor 180 - 181 - freqs = torch.outer(t, self.inv_freq) 182 - # Different from paper, but it uses a different permutation in order to obtain the same calculation 183 - emb = torch.cat((freqs, freqs), dim=-1) 184 - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 185 - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 186 - 187 - 188 - # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 189 - class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): 190 - """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 191 - 192 - def __init__( 193 - self, 194 - dim, 195 - max_position_embeddings=2048, 196 - base=10000, 197 - device=None, 198 - scaling_factor=1.0, 199 - ): 200 - self.scaling_factor = scaling_factor 201 - super().__init__(dim, max_position_embeddings, base, device) 202 - 203 - def _set_cos_sin_cache(self, seq_len, device, dtype): 204 - self.max_seq_len_cached = seq_len 205 - 206 - if seq_len > self.max_position_embeddings: 207 - base = self.base * ( 208 - (self.scaling_factor * seq_len / self.max_position_embeddings) 209 - - (self.scaling_factor - 1) 210 - ) ** (self.dim / (self.dim - 2)) 211 - inv_freq = 1.0 / ( 212 - base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 213 - ) 214 - self.register_buffer("inv_freq", inv_freq, persistent=False) 215 - 216 - t = torch.arange( 217 - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 218 - ) 219 - 220 - freqs = torch.outer(t, self.inv_freq) 221 - # Different from paper, but it uses a different permutation in order to obtain the same calculation 222 - emb = torch.cat((freqs, freqs), dim=-1) 223 - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 224 - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 225 - 226 - 227 - # Inverse dim formula to find dim based on number of rotations 228 - def yarn_find_correction_dim( 229 - num_rotations, dim, base=10000, max_position_embeddings=2048 230 - ): 231 - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 232 - 2 * math.log(base) 233 - ) 234 - 235 - 236 - # Find dim range bounds based on rotations 237 - def yarn_find_correction_range( 238 - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 239 - ): 240 - low = math.floor( 241 - yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) 242 - ) 243 - high = math.ceil( 244 - yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) 245 - ) 246 - return max(low, 0), min(high, dim - 1) # Clamp values just in case 247 - 248 - 249 - def yarn_get_mscale(scale=1, mscale=1): 250 - if scale <= 1: 251 - return 1.0 252 - return 0.1 * mscale * math.log(scale) + 1.0 253 - 254 - 255 - def yarn_linear_ramp_mask(min, max, dim): 256 - if min == max: 257 - max += 0.001 # Prevent singularity 258 - 259 - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 260 - ramp_func = torch.clamp(linear_func, 0, 1) 261 - return ramp_func 262 - 263 - 264 - class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): 265 - 266 - def __init__( 267 - self, 268 - dim, 269 - max_position_embeddings=2048, 270 - base=10000, 271 - device=None, 272 - scaling_factor=1.0, 273 - original_max_position_embeddings=4096, 274 - beta_fast=32, 275 - beta_slow=1, 276 - mscale=1, 277 - mscale_all_dim=0, 278 - ): 279 - self.scaling_factor = scaling_factor 280 - self.original_max_position_embeddings = original_max_position_embeddings 281 - self.beta_fast = beta_fast 282 - self.beta_slow = beta_slow 283 - self.mscale = mscale 284 - self.mscale_all_dim = mscale_all_dim 285 - super().__init__(dim, max_position_embeddings, base, device) 286 - 287 - def _set_cos_sin_cache(self, seq_len, device, dtype): 288 - self.max_seq_len_cached = seq_len 289 - dim = self.dim 290 - 291 - freq_extra = 1.0 / ( 292 - self.base 293 - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) 294 - ) 295 - freq_inter = 1.0 / ( 296 - self.scaling_factor 297 - * self.base 298 - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) 299 - ) 300 - 301 - low, high = yarn_find_correction_range( 302 - self.beta_fast, 303 - self.beta_slow, 304 - dim, 305 - self.base, 306 - self.original_max_position_embeddings, 307 - ) 308 - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( 309 - device=device, dtype=torch.float32 310 - ) 311 - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask 312 - self.register_buffer("inv_freq", inv_freq, persistent=False) 313 - 314 - t = torch.arange(seq_len, device=device, dtype=torch.float32) 315 - 316 - freqs = torch.outer(t, inv_freq) 317 - 318 - _mscale = float( 319 - yarn_get_mscale(self.scaling_factor, self.mscale) 320 - / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) 321 - ) 322 - 323 - emb = torch.cat((freqs, freqs), dim=-1) 324 - self.register_buffer( 325 - "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False 326 - ) 327 - self.register_buffer( 328 - "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False 329 - ) 330 - 331 - 332 - # Copied from transformers.models.llama.modeling_llama.rotate_half 333 - def rotate_half(x): 334 - """Rotates half the hidden dims of the input.""" 335 - x1 = x[..., : x.shape[-1] // 2] 336 - x2 = x[..., x.shape[-1] // 2 :] 337 - return torch.cat((-x2, x1), dim=-1) 338 - 339 - 340 - # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 341 - def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 342 - """Applies Rotary Position Embedding to the query and key tensors. 343 - 344 - Args: 345 - q (`torch.Tensor`): The query tensor. 346 - k (`torch.Tensor`): The key tensor. 347 - cos (`torch.Tensor`): The cosine part of the rotary embedding. 348 - sin (`torch.Tensor`): The sine part of the rotary embedding. 349 - position_ids (`torch.Tensor`): 350 - The position indices of the tokens corresponding to the query and key tensors. For example, this can be 351 - used to pass offsetted position ids when working with a KV-cache. 352 - unsqueeze_dim (`int`, *optional*, defaults to 1): 353 - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 354 - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 355 - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 356 - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 357 - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 358 - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 359 - Returns: 360 - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 361 - """ 362 - cos = cos[position_ids].unsqueeze(unsqueeze_dim) 363 - sin = sin[position_ids].unsqueeze(unsqueeze_dim) 364 - 365 - 366 - # print() 367 - 368 - b, h, s, d = q.shape 369 - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 370 - 371 - b, h, s, d = k.shape 372 - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 373 - 374 - q_embed = (q * cos) + (rotate_half(q) * sin) 375 - k_embed = (k * cos) + (rotate_half(k) * sin) 376 - 377 - 378 - return q_embed, k_embed 379 - 380 - 381 - class DeepseekV2MLP(nn.Module): 382 - def __init__(self, config, hidden_size=None, intermediate_size=None): 383 - super().__init__() 384 - self.config = config 385 - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size 386 - self.intermediate_size = ( 387 - config.intermediate_size if intermediate_size is None else intermediate_size 388 - ) 389 - 390 - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 391 - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 392 - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 393 - self.act_fn = ACT2FN[config.hidden_act] 394 - 395 - def forward(self, x): 396 - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 397 - return down_proj 398 - 399 - 400 - class MoEGate(nn.Module): 401 - def __init__(self, config): 402 - super().__init__() 403 - self.config = config 404 - self.top_k = config.num_experts_per_tok 405 - self.n_routed_experts = config.n_routed_experts 406 - self.routed_scaling_factor = config.routed_scaling_factor 407 - self.scoring_func = config.scoring_func 408 - self.alpha = config.aux_loss_alpha 409 - self.seq_aux = config.seq_aux 410 - self.topk_method = config.topk_method 411 - self.n_group = config.n_group 412 - self.topk_group = config.topk_group 413 - 414 - # topk selection algorithm 415 - self.norm_topk_prob = config.norm_topk_prob 416 - self.gating_dim = config.hidden_size 417 - self.weight = nn.Parameter( 418 - torch.empty((self.n_routed_experts, self.gating_dim)) 419 - ) 420 - if self.topk_method == "noaux_tc": 421 - self.e_score_correction_bias = nn.Parameter( 422 - torch.empty((self.n_routed_experts)) 423 - ) 424 - self.reset_parameters() 425 - 426 - def reset_parameters(self) -> None: 427 - import torch.nn.init as init 428 - 429 - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 430 - 431 - def forward(self, hidden_states): 432 - bsz, seq_len, h = hidden_states.shape 433 - ### compute gating score 434 - hidden_states = hidden_states.view(-1, h) 435 - logits = F.linear( 436 - hidden_states.type(torch.float32), self.weight.type(torch.float32), None 437 - ) 438 - if self.scoring_func == "softmax": 439 - scores = logits.softmax(dim=-1, dtype=torch.float32) 440 - elif self.scoring_func == "sigmoid": 441 - scores = logits.sigmoid() 442 - else: 443 - raise NotImplementedError( 444 - f"insupportable scoring function for MoE gating: {self.scoring_func}" 445 - ) 446 - 447 - ### select top-k experts 448 - if self.topk_method == "greedy": 449 - topk_weight, topk_idx = torch.topk( 450 - scores, k=self.top_k, dim=-1, sorted=False 451 - ) 452 - elif self.topk_method == "group_limited_greedy": 453 - group_scores = ( 454 - scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values 455 - ) # [n, n_group] 456 - group_idx = torch.topk( 457 - group_scores, k=self.topk_group, dim=-1, sorted=False 458 - )[ 459 - 1 460 - ] # [n, top_k_group] 461 - group_mask = torch.zeros_like(group_scores) # [n, n_group] 462 - group_mask.scatter_(1, group_idx, 1) # [n, n_group] 463 - score_mask = ( 464 - group_mask.unsqueeze(-1) 465 - .expand( 466 - bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group 467 - ) 468 - .reshape(bsz * seq_len, -1) 469 - ) # [n, e] 470 - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] 471 - topk_weight, topk_idx = torch.topk( 472 - tmp_scores, k=self.top_k, dim=-1, sorted=False 473 - ) 474 - elif self.topk_method == "noaux_tc": 475 - assert not self.training 476 - scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) 477 - group_scores = ( 478 - scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) 479 - ) # [n, n_group] 480 - group_idx = torch.topk( 481 - group_scores, k=self.topk_group, dim=-1, sorted=False 482 - )[ 483 - 1 484 - ] # [n, top_k_group] 485 - group_mask = torch.zeros_like(group_scores) # [n, n_group] 486 - group_mask.scatter_(1, group_idx, 1) # [n, n_group] 487 - score_mask = ( 488 - group_mask.unsqueeze(-1) 489 - .expand( 490 - bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group 491 - ) 492 - .reshape(bsz * seq_len, -1) 493 - ) # [n, e] 494 - tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] 495 - _, topk_idx = torch.topk( 496 - tmp_scores, k=self.top_k, dim=-1, sorted=False 497 - ) 498 - topk_weight = scores.gather(1, topk_idx) 499 - 500 - ### norm gate to sum 1 501 - if self.top_k > 1 and self.norm_topk_prob: 502 - denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 503 - topk_weight = topk_weight / denominator * self.routed_scaling_factor 504 - else: 505 - topk_weight = topk_weight * self.routed_scaling_factor 506 - ### expert-level computation auxiliary loss 507 - if self.training and self.alpha > 0.0: 508 - scores_for_aux = scores 509 - aux_topk = self.top_k 510 - # always compute aux loss based on the naive greedy topk method 511 - topk_idx_for_aux_loss = topk_idx.view(bsz, -1) 512 - if self.seq_aux: 513 - scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) 514 - ce = torch.zeros( 515 - bsz, self.n_routed_experts, device=hidden_states.device 516 - ) 517 - ce.scatter_add_( 518 - 1, 519 - topk_idx_for_aux_loss, 520 - torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), 521 - ).div_(seq_len * aux_topk / self.n_routed_experts) 522 - aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( 523 - dim=1 524 - ).mean() * self.alpha 525 - else: 526 - mask_ce = F.one_hot( 527 - topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts 528 - ) 529 - ce = mask_ce.float().mean(0) 530 - Pi = scores_for_aux.mean(0) 531 - fi = ce * self.n_routed_experts 532 - aux_loss = (Pi * fi).sum() * self.alpha 533 - else: 534 - aux_loss = None 535 - return topk_idx, topk_weight, aux_loss 536 - 537 - 538 - class AddAuxiliaryLoss(torch.autograd.Function): 539 - """ 540 - The trick function of adding auxiliary (aux) loss, 541 - which includes the gradient of the aux loss during backpropagation. 542 - """ 543 - 544 - @staticmethod 545 - def forward(ctx, x, loss): 546 - assert loss.numel() == 1 547 - ctx.dtype = loss.dtype 548 - ctx.required_aux_loss = loss.requires_grad 549 - return x 550 - 551 - @staticmethod 552 - def backward(ctx, grad_output): 553 - grad_loss = None 554 - if ctx.required_aux_loss: 555 - grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) 556 - return grad_output, grad_loss 557 - 558 - 559 - class DeepseekV2MoE(nn.Module): 560 - """ 561 - A mixed expert module containing shared experts. 562 - """ 563 - 564 - def __init__(self, config): 565 - super().__init__() 566 - self.config = config 567 - self.num_experts_per_tok = config.num_experts_per_tok 568 - 569 - if hasattr(config, "ep_size") and config.ep_size > 1: 570 - assert config.ep_size == dist.get_world_size() 571 - self.ep_size = config.ep_size 572 - self.experts_per_rank = config.n_routed_experts // config.ep_size 573 - self.ep_rank = dist.get_rank() 574 - self.experts = nn.ModuleList( 575 - [ 576 - ( 577 - DeepseekV2MLP( 578 - config, intermediate_size=config.moe_intermediate_size 579 - ) 580 - if i >= self.ep_rank * self.experts_per_rank 581 - and i < (self.ep_rank + 1) * self.experts_per_rank 582 - else None 583 - ) 584 - for i in range(config.n_routed_experts) 585 - ] 586 - ) 587 - else: 588 - self.ep_size = 1 589 - self.experts_per_rank = config.n_routed_experts 590 - self.ep_rank = 0 591 - self.experts = nn.ModuleList( 592 - [ 593 - DeepseekV2MLP( 594 - config, intermediate_size=config.moe_intermediate_size 595 - ) 596 - for i in range(config.n_routed_experts) 597 - ] 598 - ) 599 - self.gate = MoEGate(config) 600 - if config.n_shared_experts is not None: 601 - intermediate_size = config.moe_intermediate_size * config.n_shared_experts 602 - self.shared_experts = DeepseekV2MLP( 603 - config=config, intermediate_size=intermediate_size 604 - ) 605 - 606 - def forward(self, hidden_states): 607 - identity = hidden_states 608 - orig_shape = hidden_states.shape 609 - topk_idx, topk_weight, aux_loss = self.gate(hidden_states) 610 - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) 611 - flat_topk_idx = topk_idx.view(-1) 612 - if self.training: 613 - hidden_states = hidden_states.repeat_interleave( 614 - self.num_experts_per_tok, dim=0 615 - ) 616 - y = torch.empty_like(hidden_states) 617 - for i, expert in enumerate(self.experts): 618 - y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) 619 - y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) 620 - y = y.to(hidden_states.dtype).view(*orig_shape) 621 - y = AddAuxiliaryLoss.apply(y, aux_loss) 622 - else: 623 - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) 624 - if self.config.n_shared_experts is not None: 625 - y = y + self.shared_experts(identity) 626 - return y 627 - 628 - @torch.no_grad() 629 - def moe_infer(self, x, topk_ids, topk_weight): 630 - cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) 631 - cnts.scatter_(1, topk_ids, 1) 632 - tokens_per_expert = cnts.sum(dim=0) 633 - idxs = topk_ids.view(-1).argsort() 634 - sorted_tokens = x[idxs // topk_ids.shape[1]] 635 - sorted_tokens_shape = sorted_tokens.shape 636 - if self.ep_size > 1: 637 - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) 638 - tokens_per_expert_group = tokens_per_expert.new_empty( 639 - tokens_per_expert.shape[0] 640 - ) 641 - dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) 642 - output_splits = ( 643 - tokens_per_expert_group.view(self.ep_size, -1) 644 - .sum(1) 645 - .cpu() 646 - .numpy() 647 - .tolist() 648 - ) 649 - gathered_tokens = sorted_tokens.new_empty( 650 - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] 651 - ) 652 - input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() 653 - dist.all_to_all( 654 - list(gathered_tokens.split(output_splits)), 655 - list(sorted_tokens.split(input_split_sizes)), 656 - ) 657 - tokens_per_expert_post_gather = tokens_per_expert_group.view( 658 - self.ep_size, self.experts_per_rank 659 - ).sum(dim=0) 660 - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) 661 - s = 0 662 - for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): 663 - gatherd_idxs[s : s + k] = i % self.experts_per_rank 664 - s += k 665 - gatherd_idxs = gatherd_idxs.argsort() 666 - sorted_tokens = gathered_tokens[gatherd_idxs] 667 - tokens_per_expert = tokens_per_expert_post_gather 668 - tokens_per_expert = tokens_per_expert.cpu().numpy() 669 - 670 - outputs = [] 671 - start_idx = 0 672 - for i, num_tokens in enumerate(tokens_per_expert): 673 - end_idx = start_idx + num_tokens 674 - if num_tokens == 0: 675 - continue 676 - expert = self.experts[i + self.ep_rank * self.experts_per_rank] 677 - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] 678 - expert_out = expert(tokens_for_this_expert) 679 - outputs.append(expert_out) 680 - start_idx = end_idx 681 - 682 - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) 683 - if self.ep_size > 1: 684 - new_x = torch.empty_like(outs) 685 - new_x[gatherd_idxs] = outs 686 - gathered_tokens = new_x.new_empty(*sorted_tokens_shape) 687 - dist.all_to_all( 688 - list(gathered_tokens.split(input_split_sizes)), 689 - list(new_x.split(output_splits)), 690 - ) 691 - outs = gathered_tokens 692 - 693 - new_x = torch.empty_like(outs) 694 - new_x[idxs] = outs 695 - final_out = ( 696 - new_x.view(*topk_ids.shape, -1) 697 - .type(topk_weight.dtype) 698 - .mul_(topk_weight.unsqueeze(dim=-1)) 699 - .sum(dim=1) 700 - .type(new_x.dtype) 701 - ) 702 - return final_out 703 - 704 - 705 - # Copied from transformers.models.llama.modeling_llama.repeat_kv 706 - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 707 - """ 708 - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 709 - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 710 - """ 711 - batch, num_key_value_heads, slen, head_dim = hidden_states.shape 712 - if n_rep == 1: 713 - return hidden_states 714 - hidden_states = hidden_states[:, :, None, :, :].expand( 715 - batch, num_key_value_heads, n_rep, slen, head_dim 716 - ) 717 - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 718 - 719 - 720 - # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 721 - class DeepseekV2Attention(nn.Module): 722 - """Multi-headed attention from 'Attention Is All You Need' paper""" 723 - 724 - def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): 725 - super().__init__() 726 - self.config = config 727 - self.layer_idx = layer_idx 728 - if layer_idx is None: 729 - logger.warning_once( 730 - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 731 - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 732 - "when creating this class." 733 - ) 734 - 735 - self.attention_dropout = config.attention_dropout 736 - self.hidden_size = config.hidden_size 737 - self.num_heads = config.num_attention_heads 738 - 739 - self.max_position_embeddings = config.max_position_embeddings 740 - self.rope_theta = config.rope_theta 741 - self.q_lora_rank = config.q_lora_rank 742 - self.qk_rope_head_dim = config.qk_rope_head_dim 743 - self.kv_lora_rank = config.kv_lora_rank 744 - self.v_head_dim = config.v_head_dim 745 - self.qk_nope_head_dim = config.qk_nope_head_dim 746 - self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim 747 - 748 - self.is_causal = True 749 - 750 - if self.q_lora_rank is None: 751 - self.q_proj = nn.Linear( 752 - self.hidden_size, self.num_heads * self.q_head_dim, bias=False 753 - ) 754 - else: 755 - self.q_a_proj = nn.Linear( 756 - self.hidden_size, config.q_lora_rank, bias=config.attention_bias 757 - ) 758 - self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) 759 - self.q_b_proj = nn.Linear( 760 - config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False 761 - ) 762 - # config.kv_lora_rank + config.qk_rope_head_dim, 763 - self.kv_a_proj_with_mqa = nn.Linear( 764 - self.hidden_size, 765 - config.kv_lora_rank + config.qk_rope_head_dim, 766 - bias=config.attention_bias, 767 - ) 768 - self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) 769 - self.kv_b_proj = nn.Linear( 770 - config.kv_lora_rank, 771 - self.num_heads 772 - * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), 773 - bias=False, 774 - ) 775 - 776 - self.o_proj = nn.Linear( 777 - self.num_heads * self.v_head_dim, 778 - self.hidden_size, 779 - bias=config.attention_bias, 780 - ) 781 - self._init_rope() 782 - 783 - self.softmax_scale = self.q_head_dim ** (-0.5) 784 - if self.config.rope_scaling is not None: 785 - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) 786 - scaling_factor = self.config.rope_scaling["factor"] 787 - if mscale_all_dim: 788 - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) 789 - self.softmax_scale = self.softmax_scale * mscale * mscale 790 - 791 - def _init_rope(self): 792 - if self.config.rope_scaling is None: 793 - self.rotary_emb = DeepseekV2RotaryEmbedding( 794 - self.qk_rope_head_dim, 795 - max_position_embeddings=self.max_position_embeddings, 796 - base=self.rope_theta, 797 - ) 798 - # self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( 799 - # self.qk_rope_head_dim, 800 - # max_position_embeddings=self.max_position_embeddings, 801 - # scaling_factor=scaling_factor, 802 - # base=self.rope_theta, 803 - # ) 804 - else: 805 - scaling_type = self.config.rope_scaling["type"] 806 - scaling_factor = self.config.rope_scaling["factor"] 807 - if scaling_type == "linear": 808 - self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( 809 - self.qk_rope_head_dim, 810 - max_position_embeddings=self.max_position_embeddings, 811 - scaling_factor=scaling_factor, 812 - base=self.rope_theta, 813 - ) 814 - elif scaling_type == "dynamic": 815 - self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( 816 - self.qk_rope_head_dim, 817 - max_position_embeddings=self.max_position_embeddings, 818 - scaling_factor=scaling_factor, 819 - base=self.rope_theta, 820 - ) 821 - elif scaling_type == "yarn": 822 - kwargs = { 823 - key: self.config.rope_scaling[key] 824 - for key in [ 825 - "original_max_position_embeddings", 826 - "beta_fast", 827 - "beta_slow", 828 - "mscale", 829 - "mscale_all_dim", 830 - ] 831 - if key in self.config.rope_scaling 832 - } 833 - self.rotary_emb = DeepseekV2YarnRotaryEmbedding( 834 - self.qk_rope_head_dim, 835 - max_position_embeddings=self.max_position_embeddings, 836 - scaling_factor=scaling_factor, 837 - base=self.rope_theta, 838 - **kwargs, 839 - ) 840 - else: 841 - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 842 - 843 - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 844 - return ( 845 - tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) 846 - .transpose(1, 2) 847 - .contiguous() 848 - ) 849 - 850 - def forward( 851 - self, 852 - hidden_states: torch.Tensor, 853 - attention_mask: Optional[torch.Tensor] = None, 854 - position_ids: Optional[torch.LongTensor] = None, 855 - past_key_value: Optional[Cache] = None, 856 - output_attentions: bool = False, 857 - use_cache: bool = False, 858 - **kwargs, 859 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 860 - if "padding_mask" in kwargs: 861 - warnings.warn( 862 - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 863 - ) 864 - bsz, q_len, _ = hidden_states.size() 865 - 866 - if self.q_lora_rank is None: 867 - q = self.q_proj(hidden_states) 868 - else: 869 - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) 870 - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) 871 - 872 - 873 - q_nope, q_pe = torch.split( 874 - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 875 - ) 876 - 877 - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) 878 - compressed_kv, k_pe = torch.split( 879 - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 880 - ) 881 - compressed_kv = self.kv_a_layernorm(compressed_kv) 882 - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) 883 - 884 - kv_seq_len = k_pe.shape[-2] 885 - if past_key_value is not None: 886 - if self.layer_idx is None: 887 - raise ValueError( 888 - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 889 - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 890 - "with a layer index." 891 - ) 892 - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 893 - 894 - cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len) 895 - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) 896 - 897 - if past_key_value is not None: 898 - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 899 - compressed_kv = compressed_kv.unsqueeze(1) 900 - k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) 901 - compressed_kv = compressed_kv.squeeze(1) 902 - 903 - kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) 904 - q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :] 905 - out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :] 906 - 907 - q_nope = torch.matmul(q_nope, q_absorb) 908 - attn_weights = (torch.matmul(q_pe, k_pe.mT) + 909 - torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale 910 - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 911 - raise ValueError( 912 - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 913 - f" {attn_weights.size()}" 914 - ) 915 - assert attention_mask is not None 916 - if attention_mask is not None: 917 - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 918 - raise ValueError( 919 - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 920 - ) 921 - attn_weights = attn_weights + attention_mask 922 - 923 - # upcast attention to fp32 924 - attn_weights = nn.functional.softmax( 925 - attn_weights, dim=-1, dtype=torch.float32 926 - ).to(q_pe.dtype) 927 - attn_weights = nn.functional.dropout( 928 - attn_weights, p=self.attention_dropout, training=self.training 929 - ) 930 - attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) 931 - 932 - attn_output = torch.matmul(attn_output, out_absorb.mT) 933 - 934 - if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): 935 - raise ValueError( 936 - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" 937 - f" {attn_output.size()}" 938 - ) 939 - 940 - attn_output = attn_output.transpose(1, 2).contiguous() 941 - 942 - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) 943 - 944 - attn_output = self.o_proj(attn_output) 945 - 946 - if not output_attentions: 947 - attn_weights = None 948 - 949 - return attn_output, attn_weights, past_key_value 950 - 951 - 952 - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 953 - class DeepseekV2FlashAttention2(DeepseekV2Attention): 954 - """ 955 - DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays 956 - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 957 - flash attention and deal with padding tokens in case the input contains any of them. 958 - """ 959 - 960 - def __init__(self, *args, **kwargs): 961 - super().__init__(*args, **kwargs) 962 - 963 - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 964 - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 965 - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 966 - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 967 - 968 - def forward( 969 - self, 970 - hidden_states: torch.Tensor, 971 - attention_mask: Optional[torch.LongTensor] = None, 972 - position_ids: Optional[torch.LongTensor] = None, 973 - past_key_value: Optional[Cache] = None, 974 - output_attentions: bool = False, 975 - use_cache: bool = False, 976 - **kwargs, 977 - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 978 - # DeepseekV2FlashAttention2 attention does not support output_attentions 979 - if "padding_mask" in kwargs: 980 - warnings.warn( 981 - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 982 - ) 983 - 984 - # overwrite attention_mask with padding_mask 985 - attention_mask = kwargs.pop("padding_mask") 986 - 987 - output_attentions = False 988 - 989 - bsz, q_len, _ = hidden_states.size() 990 - 991 - if self.q_lora_rank is None: 992 - q = self.q_proj(hidden_states) 993 - else: 994 - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) 995 - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) 996 - q_nope, q_pe = torch.split( 997 - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 998 - ) 999 - 1000 - # Flash attention requires the input to have the shape 1001 - # batch_size x seq_length x head_dim x hidden_dim 1002 - # therefore we just need to keep the original shape 1003 - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) 1004 - compressed_kv, k_pe = torch.split( 1005 - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 1006 - ) 1007 - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) 1008 - kv = ( 1009 - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) 1010 - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) 1011 - .transpose(1, 2) 1012 - ) 1013 - 1014 - k_nope, value_states = torch.split( 1015 - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 1016 - ) 1017 - kv_seq_len = value_states.shape[-2] 1018 - 1019 - kv_seq_len = value_states.shape[-2] 1020 - if past_key_value is not None: 1021 - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 1022 - 1023 - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 1024 - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) 1025 - 1026 - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) 1027 - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope 1028 - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe 1029 - 1030 - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) 1031 - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope 1032 - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe 1033 - 1034 - if self.q_head_dim != self.v_head_dim: 1035 - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) 1036 - 1037 - # TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in flash_attention version 1038 - if past_key_value is not None: 1039 - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 1040 - key_states, value_states = past_key_value.update( 1041 - key_states, value_states, self.layer_idx, cache_kwargs 1042 - ) 1043 - 1044 - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 1045 - # to be able to avoid many of these transpose/reshape/view. 1046 - query_states = query_states.transpose(1, 2) 1047 - key_states = key_states.transpose(1, 2) 1048 - value_states = value_states.transpose(1, 2) 1049 - 1050 - dropout_rate = self.attention_dropout if self.training else 0.0 1051 - 1052 - # In PEFT, usually we cast the layer norms in float32 for training stability reasons 1053 - # therefore the input hidden states gets silently casted in float32. Hence, we need 1054 - # cast them back in the correct dtype just to be sure everything works as expected. 1055 - # This might slowdown training & inference so it is recommended to not cast the LayerNorms 1056 - # in fp32. (DeepseekV2RMSNorm handles it correctly) 1057 - 1058 - input_dtype = query_states.dtype 1059 - if input_dtype == torch.float32: 1060 - # Handle the case where the model is quantized 1061 - if hasattr(self.config, "_pre_quantization_dtype"): 1062 - target_dtype = self.config._pre_quantization_dtype 1063 - elif torch.is_autocast_enabled(): 1064 - target_dtype = torch.get_autocast_gpu_dtype() 1065 - else: 1066 - target_dtype = ( 1067 - self.q_proj.weight.dtype 1068 - if self.q_lora_rank is None 1069 - else self.q_a_proj.weight.dtype 1070 - ) 1071 - 1072 - logger.warning_once( 1073 - f"The input hidden states seems to be silently casted in float32, this might be related to" 1074 - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 1075 - f" {target_dtype}." 1076 - ) 1077 - 1078 - query_states = query_states.to(target_dtype) 1079 - key_states = key_states.to(target_dtype) 1080 - value_states = value_states.to(target_dtype) 1081 - 1082 - attn_output = self._flash_attention_forward( 1083 - query_states, 1084 - key_states, 1085 - value_states, 1086 - attention_mask, 1087 - q_len, 1088 - dropout=dropout_rate, 1089 - softmax_scale=self.softmax_scale, 1090 - ) 1091 - if self.q_head_dim != self.v_head_dim: 1092 - attn_output = attn_output[:, :, :, : self.v_head_dim] 1093 - 1094 - attn_output = attn_output.reshape( 1095 - bsz, q_len, self.num_heads * self.v_head_dim 1096 - ).contiguous() 1097 - attn_output = self.o_proj(attn_output) 1098 - 1099 - if not output_attentions: 1100 - attn_weights = None 1101 - 1102 - return attn_output, attn_weights, past_key_value 1103 - 1104 - def _flash_attention_forward( 1105 - self, 1106 - query_states, 1107 - key_states, 1108 - value_states, 1109 - attention_mask, 1110 - query_length, 1111 - dropout=0.0, 1112 - softmax_scale=None, 1113 - ): 1114 - """ 1115 - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 1116 - first unpad the input, then computes the attention scores and pad the final attention scores. 1117 - 1118 - Args: 1119 - query_states (`torch.Tensor`): 1120 - Input query states to be passed to Flash Attention API 1121 - key_states (`torch.Tensor`): 1122 - Input key states to be passed to Flash Attention API 1123 - value_states (`torch.Tensor`): 1124 - Input value states to be passed to Flash Attention API 1125 - attention_mask (`torch.Tensor`): 1126 - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 1127 - position of padding tokens and 1 for the position of non-padding tokens. 1128 - dropout (`int`, *optional*): 1129 - Attention dropout 1130 - softmax_scale (`float`, *optional*): 1131 - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 1132 - """ 1133 - if not self._flash_attn_uses_top_left_mask: 1134 - causal = self.is_causal 1135 - else: 1136 - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. 1137 - causal = self.is_causal and query_length != 1 1138 - 1139 - # Contains at least one padding token in the sequence 1140 - if attention_mask is not None: 1141 - batch_size = query_states.shape[0] 1142 - ( 1143 - query_states, 1144 - key_states, 1145 - value_states, 1146 - indices_q, 1147 - cu_seq_lens, 1148 - max_seq_lens, 1149 - ) = self._upad_input( 1150 - query_states, key_states, value_states, attention_mask, query_length 1151 - ) 1152 - 1153 - cu_seqlens_q, cu_seqlens_k = cu_seq_lens 1154 - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 1155 - 1156 - attn_output_unpad = flash_attn_varlen_func( 1157 - query_states, 1158 - key_states, 1159 - value_states, 1160 - cu_seqlens_q=cu_seqlens_q, 1161 - cu_seqlens_k=cu_seqlens_k, 1162 - max_seqlen_q=max_seqlen_in_batch_q, 1163 - max_seqlen_k=max_seqlen_in_batch_k, 1164 - dropout_p=dropout, 1165 - softmax_scale=softmax_scale, 1166 - causal=causal, 1167 - ) 1168 - 1169 - attn_output = pad_input( 1170 - attn_output_unpad, indices_q, batch_size, query_length 1171 - ) 1172 - else: 1173 - attn_output = flash_attn_func( 1174 - query_states, 1175 - key_states, 1176 - value_states, 1177 - dropout, 1178 - softmax_scale=softmax_scale, 1179 - causal=causal, 1180 - ) 1181 - 1182 - return attn_output 1183 - 1184 - def _upad_input( 1185 - self, query_layer, key_layer, value_layer, attention_mask, query_length 1186 - ): 1187 - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 1188 - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 1189 - 1190 - key_layer = index_first_axis( 1191 - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 1192 - indices_k, 1193 - ) 1194 - value_layer = index_first_axis( 1195 - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 1196 - indices_k, 1197 - ) 1198 - if query_length == kv_seq_len: 1199 - query_layer = index_first_axis( 1200 - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), 1201 - indices_k, 1202 - ) 1203 - cu_seqlens_q = cu_seqlens_k 1204 - max_seqlen_in_batch_q = max_seqlen_in_batch_k 1205 - indices_q = indices_k 1206 - elif query_length == 1: 1207 - max_seqlen_in_batch_q = 1 1208 - cu_seqlens_q = torch.arange( 1209 - batch_size + 1, dtype=torch.int32, device=query_layer.device 1210 - ) # There is a memcpy here, that is very bad. 1211 - indices_q = cu_seqlens_q[:-1] 1212 - query_layer = query_layer.squeeze(1) 1213 - else: 1214 - # The -q_len: slice assumes left padding. 1215 - attention_mask = attention_mask[:, -query_length:] 1216 - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( 1217 - query_layer, attention_mask 1218 - ) 1219 - 1220 - return ( 1221 - query_layer, 1222 - key_layer, 1223 - value_layer, 1224 - indices_q, 1225 - (cu_seqlens_q, cu_seqlens_k), 1226 - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 1227 - ) 1228 - 1229 - 1230 - ATTENTION_CLASSES = { 1231 - "eager": DeepseekV2Attention, 1232 - "flash_attention_2": DeepseekV2FlashAttention2, 1233 - 1234 - "mla_eager": DeepseekV2Attention, 1235 - "mla_flash_attention_2": DeepseekV2FlashAttention2, 1236 - 1237 - "mha_eager": LlamaAttention, 1238 - "mha_flash_attention_2": LlamaFlashAttention2 1239 - } 1240 - 1241 - 1242 - class DeepseekV2DecoderLayer(nn.Module): 1243 - def __init__(self, config: DeepseekV2Config, layer_idx: int): 1244 - super().__init__() 1245 - self.hidden_size = config.hidden_size 1246 - 1247 - 1248 - if config.use_mla: 1249 - attn_implementation = "mla_" + config._attn_implementation 1250 - else: 1251 - attn_implementation = "mha_" + config._attn_implementation 1252 - 1253 - self.self_attn = ATTENTION_CLASSES[attn_implementation]( 1254 - config=config, layer_idx=layer_idx 1255 - ) 1256 - 1257 - self.mlp = ( 1258 - DeepseekV2MoE(config) 1259 - if ( 1260 - config.n_routed_experts is not None 1261 - and layer_idx >= config.first_k_dense_replace 1262 - and layer_idx % config.moe_layer_freq == 0 1263 - ) 1264 - else DeepseekV2MLP(config) 1265 - ) 1266 - self.input_layernorm = DeepseekV2RMSNorm( 1267 - config.hidden_size, eps=config.rms_norm_eps 1268 - ) 1269 - self.post_attention_layernorm = DeepseekV2RMSNorm( 1270 - config.hidden_size, eps=config.rms_norm_eps 1271 - ) 1272 - 1273 - def forward( 1274 - self, 1275 - hidden_states: torch.Tensor, 1276 - attention_mask: Optional[torch.Tensor] = None, 1277 - position_ids: Optional[torch.LongTensor] = None, 1278 - past_key_value: Optional[Tuple[torch.Tensor]] = None, 1279 - output_attentions: Optional[bool] = False, 1280 - use_cache: Optional[bool] = False, 1281 - **kwargs, 1282 - ) -> Tuple[ 1283 - torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] 1284 - ]: 1285 - """ 1286 - Args: 1287 - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 1288 - attention_mask (`torch.FloatTensor`, *optional*): 1289 - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, 1290 - query_sequence_length, key_sequence_length)` if default attention is used. 1291 - output_attentions (`bool`, *optional*): 1292 - Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1293 - returned tensors for more detail. 1294 - use_cache (`bool`, *optional*): 1295 - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 1296 - (see `past_key_values`). 1297 - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 1298 - """ 1299 - if "padding_mask" in kwargs: 1300 - warnings.warn( 1301 - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 1302 - ) 1303 - residual = hidden_states 1304 - 1305 - hidden_states = self.input_layernorm(hidden_states) 1306 - 1307 - # Self Attention 1308 - hidden_states, self_attn_weights, present_key_value = self.self_attn( 1309 - hidden_states=hidden_states, 1310 - attention_mask=attention_mask, 1311 - position_ids=position_ids, 1312 - past_key_value=past_key_value, 1313 - output_attentions=output_attentions, 1314 - use_cache=use_cache, 1315 - **kwargs, 1316 - ) 1317 - hidden_states = residual + hidden_states 1318 - 1319 - # Fully Connected 1320 - residual = hidden_states 1321 - hidden_states = self.post_attention_layernorm(hidden_states) 1322 - hidden_states = self.mlp(hidden_states) 1323 - hidden_states = residual + hidden_states 1324 - 1325 - outputs = (hidden_states,) 1326 - 1327 - if output_attentions: 1328 - outputs += (self_attn_weights,) 1329 - 1330 - if use_cache: 1331 - outputs += (present_key_value,) 1332 - 1333 - return outputs 1334 - 1335 - 1336 - DeepseekV2_START_DOCSTRING = r""" 1337 - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 1338 - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 1339 - etc.) 1340 - 1341 - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 1342 - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 1343 - and behavior. 1344 - 1345 - Parameters: 1346 - config ([`DeepseekV2Config`]): 1347 - Model configuration class with all the parameters of the model. Initializing with a config file does not 1348 - load the weights associated with the model, only the configuration. Check out the 1349 - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 1350 - """ 1351 - 1352 - 1353 - @add_start_docstrings( 1354 - "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", 1355 - DeepseekV2_START_DOCSTRING, 1356 - ) 1357 - class DeepseekV2PreTrainedModel(PreTrainedModel): 1358 - config_class = DeepseekV2Config 1359 - base_model_prefix = "model" 1360 - supports_gradient_checkpointing = True 1361 - _no_split_modules = ["DeepseekV2DecoderLayer"] 1362 - _skip_keys_device_placement = "past_key_values" 1363 - _supports_flash_attn_2 = True 1364 - _supports_cache_class = True 1365 - 1366 - def _init_weights(self, module): 1367 - std = self.config.initializer_range 1368 - if isinstance(module, nn.Linear): 1369 - module.weight.data.normal_(mean=0.0, std=std) 1370 - if module.bias is not None: 1371 - module.bias.data.zero_() 1372 - elif isinstance(module, nn.Embedding): 1373 - module.weight.data.normal_(mean=0.0, std=std) 1374 - if module.padding_idx is not None: 1375 - module.weight.data[module.padding_idx].zero_() 1376 - 1377 - 1378 - DeepseekV2_INPUTS_DOCSTRING = r""" 1379 - Args: 1380 - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1381 - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 1382 - it. 1383 - 1384 - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1385 - [`PreTrainedTokenizer.__call__`] for details. 1386 - 1387 - [What are input IDs?](../glossary#input-ids) 1388 - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1389 - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1390 - 1391 - - 1 for tokens that are **not masked**, 1392 - - 0 for tokens that are **masked**. 1393 - 1394 - [What are attention masks?](../glossary#attention-mask) 1395 - 1396 - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1397 - [`PreTrainedTokenizer.__call__`] for details. 1398 - 1399 - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 1400 - `past_key_values`). 1401 - 1402 - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 1403 - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 1404 - information on the default strategy. 1405 - 1406 - - 1 indicates the head is **not masked**, 1407 - - 0 indicates the head is **masked**. 1408 - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1409 - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 1410 - config.n_positions - 1]`. 1411 - 1412 - [What are position IDs?](../glossary#position-ids) 1413 - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 1414 - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 1415 - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 1416 - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 1417 - 1418 - Two formats are allowed: 1419 - - a [`~cache_utils.Cache`] instance; 1420 - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 1421 - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 1422 - cache format. 1423 - 1424 - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 1425 - legacy cache format will be returned. 1426 - 1427 - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 1428 - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 1429 - of shape `(batch_size, sequence_length)`. 1430 - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1431 - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 1432 - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 1433 - model's internal embedding lookup matrix. 1434 - use_cache (`bool`, *optional*): 1435 - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 1436 - `past_key_values`). 1437 - output_attentions (`bool`, *optional*): 1438 - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1439 - tensors for more detail. 1440 - output_hidden_states (`bool`, *optional*): 1441 - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1442 - more detail. 1443 - return_dict (`bool`, *optional*): 1444 - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1445 - """ 1446 - 1447 - 1448 - @add_start_docstrings( 1449 - "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", 1450 - DeepseekV2_START_DOCSTRING, 1451 - ) 1452 - class DeepseekV2Model(DeepseekV2PreTrainedModel): 1453 - """ 1454 - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] 1455 - 1456 - Args: 1457 - config: DeepseekV2Config 1458 - """ 1459 - 1460 - def __init__(self, config: DeepseekV2Config): 1461 - super().__init__(config) 1462 - self.padding_idx = config.pad_token_id 1463 - self.vocab_size = config.vocab_size 1464 - 1465 - self.embed_tokens = nn.Embedding( 1466 - config.vocab_size, config.hidden_size, self.padding_idx 1467 - ) 1468 - self.layers = nn.ModuleList( 1469 - [ 1470 - DeepseekV2DecoderLayer(config, layer_idx) 1471 - for layer_idx in range(config.num_hidden_layers) 1472 - ] 1473 - ) 1474 - # print(config._attn_implementation) 1475 - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 1476 - self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1477 - 1478 - self.gradient_checkpointing = False 1479 - # Initialize weights and apply final processing 1480 - self.post_init() 1481 - 1482 - def get_input_embeddings(self): 1483 - return self.embed_tokens 1484 - 1485 - def set_input_embeddings(self, value): 1486 - self.embed_tokens = value 1487 - 1488 - @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) 1489 - def forward( 1490 - self, 1491 - input_ids: torch.LongTensor = None, 1492 - attention_mask: Optional[torch.Tensor] = None, 1493 - position_ids: Optional[torch.LongTensor] = None, 1494 - past_key_values: Optional[List[torch.FloatTensor]] = None, 1495 - inputs_embeds: Optional[torch.FloatTensor] = None, 1496 - use_cache: Optional[bool] = None, 1497 - output_attentions: Optional[bool] = None, 1498 - output_hidden_states: Optional[bool] = None, 1499 - return_dict: Optional[bool] = None, 1500 - cache_position: Optional[torch.LongTensor] = None 1501 - ) -> Union[Tuple, BaseModelOutputWithPast]: 1502 - output_attentions = ( 1503 - output_attentions 1504 - if output_attentions is not None 1505 - else self.config.output_attentions 1506 - ) 1507 - output_hidden_states = ( 1508 - output_hidden_states 1509 - if output_hidden_states is not None 1510 - else self.config.output_hidden_states 1511 - ) 1512 - use_cache = use_cache if use_cache is not None else self.config.use_cache 1513 - 1514 - return_dict = ( 1515 - return_dict if return_dict is not None else self.config.use_return_dict 1516 - ) 1517 - 1518 - # retrieve input_ids and inputs_embeds 1519 - if input_ids is not None and inputs_embeds is not None: 1520 - raise ValueError( 1521 - "You cannot specify both input_ids and inputs_embeds at the same time" 1522 - ) 1523 - elif input_ids is not None: 1524 - batch_size, seq_length = input_ids.shape[:2] 1525 - elif inputs_embeds is not None: 1526 - batch_size, seq_length = inputs_embeds.shape[:2] 1527 - else: 1528 - raise ValueError("You have to specify either input_ids or inputs_embeds") 1529 - 1530 - if self.gradient_checkpointing and self.training: 1531 - if use_cache: 1532 - logger.warning_once( 1533 - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 1534 - ) 1535 - use_cache = False 1536 - 1537 - past_key_values_length = 0 1538 - if use_cache: 1539 - use_legacy_cache = not isinstance(past_key_values, Cache) 1540 - if use_legacy_cache: 1541 - past_key_values = DynamicCache.from_legacy_cache(past_key_values) 1542 - past_key_values_length = past_key_values.get_usable_length(seq_length) 1543 - 1544 - if position_ids is None: 1545 - device = input_ids.device if input_ids is not None else inputs_embeds.device 1546 - position_ids = torch.arange( 1547 - past_key_values_length, 1548 - seq_length + past_key_values_length, 1549 - dtype=torch.long, 1550 - device=device, 1551 - ) 1552 - position_ids = position_ids.unsqueeze(0) 1553 - 1554 - if inputs_embeds is None: 1555 - inputs_embeds = self.embed_tokens(input_ids) 1556 - 1557 - if self._use_flash_attention_2: 1558 - # 2d mask is passed through the layers 1559 - attention_mask = ( 1560 - attention_mask 1561 - if (attention_mask is not None and 0 in attention_mask) 1562 - else None 1563 - ) 1564 - else: 1565 - # 4d mask is passed through the layers 1566 - attention_mask = _prepare_4d_causal_attention_mask( 1567 - attention_mask, 1568 - (batch_size, seq_length), 1569 - inputs_embeds, 1570 - past_key_values_length, 1571 - ) 1572 - 1573 - # embed positions 1574 - hidden_states = inputs_embeds 1575 - 1576 - # decoder layers 1577 - all_hidden_states = () if output_hidden_states else None 1578 - all_self_attns = () if output_attentions else None 1579 - next_decoder_cache = None 1580 - 1581 - for decoder_layer in self.layers: 1582 - if output_hidden_states: 1583 - all_hidden_states += (hidden_states,) 1584 - 1585 - if self.gradient_checkpointing and self.training: 1586 - layer_outputs = self._gradient_checkpointing_func( 1587 - decoder_layer.__call__, 1588 - hidden_states, 1589 - attention_mask, 1590 - position_ids, 1591 - past_key_values, 1592 - output_attentions, 1593 - use_cache, 1594 - ) 1595 - else: 1596 - layer_outputs = decoder_layer( 1597 - hidden_states, 1598 - attention_mask=attention_mask, 1599 - position_ids=position_ids, 1600 - past_key_value=past_key_values, 1601 - output_attentions=output_attentions, 1602 - use_cache=use_cache, 1603 - ) 1604 - 1605 - hidden_states = layer_outputs[0] 1606 - 1607 - if use_cache: 1608 - next_decoder_cache = layer_outputs[2 if output_attentions else 1] 1609 - 1610 - if output_attentions: 1611 - all_self_attns += (layer_outputs[1],) 1612 - 1613 - hidden_states = self.norm(hidden_states) 1614 - 1615 - # add hidden states from the last decoder layer 1616 - if output_hidden_states: 1617 - all_hidden_states += (hidden_states,) 1618 - 1619 - next_cache = None 1620 - if use_cache: 1621 - next_cache = ( 1622 - next_decoder_cache.to_legacy_cache() 1623 - if use_legacy_cache 1624 - else next_decoder_cache 1625 - ) 1626 - if not return_dict: 1627 - return tuple( 1628 - v 1629 - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] 1630 - if v is not None 1631 - ) 1632 - return BaseModelOutputWithPast( 1633 - last_hidden_state=hidden_states, 1634 - past_key_values=next_cache, 1635 - hidden_states=all_hidden_states, 1636 - attentions=all_self_attns, 1637 - ) 1638 - 1639 - 1640 - class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): 1641 - _tied_weights_keys = ["lm_head.weight"] 1642 - 1643 - def __init__(self, config): 1644 - super().__init__(config) 1645 - self.model = DeepseekV2Model(config) 1646 - self.vocab_size = config.vocab_size 1647 - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1648 - 1649 - # Initialize weights and apply final processing 1650 - self.post_init() 1651 - 1652 - def get_input_embeddings(self): 1653 - return self.model.embed_tokens 1654 - 1655 - def set_input_embeddings(self, value): 1656 - self.model.embed_tokens = value 1657 - 1658 - def get_output_embeddings(self): 1659 - return self.lm_head 1660 - 1661 - def set_output_embeddings(self, new_embeddings): 1662 - self.lm_head = new_embeddings 1663 - 1664 - def set_decoder(self, decoder): 1665 - self.model = decoder 1666 - 1667 - def get_decoder(self): 1668 - return self.model 1669 - 1670 - @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) 1671 - @replace_return_docstrings( 1672 - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC 1673 - ) 1674 - def forward( 1675 - self, 1676 - input_ids: torch.LongTensor = None, 1677 - attention_mask: Optional[torch.Tensor] = None, 1678 - position_ids: Optional[torch.LongTensor] = None, 1679 - past_key_values: Optional[List[torch.FloatTensor]] = None, 1680 - inputs_embeds: Optional[torch.FloatTensor] = None, 1681 - labels: Optional[torch.LongTensor] = None, 1682 - use_cache: Optional[bool] = None, 1683 - output_attentions: Optional[bool] = None, 1684 - output_hidden_states: Optional[bool] = None, 1685 - return_dict: Optional[bool] = None, 1686 - cache_position: Optional[torch.LongTensor] = None 1687 - ) -> Union[Tuple, CausalLMOutputWithPast]: 1688 - r""" 1689 - Args: 1690 - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1691 - Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., 1692 - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1693 - (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. 1694 - 1695 - Returns: 1696 - 1697 - Example: 1698 - 1699 - ```python 1700 - >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM 1701 - 1702 - >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1703 - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1704 - 1705 - >>> prompt = "Hey, are you conscious? Can you talk to me?" 1706 - >>> inputs = tokenizer(prompt, return_tensors="pt") 1707 - 1708 - >>> # Generate 1709 - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1710 - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1711 - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 1712 - ```""" 1713 - output_attentions = ( 1714 - output_attentions 1715 - if output_attentions is not None 1716 - else self.config.output_attentions 1717 - ) 1718 - output_hidden_states = ( 1719 - output_hidden_states 1720 - if output_hidden_states is not None 1721 - else self.config.output_hidden_states 1722 - ) 1723 - return_dict = ( 1724 - return_dict if return_dict is not None else self.config.use_return_dict 1725 - ) 1726 - 1727 - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1728 - outputs = self.model( 1729 - input_ids=input_ids, 1730 - attention_mask=attention_mask, 1731 - position_ids=position_ids, 1732 - past_key_values=past_key_values, 1733 - inputs_embeds=inputs_embeds, 1734 - use_cache=use_cache, 1735 - output_attentions=output_attentions, 1736 - output_hidden_states=output_hidden_states, 1737 - return_dict=return_dict, 1738 - cache_position=cache_position 1739 - ) 1740 - 1741 - hidden_states = outputs[0] 1742 - logits = self.lm_head(hidden_states) 1743 - logits = logits.float() 1744 - 1745 - loss = None 1746 - if labels is not None: 1747 - # Shift so that tokens < n predict n 1748 - shift_logits = logits[..., :-1, :].contiguous() 1749 - shift_labels = labels[..., 1:].contiguous() 1750 - # Flatten the tokens 1751 - loss_fct = CrossEntropyLoss() 1752 - shift_logits = shift_logits.view(-1, self.config.vocab_size) 1753 - shift_labels = shift_labels.view(-1) 1754 - # Enable model parallelism 1755 - shift_labels = shift_labels.to(shift_logits.device) 1756 - loss = loss_fct(shift_logits, shift_labels) 1757 - 1758 - if not return_dict: 1759 - output = (logits,) + outputs[1:] 1760 - return (loss,) + output if loss is not None else output 1761 - 1762 - return CausalLMOutputWithPast( 1763 - loss=loss, 1764 - logits=logits, 1765 - past_key_values=outputs.past_key_values, 1766 - hidden_states=outputs.hidden_states, 1767 - attentions=outputs.attentions, 1768 - ) 1769 - 1770 - def prepare_inputs_for_generation( 1771 - self, 1772 - input_ids, 1773 - past_key_values=None, 1774 - attention_mask=None, 1775 - inputs_embeds=None, 1776 - **kwargs, 1777 - ): 1778 - past_length = 0 1779 - if past_key_values is not None: 1780 - if isinstance(past_key_values, Cache): 1781 - cache_length = past_key_values.get_seq_length() 1782 - past_length = past_key_values.seen_tokens 1783 - max_cache_length = past_key_values.get_max_length() 1784 - else: 1785 - cache_length = past_length = past_key_values[0][0].shape[2] 1786 - max_cache_length = None 1787 - 1788 - # Keep only the unprocessed tokens: 1789 - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 1790 - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 1791 - # input) 1792 - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 1793 - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] 1794 - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 1795 - # input_ids based on the past_length. 1796 - elif past_length < input_ids.shape[1]: 1797 - input_ids = input_ids[:, past_length:] 1798 - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 1799 - 1800 - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 1801 - if ( 1802 - max_cache_length is not None 1803 - and attention_mask is not None 1804 - and cache_length + input_ids.shape[1] > max_cache_length 1805 - ): 1806 - attention_mask = attention_mask[:, -max_cache_length:] 1807 - 1808 - position_ids = kwargs.get("position_ids", None) 1809 - if attention_mask is not None and position_ids is None: 1810 - # create position_ids on the fly for batch generation 1811 - position_ids = attention_mask.long().cumsum(-1) - 1 1812 - position_ids.masked_fill_(attention_mask == 0, 1) 1813 - if past_key_values: 1814 - position_ids = position_ids[:, -input_ids.shape[1]:] 1815 - 1816 - if self.generation_config.cache_implementation == "static": 1817 - # generation with static cache 1818 - cache_position = kwargs.get("cache_position", None) 1819 - if cache_position is None: 1820 - past_length = 0 1821 - else: 1822 - past_length = cache_position[-1] + 1 1823 - input_ids = input_ids[:, past_length:] 1824 - position_ids = position_ids[:, past_length:] 1825 - 1826 - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. 1827 - # same goes for position ids. Could also help with continued generation. 1828 - cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) 1829 - 1830 - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1831 - if inputs_embeds is not None and past_key_values is None: 1832 - model_inputs = {"inputs_embeds": inputs_embeds} 1833 - else: 1834 - # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 1835 - # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 1836 - # TODO: use `next_tokens` directly instead. 1837 - model_inputs = {"input_ids": input_ids.contiguous()} 1838 - 1839 - model_inputs.update( 1840 - { 1841 - "position_ids": position_ids.contiguous(), 1842 - "cache_position": cache_position, 1843 - "past_key_values": past_key_values, 1844 - "use_cache": kwargs.get("use_cache"), 1845 - "attention_mask": attention_mask, 1846 - } 1847 - ) 1848 - return model_inputs 1849 - 1850 - @staticmethod 1851 - def _reorder_cache(past_key_values, beam_idx): 1852 - reordered_past = () 1853 - for layer_past in past_key_values: 1854 - reordered_past += ( 1855 - tuple( 1856 - past_state.index_select(0, beam_idx.to(past_state.device)) 1857 - for past_state in layer_past 1858 - ), 1859 - ) 1860 - return reordered_past 1861 - 1862 - 1863 - @add_start_docstrings( 1864 - """ 1865 - The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). 1866 - 1867 - [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1868 - (e.g. GPT-2) do. 1869 - 1870 - Since it does classification on the last token, it requires to know the position of the last token. If a 1871 - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1872 - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1873 - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1874 - each row of the batch). 1875 - """, 1876 - DeepseekV2_START_DOCSTRING, 1877 - ) 1878 - class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): 1879 - def __init__(self, config): 1880 - super().__init__(config) 1881 - self.num_labels = config.num_labels 1882 - self.model = DeepseekV2Model(config) 1883 - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1884 - 1885 - # Initialize weights and apply final processing 1886 - self.post_init() 1887 - 1888 - def get_input_embeddings(self): 1889 - return self.model.embed_tokens 1890 - 1891 - def set_input_embeddings(self, value): 1892 - self.model.embed_tokens = value 1893 - 1894 - @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) 1895 - def forward( 1896 - self, 1897 - input_ids: torch.LongTensor = None, 1898 - attention_mask: Optional[torch.Tensor] = None, 1899 - position_ids: Optional[torch.LongTensor] = None, 1900 - past_key_values: Optional[List[torch.FloatTensor]] = None, 1901 - inputs_embeds: Optional[torch.FloatTensor] = None, 1902 - labels: Optional[torch.LongTensor] = None, 1903 - use_cache: Optional[bool] = None, 1904 - output_attentions: Optional[bool] = None, 1905 - output_hidden_states: Optional[bool] = None, 1906 - return_dict: Optional[bool] = None, 1907 - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1908 - r""" 1909 - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1910 - Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., 1911 - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1912 - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1913 - """ 1914 - return_dict = ( 1915 - return_dict if return_dict is not None else self.config.use_return_dict 1916 - ) 1917 - 1918 - transformer_outputs = self.model( 1919 - input_ids, 1920 - attention_mask=attention_mask, 1921 - position_ids=position_ids, 1922 - past_key_values=past_key_values, 1923 - inputs_embeds=inputs_embeds, 1924 - use_cache=use_cache, 1925 - output_attentions=output_attentions, 1926 - output_hidden_states=output_hidden_states, 1927 - return_dict=return_dict, 1928 - ) 1929 - hidden_states = transformer_outputs[0] 1930 - logits = self.score(hidden_states) 1931 - 1932 - if input_ids is not None: 1933 - batch_size = input_ids.shape[0] 1934 - else: 1935 - batch_size = inputs_embeds.shape[0] 1936 - 1937 - if self.config.pad_token_id is None and batch_size != 1: 1938 - raise ValueError( 1939 - "Cannot handle batch sizes > 1 if no padding token is defined." 1940 - ) 1941 - if self.config.pad_token_id is None: 1942 - sequence_lengths = -1 1943 - else: 1944 - if input_ids is not None: 1945 - sequence_lengths = ( 1946 - torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 1947 - ).to(logits.device) 1948 - else: 1949 - sequence_lengths = -1 1950 - 1951 - pooled_logits = logits[ 1952 - torch.arange(batch_size, device=logits.device), sequence_lengths 1953 - ] 1954 - 1955 - loss = None 1956 - if labels is not None: 1957 - labels = labels.to(logits.device) 1958 - if self.config.problem_type is None: 1959 - if self.num_labels == 1: 1960 - self.config.problem_type = "regression" 1961 - elif self.num_labels > 1 and ( 1962 - labels.dtype == torch.long or labels.dtype == torch.int 1963 - ): 1964 - self.config.problem_type = "single_label_classification" 1965 - else: 1966 - self.config.problem_type = "multi_label_classification" 1967 - 1968 - if self.config.problem_type == "regression": 1969 - loss_fct = MSELoss() 1970 - if self.num_labels == 1: 1971 - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1972 - else: 1973 - loss = loss_fct(pooled_logits, labels) 1974 - elif self.config.problem_type == "single_label_classification": 1975 - loss_fct = CrossEntropyLoss() 1976 - loss = loss_fct( 1977 - pooled_logits.view(-1, self.num_labels), labels.view(-1) 1978 - ) 1979 - elif self.config.problem_type == "multi_label_classification": 1980 - loss_fct = BCEWithLogitsLoss() 1981 - loss = loss_fct(pooled_logits, labels) 1982 - if not return_dict: 1983 - output = (pooled_logits,) + transformer_outputs[1:] 1984 - return ((loss,) + output) if loss is not None else output 1985 - 1986 - return SequenceClassifierOutputWithPast( 1987 - loss=loss, 1988 - logits=pooled_logits, 1989 - past_key_values=transformer_outputs.past_key_values, 1990 - hidden_states=transformer_outputs.hidden_states, 1991 - attentions=transformer_outputs.attentions, 1992 - )
-82
src/eval_deepseek.py
··· 1 - """ 2 - ExpRate evaluation of a fine-tuned DeepSeek-OCR-2 checkpoint on the test splits. 3 - 4 - ExpRate = fraction of exact string matches after whitespace normalisation. 5 - 6 - Usage: 7 - uv run eval-deepseek [--checkpoint checkpoints/deepseek/final] 8 - [--splits mathwriting_test typeset_test typeset_mixed_test] 9 - [--n 500] 10 - """ 11 - 12 - import argparse 13 - 14 - import torch 15 - from PIL import Image 16 - from tqdm import tqdm 17 - 18 - import re 19 - 20 - import random 21 - 22 - from .data import TEST_SPLITS, load_records 23 - from .mine_failures import _infer, _load 24 - 25 - 26 - def normalize(s: str) -> str: 27 - return re.sub(r"\s+", " ", s).strip() 28 - 29 - 30 - def evaluate( 31 - checkpoint: str, 32 - splits: list[str] | None = None, 33 - n: int | None = None, 34 - ) -> float: 35 - model, tokenizer = _load(checkpoint) 36 - prompt_ids = tokenizer.encode( 37 - "\nTranscribe this image to Typst notation.\n", add_special_tokens=False 38 - ) 39 - 40 - split_names = splits or TEST_SPLITS 41 - rng = random.Random(0) 42 - records = [] 43 - for name in split_names: 44 - recs = load_records([name], dedupe=False) 45 - if n is not None and len(recs) > n: 46 - recs = rng.sample(recs, n) 47 - records.extend(recs) 48 - 49 - correct = 0 50 - per_split: dict[str, list[bool]] = {} 51 - 52 - for r in tqdm(records, desc="Evaluating"): 53 - img = Image.open(r["image_path"]).convert("RGB") 54 - pred = normalize(_infer(model, tokenizer, prompt_ids, img)) 55 - gt = normalize(r["typst"]) 56 - hit = pred == gt 57 - correct += hit 58 - per_split.setdefault(r.get("split", "unknown"), []).append(hit) 59 - print(f" GT : {repr(gt)}") 60 - print(f" PRED: {repr(pred)}") 61 - print(f" HIT : {hit}\n") 62 - 63 - total = len(records) 64 - print(f"\nExpRate: {correct/total:.4f} ({correct}/{total})") 65 - for split, hits in sorted(per_split.items()): 66 - print(f" {split}: {sum(hits)/len(hits):.4f} ({sum(hits)}/{len(hits)})") 67 - return correct / total 68 - 69 - 70 - def main() -> None: 71 - parser = argparse.ArgumentParser() 72 - parser.add_argument("--checkpoint", default="checkpoints/deepseek/final") 73 - parser.add_argument("--splits", nargs="+", default=None, 74 - metavar="SPLIT", help="Override test splits") 75 - parser.add_argument("--n", type=int, default=None, 76 - help="Max records per split (stratified); omit for full test set") 77 - args = parser.parse_args() 78 - evaluate(args.checkpoint, args.splits, args.n) 79 - 80 - 81 - if __name__ == "__main__": 82 - main()
-82
src/infer_debug.py
··· 1 - """Quick diagnostic: run one inference and print raw token IDs + image injection status.""" 2 - import torch 3 - from PIL import Image 4 - from transformers import AutoTokenizer, BitsAndBytesConfig 5 - from peft import PeftModel 6 - 7 - from .collate_deepseek import IMAGE_TOKEN_ID, N_IMAGE_TOKENS, _PROMPT, preprocess_image 8 - from .data import TEST_SPLITS, load_records 9 - from .deepseek_ocr2.modeling_deepseekocr2 import DeepseekOCR2ForCausalLM 10 - 11 - MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 12 - CHECKPOINT = "checkpoints/deepseek/checkpoint-10500" 13 - 14 - 15 - def main(): 16 - bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", 17 - bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True) 18 - base = DeepseekOCR2ForCausalLM.from_pretrained(MODEL_ID, quantization_config=bnb, 19 - use_safetensors=True, _attn_implementation="eager") 20 - model = PeftModel.from_pretrained(base, CHECKPOINT) 21 - for name, param in model.named_parameters(): 22 - if param.dtype == torch.float16: 23 - param.data = param.data.to(torch.bfloat16) 24 - model.eval() 25 - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) 26 - 27 - prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False) 28 - print(f"Prompt token IDs: {prompt_ids}") 29 - print(f"Prompt decoded: {repr(tokenizer.decode(prompt_ids))}") 30 - print(f"IMAGE_TOKEN_ID={IMAGE_TOKEN_ID}, N_IMAGE_TOKENS={N_IMAGE_TOKENS}") 31 - 32 - r = load_records(TEST_SPLITS, dedupe=False)[0] 33 - img = Image.open(r["image_path"]).convert("RGB") 34 - img_t = preprocess_image(img).cuda() 35 - print(f"Image tensor: dtype={img_t.dtype}, sum={img_t.sum().item():.2f}, shape={img_t.shape}") 36 - 37 - img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS 38 - ids = torch.tensor([img_ids + prompt_ids], dtype=torch.long, device="cuda") 39 - n_input = ids.shape[1] 40 - seq_mask = torch.zeros(1, n_input, dtype=torch.bool, device="cuda") 41 - seq_mask[0, :N_IMAGE_TOKENS] = True 42 - images = [(img_t.unsqueeze(0), img_t.unsqueeze(0))] 43 - spatial = torch.tensor([[1, 1]], dtype=torch.long, device="cuda") 44 - attn = torch.ones_like(ids) 45 - 46 - # Monkey-patch masked_scatter_ to verify image injection fires 47 - injection_happened = [False] 48 - original_mscatter = torch.Tensor.masked_scatter_ 49 - 50 - def patched_mscatter(self_t, mask, source): 51 - injection_happened[0] = True 52 - print(f" masked_scatter_ called: self={self_t.shape} {self_t.dtype}, " 53 - f"mask={mask.shape}, source={source.shape} {source.dtype}") 54 - return original_mscatter(self_t, mask, source) 55 - 56 - torch.Tensor.masked_scatter_ = patched_mscatter 57 - 58 - with torch.no_grad(): 59 - out = model.generate( 60 - input_ids=ids, 61 - attention_mask=attn, 62 - images=images, 63 - images_seq_mask=seq_mask, 64 - images_spatial_crop=spatial, 65 - max_new_tokens=30, 66 - do_sample=False, 67 - ) 68 - 69 - torch.Tensor.masked_scatter_ = original_mscatter 70 - print(f"Image injection fired: {injection_happened[0]}") 71 - generated = out[0][n_input:] 72 - print(f"\nGenerated token IDs ({len(generated)} tokens): {generated[:20].tolist()}") 73 - print(f"Decoded (skip_special=True): {repr(tokenizer.decode(generated, skip_special_tokens=True))}") 74 - print(f"Decoded (skip_special=False): {repr(tokenizer.decode(generated, skip_special_tokens=False))}") 75 - print(f"\nGT: {repr(r['typst'])}") 76 - print(f"\nEOS token id (tokenizer): {tokenizer.eos_token_id}") 77 - print(f"EOS token id (model.config): {model.config.eos_token_id}") 78 - print(f"BOS token id (model.config): {model.config.bos_token_id}") 79 - 80 - 81 - if __name__ == "__main__": 82 - main()
-143
src/mine_failures.py
··· 1 - """ 2 - HNM Step 1: identify hard failures from the training pool. 3 - 4 - Runs inference on a random sample of training records (default 10k), 5 - computes normalized edit distance between prediction and ground truth, 6 - and writes the top-N hardest failures to data/hnm_pool.jsonl. 7 - 8 - Samples from the training pool (not val/test) to keep evaluation sets clean. 9 - 10 - Usage: uv run mine [--checkpoint checkpoints/deepseek/final] 11 - [--sample 10000] 12 - [--top-n 3000] 13 - [--out data/hnm_pool.jsonl] 14 - """ 15 - 16 - import argparse 17 - import json 18 - import random 19 - from pathlib import Path 20 - 21 - import editdistance 22 - import torch 23 - from peft import PeftModel 24 - from PIL import Image 25 - from tqdm import tqdm 26 - from transformers import AutoTokenizer, BitsAndBytesConfig 27 - 28 - import re 29 - 30 - from .collate_deepseek import IMAGE_TOKEN_ID, N_IMAGE_TOKENS, _PROMPT, preprocess_image 31 - from .data import TRAIN_SPLITS, load_records 32 - from .deepseek_ocr2.modeling_deepseekocr2 import DeepseekOCR2ForCausalLM 33 - 34 - 35 - def normalize(s: str) -> str: 36 - return re.sub(r"\s+", " ", s).strip() 37 - 38 - MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 39 - 40 - 41 - def _load(checkpoint: str): 42 - bnb = BitsAndBytesConfig( 43 - load_in_4bit=True, 44 - bnb_4bit_quant_type="nf4", 45 - bnb_4bit_compute_dtype=torch.bfloat16, 46 - bnb_4bit_use_double_quant=True, 47 - ) 48 - base = DeepseekOCR2ForCausalLM.from_pretrained( 49 - MODEL_ID, 50 - quantization_config=bnb, 51 - use_safetensors=True, 52 - _attn_implementation="eager", 53 - ) 54 - model = PeftModel.from_pretrained(base, checkpoint) 55 - 56 - for name, param in model.named_parameters(): 57 - if param.dtype == torch.float16: 58 - param.data = param.data.to(torch.bfloat16) 59 - 60 - model.eval() 61 - tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) 62 - return model, tokenizer 63 - 64 - 65 - def _infer(model, tokenizer, prompt_ids: list[int], img: Image.Image) -> str: 66 - img_t = preprocess_image(img).cuda() # [3, 768, 768] 67 - img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS 68 - ids = torch.tensor([img_ids + prompt_ids], dtype=torch.long, device="cuda") 69 - n_input = ids.shape[1] 70 - 71 - seq_mask = torch.zeros(1, n_input, dtype=torch.bool, device="cuda") 72 - seq_mask[0, :N_IMAGE_TOKENS] = True 73 - 74 - # zeros for local crops: forces the crop_mode=False else-branch → 145 global features 75 - images = [(torch.zeros_like(img_t.unsqueeze(0)), img_t.unsqueeze(0))] 76 - spatial = torch.tensor([[1, 1]], dtype=torch.long, device="cuda") 77 - attn = torch.ones_like(ids) 78 - 79 - with torch.no_grad(): 80 - out = model.generate( 81 - input_ids=ids, 82 - attention_mask=attn, 83 - images=images, 84 - images_seq_mask=seq_mask, 85 - images_spatial_crop=spatial, 86 - max_new_tokens=512, 87 - do_sample=False, 88 - ) 89 - return tokenizer.decode(out[0][n_input:], skip_special_tokens=True) 90 - 91 - 92 - def mine( 93 - checkpoint: str, 94 - sample_size: int = 10_000, 95 - top_n: int = 3_000, 96 - out_path: str = "data/hnm_pool.jsonl", 97 - ) -> None: 98 - model, tokenizer = _load(checkpoint) 99 - prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False) 100 - 101 - all_records = load_records(TRAIN_SPLITS, dedupe=True) 102 - sample = random.sample(all_records, min(sample_size, len(all_records))) 103 - print(f"Mining over {len(sample):,} training samples") 104 - 105 - scored: list[tuple[float, dict]] = [] 106 - 107 - for r in tqdm(sample, desc="Inferring"): 108 - img = Image.open(r["image_path"]).convert("RGB") 109 - pred = normalize(_infer(model, tokenizer, prompt_ids, img)) 110 - gt = normalize(r["typst"]) 111 - 112 - if pred == gt: 113 - continue 114 - 115 - max_len = max(len(pred), len(gt), 1) 116 - dist = editdistance.eval(pred, gt) / max_len 117 - scored.append((dist, r)) 118 - 119 - scored.sort(key=lambda x: x[0], reverse=True) 120 - failures = [r for _, r in scored[:top_n]] 121 - 122 - out = Path(out_path) 123 - out.parent.mkdir(parents=True, exist_ok=True) 124 - with out.open("w") as f: 125 - for r in failures: 126 - f.write(json.dumps(r) + "\n") 127 - 128 - print(f"Wrote {len(failures)} hard negatives to {out_path}") 129 - print(f"Failure rate in sample: {len(scored)}/{len(sample)} = {len(scored)/len(sample):.2%}") 130 - 131 - 132 - def main() -> None: 133 - parser = argparse.ArgumentParser() 134 - parser.add_argument("--checkpoint", default="checkpoints/deepseek/final") 135 - parser.add_argument("--sample", type=int, default=10_000) 136 - parser.add_argument("--top-n", type=int, default=3_000) 137 - parser.add_argument("--out", default="data/hnm_pool.jsonl") 138 - args = parser.parse_args() 139 - mine(args.checkpoint, args.sample, args.top_n, args.out) 140 - 141 - 142 - if __name__ == "__main__": 143 - main()
-121
src/probe_deepseek.py
··· 1 - """ 2 - Inference probe for DeepSeek-OCR-2 on Typst OCR. 3 - 4 - The model outputs LaTeX / plain text (not Typst), so MATCH will almost 5 - always be False -- this is for qualitative inspection, not ExpRate. 6 - 7 - Quantisation (--bits): 8 - 16 bf16, no quant ~6.8 GB (official recipe) 9 - 8 INT8 bitsandbytes ~3.4 GB (experimental -- custom model, may fail) 10 - 4 NF4 bitsandbytes ~1.7 GB (experimental) 11 - 12 - Usage: 13 - uv run probe-deepseek [--bits 16] [--n 5] 14 - uv run probe-deepseek --images path/to/a.png path/to/b.png 15 - """ 16 - 17 - import argparse 18 - import os 19 - import tempfile 20 - from pathlib import Path 21 - 22 - import torch 23 - from PIL import Image 24 - from transformers import AutoModel, AutoTokenizer 25 - 26 - from .data import TEST_SPLITS, load_records 27 - from .eval import normalize 28 - 29 - MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 30 - 31 - DEFAULT_PROMPT = "<image>\nConvert this mathematical expression to Typst math notation. " 32 - 33 - 34 - def load_model(model_id: str = MODEL_ID, bits: int = 16): 35 - """Load DeepSeek-OCR-2. Returns (model, tokenizer).""" 36 - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) 37 - 38 - kwargs = dict( 39 - trust_remote_code=True, 40 - use_safetensors=True, 41 - _attn_implementation="flash_attention_2", 42 - ) 43 - 44 - if bits in (4, 8): 45 - from transformers import BitsAndBytesConfig 46 - if bits == 4: 47 - bnb_cfg = BitsAndBytesConfig( 48 - load_in_4bit=True, 49 - bnb_4bit_quant_type="nf4", 50 - bnb_4bit_compute_dtype=torch.bfloat16, 51 - bnb_4bit_use_double_quant=True, 52 - ) 53 - else: 54 - bnb_cfg = BitsAndBytesConfig(load_in_8bit=True) 55 - model = AutoModel.from_pretrained(model_id, quantization_config=bnb_cfg, **kwargs) 56 - else: 57 - model = AutoModel.from_pretrained(model_id, **kwargs) 58 - model = model.eval().cuda().to(torch.bfloat16) 59 - 60 - model.eval() 61 - return model, tokenizer 62 - 63 - 64 - def run_image(img: Image.Image, model, tokenizer, prompt: str = DEFAULT_PROMPT) -> str: 65 - """Run inference on a PIL image. Saves to a tempfile (model.infer needs a path).""" 66 - with tempfile.TemporaryDirectory() as tmpdir: 67 - img_path = os.path.join(tmpdir, "input.png") 68 - out_dir = os.path.join(tmpdir, "out") 69 - os.makedirs(out_dir) 70 - img.save(img_path) 71 - 72 - res = model.infer( 73 - tokenizer, 74 - prompt=prompt, 75 - image_file=img_path, 76 - output_path=out_dir, 77 - base_size=1024, 78 - image_size=768, 79 - crop_mode=False, # math expressions are small -- don't crop 80 - save_results=False, 81 - ) 82 - 83 - # res is typically a string or list; normalise to str 84 - if isinstance(res, list): 85 - return "\n".join(str(x) for x in res).strip() 86 - return str(res).strip() 87 - 88 - 89 - def main() -> None: 90 - parser = argparse.ArgumentParser() 91 - parser.add_argument("--model", default=MODEL_ID) 92 - parser.add_argument("--bits", type=int, default=16, choices=[4, 8, 16]) 93 - parser.add_argument("--prompt", default=DEFAULT_PROMPT, 94 - help="Full prompt string; must start with '<image>\\n'") 95 - parser.add_argument("--n", type=int, default=5) 96 - parser.add_argument("--images", nargs="+", metavar="IMG") 97 - args = parser.parse_args() 98 - 99 - print(f"Loading {args.model} at {args.bits}-bit ...") 100 - model, tokenizer = load_model(args.model, args.bits) 101 - print("Model ready.\n") 102 - 103 - if args.images: 104 - for path in args.images: 105 - img = Image.open(path).convert("RGB") 106 - pred = run_image(img, model, tokenizer, args.prompt) 107 - print(f"{path}:\n {pred}\n") 108 - else: 109 - records = load_records(TEST_SPLITS, dedupe=False)[: args.n] 110 - for i, r in enumerate(records): 111 - img = Image.open(r["image_path"]).convert("RGB") 112 - pred = run_image(img, model, tokenizer, args.prompt) 113 - print(f"\n{'='*60}") 114 - print(f"[{i}] {r['image_path']}") 115 - print(f" EXPECTED : {repr(r['typst'])}") 116 - print(f" PREDICTED: {repr(pred)}") 117 - print(f" MATCH : {normalize(pred) == normalize(r['typst'])}") 118 - 119 - 120 - if __name__ == "__main__": 121 - main()
-275
src/train_deepseek.py
··· 1 - """ 2 - QLoRA fine-tuning of DeepSeek-OCR-2 for Typst math OCR. 3 - 4 - Architecture notes: 5 - - DeepseekOCR2ForCausalLM extends DeepseekV2ForCausalLM with a SAM-based 6 - image encoder. Standard Trainer + PEFT work because forward() accepts 7 - `labels` and returns a loss in the CausalLM style. 8 - - MLA (Multi-head Latent Attention) in the LM backbone uses compressed KV 9 - projections. LoRA targets cover both the MLA attention matrices and the 10 - dense MLP gates. The vision encoder is frozen -- it is already strongly 11 - pretrained and contributes no new information for Typst notation style. 12 - - With 4-bit NF4 + LoRA r=16, peak VRAM on an RTX 3060 12 GB is roughly 13 - 6-8 GB for batch=1, grad_accum=8, seq≤512 tokens. 14 - 15 - Usage: 16 - uv run train-deepseek [--smoke-test] [--output-dir checkpoints/deepseek] 17 - uv run train-deepseek --model /path/to/local/checkpoint ... 18 - """ 19 - 20 - import argparse 21 - from pathlib import Path 22 - 23 - import torch 24 - from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training 25 - from transformers import ( 26 - AutoTokenizer, 27 - BitsAndBytesConfig, 28 - Trainer, 29 - TrainingArguments, 30 - ) 31 - 32 - from .collate_deepseek import DeepSeekOCRCollator 33 - from .data import TRAIN_SPLITS, VAL_SPLITS, load_records, make_dataset 34 - from .deepseek_ocr2.modeling_deepseekocr2 import DeepseekOCR2ForCausalLM 35 - 36 - MODEL_ID = "deepseek-ai/DeepSeek-OCR-2" 37 - 38 - # DeepSeek V2 MLA attention matrices + shared/routed MLP gates. 39 - # MLA splits QKV into low-rank compressed projections (a/b pairs). 40 - # If any name is absent the model will raise; adjust to match the loaded arch. 41 - LORA_TARGET_MODULES = [ 42 - "q_a_proj", "q_b_proj", 43 - "kv_a_proj_with_mqa", "kv_b_proj", 44 - "o_proj", 45 - "gate_proj", "up_proj", "down_proj", 46 - ] 47 - 48 - 49 - # ── Model loading ────────────────────────────────────────────────────────────── 50 - 51 - def _bnb_config() -> BitsAndBytesConfig: 52 - return BitsAndBytesConfig( 53 - load_in_4bit=True, 54 - bnb_4bit_quant_type="nf4", 55 - bnb_4bit_compute_dtype=torch.bfloat16, 56 - bnb_4bit_use_double_quant=True, 57 - ) 58 - 59 - 60 - def load_model_and_tokenizer(model_id: str): 61 - # Tokenizer still loaded from hub (no code, just vocab/config files). 62 - tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) 63 - 64 - # Model loaded via our local patched class -- no trust_remote_code needed. 65 - # Weights still fetched from hub; only the Python forward() code is local. 66 - model = DeepseekOCR2ForCausalLM.from_pretrained( 67 - model_id, 68 - quantization_config=_bnb_config(), 69 - use_safetensors=True, 70 - _attn_implementation="eager", 71 - ) 72 - return model, tokenizer 73 - 74 - 75 - def _freeze_and_cast_vision_encoder(model) -> None: 76 - """ 77 - Freeze the SAM / Qwen2 vision encoders and projector, and cast all their 78 - parameters to bfloat16. 79 - 80 - BnB only quantizes nn.Linear layers; conv layers (e.g. patch_embed.proj in 81 - SAM) are left in float32. prepare_model_for_kbit_training does not cast 82 - them either, so the conv receives bfloat16 activations but has a float32 83 - bias -- RuntimeError. Casting to bf16 before training fixes this. 84 - 85 - Module structure (DeepseekOCR2Model): 86 - model.sam_model -- SAM ViT-B image encoder (conv + attention) 87 - model.qwen2_model -- Qwen2 decoder used as a second encoder 88 - model.projector -- MLP mapping vision features to LM embedding dim 89 - """ 90 - # Substrings that uniquely identify vision encoder params (post-PEFT naming) 91 - VISION_SUBSTRINGS = ("sam_model", "qwen2_model", "projector", "view_seperator") 92 - 93 - frozen = 0 94 - for name, param in model.named_parameters(): 95 - if any(s in name for s in VISION_SUBSTRINGS): 96 - param.requires_grad_(False) 97 - param.data = param.data.to(torch.bfloat16) 98 - frozen += param.numel() 99 - if frozen: 100 - print(f"Frozen + cast to bf16: {frozen / 1e6:.1f} M vision params.") 101 - else: 102 - print("WARNING: no vision encoder params matched -- check module names.") 103 - 104 - 105 - # ── Custom Trainer (handles non-tensor 'images' key) ────────────────────────── 106 - 107 - class DeepSeekTrainer(Trainer): 108 - """ 109 - Overrides _prepare_inputs to move the `images` list-of-tuples to the 110 - correct device. The base Trainer's send_to_device recurses through dicts 111 - and lists but not tuples-inside-lists reliably across all TF versions. 112 - """ 113 - 114 - def _prepare_inputs(self, inputs: dict) -> dict: 115 - inputs = super()._prepare_inputs(inputs) 116 - if "images" in inputs: 117 - dev = self.args.device 118 - inputs["images"] = [ 119 - (lc.to(dev, dtype=torch.bfloat16), 120 - gv.to(dev, dtype=torch.bfloat16)) 121 - for lc, gv in inputs["images"] 122 - ] 123 - return inputs 124 - 125 - 126 - # ── Main ─────────────────────────────────────────────────────────────────────── 127 - 128 - def main() -> None: 129 - parser = argparse.ArgumentParser() 130 - parser.add_argument("--model", default=MODEL_ID, 131 - help="HF model ID or local path") 132 - parser.add_argument("--smoke-test", action="store_true", 133 - help="One forward+backward pass then exit") 134 - parser.add_argument("--output-dir", default="checkpoints/deepseek") 135 - parser.add_argument("--epochs", type=int, default=1) 136 - parser.add_argument("--lr", type=float, default=1e-4) 137 - parser.add_argument("--lora-r", type=int, default=16) 138 - args = parser.parse_args() 139 - 140 - print(f"Loading {args.model} ...") 141 - model, tokenizer = load_model_and_tokenizer(args.model) 142 - 143 - # prepare_model_for_kbit_training: casts LayerNorm + embeddings to fp32, 144 - # enables grad checkpointing on the model itself. 145 - model = prepare_model_for_kbit_training( 146 - model, use_gradient_checkpointing=True, 147 - ) 148 - 149 - _freeze_and_cast_vision_encoder(model) 150 - 151 - lora_cfg = LoraConfig( 152 - r=args.lora_r, 153 - lora_alpha=args.lora_r * 2, 154 - target_modules=LORA_TARGET_MODULES, 155 - lora_dropout=0.0, 156 - bias="none", 157 - task_type="CAUSAL_LM", 158 - ) 159 - model = get_peft_model(model, lora_cfg) 160 - model.print_trainable_parameters() 161 - 162 - import random as _random 163 - _rng = _random.Random(29979) 164 - 165 - # Train: load each split individually, apply caps, then combine. 166 - # Caps prevent synthetic-heavy splits from dominating. 167 - # Policy: cap synthetics (mathwriting_synthetic, crohme_gen_2019); keep 168 - # real and document-structure splits uncapped. 169 - _TRAIN_CAPS = { 170 - "mathwriting_synthetic": 20_000, 171 - "crohme_gen_2019": 15_000, 172 - "mathwriting_train": 10_000, # real but large; cap to avoid dominating 173 - } 174 - train_records: list[dict] = [] 175 - for split in TRAIN_SPLITS: 176 - recs = load_records([split], dedupe=True) 177 - cap = _TRAIN_CAPS.get(split) 178 - if cap and len(recs) > cap: 179 - recs = _rng.sample(recs, cap) 180 - train_records.extend(recs) 181 - _rng.shuffle(train_records) 182 - 183 - # Val: all typeset_val + all typeset_mixed_val + 250 from mathwriting_val. 184 - _rng = _random.Random(29979) 185 - mw_val = load_records(["mathwriting_val"], dedupe=False) 186 - typeset_val = load_records(["typeset_val"], dedupe=False) 187 - mixed_val = load_records(["typeset_mixed_val"], dedupe=False) 188 - mw_sample = _rng.sample(mw_val, min(250, len(mw_val))) 189 - ts_sample = _rng.sample(typeset_val, min(250, len(typeset_val))) 190 - val_records = mw_sample + ts_sample + mixed_val 191 - _rng.shuffle(val_records) 192 - 193 - print(f"Train: {len(train_records):,} Val: {len(val_records):,} " 194 - f"(mathwriting={len(mw_sample)}, typeset={len(ts_sample)}, mixed={len(mixed_val)})") 195 - 196 - train_ds = make_dataset(train_records, do_augment=True) 197 - val_ds = make_dataset(val_records, do_augment=False) 198 - 199 - collator = DeepSeekOCRCollator(tokenizer) 200 - 201 - if args.smoke_test: 202 - _run_smoke_test(model, collator, train_ds) 203 - return 204 - 205 - out_dir = args.output_dir 206 - training_args = TrainingArguments( 207 - output_dir=out_dir, 208 - per_device_train_batch_size=2, 209 - per_device_eval_batch_size=2, 210 - gradient_accumulation_steps=4, # effective batch size 8 211 - num_train_epochs=args.epochs, 212 - learning_rate=args.lr, 213 - warmup_steps=200, 214 - lr_scheduler_type="cosine", 215 - bf16=True, 216 - fp16=False, 217 - gradient_checkpointing=True, 218 - dataloader_num_workers=0, # PIL + lazy load requires 0 219 - logging_steps=50, 220 - eval_strategy="steps", 221 - eval_steps=500, 222 - save_steps=250, 223 - save_total_limit=10, 224 - load_best_model_at_end=False, 225 - remove_unused_columns=False, # collator uses non-standard keys 226 - report_to=["tensorboard"], 227 - ) 228 - 229 - has_checkpoint = any(Path(out_dir).glob("checkpoint-*")) if Path(out_dir).exists() else False 230 - 231 - trainer = DeepSeekTrainer( 232 - model=model, 233 - args=training_args, 234 - train_dataset=train_ds, 235 - eval_dataset=val_ds, 236 - data_collator=collator, 237 - ) 238 - 239 - trainer.train(resume_from_checkpoint=has_checkpoint) 240 - 241 - final_dir = f"{out_dir}/final" 242 - model.save_pretrained(final_dir) 243 - tokenizer.save_pretrained(final_dir) 244 - print(f"Saved to {final_dir}") 245 - 246 - 247 - def _run_smoke_test(model, collator, train_ds) -> None: 248 - print("Running smoke test (2 samples, 1 forward+backward) ...") 249 - batch = collator([train_ds[0], train_ds[1]]) 250 - 251 - dev = next(model.parameters()).device 252 - batch_cuda = {} 253 - for k, v in batch.items(): 254 - if k == "images": 255 - batch_cuda[k] = [ 256 - (lc.to(dev, dtype=torch.bfloat16), 257 - gv.to(dev, dtype=torch.bfloat16)) 258 - for lc, gv in v 259 - ] 260 - elif isinstance(v, torch.Tensor): 261 - batch_cuda[k] = v.to(dev) 262 - else: 263 - batch_cuda[k] = v 264 - 265 - model.train() 266 - out = model(**batch_cuda) 267 - loss = out.loss 268 - print(f" forward OK -- loss: {loss.item():.4f}") 269 - loss.backward() 270 - print(" backward OK") 271 - print("Smoke test passed.") 272 - 273 - 274 - if __name__ == "__main__": 275 - main()