···11-"""
22-Backfill TensorBoard events from training stdout.
33-44-Usage:
55- # Save your terminal output to a file, then:
66- uv run python -m src.backfill_tb training.log checkpoints/deepseek/runs/backfill
77-88-The log file should contain lines like:
99- {'loss': 9.56, 'grad_norm': 11.6, 'learning_rate': 9.99e-05, 'epoch': 0.04}
1010- {'eval_loss': 2.72, 'eval_runtime': 542.7, ..., 'epoch': 0.04}
1111-1212-Train entries are assumed to occur every logging_steps=50 steps.
1313-Eval entries are assumed to occur every eval_steps=500 steps.
1414-Adjust LOGGING_STEPS / EVAL_STEPS below if different.
1515-"""
1616-1717-import ast
1818-import re
1919-import sys
2020-from pathlib import Path
2121-2222-LOGGING_STEPS = 50
2323-EVAL_STEPS = 500
2424-2525-2626-def parse_log(path: str) -> tuple[list[dict], list[dict]]:
2727- train_entries: list[dict] = []
2828- eval_entries: list[dict] = []
2929- dict_re = re.compile(r"^\{.*\}$")
3030- for line in Path(path).read_text().splitlines():
3131- line = line.strip()
3232- if not dict_re.match(line):
3333- continue
3434- try:
3535- d = ast.literal_eval(line)
3636- except Exception:
3737- continue
3838- if not isinstance(d, dict):
3939- continue
4040- if "eval_loss" in d:
4141- eval_entries.append(d)
4242- elif "loss" in d:
4343- train_entries.append(d)
4444- return train_entries, eval_entries
4545-4646-4747-def write_events(train_entries, eval_entries, out_dir: str) -> None:
4848- from torch.utils.tensorboard import SummaryWriter
4949- writer = SummaryWriter(log_dir=out_dir)
5050-5151- for i, d in enumerate(train_entries):
5252- step = (i + 1) * LOGGING_STEPS
5353- writer.add_scalar("train/loss", d["loss"], step)
5454- writer.add_scalar("train/grad_norm", d["grad_norm"], step)
5555- writer.add_scalar("train/learning_rate", d["learning_rate"], step)
5656- writer.add_scalar("train/epoch", d["epoch"], step)
5757-5858- for i, d in enumerate(eval_entries):
5959- step = (i + 1) * EVAL_STEPS
6060- writer.add_scalar("eval/loss", d["eval_loss"], step)
6161- writer.add_scalar("eval/epoch", d["epoch"], step)
6262-6363- writer.close()
6464- print(f"Wrote {len(train_entries)} train + {len(eval_entries)} eval entries to {out_dir}")
6565-6666-6767-def main() -> None:
6868- if len(sys.argv) < 3:
6969- print("Usage: python -m src.backfill_tb <log_file> <out_dir>")
7070- sys.exit(1)
7171- log_file, out_dir = sys.argv[1], sys.argv[2]
7272- train_entries, eval_entries = parse_log(log_file)
7373- print(f"Parsed {len(train_entries)} train entries, {len(eval_entries)} eval entries")
7474- write_events(train_entries, eval_entries, out_dir)
7575-7676-7777-if __name__ == "__main__":
7878- main()
-167
src/collate_deepseek.py
···11-"""
22-Data collator for DeepSeek-OCR-2 fine-tuning.
33-44-Preprocesses image+typst pairs into the tensor format expected by
55-DeepseekOCR2ForCausalLM.forward():
66-77- - PIL image letterbox-padded to 768×768 (neutral gray fill, mean=0.5)
88- - 145 image tokens at the start of the sequence (12² spatial + 1 separator)
99- - prompt tokens follow, then response tokens + EOS
1010- - labels mask the image+prompt prefix with -100 (train on response only)
1111-1212-Design decisions:
1313- - crop_mode=False: single global view, no dynamic tiling. Our math
1414- expression images are small fragments -- tiling adds no information and
1515- wastes tokens.
1616- - image_size=768: fixed encoder input (SAM ViT-B, patch=16, downsample=4
1717- → 12×12 grid = 144 + 1 separator = 145 tokens).
1818- - The `images` argument to forward() is a list of (local_crops, global_view)
1919- tuples. With crop_mode=False there are no distinct local crops; we pass
2020- the same 768×768 tensor for both slots. This matches the single-tile
2121- case in infer() but should be validated against the model source once
2222- available locally.
2323-"""
2424-2525-import math
2626-2727-import torch
2828-from PIL import Image, ImageOps
2929-from torch.nn.utils.rnn import pad_sequence
3030-from torchvision import transforms
3131-3232-# ── Constants ──────────────────────────────────────────────────────────────────
3333-3434-IMAGE_TOKEN_ID = 128815 # <image> placeholder token in the DeepSeek-OCR-2 vocab
3535-IMAGE_SIZE = 768 # fixed encoder input resolution (px)
3636-PATCH_SIZE = 16 # SAM ViT-B patch size (px)
3737-DOWNSAMPLE = 4 # spatial downsample ratio inside the encoder
3838-3939-# Number of image tokens inserted per image:
4040-# ceil(768 / 16 / 4) = 12 → 12×12 + 1 separator = 145
4141-_N_GRID = math.ceil(IMAGE_SIZE / PATCH_SIZE / DOWNSAMPLE) # 12
4242-N_IMAGE_TOKENS = _N_GRID * _N_GRID + 1 # 145
4343-4444-_MEAN = (0.5, 0.5, 0.5)
4545-_STD = (0.5, 0.5, 0.5)
4646-_PAD_COLOR = tuple(int(x * 255) for x in _MEAN) # (127, 127, 127) gray
4747-4848-_transform = transforms.Compose([
4949- transforms.ToTensor(),
5050- transforms.Normalize(mean=_MEAN, std=_STD),
5151-])
5252-5353-# Prompt appended after the image tokens. Leading \n because <image> is the
5454-# very first token and the model expects a newline before the text instruction.
5555-_PROMPT = "\nTranscribe this image to Typst notation.\n"
5656-5757-5858-# ── Image preprocessing ────────────────────────────────────────────────────────
5959-6060-def preprocess_image(img: Image.Image) -> torch.Tensor:
6161- """
6262- Letterbox-pad a PIL image to IMAGE_SIZE×IMAGE_SIZE, normalize.
6363-6464- ImageOps.pad preserves aspect ratio by centering and filling borders with
6565- _PAD_COLOR, so non-square images are not distorted.
6666-6767- Returns float16 tensor of shape [3, IMAGE_SIZE, IMAGE_SIZE].
6868- """
6969- img = img.convert("RGB")
7070- img = ImageOps.pad(img, (IMAGE_SIZE, IMAGE_SIZE), color=_PAD_COLOR)
7171- return _transform(img).to(torch.bfloat16)
7272-7373-7474-# ── Collator ───────────────────────────────────────────────────────────────────
7575-7676-class DeepSeekOCRCollator:
7777- """
7878- Collates MathOCRDataset items into inputs for DeepseekOCR2ForCausalLM.
7979-8080- Each dataset item has the structure produced by MathOCRDataset.__getitem__:
8181- {
8282- "messages": [
8383- {"role": "user", "content": [{"type": "image", "image": <PIL>},
8484- {"type": "text", "text": PROMPT}]},
8585- {"role": "assistant", "content": [{"type": "text", "text": <typst>}]},
8686- ]
8787- }
8888-8989- Output dict keys match DeepseekOCR2ForCausalLM.forward() signature:
9090- input_ids, attention_mask, labels,
9191- images, images_seq_mask, images_spatial_crop
9292- """
9393-9494- def __init__(self, tokenizer) -> None:
9595- self.tokenizer = tokenizer
9696- self.prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False)
9797- self.eos_id = tokenizer.eos_token_id
9898- self.pad_id = (tokenizer.pad_token_id
9999- if tokenizer.pad_token_id is not None
100100- else tokenizer.eos_token_id)
101101-102102- def __call__(self, batch: list[dict]) -> dict:
103103- input_ids_list = []
104104- labels_list = []
105105- seq_mask_list = []
106106- img_tensors = []
107107- spatial_crops = []
108108-109109- for item in batch:
110110- user_content = item["messages"][0]["content"]
111111- pil_img = user_content[0]["image"] # PIL Image
112112- typst = item["messages"][1]["content"][0]["text"] # target string
113113-114114- # Image → tensor ───────────────────────────────────────────────────
115115- img_t = preprocess_image(pil_img) # [3, 768, 768] bfloat16
116116- img_tensors.append(img_t)
117117-118118- # Token sequence ───────────────────────────────────────────────────
119119- # Layout: [img×145] [prompt] [response] [EOS]
120120- response_ids = self.tokenizer.encode(typst, add_special_tokens=False)
121121- img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS
122122- ids = img_ids + self.prompt_ids + response_ids + [self.eos_id]
123123-124124- # Labels: -100 on image+prompt; train only on response+EOS
125125- n_prefix = N_IMAGE_TOKENS + len(self.prompt_ids)
126126- lbl = [-100] * n_prefix + response_ids + [self.eos_id]
127127-128128- # images_seq_mask: True at the N_IMAGE_TOKENS image token positions
129129- mask = [False] * len(ids)
130130- for i in range(N_IMAGE_TOKENS):
131131- mask[i] = True
132132-133133- input_ids_list.append(torch.tensor(ids, dtype=torch.long))
134134- labels_list.append(torch.tensor(lbl, dtype=torch.long))
135135- seq_mask_list.append(torch.tensor(mask, dtype=torch.bool))
136136- # [width_crops, height_crops] = [1, 1] for single global view
137137- spatial_crops.append([1, 1])
138138-139139- # Pad sequences ────────────────────────────────────────────────────────
140140- input_ids = pad_sequence(input_ids_list, batch_first=True,
141141- padding_value=self.pad_id)
142142- labels = pad_sequence(labels_list, batch_first=True,
143143- padding_value=-100)
144144- seq_mask = pad_sequence(seq_mask_list, batch_first=True,
145145- padding_value=False)
146146- attn_mask = (input_ids != self.pad_id)
147147-148148- # images: list of (local_crops, global_view) tuples, one per sample.
149149- # local_crops must be zeros so forward() takes the crop_mode=False else-branch,
150150- # producing exactly N_IMAGE_TOKENS=145 global features to match images_seq_mask.
151151- # Passing the real image for both slots triggers the if-branch which concatenates
152152- # local+global+sep = 289 features; masked_scatter_ then injects the wrong subset.
153153- imgs_stacked = torch.stack(img_tensors) # [B, 3, 768, 768]
154154- zeros = torch.zeros_like(imgs_stacked[0].unsqueeze(0))
155155- images = [
156156- (zeros, imgs_stacked[i].unsqueeze(0))
157157- for i in range(len(batch))
158158- ]
159159-160160- return {
161161- "input_ids": input_ids, # [B, T]
162162- "attention_mask": attn_mask, # [B, T]
163163- "labels": labels, # [B, T]
164164- "images": images, # list[B] of (tensor, tensor)
165165- "images_seq_mask": seq_mask, # [B, T]
166166- "images_spatial_crop": torch.tensor(spatial_crops, dtype=torch.long), # [B, 2]
167167- }
src/deepseek_ocr2/__init__.py
This is a binary file and will not be displayed.
-210
src/deepseek_ocr2/configuration_deepseek_v2.py
···11-from transformers.configuration_utils import PretrainedConfig
22-from transformers.utils import logging
33-44-logger = logging.get_logger(__name__)
55-66-DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
77-class DeepseekV2Config(PretrainedConfig):
88- r"""
99- This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
1010- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
1111- defaults will yield a similar configuration to that of the DeepSeek-V2 with multi-latent attention.
1212-1313- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
1414- documentation from [`PretrainedConfig`] for more information.
1515-1616-1717- Args:
1818- vocab_size (`int`, *optional*, defaults to 102400):
1919- Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
2020- `inputs_ids` passed when calling [`DeepseekV2Model`]
2121- hidden_size (`int`, *optional*, defaults to 4096):
2222- Dimension of the hidden representations.
2323- intermediate_size (`int`, *optional*, defaults to 11008):
2424- Dimension of the MLP representations.
2525- moe_intermediate_size (`int`, *optional*, defaults to 1407):
2626- Dimension of the MoE representations.
2727- num_hidden_layers (`int`, *optional*, defaults to 32):
2828- Number of hidden layers in the Transformer decoder.
2929- num_attention_heads (`int`, *optional*, defaults to 32):
3030- Number of attention heads for each attention layer in the Transformer decoder.
3131- n_shared_experts (`int`, *optional*, defaults to None):
3232- Number of shared experts, None means dense model.
3333- n_routed_experts (`int`, *optional*, defaults to None):
3434- Number of routed experts, None means dense model.
3535- routed_scaling_factor (`float`, *optional*, defaults to 1.0):
3636- Scaling factor or routed experts.
3737- topk_method (`str`, *optional*, defaults to `gready`):
3838- Topk method used in routed gate.
3939- n_group (`int`, *optional*, defaults to None):
4040- Number of groups for routed experts.
4141- topk_group (`int`, *optional*, defaults to None):
4242- Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
4343- num_experts_per_tok (`int`, *optional*, defaults to None):
4444- Number of selected experts, None means dense model.
4545- moe_layer_freq (`int`, *optional*, defaults to 1):
4646- The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
4747- first_k_dense_replace (`int`, *optional*, defaults to 0):
4848- Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
4949- \--k dense layers--/
5050- norm_topk_prob (`bool`, *optional*, defaults to False):
5151- Whether to normalize the weights of the routed experts.
5252- scoring_func (`str`, *optional*, defaults to 'softmax'):
5353- Method of computing expert weights.
5454- aux_loss_alpha (`float`, *optional*, defaults to 0.001):
5555- Auxiliary loss weight coefficient.
5656- seq_aux = (`bool`, *optional*, defaults to True):
5757- Whether to compute the auxiliary loss for each individual sample.
5858- num_key_value_heads (`int`, *optional*):
5959- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
6060- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
6161- `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
6262- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
6363- by meanpooling all the original heads within that group. For more details checkout [this
6464- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
6565- `num_attention_heads`.
6666- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
6767- The non-linear activation function (function or string) in the decoder.
6868- max_position_embeddings (`int`, *optional*, defaults to 2048):
6969- The maximum sequence length that this model might ever be used with.
7070- initializer_range (`float`, *optional*, defaults to 0.02):
7171- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
7272- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
7373- The epsilon used by the rms normalization layers.
7474- use_cache (`bool`, *optional*, defaults to `True`):
7575- Whether or not the model should return the last key/values attentions (not used by all models). Only
7676- relevant if `config.is_decoder=True`.
7777- pad_token_id (`int`, *optional*):
7878- Padding token id.
7979- bos_token_id (`int`, *optional*, defaults to 1):
8080- Beginning of stream token id.
8181- eos_token_id (`int`, *optional*, defaults to 2):
8282- End of stream token id.
8383- pretraining_tp (`int`, *optional*, defaults to 1):
8484- Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
8585- document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
8686- necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
8787- issue](https://github.com/pytorch/pytorch/issues/76232).
8888- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
8989- Whether to tie weight embeddings
9090- rope_theta (`float`, *optional*, defaults to 10000.0):
9191- The base period of the RoPE embeddings.
9292- rope_scaling (`Dict`, *optional*):
9393- Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
9494- strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
9595- `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
9696- `max_position_embeddings` to the expected new maximum.
9797- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
9898- Whether to use a bias in the query, key, value and output projection layers during self-attention.
9999- attention_dropout (`float`, *optional*, defaults to 0.0):
100100- The dropout ratio for the attention probabilities.
101101- use_mla (`bool`, *optional*, defaults to `True`): Use multi-latent attention or multi-head attention. If True,
102102- the model will use multi-latent attention, otherwise, it will use multi-head attention.
103103-104104- ```python
105105- >>> from transformers import DeepseekV2Model, DeepseekV2Config
106106-107107- >>> # Initializing a Deepseek-V2 style configuration
108108- >>> configuration = DeepseekV2Config()
109109-110110- >>> # Accessing the model configuration
111111- >>> configuration = model.config
112112- ```"""
113113-114114- model_type = "deepseek_v2"
115115- keys_to_ignore_at_inference = ["past_key_values"]
116116-117117- def __init__(
118118- self,
119119- vocab_size=102400,
120120- hidden_size=4096,
121121- intermediate_size=11008,
122122- moe_intermediate_size = 1407,
123123- num_hidden_layers=30,
124124- num_attention_heads=32,
125125- num_key_value_heads=32,
126126- n_shared_experts = None,
127127- n_routed_experts = None,
128128- ep_size = 1,
129129- routed_scaling_factor = 1.0,
130130- kv_lora_rank = 512,
131131- q_lora_rank = 1536,
132132- qk_rope_head_dim = 64,
133133- v_head_dim = 128,
134134- qk_nope_head_dim = 128,
135135- topk_method = 'gready',
136136- n_group = None,
137137- topk_group = None,
138138- num_experts_per_tok = None,
139139- moe_layer_freq = 1,
140140- first_k_dense_replace = 0,
141141- norm_topk_prob = False,
142142- scoring_func = 'softmax',
143143- aux_loss_alpha = 0.001,
144144- seq_aux = True,
145145- hidden_act="silu",
146146- max_position_embeddings=2048,
147147- initializer_range=0.02,
148148- rms_norm_eps=1e-6,
149149- use_cache=True,
150150- pad_token_id=None,
151151- bos_token_id=100000,
152152- eos_token_id=100001,
153153- pretraining_tp=1,
154154- tie_word_embeddings=False,
155155- rope_theta=10000.0,
156156- rope_scaling=None,
157157- attention_bias=False,
158158- attention_dropout=0.0,
159159- use_mla=True,
160160- **kwargs,
161161- ):
162162- self.vocab_size = vocab_size
163163- self.max_position_embeddings = max_position_embeddings
164164- self.hidden_size = hidden_size
165165- self.intermediate_size = intermediate_size
166166- self.moe_intermediate_size = moe_intermediate_size
167167- self.num_hidden_layers = num_hidden_layers
168168- self.num_attention_heads = num_attention_heads
169169- self.n_shared_experts = n_shared_experts
170170- self.n_routed_experts = n_routed_experts
171171- self.ep_size = ep_size
172172- self.routed_scaling_factor = routed_scaling_factor
173173- self.kv_lora_rank = kv_lora_rank
174174- self.q_lora_rank = q_lora_rank
175175- self.qk_rope_head_dim = qk_rope_head_dim
176176- self.v_head_dim = v_head_dim
177177- self.qk_nope_head_dim = qk_nope_head_dim
178178- self.topk_method = topk_method
179179- self.n_group = n_group
180180- self.topk_group = topk_group
181181- self.num_experts_per_tok = num_experts_per_tok
182182- self.moe_layer_freq = moe_layer_freq
183183- self.first_k_dense_replace = first_k_dense_replace
184184- self.norm_topk_prob = norm_topk_prob
185185- self.scoring_func = scoring_func
186186- self.aux_loss_alpha = aux_loss_alpha
187187- self.seq_aux = seq_aux
188188- # for backward compatibility
189189- if num_key_value_heads is None:
190190- num_key_value_heads = num_attention_heads
191191-192192- self.num_key_value_heads = num_key_value_heads
193193- self.hidden_act = hidden_act
194194- self.initializer_range = initializer_range
195195- self.rms_norm_eps = float(rms_norm_eps)
196196- self.pretraining_tp = pretraining_tp
197197- self.use_cache = use_cache
198198- self.rope_theta = rope_theta
199199- self.rope_scaling = rope_scaling
200200- self.attention_bias = attention_bias
201201- self.attention_dropout = attention_dropout
202202- self.use_mla = use_mla
203203-204204- super().__init__(
205205- pad_token_id=pad_token_id,
206206- bos_token_id=bos_token_id,
207207- eos_token_id=eos_token_id,
208208- tie_word_embeddings=tie_word_embeddings,
209209- **kwargs,
210210- )
-280
src/deepseek_ocr2/conversation.py
···11-"""
22-From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
33-"""
44-55-import dataclasses
66-from enum import IntEnum, auto
77-from typing import Any, Dict, List
88-99-1010-class SeparatorStyle(IntEnum):
1111- """Separator styles."""
1212-1313- DeepSeek = auto()
1414- DeepSeekV2 = auto()
1515- PLAIN = auto()
1616- ALIGNMENT = auto()
1717-1818-1919-@dataclasses.dataclass
2020-class Conversation:
2121- """A class that manages prompt templates and keeps all conversation history."""
2222-2323- # The name of this template
2424- name: str
2525- # The template of the system prompt
2626- system_template: str = "{system_message}"
2727- # The system message
2828- system_message: str = ""
2929- # The names of two roles
3030- roles: List[str] = (("USER", "ASSISTANT"),)
3131- # All messages. Each item is (role, message).
3232- messages: List[List[str]] = ()
3333- # The number of few shot examples
3434- offset: int = 0
3535- # The separator style and configurations
3636- sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
3737- sep: str = "\n"
3838- sep2: str = None
3939- # Stop criteria (the default one is EOS token)
4040- stop_str: str = None
4141- # Stops generation if meeting any token in this list
4242- stop_token_ids: List[int] = None
4343-4444- def get_prompt(self) -> str:
4545- """Get the prompt for generation."""
4646- system_prompt = self.system_template.format(system_message=self.system_message)
4747- if self.sep_style == SeparatorStyle.DeepSeek:
4848- seps = [self.sep, self.sep2]
4949- if system_prompt == "" or system_prompt is None:
5050- ret = ""
5151- else:
5252- ret = system_prompt + seps[0]
5353- for i, (role, message) in enumerate(self.messages):
5454- if message:
5555- ret += role + ": " + message + seps[i % 2]
5656- else:
5757- ret += role + ":"
5858- return ret
5959- elif self.sep_style == SeparatorStyle.DeepSeekV2:
6060- seps = [self.sep, self.sep2]
6161- if system_prompt == "" or system_prompt is None:
6262- ret = ""
6363- else:
6464- ret = system_prompt + seps[0]
6565- for i, (role, message) in enumerate(self.messages):
6666- if message:
6767- if role == "User":
6868- ret += "<|sft▁begin|>\n" + message + self.sep #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
6969- else:
7070- ret += message + self.sep2
7171- else:
7272- ret = ret
7373- return ret
7474-7575- elif self.sep_style == SeparatorStyle.PLAIN:
7676- seps = [self.sep, self.sep2]
7777- ret = ""
7878- for i, (role, message) in enumerate(self.messages):
7979- if message:
8080- if type(message) is tuple:
8181- message, _, _ = message
8282- if i % 2 == 0:
8383- ret += message + seps[i % 2]
8484- else:
8585- ret += message + seps[i % 2]
8686- else:
8787- ret += ""
8888- return ret
8989- elif self.sep_style == SeparatorStyle.ALIGNMENT:
9090- seps = [self.sep, self.sep2]
9191- ret = ""
9292- for i, (role, message) in enumerate(self.messages):
9393- if message:
9494- if type(message) is tuple:
9595- message, _, _ = message
9696- if i % 2 == 0:
9797- ret += '<image>\n' + seps[i % 2]
9898- else:
9999- ret += message + seps[i % 2]
100100- else:
101101- ret += ""
102102- return ret
103103- else:
104104- raise ValueError(f"Invalid style: {self.sep_style}")
105105-106106- def set_system_message(self, system_message: str):
107107- """Set the system message."""
108108- self.system_message = system_message
109109-110110- def append_message(self, role: str, message: str):
111111- """Append a new message."""
112112- self.messages.append([role, message])
113113-114114- def update_last_message(self, message: str):
115115- """Update the last output.
116116-117117- The last message is typically set to be None when constructing the prompt,
118118- so we need to update it in-place after getting the response from a model.
119119- """
120120- self.messages[-1][1] = message
121121-122122- def reset_message(self):
123123- """Reset a new message."""
124124- self.messages = []
125125-126126- def to_gradio_chatbot(self):
127127- """Convert the conversation to gradio chatbot format."""
128128- ret = []
129129- for i, (role, msg) in enumerate(self.messages[self.offset :]):
130130- if i % 2 == 0:
131131- ret.append([msg, None])
132132- else:
133133- ret[-1][-1] = msg
134134- return ret
135135-136136- def to_openai_api_messages(self):
137137- """Convert the conversation to OpenAI chat completion format."""
138138- system_prompt = self.system_template.format(system_message=self.system_message)
139139- ret = [{"role": "system", "content": system_prompt}]
140140-141141- for i, (_, msg) in enumerate(self.messages[self.offset :]):
142142- if i % 2 == 0:
143143- ret.append({"role": "user", "content": msg})
144144- else:
145145- if msg is not None:
146146- ret.append({"role": "assistant", "content": msg})
147147- return ret
148148-149149- def copy(self):
150150- return Conversation(
151151- name=self.name,
152152- system_template=self.system_template,
153153- system_message=self.system_message,
154154- roles=self.roles,
155155- messages=[[x, y] for x, y in self.messages],
156156- offset=self.offset,
157157- sep_style=self.sep_style,
158158- sep=self.sep,
159159- sep2=self.sep2,
160160- stop_str=self.stop_str,
161161- stop_token_ids=self.stop_token_ids,
162162- )
163163-164164- def dict(self):
165165- return {
166166- "template_name": self.name,
167167- "system_message": self.system_message,
168168- "roles": self.roles,
169169- "messages": self.messages,
170170- "offset": self.offset,
171171- }
172172-173173-174174-# A global registry for all conversation templates
175175-conv_templates: Dict[str, Conversation] = {}
176176-177177-178178-def register_conv_template(template: Conversation, override: bool = False):
179179- """Register a new conversation template."""
180180- if not override:
181181- assert template.name not in conv_templates, f"{template.name} has been registered."
182182-183183- conv_templates[template.name] = template
184184-185185-186186-def get_conv_template(name: str) -> Conversation:
187187- """Get a conversation template."""
188188- return conv_templates[name].copy()
189189-190190-191191-register_conv_template(
192192- Conversation(
193193- name="deepseek",
194194- system_template="{system_message}",
195195- # system_message="You are a helpful assistant. Please answer truthfully and write out your "
196196- # "thinking step by step to be sure you get the right answer.",
197197- system_message="",
198198- roles=("<|User|>", "<|Assistant|>"),
199199- messages=(),
200200- offset=0,
201201- sep_style=SeparatorStyle.DeepSeek,
202202- sep="\n\n",
203203- sep2="<|end▁of▁sentence|>",
204204- stop_token_ids=[100001],
205205- stop_str=["User:", "<|end▁of▁sentence|>"]
206206- )
207207-)
208208-register_conv_template(
209209- Conversation(
210210- name="deepseekv2",
211211- system_template="{system_message}",
212212- # system_message="You are a helpful assistant. Please answer truthfully and write out your "
213213- # "thinking step by step to be sure you get the right answer.",
214214- system_message="",
215215- roles=("<|User|>", "<|Assistant|>"),
216216- messages=(),
217217- offset=0,
218218- sep_style=SeparatorStyle.DeepSeek,
219219- sep="",
220220- sep2="<|end▁of▁sentence|>",
221221- stop_token_ids=[100001],
222222- stop_str=["User:", "<|end▁of▁sentence|>"]
223223- )
224224-)
225225-226226-227227-register_conv_template(
228228- Conversation(
229229- name="plain",
230230- system_template="",
231231- system_message="",
232232- roles=("", ""),
233233- messages=(),
234234- offset=0,
235235- sep_style=SeparatorStyle.PLAIN,
236236- sep="",
237237- sep2="",
238238- stop_token_ids=[100001],
239239- stop_str=['</s>'],
240240- )
241241-)
242242-243243-244244-register_conv_template(
245245- Conversation(
246246- name="alignment",
247247- system_template="",
248248- system_message="",
249249- roles=("", ""),
250250- messages=(),
251251- offset=0,
252252- sep_style=SeparatorStyle.ALIGNMENT,
253253- sep="",
254254- sep2="",
255255- stop_token_ids=[100001],
256256- stop_str=['</s>'],
257257- )
258258-)
259259-260260-261261-if __name__ == "__main__":
262262- print("deepseek template:")
263263- conv = get_conv_template("deepseek")
264264- conv.append_message(conv.roles[0], "Hello!")
265265- conv.append_message(conv.roles[1], "Hi! This is Tony.")
266266- conv.append_message(conv.roles[0], "Who are you?")
267267- conv.append_message(conv.roles[1], "I am a helpful assistant.")
268268- conv.append_message(conv.roles[0], "How are you?")
269269- conv.append_message(conv.roles[1], None)
270270- print(conv.get_prompt())
271271-272272- print("deepseekv2 template:")
273273- conv = get_conv_template("deepseekv2")
274274- conv.append_message(conv.roles[0], "Hello!")
275275- conv.append_message(conv.roles[1], "Hi! This is Tony.")
276276- conv.append_message(conv.roles[0], "Who are you?")
277277- conv.append_message(conv.roles[1], "I am a helpful assistant.")
278278- conv.append_message(conv.roles[0], "How are you?")
279279- conv.append_message(conv.roles[1], None)
280280- print(conv.get_prompt())
-1015
src/deepseek_ocr2/deepencoderv2.py
···11-import torch.nn as nn
22-import torch
33-import torch.nn.functional as F
44-import copy
55-66-77-from typing import Optional, Tuple
88-99-# from megatron.model import LayerNorm
1010-1111-import transformers
1212-1313-1414-from typing import Optional, Tuple, Type
1515-from functools import partial
1616-1717-1818-1919-class MlpProjector(nn.Module):
2020-2121- def __init__(self, cfg):
2222-2323- super().__init__()
2424-2525- self.cfg = cfg
2626-2727- if cfg.projector_type == "identity":
2828- modules = nn.Identity()
2929-3030- elif cfg.projector_type == "linear":
3131- modules = nn.Linear(cfg.input_dim, cfg.n_embed)
3232-3333- elif cfg.projector_type == "mlp_gelu":
3434- mlp_depth = cfg.get("depth", 1)
3535- modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
3636- for _ in range(1, mlp_depth):
3737- modules.append(nn.GELU())
3838- modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
3939- modules = nn.Sequential(*modules)
4040-4141- elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
4242- mlp_depth = cfg.get("depth", 1)
4343- mlp_ratio = cfg.get("mlp_ratio", 1)
4444- modules = [
4545- nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
4646- nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)
4747- ]
4848- for _ in range(1, mlp_depth - 1):
4949- modules.append(nn.GELU())
5050- modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
5151- modules.append(nn.GELU())
5252- modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
5353- modules = nn.Sequential(*modules)
5454-5555- elif cfg.projector_type == "downsample_mlp_gelu":
5656- mlp_depth = cfg.get("depth", 1)
5757- mlp_ratio = cfg.get("mlp_ratio", 1)
5858- modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
5959- for _ in range(1, mlp_depth - 1):
6060- modules.append(nn.GELU())
6161- modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
6262- modules.append(nn.GELU())
6363- modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
6464- modules = nn.Sequential(*modules)
6565-6666- elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
6767- mlp_depth = cfg.get("depth", 1)
6868- self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
6969- self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
7070-7171- modules = []
7272- for _ in range(1, mlp_depth):
7373- modules.append(nn.GELU())
7474- modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
7575- modules = nn.Sequential(*modules)
7676-7777- elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
7878- mlp_depth = cfg.get("depth", 1)
7979- channel_div = cfg.get("channel_div", 0.5)
8080- self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
8181- self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))
8282-8383- modules = []
8484- for _ in range(1, mlp_depth):
8585- modules.append(nn.GELU())
8686- modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
8787- modules = nn.Sequential(*modules)
8888-8989- elif cfg.projector_type == "low_high_split_mlp_gelu":
9090- mlp_depth = cfg.get("depth", 1)
9191- modules = []
9292- for _ in range(1, mlp_depth):
9393- modules.append(nn.GELU())
9494- modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
9595- modules = nn.Sequential(*modules)
9696- self.high_layers = nn.Sequential(*modules)
9797- self.low_layers = copy.deepcopy(modules)
9898-9999- else:
100100- raise ValueError(f"Unknown projector type: {cfg.projector_type}")
101101-102102- if cfg.get("token_pooling", False):
103103- self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)
104104-105105- if cfg.get("conv_fusion_high_low_features", False):
106106- self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
107107- self.layers = modules
108108-109109- def forward(self, x):
110110- if self.cfg.get("token_pooling", False):
111111- batch_size, wxh, channels = x.shape
112112- w = h = int(wxh**0.5)
113113- x = x.view(batch_size, w, h, channels)
114114- x = x.permute(0, 3, 1, 2)
115115- # import ipdb; ipdb.set_trace()
116116- patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
117117- batch_size, channels, h_patches, w_patches, _, _ = patches.size()
118118- # 在通道维度上拼接
119119- patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)
120120-121121- # 通过线性层
122122- patches = patches.permute(0, 2, 1, 3).contiguous()
123123- patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
124124-125125- x = self.token_pooling_layer(patches)
126126-127127- if self.cfg.get("conv_fusion_high_low_features", False):
128128- x = self.fusion_layer(x[:, 0]) + x[:, 1]
129129-130130- if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu':
131131- high_x, low_x = x[0], x[1]
132132- high_x = self.high_up_proj(high_x)
133133- low_x = self.low_up_proj(low_x)
134134- x = torch.concat([high_x, low_x], dim=-1)
135135-136136- if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu':
137137- high_x = x[...,:self.cfg.input_dim[0]]
138138- low_x = x[...,self.cfg.input_dim[0]:]
139139- high_x = self.high_up_proj(high_x)
140140- low_x = self.low_up_proj(low_x)
141141- x = torch.concat([high_x, low_x], dim=-1)
142142-143143- if self.cfg.projector_type == 'low_high_split_mlp_gelu':
144144- high_x, low_x = x[0], x[1]
145145- high_x = self.high_layers(high_x)
146146- low_x = self.low_layers(low_x)
147147- x = torch.concat([high_x, low_x], dim=-1)
148148- return x
149149-150150- if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu':
151151- bs, hw, input_dim = x.shape
152152- h = w = int((hw) ** 0.5)
153153-154154- """compute padding"""
155155- if h % self.cfg.downsample_ratio:
156156- pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
157157- else:
158158- pad = 0
159159- x = x.reshape(bs, h, w, input_dim)
160160- if pad > 0:
161161- x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
162162-163163- """4 to 1 concat"""
164164- x = x.permute(0, 3, 1, 2) # B, C, H, W
165165- x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4
166166- x = x.permute(0, 2, 1)
167167-168168- return self.layers(x)
169169-170170- @staticmethod
171171- def get_flops_per_sample(cfg):
172172- if cfg.projector_type == "linear":
173173- fwd = 2 * cfg.input_dim * cfg.n_embed
174174-175175- elif "mlp_gelu" in cfg.projector_type :
176176- mlp_depth = cfg.get("depth", 1)
177177- downsample_ratio = cfg.get("downsample_ratio", 1)
178178- input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
179179- input_dim = input_dim * downsample_ratio * downsample_ratio
180180- fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
181181- else:
182182- fwd = 0
183183-184184- return fwd * 3
185185-186186-187187-#===================qwen2================================
188188-189189-class CustomQwen2Decoder(nn.Module):
190190- """
191191- Qwen2 visual encoder
192192- non-causal attention + causal attention
193193- token_type_ids :0=non-causal, 1=causal
194194- """
195195-196196- def __init__(
197197- self,
198198- decoder_layer: int = 24,
199199- max_position_embeddings: int = 131072,
200200- hidden_dimension: int = 896,
201201- num_attention_heads: int = 14,
202202- num_key_value_heads: int = 2,
203203- intermediate_size: int = 4864,
204204- vocab_size: int = 151936,
205205- attn_implementation: str = "sdpa", # ⭐
206206- rms_norm_eps: float = 1e-06,
207207- rope_theta: float = 1000000.0,
208208- attention_dropout: float = 0.0,
209209- hidden_act: str = "silu",
210210- initializer_range: float = 0.02,
211211- ):
212212- super().__init__()
213213-214214- # attn_implementation check
215215- if attn_implementation == "flash_attention_2":
216216- raise ValueError(
217217- "CustomQwen2Decoder do not support flash_attention_2,"
218218- "new attention mask needs 'sdpa' or 'eager'"
219219- )
220220-221221- # load
222222- Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model')
223223- Qwen2Config = getattr(transformers, 'Qwen2Config')
224224-225225- # config
226226- config = Qwen2Config(
227227- hidden_size=hidden_dimension,
228228- num_hidden_layers=decoder_layer,
229229- num_attention_heads=num_attention_heads,
230230- num_key_value_heads=num_key_value_heads,
231231- intermediate_size=intermediate_size,
232232- max_position_embeddings=max_position_embeddings,
233233- vocab_size=vocab_size,
234234- rms_norm_eps=rms_norm_eps,
235235- rope_theta=rope_theta,
236236- attention_dropout=attention_dropout,
237237- hidden_act=hidden_act,
238238- initializer_range=initializer_range,
239239- _attn_implementation=attn_implementation, # ⭐
240240- )
241241-242242- #
243243- self.model = self._create_custom_model(Qwen2Model, config)
244244-245245- del self.model.embed_tokens
246246-247247- def _create_custom_model(self, Qwen2Model, config):
248248- """ Qwen2Model """
249249-250250- class CustomQwen2ModelInner(Qwen2Model):
251251-252252-253253- def forward(
254254- self,
255255- input_ids=None,
256256- attention_mask=None,
257257- position_ids=None,
258258- past_key_values=None,
259259- inputs_embeds=None,
260260- token_type_ids=None, # ⭐
261261- use_cache=None,
262262- output_attentions=None,
263263- output_hidden_states=None,
264264- return_dict=None,
265265- cache_position=None,
266266- ):
267267- # token_type_ids
268268- self._current_token_type_ids = token_type_ids
269269-270270- outputs = super().forward(
271271- input_ids=input_ids,
272272- attention_mask=attention_mask,
273273- position_ids=position_ids,
274274- past_key_values=past_key_values,
275275- inputs_embeds=inputs_embeds,
276276- use_cache=use_cache,
277277- output_attentions=output_attentions,
278278- output_hidden_states=output_hidden_states,
279279- return_dict=return_dict,
280280- cache_position=cache_position,
281281- )
282282-283283- return outputs
284284-285285- def _update_causal_mask(
286286- self,
287287- attention_mask,
288288- input_tensor,
289289- cache_position,
290290- past_key_values,
291291- output_attentions,
292292- ):
293293- dtype, device = input_tensor.dtype, input_tensor.device
294294- min_dtype = torch.finfo(dtype).min
295295- batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]
296296-297297- token_type_ids = self._current_token_type_ids
298298-299299- # attention mask
300300- causal_mask = self._create_custom_4d_mask(
301301- sequence_length=sequence_length,
302302- dtype=dtype,
303303- device=device,
304304- batch_size=batch_size,
305305- token_type_ids=token_type_ids,
306306- )
307307-308308- # padding mask
309309- if attention_mask is not None and attention_mask.dim() == 2:
310310- padding_mask = attention_mask[:, None, None, :].to(dtype=dtype)
311311- padding_mask = (1.0 - padding_mask) * min_dtype
312312- causal_mask = causal_mask + padding_mask
313313-314314- return causal_mask
315315-316316- def _create_custom_4d_mask(
317317- self,
318318- sequence_length,
319319- dtype,
320320- device,
321321- batch_size,
322322- token_type_ids,
323323- ):
324324- min_dtype = torch.finfo(dtype).min
325325-326326- masks = []
327327- for b in range(batch_size):
328328- mask = torch.full(
329329- (sequence_length, sequence_length),
330330- fill_value=min_dtype,
331331- dtype=dtype,
332332- device=device
333333- )
334334-335335- type_ids = token_type_ids[b]
336336-337337- image_positions = (type_ids == 0).nonzero(as_tuple=True)[0]
338338- text_positions = (type_ids == 1).nonzero(as_tuple=True)[0]
339339-340340- # non-casual
341341- if len(image_positions) > 0:
342342- mask[image_positions[:, None], image_positions] = 0.0
343343-344344- # causal
345345- for i, text_pos in enumerate(text_positions):
346346- if len(image_positions) > 0:
347347- mask[text_pos, image_positions] = 0.0
348348- mask[text_pos, text_positions[:i+1]] = 0.0
349349-350350- masks.append(mask)
351351-352352- mask = torch.stack(masks, dim=0).unsqueeze(1)
353353- return mask
354354-355355- return CustomQwen2ModelInner(config)
356356-357357- def forward(
358358- self,
359359- inputs_embeds,
360360- token_type_ids,
361361- attention_mask=None,
362362- **kwargs
363363- ):
364364- """
365365- Args:
366366- inputs_embeds: [batch_size, seq_len, hidden_dim]
367367- token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal
368368- attention_mask: [batch_size, seq_len], optional
369369- """
370370- return self.model(
371371- inputs_embeds=inputs_embeds,
372372- token_type_ids=token_type_ids,
373373- attention_mask=attention_mask,
374374- **kwargs
375375- )
376376-377377-378378-379379-380380-381381-# batch_size = 2
382382-# inputs_embeds = torch.randn(batch_size, 512, 896).cuda()
383383-384384-# inputs_embeds = torch.randn(batch_size, 512, 896).cuda()
385385-# token_type_ids = torch.cat([
386386-# torch.zeros(batch_size, 256, dtype=torch.long),
387387-# torch.ones(batch_size, 256, dtype=torch.long),
388388-# ], dim=1).cuda()
389389-390390-# # start = time.time()
391391-# with torch.no_grad():
392392-# outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids)
393393-# print(outputs_sdpa[0].shape)
394394-# print(f"SDPA time: {time.time() - start:.4f}s")
395395-396396-397397-398398-class Qwen2Decoder2Encoder(nn.Module):
399399- """
400400- Decoder based on Multilingual BART
401401- Set the initial weights and configuration with a pretrained multilingual BART model,
402402- and modify the detailed configurations as a Nougat decoder
403403- """
404404-405405- def __init__(
406406- self,
407407- decoder_layer: int,
408408- hidden_dimension: int,
409409- num_attention_heads: int,
410410- num_key_value_heads: int,
411411- intermediate_size: int,
412412- max_query: int,
413413- ):
414414- super().__init__()
415415-416416- self.model = CustomQwen2Decoder(
417417- decoder_layer=decoder_layer,
418418- hidden_dimension=hidden_dimension,
419419- num_attention_heads=num_attention_heads,
420420- num_key_value_heads=num_key_value_heads,
421421- intermediate_size=intermediate_size,
422422- attn_implementation="sdpa",
423423- )
424424-425425-426426-427427-428428- self.query_768 = nn.Embedding(144, hidden_dimension)
429429- self.query_1024 = nn.Embedding(256, hidden_dimension)
430430-431431-432432- # self.query_refixation = nn.Embedding(int(math.sqrt(max_query)), hidden_dimension)
433433-434434-435435- def forward(self, x: torch.Tensor) -> torch.Tensor:
436436- x = x.flatten(2).transpose(1, 2)
437437-438438- bs, n_query, _ = x.shape
439439-440440- if n_query == 144:
441441- param_img = self.query_768.weight
442442- elif n_query == 256:
443443- param_img = self.query_1024.weight
444444-445445- batch_query_imgs = param_img.unsqueeze(0).expand(
446446- bs, -1, -1
447447- ) # (batch_size, num_queries, hidden_size)
448448-449449-450450-451451- x_combined = torch.cat([x, batch_query_imgs], dim=1)
452452-453453- token_type_ids = torch.cat([
454454- torch.zeros(bs, n_query, dtype=torch.long),
455455- torch.ones(bs, n_query, dtype=torch.long),
456456- ], dim=1)
457457-458458-459459- y = self.model(x_combined, token_type_ids)[0]
460460-461461-462462- y = y[:, n_query:, :] # causal flow query
463463-464464-465465- return y
466466-467467-468468-def build_qwen2_decoder_as_encoder(
469469- decoder_layer=24,
470470- hidden_dimension=896,
471471- num_attention_heads=14,
472472- num_key_value_heads=2,
473473- intermediate_size=4864,
474474- max_query = 400,
475475- checkpoint=None,
476476-):
477477-478478- decoder_as_encoder = Qwen2Decoder2Encoder(
479479- decoder_layer=decoder_layer,
480480- hidden_dimension = hidden_dimension,
481481- num_attention_heads = num_attention_heads,
482482- num_key_value_heads = num_key_value_heads,
483483- intermediate_size = intermediate_size,
484484- max_query = max_query
485485- )
486486-487487-488488-489489-490490- if checkpoint is not None:
491491- # with open(checkpoint, "rb") as f:
492492- state_dict = torch.load(checkpoint)
493493-494494- decoder_as_encoder.load_state_dict(state_dict, strict=True)
495495- # tob
496496- print(checkpoint)
497497- return decoder_as_encoder
498498-499499-500500-501501-502502-#=========================Sam-Vary=================================
503503-504504-505505-def get_abs_pos_sam(abs_pos, tgt_size):
506506-507507- dtype = abs_pos.dtype
508508-509509- src_size = abs_pos.size(1)
510510-511511- if src_size != tgt_size:
512512- old_pos_embed = abs_pos.permute(0, 3, 1, 2)
513513- old_pos_embed = old_pos_embed.to(torch.float32)
514514- new_pos_embed = F.interpolate(
515515- old_pos_embed,
516516- size=(tgt_size, tgt_size),
517517- mode='bicubic',
518518- antialias=True,
519519- align_corners=False,
520520- ).to(dtype)
521521- new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
522522- return new_pos_embed
523523- else:
524524- return abs_pos
525525-526526-527527-528528-529529-class MLPBlock(nn.Module):
530530- def __init__(
531531- self,
532532- embedding_dim: int,
533533- mlp_dim: int,
534534- act: Type[nn.Module] = nn.GELU,
535535- ) -> None:
536536- super().__init__()
537537- self.lin1 = nn.Linear(embedding_dim, mlp_dim)
538538- self.lin2 = nn.Linear(mlp_dim, embedding_dim)
539539- self.act = act()
540540-541541- def forward(self, x: torch.Tensor) -> torch.Tensor:
542542- return self.lin2(self.act(self.lin1(x)))
543543-544544-545545-# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
546546-# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
547547-class LayerNorm2d(nn.Module):
548548- def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
549549- super().__init__()
550550- self.weight = nn.Parameter(torch.ones(num_channels))
551551- self.bias = nn.Parameter(torch.zeros(num_channels))
552552- self.eps = eps
553553-554554- def forward(self, x: torch.Tensor) -> torch.Tensor:
555555- u = x.mean(1, keepdim=True)
556556- s = (x - u).pow(2).mean(1, keepdim=True)
557557- x = (x - u) / torch.sqrt(s + self.eps)
558558- x = self.weight[:, None, None] * x + self.bias[:, None, None]
559559- return x
560560-561561-562562-# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
563563-class ImageEncoderViT(nn.Module):
564564- def __init__(
565565- self,
566566- img_size: int = 1024,
567567- patch_size: int = 16,
568568- in_chans: int = 3,
569569- embed_dim: int = 768,
570570- depth: int = 12,
571571- num_heads: int = 12,
572572- mlp_ratio: float = 4.0,
573573- out_chans: int = 256,
574574- qkv_bias: bool = True,
575575- norm_layer: Type[nn.Module] = nn.LayerNorm,
576576- act_layer: Type[nn.Module] = nn.GELU,
577577- use_abs_pos: bool = True,
578578- use_rel_pos: bool = False,
579579- rel_pos_zero_init: bool = True,
580580- window_size: int = 0,
581581- global_attn_indexes: Tuple[int, ...] = (),
582582- ) -> None:
583583- """
584584- Args:
585585- img_size (int): Input image size.
586586- patch_size (int): Patch size.
587587- in_chans (int): Number of input image channels.
588588- embed_dim (int): Patch embedding dimension.
589589- depth (int): Depth of ViT.
590590- num_heads (int): Number of attention heads in each ViT block.
591591- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
592592- qkv_bias (bool): If True, add a learnable bias to query, key, value.
593593- norm_layer (nn.Module): Normalization layer.
594594- act_layer (nn.Module): Activation layer.
595595- use_abs_pos (bool): If True, use absolute positional embeddings.
596596- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
597597- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
598598- window_size (int): Window size for window attention blocks.
599599- global_attn_indexes (list): Indexes for blocks using global attention.
600600- """
601601- super().__init__()
602602- self.img_size = img_size
603603-604604- self.patch_embed = PatchEmbed(
605605- kernel_size=(patch_size, patch_size),
606606- stride=(patch_size, patch_size),
607607- in_chans=in_chans,
608608- embed_dim=embed_dim,
609609- )
610610-611611- self.pos_embed: Optional[nn.Parameter] = None
612612- if use_abs_pos:
613613- # Initialize absolute positional embedding with pretrain image size.
614614- self.pos_embed = nn.Parameter(
615615- torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
616616- )
617617-618618- self.blocks = nn.ModuleList()
619619- for i in range(depth):
620620- block = Block(
621621- dim=embed_dim,
622622- num_heads=num_heads,
623623- mlp_ratio=mlp_ratio,
624624- qkv_bias=qkv_bias,
625625- norm_layer=norm_layer,
626626- act_layer=act_layer,
627627- use_rel_pos=use_rel_pos,
628628- rel_pos_zero_init=rel_pos_zero_init,
629629- window_size=window_size if i not in global_attn_indexes else 0,
630630- input_size=(img_size // patch_size, img_size // patch_size),
631631- )
632632- self.blocks.append(block)
633633-634634- self.neck = nn.Sequential(
635635- nn.Conv2d(
636636- embed_dim,
637637- out_chans,
638638- kernel_size=1,
639639- bias=False,
640640- ),
641641- LayerNorm2d(out_chans),
642642- nn.Conv2d(
643643- out_chans,
644644- out_chans,
645645- kernel_size=3,
646646- padding=1,
647647- bias=False,
648648- ),
649649- LayerNorm2d(out_chans),
650650- )
651651-652652- self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
653653- self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False)
654654-655655- def forward(self, x: torch.Tensor) -> torch.Tensor:
656656- x = self.patch_embed(x)
657657- if self.pos_embed is not None:
658658- # x = x + self.pos_embed
659659- x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
660660-661661- for blk in self.blocks:
662662- x = blk(x)
663663-664664- x = self.neck(x.permute(0, 3, 1, 2))
665665- x2 = self.net_2(x)
666666- x3 = self.net_3(x2.clone())
667667-668668- return x3
669669-670670-671671-class Block(nn.Module):
672672- """Transformer blocks with support of window attention and residual propagation blocks"""
673673-674674- def __init__(
675675- self,
676676- dim: int,
677677- num_heads: int,
678678- mlp_ratio: float = 4.0,
679679- qkv_bias: bool = True,
680680- norm_layer: Type[nn.Module] = nn.LayerNorm,
681681- act_layer: Type[nn.Module] = nn.GELU,
682682- use_rel_pos: bool = False,
683683- rel_pos_zero_init: bool = True,
684684- window_size: int = 0,
685685- input_size: Optional[Tuple[int, int]] = None,
686686- ) -> None:
687687- """
688688- Args:
689689- dim (int): Number of input channels.
690690- num_heads (int): Number of attention heads in each ViT block.
691691- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
692692- qkv_bias (bool): If True, add a learnable bias to query, key, value.
693693- norm_layer (nn.Module): Normalization layer.
694694- act_layer (nn.Module): Activation layer.
695695- use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
696696- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
697697- window_size (int): Window size for window attention blocks. If it equals 0, then
698698- use global attention.
699699- input_size (tuple(int, int) or None): Input resolution for calculating the relative
700700- positional parameter size.
701701- """
702702- super().__init__()
703703- self.norm1 = norm_layer(dim)
704704- self.attn = Attention(
705705- dim,
706706- num_heads=num_heads,
707707- qkv_bias=qkv_bias,
708708- use_rel_pos=use_rel_pos,
709709- rel_pos_zero_init=rel_pos_zero_init,
710710- input_size=input_size if window_size == 0 else (window_size, window_size),
711711- )
712712-713713- self.norm2 = norm_layer(dim)
714714- self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
715715-716716- self.window_size = window_size
717717-718718- def forward(self, x: torch.Tensor) -> torch.Tensor:
719719- shortcut = x
720720- x = self.norm1(x)
721721- # Window partition
722722- if self.window_size > 0:
723723- H, W = x.shape[1], x.shape[2]
724724- x, pad_hw = window_partition(x, self.window_size)
725725-726726- x = self.attn(x)
727727- # Reverse window partition
728728- if self.window_size > 0:
729729- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
730730-731731- x = shortcut + x
732732- x = x + self.mlp(self.norm2(x))
733733-734734- return x
735735-736736-737737-class Attention(nn.Module):
738738- """Multi-head Attention block with relative position embeddings."""
739739-740740- def __init__(
741741- self,
742742- dim: int,
743743- num_heads: int = 8,
744744- qkv_bias: bool = True,
745745- use_rel_pos: bool = False,
746746- rel_pos_zero_init: bool = True,
747747- input_size: Optional[Tuple[int, int]] = None,
748748- ) -> None:
749749- """
750750- Args:
751751- dim (int): Number of input channels.
752752- num_heads (int): Number of attention heads.
753753- qkv_bias (bool): If True, add a learnable bias to query, key, value.
754754- rel_pos (bool): If True, add relative positional embeddings to the attention map.
755755- rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
756756- input_size (tuple(int, int) or None): Input resolution for calculating the relative
757757- positional parameter size.
758758- """
759759- super().__init__()
760760- self.num_heads = num_heads
761761- head_dim = dim // num_heads
762762- self.scale = head_dim**-0.5
763763-764764- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
765765- self.proj = nn.Linear(dim, dim)
766766-767767- self.use_rel_pos = use_rel_pos
768768- if self.use_rel_pos:
769769- assert (
770770- input_size is not None
771771- ), "Input size must be provided if using relative positional encoding."
772772- # initialize relative positional embeddings
773773- self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
774774- self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
775775-776776- def forward(self, x: torch.Tensor) -> torch.Tensor:
777777- B, H, W, _ = x.shape
778778- # qkv with shape (3, B, nHead, H * W, C)
779779- qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
780780- # q, k, v with shape (B * nHead, H * W, C)
781781- q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
782782-783783- rel_h, rel_w = None, None
784784- if self.use_rel_pos:
785785- rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
786786-787787- q = q.view(B, self.num_heads, H * W, -1)
788788- k = k.view(B, self.num_heads, H * W, -1)
789789- v = v.view(B, self.num_heads, H * W, -1)
790790-791791- if self.use_rel_pos:
792792- rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
793793- rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
794794- attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
795795- x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
796796- # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
797797- else:
798798- x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
799799-800800- x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
801801-802802- x = self.proj(x)
803803-804804- return x
805805-806806-807807-def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
808808- """
809809- Partition into non-overlapping windows with padding if needed.
810810- Args:
811811- x (tensor): input tokens with [B, H, W, C].
812812- window_size (int): window size.
813813-814814- Returns:
815815- windows: windows after partition with [B * num_windows, window_size, window_size, C].
816816- (Hp, Wp): padded height and width before partition
817817- """
818818- B, H, W, C = x.shape
819819-820820- pad_h = (window_size - H % window_size) % window_size
821821- pad_w = (window_size - W % window_size) % window_size
822822- if pad_h > 0 or pad_w > 0:
823823- x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
824824- Hp, Wp = H + pad_h, W + pad_w
825825-826826- x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
827827- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
828828- return windows, (Hp, Wp)
829829-830830-831831-def window_unpartition(
832832- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
833833-) -> torch.Tensor:
834834- """
835835- Window unpartition into original sequences and removing padding.
836836- Args:
837837- windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
838838- window_size (int): window size.
839839- pad_hw (Tuple): padded height and width (Hp, Wp).
840840- hw (Tuple): original height and width (H, W) before padding.
841841-842842- Returns:
843843- x: unpartitioned sequences with [B, H, W, C].
844844- """
845845- Hp, Wp = pad_hw
846846- H, W = hw
847847- B = windows.shape[0] // (Hp * Wp // window_size // window_size)
848848- x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
849849- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
850850-851851- if Hp > H or Wp > W:
852852- x = x[:, :H, :W, :].contiguous()
853853- return x
854854-855855-856856-def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
857857- """
858858- Get relative positional embeddings according to the relative positions of
859859- query and key sizes.
860860- Args:
861861- q_size (int): size of query q.
862862- k_size (int): size of key k.
863863- rel_pos (Tensor): relative position embeddings (L, C).
864864-865865- Returns:
866866- Extracted positional embeddings according to relative positions.
867867- """
868868- max_rel_dist = int(2 * max(q_size, k_size) - 1)
869869- # Interpolate rel pos if needed.
870870- if rel_pos.shape[0] != max_rel_dist:
871871- # Interpolate rel pos.
872872- dtype = rel_pos.dtype
873873- rel_pos = rel_pos.to(torch.float32)
874874- rel_pos_resized = F.interpolate(
875875- rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
876876- size=max_rel_dist,
877877- mode="linear",
878878- ).to(dtype)
879879- rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
880880- else:
881881- rel_pos_resized = rel_pos
882882-883883- # Scale the coords with short length if shapes for q and k are different.
884884- q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
885885- k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
886886- relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
887887-888888- return rel_pos_resized[relative_coords.long()]
889889-890890-891891-def add_decomposed_rel_pos(
892892- q: torch.Tensor,
893893- rel_pos_h: torch.Tensor,
894894- rel_pos_w: torch.Tensor,
895895- q_size: Tuple[int, int],
896896- k_size: Tuple[int, int],
897897-) -> torch.Tensor:
898898- """
899899- Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
900900- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
901901- Args:
902902- q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
903903- rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
904904- rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
905905- q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
906906- k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
907907-908908- Returns:
909909- attn (Tensor): attention map with added relative positional embeddings.
910910- """
911911- q_h, q_w = q_size
912912- k_h, k_w = k_size
913913- Rh = get_rel_pos(q_h, k_h, rel_pos_h)
914914- Rw = get_rel_pos(q_w, k_w, rel_pos_w)
915915-916916- B, _, dim = q.shape
917917- r_q = q.reshape(B, q_h, q_w, dim)
918918- rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
919919- rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
920920- rel_h = rel_h.unsqueeze(-1)
921921- rel_w = rel_w.unsqueeze(-2)
922922- rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
923923- rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
924924-925925- return rel_h, rel_w
926926-927927-928928-class PatchEmbed(nn.Module):
929929- """
930930- Image to Patch Embedding.
931931- """
932932-933933- def __init__(
934934- self,
935935- kernel_size: Tuple[int, int] = (16, 16),
936936- stride: Tuple[int, int] = (16, 16),
937937- padding: Tuple[int, int] = (0, 0),
938938- in_chans: int = 3,
939939- embed_dim: int = 768,
940940- ) -> None:
941941- """
942942- Args:
943943- kernel_size (Tuple): kernel size of the projection layer.
944944- stride (Tuple): stride of the projection layer.
945945- padding (Tuple): padding size of the projection layer.
946946- in_chans (int): Number of input image channels.
947947- embed_dim (int): Patch embedding dimension.
948948- """
949949- super().__init__()
950950-951951- self.proj = nn.Conv2d(
952952- in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
953953- )
954954-955955- def forward(self, x: torch.Tensor) -> torch.Tensor:
956956- x = self.proj(x)
957957- # B C H W -> B H W C
958958- x = x.permute(0, 2, 3, 1)
959959- return x
960960-961961-962962-def build_sam_vit_b(checkpoint=None):
963963- return _build_sam(
964964- encoder_embed_dim=768,
965965- encoder_depth=12,
966966- encoder_num_heads=12,
967967- encoder_global_attn_indexes=[2, 5, 8, 11],
968968- checkpoint=checkpoint,
969969- )
970970-971971-def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
972972- image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
973973- # sam = _apply_eval_dtype_sam(sam, dtype)
974974- image_encoder = torch.compile(image_encoder, mode=compile_mode)
975975- return image_encoder
976976-977977-978978-def _build_sam(
979979- encoder_embed_dim,
980980- encoder_depth,
981981- encoder_num_heads,
982982- encoder_global_attn_indexes,
983983- checkpoint=None,
984984-):
985985- prompt_embed_dim = 256
986986- image_size = 1024
987987- vit_patch_size = 16
988988- image_embedding_size = image_size // vit_patch_size
989989- image_encoder=ImageEncoderViT(
990990- depth=encoder_depth,
991991- embed_dim=encoder_embed_dim,
992992- img_size=image_size,
993993- mlp_ratio=4,
994994- norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
995995- num_heads=encoder_num_heads,
996996- patch_size=vit_patch_size,
997997- qkv_bias=True,
998998- use_rel_pos=True,
999999- global_attn_indexes=encoder_global_attn_indexes,
10001000- window_size=14,
10011001- out_chans=prompt_embed_dim,
10021002- )
10031003- image_encoder.eval()
10041004- if checkpoint is not None:
10051005- # with open(checkpoint, "rb") as f:
10061006- state_dict = torch.load(checkpoint)
10071007- # print(state_dict.keys())
10081008- # for key in state_dict:
10091009- # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False)
10101010- # ocr-anyting
10111011- # image_encoder.load_state_dict(state_dict, strict=True)
10121012- # tob
10131013- image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
10141014- print(checkpoint)
10151015- return image_encoder
-1029
src/deepseek_ocr2/modeling_deepseekocr2.py
···11-from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM
22-from .configuration_deepseek_v2 import DeepseekV2Config
33-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
44-from typing import List, Optional, Tuple, Union
55-from transformers.cache_utils import Cache
66-import requests
77-from PIL import Image, ImageOps, ImageDraw, ImageFont
88-from io import BytesIO
99-import torch
1010-import torch.nn as nn
1111-from torch.nn import CrossEntropyLoss
1212-from torchvision import transforms
1313-# from torchvision.transforms.functional import InterpolationMode
1414-import os
1515-from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector
1616-from addict import Dict
1717-from transformers import TextStreamer
1818-from .conversation import get_conv_template
1919-from abc import ABC
2020-import math
2121-import re
2222-from tqdm import tqdm
2323-import numpy as np
2424-# import time
2525-2626-2727-2828-def load_image(image_path):
2929-3030- try:
3131- image = Image.open(image_path)
3232-3333- corrected_image = ImageOps.exif_transpose(image)
3434-3535- return corrected_image
3636-3737- except Exception as e:
3838- print(f"error: {e}")
3939- try:
4040- return Image.open(image_path)
4141- except:
4242- return None
4343-4444-4545-def re_match(text):
4646- pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)'
4747- matches = re.findall(pattern, text, re.DOTALL)
4848-4949- # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n'
5050- # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL)
5151-5252- mathes_image = []
5353- mathes_other = []
5454- for a_match in matches:
5555- if '<|ref|>image<|/ref|>' in a_match[0]:
5656- mathes_image.append(a_match[0])
5757- else:
5858- mathes_other.append(a_match[0])
5959- return matches, mathes_image, mathes_other
6060-6161-6262-def extract_coordinates_and_label(ref_text, image_width, image_height):
6363-6464- try:
6565- label_type = ref_text[1]
6666- cor_list = eval(ref_text[2])
6767- except Exception as e:
6868- print(e)
6969- return None
7070-7171- return (label_type, cor_list)
7272-7373-7474-def draw_bounding_boxes(image, refs, ouput_path):
7575-7676- image_width, image_height = image.size
7777-7878- img_draw = image.copy()
7979- draw = ImageDraw.Draw(img_draw)
8080-8181- overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0))
8282- draw2 = ImageDraw.Draw(overlay)
8383-8484- # try:
8585- # except IOError:
8686- # try:
8787- # font = ImageFont.truetype("DejaVuSans.ttf", 20)
8888- # except IOError:
8989- font = ImageFont.load_default()
9090-9191- img_idx = 0
9292-9393- for i, ref in enumerate(refs):
9494- try:
9595- result = extract_coordinates_and_label(ref, image_width, image_height)
9696- if result:
9797- label_type, points_list = result
9898-9999- color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255))
100100-101101- color_a = color + (20, )
102102- for points in points_list:
103103- x1, y1, x2, y2 = points
104104-105105- x1 = int(x1 / 999 * image_width)
106106- y1 = int(y1 / 999 * image_height)
107107-108108- x2 = int(x2 / 999 * image_width)
109109- y2 = int(y2 / 999 * image_height)
110110-111111- if label_type == 'image':
112112- try:
113113- cropped = image.crop((x1, y1, x2, y2))
114114- cropped.save(f"{ouput_path}/images/{img_idx}.jpg")
115115- except Exception as e:
116116- print(e)
117117- pass
118118- img_idx += 1
119119-120120- try:
121121- if label_type == 'title':
122122- draw.rectangle([x1, y1, x2, y2], outline=color, width=4)
123123- draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
124124- else:
125125- draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
126126- draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1)
127127- text_x = x1
128128- text_y = max(0, y1 - 15)
129129-130130-131131- text_bbox = draw.textbbox((0, 0), label_type, font=font)
132132- text_width = text_bbox[2] - text_bbox[0]
133133- text_height = text_bbox[3] - text_bbox[1]
134134- draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height],
135135- fill=(255, 255, 255, 30))
136136-137137- draw.text((text_x, text_y), label_type, font=font, fill=color)
138138- except:
139139- pass
140140- except:
141141- continue
142142- img_draw.paste(overlay, (0, 0), overlay)
143143- return img_draw
144144-145145-146146-def process_image_with_refs(image, ref_texts, output_path):
147147-148148- result_image = draw_bounding_boxes(image, ref_texts, output_path)
149149-150150- return result_image
151151-152152-153153-154154-155155-156156-def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
157157- best_ratio_diff = float('inf')
158158- best_ratio = (1, 1)
159159- area = width * height
160160- for ratio in target_ratios:
161161- target_aspect_ratio = ratio[0] / ratio[1]
162162- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
163163- if ratio_diff < best_ratio_diff:
164164- best_ratio_diff = ratio_diff
165165- best_ratio = ratio
166166- elif ratio_diff == best_ratio_diff:
167167- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
168168- best_ratio = ratio
169169- # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
170170- return best_ratio
171171-172172-173173-def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False):
174174- orig_width, orig_height = image.size
175175- aspect_ratio = orig_width / orig_height
176176-177177- # calculate the existing image aspect ratio
178178- target_ratios = set(
179179- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
180180- i * j <= max_num and i * j >= min_num)
181181- # print(target_ratios)
182182- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
183183-184184- # find the closest aspect ratio to the target
185185- target_aspect_ratio = find_closest_aspect_ratio(
186186- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
187187-188188- # print(target_aspect_ratio)
189189- # calculate the target width and height
190190- target_width = image_size * target_aspect_ratio[0]
191191- target_height = image_size * target_aspect_ratio[1]
192192- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
193193-194194- # resize the image
195195- resized_img = image.resize((target_width, target_height))
196196- processed_images = []
197197- for i in range(blocks):
198198- box = (
199199- (i % (target_width // image_size)) * image_size,
200200- (i // (target_width // image_size)) * image_size,
201201- ((i % (target_width // image_size)) + 1) * image_size,
202202- ((i // (target_width // image_size)) + 1) * image_size
203203- )
204204- # split the image
205205- split_img = resized_img.crop(box)
206206- processed_images.append(split_img)
207207- assert len(processed_images) == blocks
208208- if use_thumbnail and len(processed_images) != 1:
209209- thumbnail_img = image.resize((image_size, image_size))
210210- processed_images.append(thumbnail_img)
211211- return processed_images, target_aspect_ratio
212212-213213-214214-215215-def normalize_transform(mean, std):
216216- if mean is None and std is None:
217217- transform = None
218218- elif mean is None and std is not None:
219219- mean = [0.] * len(std)
220220- transform = transforms.Normalize(mean=mean, std=std)
221221- elif mean is not None and std is None:
222222- std = [1.] * len(mean)
223223- transform = transforms.Normalize(mean=mean, std=std)
224224- else:
225225- transform = transforms.Normalize(mean=mean, std=std)
226226-227227- return transform
228228-229229-230230-231231-def format_messages(
232232- conversations: List[Dict[str, str]],
233233- sft_format: str = "deepseek",
234234- system_prompt: str = "",
235235-):
236236- """
237237- Applies the SFT template to conversation.
238238-239239- Args:
240240- conversations (List[Dict]): A List of messages.
241241- sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek".
242242- system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "".
243243-244244- Returns:
245245- sft_prompt (str): The formatted text.
246246- """
247247-248248- conv = get_conv_template(sft_format)
249249- conv.set_system_message(system_prompt)
250250- for message in conversations:
251251- conv.append_message(message["role"], message["content"].strip())
252252- sft_prompt = conv.get_prompt().strip()
253253-254254- return sft_prompt
255255-256256-257257-def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False):
258258- t = tokenizer.encode(text, add_special_tokens=False)
259259- bos_id = 0
260260- eos_id = 1
261261- if bos:
262262- t = [bos_id] + t
263263- if eos:
264264- t = t + [eos_id]
265265-266266- return t
267267-268268-def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]:
269269- """
270270-271271- Args:
272272- conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is :
273273- [
274274- {
275275- "role": "User",
276276- "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.",
277277- "images": ["./examples/table_datasets.png"]
278278- },
279279- {"role": "Assistant", "content": ""},
280280- ]
281281-282282- Returns:
283283- pil_images (List[PIL.Image.Image]): the list of PIL images.
284284-285285- """
286286-287287- pil_images = []
288288-289289- for message in conversations:
290290- if "images" not in message:
291291- continue
292292-293293- for image_path in message["images"]:
294294- # print('----------------')
295295- # print(image_path)
296296- # print('----------------')
297297- # exit()
298298-299299- # pil_img = Image.open(image_path)
300300- pil_img = load_image(image_path)
301301- pil_img = pil_img.convert("RGB")
302302- pil_images.append(pil_img)
303303-304304- return pil_images
305305-306306-307307-class BaseTransform(ABC):
308308-309309- def set_rng(self, *args, **kwargs):
310310- pass
311311-312312- def __call__(self, *args, **kwargs) -> torch.Tensor:
313313- pass
314314-315315- @property
316316- def default_shape(self):
317317- raise NotImplementedError
318318-319319-320320-class BasicImageTransform(BaseTransform):
321321- def __init__(
322322- self,
323323- mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
324324- std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
325325- normalize: bool = True
326326- ):
327327- self.mean = mean
328328- self.std = std
329329-330330- transform_pipelines = [
331331- transforms.ToTensor()
332332- ]
333333-334334- normalize = normalize_transform(mean, std) if normalize else nn.Identity()
335335- if normalize is not None:
336336- transform_pipelines.append(normalize)
337337-338338- self.transform = transforms.Compose(transform_pipelines)
339339-340340- def __call__(self, x):
341341- x = self.transform(x)
342342- return x
343343-344344-class NoEOSTextStreamer(TextStreamer):
345345- def on_finalized_text(self, text: str, stream_end: bool = False):
346346-347347- eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False)
348348- text = text.replace(eos_text, "\n")
349349- print(text, flush=True, end="")
350350-351351-352352-class DeepseekOCR2Config(DeepseekV2Config):
353353- model_type = "DeepseekOCR2"
354354-355355-class DeepseekOCR2Model(DeepseekV2Model):
356356- config_class = DeepseekOCR2Config
357357-358358- def __init__(self, config: DeepseekV2Config):
359359- super(DeepseekOCR2Model, self).__init__(config)
360360-361361- self.sam_model = build_sam_vit_b()
362362- self.qwen2_model = build_qwen2_decoder_as_encoder()
363363- # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2)
364364- n_embed = 1280
365365- self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed))
366366- embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
367367- # self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
368368- self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
369369-370370-371371-372372-373373- def forward(
374374- self,
375375- input_ids: torch.LongTensor = None,
376376- attention_mask: Optional[torch.Tensor] = None,
377377- position_ids: Optional[torch.LongTensor] = None,
378378- past_key_values: Optional[List[torch.FloatTensor]] = None,
379379- inputs_embeds: Optional[torch.FloatTensor] = None,
380380- use_cache: Optional[bool] = None,
381381- output_attentions: Optional[bool] = None,
382382- output_hidden_states: Optional[bool] = None,
383383- images: Optional[torch.FloatTensor] = None,
384384- images_seq_mask: Optional[torch.FloatTensor] = None,
385385- images_spatial_crop: Optional[torch.FloatTensor] = None,
386386- return_dict: Optional[bool] = None,
387387- ) -> Union[Tuple, BaseModelOutputWithPast]:
388388-389389-390390-391391-392392- if inputs_embeds is None:
393393- # .clone().to(bfloat16): two training-only fixes --
394394- # 1. clone() breaks the autograd leaf-variable link so the
395395- # masked_scatter_ below does not raise "in-place on leaf".
396396- # 2. to(bfloat16) matches the dtype of image features produced
397397- # by the vision encoder (prepare_model_for_kbit_training
398398- # upcasts the embedding table to float32; the vision encoder
399399- # stays bfloat16 after our explicit cast in the training script).
400400- inputs_embeds = self.get_input_embeddings()(input_ids).clone().to(torch.bfloat16)
401401-402402-403403-404404- sam_model = getattr(self, 'sam_model', None)
405405- # sam_model = self.sam_model
406406- qwen2_model = getattr(self, 'qwen2_model', None)
407407-408408-409409-410410- if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0:
411411-412412- idx = 0
413413-414414- # sam_model = torch.jit.script(sam_model)
415415-416416- # start_time = time.time()
417417- for image, crop_shape in zip(images, images_spatial_crop):
418418- images_in_this_batch = []
419419-420420- patches = image[0]
421421- image_ori = image[1]
422422-423423- with torch.no_grad():
424424- # with torch.inference_mode():
425425-426426- if torch.sum(patches).item() != 0:
427427- # P, C, H, W = patches.shape
428428- crop_flag = 1
429429- local_features_1 = sam_model(patches)
430430-431431- local_features_2 = qwen2_model(local_features_1)
432432- # vit_time = time.time()
433433- local_features = local_features_2
434434- local_features = self.projector(local_features)
435435-436436-437437- global_features_1 = sam_model(image_ori)
438438- global_features_2 = qwen2_model(global_features_1)
439439- global_features = global_features_2
440440- global_features = self.projector(global_features)
441441-442442- pass
443443-444444- _, hw, n_dim = global_features.shape
445445- # h = w = int(hw ** 0.5)
446446-447447- _2, hw2, n_dim2 = local_features.shape
448448- # h2 = w2 = int(hw2 ** 0.5)
449449-450450-451451- global_features = global_features.view(-1, n_dim)
452452-453453-454454- local_features = local_features.view(-1, n_dim2)
455455-456456- global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0)
457457-458458- # end_time = time.time()
459459-460460- # print('sam: ', sam_time - start_time)
461461- # print('vit: ', vit_time - sam_time)
462462- # print('all: ', end_time - start_time)
463463-464464- # exit()
465465-466466- else:
467467- global_features_1 = sam_model(image_ori)
468468- global_features_2 = qwen2_model(global_features_1)
469469- global_features = global_features_2
470470- global_features = self.projector(global_features)
471471- pass
472472- _, hw, n_dim = global_features.shape
473473- # h = w = int(hw ** 0.5)
474474-475475-476476- # global_features = global_features.view(h, w, n_dim)
477477-478478- # global_features = torch.cat(
479479- # [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
480480- # )
481481-482482- global_features = global_features.view(-1, n_dim)
483483-484484- global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0)
485485-486486- images_in_this_batch.append(global_local_features)
487487-488488-489489- # print(inputs_embeds.shape)
490490-491491- if images_in_this_batch:
492492- images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
493493- # exit()
494494-495495- inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
496496-497497- idx += 1
498498-499499-500500- return super(DeepseekOCR2Model, self).forward(
501501- input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
502502- inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids,
503503- output_attentions=output_attentions, output_hidden_states=output_hidden_states,
504504- return_dict=return_dict
505505- )
506506-507507-508508-class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM):
509509-510510- config_class = DeepseekOCR2Config
511511- # supports_gradient_checkpointing = True
512512-513513- def __init__(self, config):
514514- super(DeepseekV2ForCausalLM, self).__init__(config)
515515- self.model = DeepseekOCR2Model(config)
516516-517517- self.vocab_size = config.vocab_size
518518-519519- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
520520-521521- # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
522522-523523- # Initialize weights and apply final processing
524524- self.post_init()
525525-526526- def get_model(self):
527527- return self.model
528528-529529-530530- def forward(
531531- self,
532532- input_ids: torch.LongTensor = None,
533533- attention_mask: Optional[torch.Tensor] = None,
534534- position_ids: Optional[torch.LongTensor] = None,
535535- past_key_values: Optional[List[torch.FloatTensor]] = None,
536536- inputs_embeds: Optional[torch.FloatTensor] = None,
537537- labels: Optional[torch.LongTensor] = None,
538538- use_cache: Optional[bool] = None,
539539- output_attentions: Optional[bool] = None,
540540- output_hidden_states: Optional[bool] = None,
541541- images: Optional[torch.FloatTensor] = None,
542542- images_seq_mask: Optional[torch.FloatTensor] = None,
543543- images_spatial_crop: Optional[torch.FloatTensor] = None,
544544- return_dict: Optional[bool] = None,
545545-546546- ) -> Union[Tuple, CausalLMOutputWithPast]:
547547- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
548548- output_hidden_states = (
549549- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
550550- )
551551- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
552552-553553-554554-555555- outputs = self.model(
556556- input_ids=input_ids,
557557- past_key_values=past_key_values,
558558- attention_mask=attention_mask,
559559- position_ids=position_ids,
560560- inputs_embeds=inputs_embeds,
561561- use_cache=use_cache,
562562- output_attentions=output_attentions,
563563- output_hidden_states=output_hidden_states,
564564- images=images,
565565- images_seq_mask = images_seq_mask,
566566- images_spatial_crop = images_spatial_crop,
567567- return_dict=return_dict
568568-569569- )
570570-571571-572572-573573- # print(transformer_outputs)
574574-575575- hidden_states = outputs[0]
576576- logits = self.lm_head(hidden_states)
577577- logits = logits.float()
578578-579579- # logits
580580-581581- loss = None
582582- if labels is not None:
583583- # Shift so that tokens < n predict n
584584- shift_logits = logits[..., :-1, :].contiguous()
585585- shift_labels = labels[..., 1:].contiguous()
586586- # Flatten the tokens
587587- loss_fct = CrossEntropyLoss()
588588- shift_logits = shift_logits.view(-1, self.config.vocab_size)
589589- shift_labels = shift_labels.view(-1)
590590- # Enable model parallelism
591591- shift_labels = shift_labels.to(shift_logits.device)
592592- loss = loss_fct(shift_logits, shift_labels)
593593-594594- if not return_dict:
595595- output = (logits,) + outputs[1:]
596596- return (loss,) + output if loss is not None else output
597597-598598- return CausalLMOutputWithPast(
599599- loss=loss,
600600- logits=logits,
601601- past_key_values=outputs.past_key_values,
602602- hidden_states=outputs.hidden_states,
603603- attentions=outputs.attentions,
604604- )
605605-606606-607607- def prepare_inputs_for_generation(
608608- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
609609- ):
610610- # Omit tokens covered by past_key_values
611611- past_length = 0
612612- if past_key_values is not None:
613613- if isinstance(past_key_values, Cache):
614614- cache_length = past_key_values.get_seq_length()
615615- past_length = past_key_values.seen_tokens
616616- max_cache_length = past_key_values.get_max_length()
617617- else:
618618- cache_length = past_length = past_key_values[0][0].shape[2]
619619- max_cache_length = None
620620-621621- # Keep only the unprocessed tokens:
622622- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
623623- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
624624- # input)
625625- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
626626- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
627627- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
628628- # input_ids based on the past_length.
629629- elif past_length < input_ids.shape[1]:
630630- input_ids = input_ids[:, past_length:]
631631- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
632632-633633- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
634634- if (
635635- max_cache_length is not None
636636- and attention_mask is not None
637637- and cache_length + input_ids.shape[1] > max_cache_length
638638- ):
639639- attention_mask = attention_mask[:, -max_cache_length:]
640640-641641- position_ids = kwargs.get("position_ids", None)
642642- if attention_mask is not None and position_ids is None:
643643- # create position_ids on the fly for batch generation
644644- position_ids = attention_mask.long().cumsum(-1) - 1
645645- position_ids.masked_fill_(attention_mask == 0, 1)
646646- if past_key_values:
647647- position_ids = position_ids[:, -input_ids.shape[1] :]
648648-649649- # if self.generation_config.cache_implementation == "static":
650650- # # generation with static cache
651651- # cache_position = kwargs.get("cache_position", None)
652652- # if cache_position is None:
653653- # past_length = 0
654654- # else:
655655- # past_length = cache_position[-1] + 1
656656- # input_ids = input_ids[:, past_length:]
657657- # position_ids = position_ids[:, past_length:]
658658-659659- # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
660660- # same goes for position ids. Could also help with continued generation.
661661- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
662662-663663- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
664664- if inputs_embeds is not None and past_key_values is None:
665665- model_inputs = {"inputs_embeds": inputs_embeds}
666666- else:
667667- model_inputs = {"input_ids": input_ids}
668668-669669- model_inputs.update(
670670- {
671671- "position_ids": position_ids,
672672- "past_key_values": past_key_values,
673673- "use_cache": kwargs.get("use_cache"),
674674- "attention_mask": attention_mask,
675675- "images": kwargs.get("images", None),
676676- "images_seq_mask": kwargs.get("images_seq_mask", None),
677677- "images_spatial_crop": kwargs.get("images_spatial_crop", None),
678678- }
679679- )
680680- return model_inputs
681681-682682-683683- def disable_torch_init(self):
684684- """
685685- Disable the redundant torch default initialization to accelerate model creation.
686686- """
687687- import torch
688688- setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
689689- setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
690690-691691-692692-693693- def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False):
694694- self.disable_torch_init()
695695-696696- os.makedirs(output_path, exist_ok=True)
697697- os.makedirs(f'{output_path}/images', exist_ok=True)
698698-699699- if prompt and image_file:
700700- conversation = [
701701- {
702702- "role": "<|User|>",
703703- # "content": "<image>\n<|grounding|>Given the layout of the image. ",
704704- "content": f'{prompt}',
705705- # "content": "君不见黄河之水天上来的下一句是什么?",
706706- # "content": "<image>\nFree OCR. ",
707707- # "content": "<image>\nParse the figure. ",
708708- # "content": "<image>\nExtract the text in the image. ",
709709- "images": [f'{image_file}'],
710710- },
711711- {"role": "<|Assistant|>", "content": ""},
712712- ]
713713-714714- elif prompt:
715715- conversation = [
716716- {
717717- "role": "<|User|>",
718718- # "content": "<image>\n<|grounding|>Given the layout of the image. ",
719719- "content": f'{prompt}',
720720- # "content": "君不见黄河之水天上来的下一句是什么?",
721721- # "content": "<image>\nFree OCR. ",
722722- # "content": "<image>\nParse the figure. ",
723723- # "content": "<image>\nExtract the text in the image. ",
724724- # "images": [f'{image_file}'],
725725- },
726726- {"role": "<|Assistant|>", "content": ""},
727727- ]
728728- else:
729729- assert False, f'prompt is none!'
730730-731731- prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='')
732732-733733- patch_size = 16
734734- downsample_ratio = 4
735735- images = load_pil_images(conversation)
736736-737737- valid_img_tokens = 0
738738- ratio = 1
739739-740740- image_draw = images[0].copy()
741741-742742- w,h = image_draw.size
743743- # print(w, h)
744744- ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))
745745-746746-747747- image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)
748748- images_seq_mask = []
749749-750750- image_token = '<image>'
751751- image_token_id = 128815
752752- text_splits = prompt.split(image_token)
753753-754754- images_list, images_crop_list, images_seq_mask = [], [], []
755755- tokenized_str = []
756756- images_spatial_crop = []
757757- for text_sep, image in zip(text_splits, images):
758758-759759- tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
760760- tokenized_str += tokenized_sep
761761- images_seq_mask += [False] * len(tokenized_sep)
762762-763763- if crop_mode:
764764-765765- if image.size[0] <= 768 and image.size[1] <= 768:
766766- crop_ratio = [1, 1]
767767-768768- else:
769769- if crop_mode:
770770- # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions)
771771- images_crop_raw, crop_ratio = dynamic_preprocess(image)
772772- else:
773773- # best_width, best_height = self.image_size, self.image_size
774774- crop_ratio = [1, 1]
775775-776776- """process the global view"""
777777- # image = image.resize((base_size, base_size))
778778- global_view = ImageOps.pad(image, (base_size, base_size),
779779- color=tuple(int(x * 255) for x in image_transform.mean))
780780-781781- if base_size == 1024:
782782- valid_img_tokens += int(256 * ratio)
783783- elif base_size == 1280:
784784- valid_img_tokens += int(400 * ratio)
785785- # elif base_size == 640:
786786- # valid_img_tokens += int(100 * ratio)
787787-788788-789789-790790-791791-792792- images_list.append(image_transform(global_view).to(torch.bfloat16))
793793-794794- # global_view_tensor = image_transform(global_view).to(torch.bfloat16)
795795-796796- width_crop_num, height_crop_num = crop_ratio
797797-798798- images_spatial_crop.append([width_crop_num, height_crop_num])
799799-800800-801801- if width_crop_num > 1 or height_crop_num > 1:
802802- """process the local views"""
803803-804804- for i in range(len(images_crop_raw)):
805805- images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))
806806-807807- if image_size == 768:
808808- valid_img_tokens += len(images_crop_list) * 144
809809-810810- num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
811811- num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)
812812-813813-814814-815815- """add image tokens"""
816816-817817-818818-819819- tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base
820820- tokenized_image += [image_token_id]
821821- if width_crop_num > 1 or height_crop_num > 1:
822822- tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * (
823823- num_queries * height_crop_num)
824824- tokenized_str += tokenized_image
825825- images_seq_mask += [True] * len(tokenized_image)
826826- # num_image_tokens.append(len(tokenized_image))
827827-828828- else:
829829- # best_width, best_height = self.image_size, self.image_size
830830- # print(image.size, (best_width, best_height)) # check the select_best_resolutions func
831831-832832- """process the global view"""
833833- if image_size <= 768:
834834- print('directly resize')
835835- image = image.resize((image_size, image_size))
836836- # else:
837837- global_view = ImageOps.pad(image, (image_size, image_size),
838838- color=tuple(int(x * 255) for x in image_transform.mean))
839839- images_list.append(image_transform(global_view).to(torch.bfloat16))
840840-841841- if base_size == 1024:
842842- valid_img_tokens += int(256 * ratio)
843843- elif base_size == 1280:
844844- valid_img_tokens += int(400 * ratio)
845845- elif base_size == 640:
846846- valid_img_tokens += int(100 * 1)
847847- elif base_size == 512:
848848- valid_img_tokens += int(64 * 1)
849849- elif base_size == 768:
850850- valid_img_tokens += int(144 * 1)
851851-852852- width_crop_num, height_crop_num = 1, 1
853853-854854- images_spatial_crop.append([width_crop_num, height_crop_num])
855855-856856-857857- """add image tokens"""
858858- num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
859859-860860- tokenized_image = ([image_token_id] * num_queries) * num_queries
861861- tokenized_image += [image_token_id]
862862- # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * (
863863- # num_queries * height_crop_num)
864864- tokenized_str += tokenized_image
865865- images_seq_mask += [True] * len(tokenized_image)
866866- # num_image_tokens.append(len(tokenized_image))
867867-868868-869869- """process the last text split"""
870870- tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
871871- tokenized_str += tokenized_sep
872872- images_seq_mask += [False] * len(tokenized_sep)
873873-874874- """add the bos tokens"""
875875- bos_id = 0
876876- tokenized_str = [bos_id] + tokenized_str
877877- images_seq_mask = [False] + images_seq_mask
878878-879879-880880-881881- input_ids = torch.LongTensor(tokenized_str)
882882-883883-884884-885885-886886- images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
887887-888888-889889- if len(images_list) == 0:
890890- images_ori = torch.zeros((1, 3, image_size, image_size))
891891- images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
892892- images_crop = torch.zeros((1, 3, base_size, base_size))
893893-894894- else:
895895- images_ori = torch.stack(images_list, dim=0)
896896- images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
897897- if images_crop_list:
898898- images_crop = torch.stack(images_crop_list, dim=0)
899899- else:
900900- images_crop = torch.zeros((1, 3, base_size, base_size))
901901-902902-903903-904904- if not eval_mode:
905905- streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
906906- with torch.autocast("cuda", dtype=torch.bfloat16):
907907- with torch.no_grad():
908908- output_ids = self.generate(
909909- input_ids.unsqueeze(0).cuda(),
910910- images=[(images_crop.cuda(), images_ori.cuda())],
911911- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
912912- images_spatial_crop = images_spatial_crop,
913913- # do_sample=False,
914914- # num_beams = 1,
915915- temperature=0.0,
916916- eos_token_id=tokenizer.eos_token_id,
917917- streamer=streamer,
918918- max_new_tokens=8192,
919919- no_repeat_ngram_size = 20,
920920- use_cache = True
921921- )
922922-923923- else:
924924- with torch.autocast("cuda", dtype=torch.bfloat16):
925925- with torch.no_grad():
926926- output_ids = self.generate(
927927- input_ids.unsqueeze(0).cuda(),
928928- images=[(images_crop.cuda(), images_ori.cuda())],
929929- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
930930- images_spatial_crop = images_spatial_crop,
931931- # do_sample=False,
932932- # num_beams = 1,
933933- temperature=0.0,
934934- eos_token_id=tokenizer.eos_token_id,
935935- max_new_tokens=8192,
936936- no_repeat_ngram_size = 35,
937937- use_cache = True
938938- )
939939-940940-941941- if '<image>' in conversation[0]['content'] and eval_mode:
942942- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
943943- stop_str = '<|end▁of▁sentence|>'
944944- if outputs.endswith(stop_str):
945945- outputs = outputs[:-len(stop_str)]
946946- # re_match
947947- outputs = outputs.strip()
948948-949949- return outputs
950950-951951- if '<image>' in conversation[0]['content'] and test_compress:
952952- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
953953- pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
954954- print('='*50)
955955- print('image size: ', (w, h))
956956- print('valid image tokens: ', int(valid_img_tokens))
957957- print('output texts tokens (valid): ', pure_texts_outputs_token_length)
958958- print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
959959- print('='*50)
960960-961961-962962- if '<image>' in conversation[0]['content'] and save_results:
963963- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
964964- stop_str = '<|end▁of▁sentence|>'
965965-966966- print('='*15 + 'save results:' + '='*15)
967967-968968- # # # # conv.messages[-1][-1] = outputs
969969- if outputs.endswith(stop_str):
970970- outputs = outputs[:-len(stop_str)]
971971- outputs = outputs.strip()
972972-973973- matches_ref, matches_images, mathes_other = re_match(outputs)
974974- # print(matches_ref)
975975- result = process_image_with_refs(image_draw, matches_ref, output_path)
976976-977977-978978- for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
979979- outputs = outputs.replace(a_match_image, ' + '.jpg)\n')
980980-981981- for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")):
982982- outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')
983983-984984-985985- # if 'structural formula' in conversation[0]['content']:
986986- # outputs = '<smiles>' + outputs + '</smiles>'
987987- with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile:
988988- afile.write(outputs)
989989-990990- if 'line_type' in outputs:
991991- import matplotlib.pyplot as plt
992992- lines = eval(outputs)['Line']['line']
993993-994994- line_type = eval(outputs)['Line']['line_type']
995995- # print(lines)
996996-997997- endpoints = eval(outputs)['Line']['line_endpoint']
998998-999999- fig, ax = plt.subplots(figsize=(3,3), dpi=200)
10001000- ax.set_xlim(-15, 15)
10011001- ax.set_ylim(-15, 15)
10021002-10031003- for idx, line in enumerate(lines):
10041004- try:
10051005- p0 = eval(line.split(' -- ')[0])
10061006- p1 = eval(line.split(' -- ')[-1])
10071007-10081008- if line_type[idx] == '--':
10091009- ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')
10101010- else:
10111011- ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k')
10121012-10131013- ax.scatter(p0[0], p0[1], s=5, color = 'k')
10141014- ax.scatter(p1[0], p1[1], s=5, color = 'k')
10151015- except:
10161016- pass
10171017-10181018- for endpoint in endpoints:
10191019-10201020- label = endpoint.split(': ')[0]
10211021- (x, y) = eval(endpoint.split(': ')[1])
10221022- ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
10231023- fontsize=5, fontweight='light')
10241024-10251025-10261026- plt.savefig(f'{output_path}/geo.jpg')
10271027- plt.close()
10281028-10291029- result.save(f"{output_path}/result_with_boxes.jpg")
-1992
src/deepseek_ocr2/modeling_deepseekv2.py
···11-# coding=utf-8
22-# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
33-#
44-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
55-# and OPT implementations in this library. It has been modified from its
66-# original forms to accommodate minor architectural differences compared
77-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
88-#
99-# Licensed under the Apache License, Version 2.0 (the "License");
1010-# you may not use this file except in compliance with the License.
1111-# You may obtain a copy of the License at
1212-#
1313-# http://www.apache.org/licenses/LICENSE-2.0
1414-#
1515-# Unless required by applicable law or agreed to in writing, software
1616-# distributed under the License is distributed on an "AS IS" BASIS,
1717-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1818-# See the License for the specific language governing permissions and
1919-# limitations under the License.
2020-""" PyTorch DeepSeek model and compatible with both DeepSeekV2 and DeepSeekV3"""
2121-import math
2222-import warnings
2323-from typing import List, Optional, Tuple, Union
2424-import numpy as np
2525-2626-import torch
2727-import torch.nn.functional as F
2828-import torch.utils.checkpoint
2929-import torch.distributed as dist
3030-from einops import repeat
3131-from torch import nn
3232-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
3333-3434-from transformers.activations import ACT2FN
3535-from transformers.cache_utils import Cache, DynamicCache
3636-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
3737-from transformers.models.llama.modeling_llama import (
3838- LlamaAttention,
3939- LlamaFlashAttention2
4040-)
4141-from transformers.modeling_outputs import (
4242- BaseModelOutputWithPast,
4343- CausalLMOutputWithPast,
4444- SequenceClassifierOutputWithPast,
4545-)
4646-from transformers.modeling_utils import PreTrainedModel
4747-from transformers.pytorch_utils import (
4848- ALL_LAYERNORM_LAYERS,
4949- is_torch_greater_or_equal_than_1_13,
5050-)
5151-from transformers.utils import (
5252- add_start_docstrings,
5353- add_start_docstrings_to_model_forward,
5454- is_flash_attn_2_available,
5555- is_flash_attn_greater_or_equal_2_10,
5656- logging,
5757- replace_return_docstrings,
5858-)
5959-from transformers.utils.import_utils import is_torch_fx_available
6060-6161-from .configuration_deepseek_v2 import DeepseekV2Config
6262-6363-if is_flash_attn_2_available():
6464- from flash_attn import flash_attn_func, flash_attn_varlen_func
6565- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
6666-6767-# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
6868-# It means that the function will not be traced through and simply appear as a node in the graph.
6969-if is_torch_fx_available():
7070- if not is_torch_greater_or_equal_than_1_13:
7171- import torch.fx
7272-7373- _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
7474-7575-logger = logging.get_logger(__name__)
7676-7777-_CONFIG_FOR_DOC = "DeepseekV2Config"
7878-7979-8080-def _get_unpad_data(attention_mask):
8181- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
8282- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
8383- max_seqlen_in_batch = seqlens_in_batch.max().item()
8484- cu_seqlens = F.pad(
8585- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
8686- )
8787- return (
8888- indices,
8989- cu_seqlens,
9090- max_seqlen_in_batch,
9191- )
9292-9393-9494-class DeepseekV2RMSNorm(nn.Module):
9595- def __init__(self, hidden_size, eps=1e-6):
9696- """
9797- DeepseekV2RMSNorm is equivalent to T5LayerNorm
9898- """
9999- super().__init__()
100100- self.weight = nn.Parameter(torch.ones(hidden_size))
101101- self.variance_epsilon = eps
102102-103103- def forward(self, hidden_states):
104104- input_dtype = hidden_states.dtype
105105- hidden_states = hidden_states.to(torch.float32)
106106- variance = hidden_states.pow(2).mean(-1, keepdim=True)
107107- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
108108- return self.weight * hidden_states.to(input_dtype)
109109-110110-111111-ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm)
112112-113113-114114-115115-116116-class DeepseekV2RotaryEmbedding(nn.Module):
117117- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
118118- super().__init__()
119119-120120- self.dim = dim
121121- self.max_position_embeddings = max_position_embeddings
122122- self.base = base
123123- inv_freq = 1.0 / (
124124- self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
125125- )
126126- self.register_buffer("inv_freq", inv_freq, persistent=False)
127127-128128- # Build here to make `torch.jit.trace` work.
129129- self._set_cos_sin_cache(
130130- seq_len=max_position_embeddings,
131131- device=self.inv_freq.device,
132132- dtype=torch.get_default_dtype(),
133133- )
134134- self.max_seq_len_cached = None
135135-136136- def _set_cos_sin_cache(self, seq_len, device, dtype):
137137- self.max_seq_len_cached = seq_len
138138- t = torch.arange(
139139- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
140140- )
141141-142142- freqs = torch.outer(t, self.inv_freq.to(t.device))
143143- # Different from paper, but it uses a different permutation in order to obtain the same calculation
144144- emb = torch.cat((freqs, freqs), dim=-1)
145145- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
146146- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
147147-148148- def forward(self, x, seq_len=None):
149149- # x: [bs, num_attention_heads, seq_len, head_size]
150150- if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
151151- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
152152-153153- return (
154154- self.cos_cached[:seq_len].to(dtype=x.dtype),
155155- self.sin_cached[:seq_len].to(dtype=x.dtype),
156156- )
157157-158158-159159-# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2
160160-class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
161161- """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
162162-163163- def __init__(
164164- self,
165165- dim,
166166- max_position_embeddings=2048,
167167- base=10000,
168168- device=None,
169169- scaling_factor=1.0,
170170- ):
171171- self.scaling_factor = scaling_factor
172172- super().__init__(dim, max_position_embeddings, base, device)
173173-174174- def _set_cos_sin_cache(self, seq_len, device, dtype):
175175- self.max_seq_len_cached = seq_len
176176- t = torch.arange(
177177- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
178178- )
179179- t = t / self.scaling_factor
180180-181181- freqs = torch.outer(t, self.inv_freq)
182182- # Different from paper, but it uses a different permutation in order to obtain the same calculation
183183- emb = torch.cat((freqs, freqs), dim=-1)
184184- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
185185- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186186-187187-188188-# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2
189189-class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding):
190190- """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
191191-192192- def __init__(
193193- self,
194194- dim,
195195- max_position_embeddings=2048,
196196- base=10000,
197197- device=None,
198198- scaling_factor=1.0,
199199- ):
200200- self.scaling_factor = scaling_factor
201201- super().__init__(dim, max_position_embeddings, base, device)
202202-203203- def _set_cos_sin_cache(self, seq_len, device, dtype):
204204- self.max_seq_len_cached = seq_len
205205-206206- if seq_len > self.max_position_embeddings:
207207- base = self.base * (
208208- (self.scaling_factor * seq_len / self.max_position_embeddings)
209209- - (self.scaling_factor - 1)
210210- ) ** (self.dim / (self.dim - 2))
211211- inv_freq = 1.0 / (
212212- base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
213213- )
214214- self.register_buffer("inv_freq", inv_freq, persistent=False)
215215-216216- t = torch.arange(
217217- self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
218218- )
219219-220220- freqs = torch.outer(t, self.inv_freq)
221221- # Different from paper, but it uses a different permutation in order to obtain the same calculation
222222- emb = torch.cat((freqs, freqs), dim=-1)
223223- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
224224- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
225225-226226-227227-# Inverse dim formula to find dim based on number of rotations
228228-def yarn_find_correction_dim(
229229- num_rotations, dim, base=10000, max_position_embeddings=2048
230230-):
231231- return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
232232- 2 * math.log(base)
233233- )
234234-235235-236236-# Find dim range bounds based on rotations
237237-def yarn_find_correction_range(
238238- low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
239239-):
240240- low = math.floor(
241241- yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
242242- )
243243- high = math.ceil(
244244- yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
245245- )
246246- return max(low, 0), min(high, dim - 1) # Clamp values just in case
247247-248248-249249-def yarn_get_mscale(scale=1, mscale=1):
250250- if scale <= 1:
251251- return 1.0
252252- return 0.1 * mscale * math.log(scale) + 1.0
253253-254254-255255-def yarn_linear_ramp_mask(min, max, dim):
256256- if min == max:
257257- max += 0.001 # Prevent singularity
258258-259259- linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
260260- ramp_func = torch.clamp(linear_func, 0, 1)
261261- return ramp_func
262262-263263-264264-class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding):
265265-266266- def __init__(
267267- self,
268268- dim,
269269- max_position_embeddings=2048,
270270- base=10000,
271271- device=None,
272272- scaling_factor=1.0,
273273- original_max_position_embeddings=4096,
274274- beta_fast=32,
275275- beta_slow=1,
276276- mscale=1,
277277- mscale_all_dim=0,
278278- ):
279279- self.scaling_factor = scaling_factor
280280- self.original_max_position_embeddings = original_max_position_embeddings
281281- self.beta_fast = beta_fast
282282- self.beta_slow = beta_slow
283283- self.mscale = mscale
284284- self.mscale_all_dim = mscale_all_dim
285285- super().__init__(dim, max_position_embeddings, base, device)
286286-287287- def _set_cos_sin_cache(self, seq_len, device, dtype):
288288- self.max_seq_len_cached = seq_len
289289- dim = self.dim
290290-291291- freq_extra = 1.0 / (
292292- self.base
293293- ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
294294- )
295295- freq_inter = 1.0 / (
296296- self.scaling_factor
297297- * self.base
298298- ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
299299- )
300300-301301- low, high = yarn_find_correction_range(
302302- self.beta_fast,
303303- self.beta_slow,
304304- dim,
305305- self.base,
306306- self.original_max_position_embeddings,
307307- )
308308- inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
309309- device=device, dtype=torch.float32
310310- )
311311- inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
312312- self.register_buffer("inv_freq", inv_freq, persistent=False)
313313-314314- t = torch.arange(seq_len, device=device, dtype=torch.float32)
315315-316316- freqs = torch.outer(t, inv_freq)
317317-318318- _mscale = float(
319319- yarn_get_mscale(self.scaling_factor, self.mscale)
320320- / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
321321- )
322322-323323- emb = torch.cat((freqs, freqs), dim=-1)
324324- self.register_buffer(
325325- "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
326326- )
327327- self.register_buffer(
328328- "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
329329- )
330330-331331-332332-# Copied from transformers.models.llama.modeling_llama.rotate_half
333333-def rotate_half(x):
334334- """Rotates half the hidden dims of the input."""
335335- x1 = x[..., : x.shape[-1] // 2]
336336- x2 = x[..., x.shape[-1] // 2 :]
337337- return torch.cat((-x2, x1), dim=-1)
338338-339339-340340-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
341341-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
342342- """Applies Rotary Position Embedding to the query and key tensors.
343343-344344- Args:
345345- q (`torch.Tensor`): The query tensor.
346346- k (`torch.Tensor`): The key tensor.
347347- cos (`torch.Tensor`): The cosine part of the rotary embedding.
348348- sin (`torch.Tensor`): The sine part of the rotary embedding.
349349- position_ids (`torch.Tensor`):
350350- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
351351- used to pass offsetted position ids when working with a KV-cache.
352352- unsqueeze_dim (`int`, *optional*, defaults to 1):
353353- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
354354- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
355355- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
356356- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
357357- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
358358- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
359359- Returns:
360360- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
361361- """
362362- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
363363- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
364364-365365-366366- # print()
367367-368368- b, h, s, d = q.shape
369369- q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
370370-371371- b, h, s, d = k.shape
372372- k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)
373373-374374- q_embed = (q * cos) + (rotate_half(q) * sin)
375375- k_embed = (k * cos) + (rotate_half(k) * sin)
376376-377377-378378- return q_embed, k_embed
379379-380380-381381-class DeepseekV2MLP(nn.Module):
382382- def __init__(self, config, hidden_size=None, intermediate_size=None):
383383- super().__init__()
384384- self.config = config
385385- self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
386386- self.intermediate_size = (
387387- config.intermediate_size if intermediate_size is None else intermediate_size
388388- )
389389-390390- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
391391- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
392392- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
393393- self.act_fn = ACT2FN[config.hidden_act]
394394-395395- def forward(self, x):
396396- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
397397- return down_proj
398398-399399-400400-class MoEGate(nn.Module):
401401- def __init__(self, config):
402402- super().__init__()
403403- self.config = config
404404- self.top_k = config.num_experts_per_tok
405405- self.n_routed_experts = config.n_routed_experts
406406- self.routed_scaling_factor = config.routed_scaling_factor
407407- self.scoring_func = config.scoring_func
408408- self.alpha = config.aux_loss_alpha
409409- self.seq_aux = config.seq_aux
410410- self.topk_method = config.topk_method
411411- self.n_group = config.n_group
412412- self.topk_group = config.topk_group
413413-414414- # topk selection algorithm
415415- self.norm_topk_prob = config.norm_topk_prob
416416- self.gating_dim = config.hidden_size
417417- self.weight = nn.Parameter(
418418- torch.empty((self.n_routed_experts, self.gating_dim))
419419- )
420420- if self.topk_method == "noaux_tc":
421421- self.e_score_correction_bias = nn.Parameter(
422422- torch.empty((self.n_routed_experts))
423423- )
424424- self.reset_parameters()
425425-426426- def reset_parameters(self) -> None:
427427- import torch.nn.init as init
428428-429429- init.kaiming_uniform_(self.weight, a=math.sqrt(5))
430430-431431- def forward(self, hidden_states):
432432- bsz, seq_len, h = hidden_states.shape
433433- ### compute gating score
434434- hidden_states = hidden_states.view(-1, h)
435435- logits = F.linear(
436436- hidden_states.type(torch.float32), self.weight.type(torch.float32), None
437437- )
438438- if self.scoring_func == "softmax":
439439- scores = logits.softmax(dim=-1, dtype=torch.float32)
440440- elif self.scoring_func == "sigmoid":
441441- scores = logits.sigmoid()
442442- else:
443443- raise NotImplementedError(
444444- f"insupportable scoring function for MoE gating: {self.scoring_func}"
445445- )
446446-447447- ### select top-k experts
448448- if self.topk_method == "greedy":
449449- topk_weight, topk_idx = torch.topk(
450450- scores, k=self.top_k, dim=-1, sorted=False
451451- )
452452- elif self.topk_method == "group_limited_greedy":
453453- group_scores = (
454454- scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values
455455- ) # [n, n_group]
456456- group_idx = torch.topk(
457457- group_scores, k=self.topk_group, dim=-1, sorted=False
458458- )[
459459- 1
460460- ] # [n, top_k_group]
461461- group_mask = torch.zeros_like(group_scores) # [n, n_group]
462462- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
463463- score_mask = (
464464- group_mask.unsqueeze(-1)
465465- .expand(
466466- bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
467467- )
468468- .reshape(bsz * seq_len, -1)
469469- ) # [n, e]
470470- tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
471471- topk_weight, topk_idx = torch.topk(
472472- tmp_scores, k=self.top_k, dim=-1, sorted=False
473473- )
474474- elif self.topk_method == "noaux_tc":
475475- assert not self.training
476476- scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0)
477477- group_scores = (
478478- scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1)
479479- ) # [n, n_group]
480480- group_idx = torch.topk(
481481- group_scores, k=self.topk_group, dim=-1, sorted=False
482482- )[
483483- 1
484484- ] # [n, top_k_group]
485485- group_mask = torch.zeros_like(group_scores) # [n, n_group]
486486- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
487487- score_mask = (
488488- group_mask.unsqueeze(-1)
489489- .expand(
490490- bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
491491- )
492492- .reshape(bsz * seq_len, -1)
493493- ) # [n, e]
494494- tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
495495- _, topk_idx = torch.topk(
496496- tmp_scores, k=self.top_k, dim=-1, sorted=False
497497- )
498498- topk_weight = scores.gather(1, topk_idx)
499499-500500- ### norm gate to sum 1
501501- if self.top_k > 1 and self.norm_topk_prob:
502502- denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
503503- topk_weight = topk_weight / denominator * self.routed_scaling_factor
504504- else:
505505- topk_weight = topk_weight * self.routed_scaling_factor
506506- ### expert-level computation auxiliary loss
507507- if self.training and self.alpha > 0.0:
508508- scores_for_aux = scores
509509- aux_topk = self.top_k
510510- # always compute aux loss based on the naive greedy topk method
511511- topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
512512- if self.seq_aux:
513513- scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
514514- ce = torch.zeros(
515515- bsz, self.n_routed_experts, device=hidden_states.device
516516- )
517517- ce.scatter_add_(
518518- 1,
519519- topk_idx_for_aux_loss,
520520- torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
521521- ).div_(seq_len * aux_topk / self.n_routed_experts)
522522- aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(
523523- dim=1
524524- ).mean() * self.alpha
525525- else:
526526- mask_ce = F.one_hot(
527527- topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
528528- )
529529- ce = mask_ce.float().mean(0)
530530- Pi = scores_for_aux.mean(0)
531531- fi = ce * self.n_routed_experts
532532- aux_loss = (Pi * fi).sum() * self.alpha
533533- else:
534534- aux_loss = None
535535- return topk_idx, topk_weight, aux_loss
536536-537537-538538-class AddAuxiliaryLoss(torch.autograd.Function):
539539- """
540540- The trick function of adding auxiliary (aux) loss,
541541- which includes the gradient of the aux loss during backpropagation.
542542- """
543543-544544- @staticmethod
545545- def forward(ctx, x, loss):
546546- assert loss.numel() == 1
547547- ctx.dtype = loss.dtype
548548- ctx.required_aux_loss = loss.requires_grad
549549- return x
550550-551551- @staticmethod
552552- def backward(ctx, grad_output):
553553- grad_loss = None
554554- if ctx.required_aux_loss:
555555- grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
556556- return grad_output, grad_loss
557557-558558-559559-class DeepseekV2MoE(nn.Module):
560560- """
561561- A mixed expert module containing shared experts.
562562- """
563563-564564- def __init__(self, config):
565565- super().__init__()
566566- self.config = config
567567- self.num_experts_per_tok = config.num_experts_per_tok
568568-569569- if hasattr(config, "ep_size") and config.ep_size > 1:
570570- assert config.ep_size == dist.get_world_size()
571571- self.ep_size = config.ep_size
572572- self.experts_per_rank = config.n_routed_experts // config.ep_size
573573- self.ep_rank = dist.get_rank()
574574- self.experts = nn.ModuleList(
575575- [
576576- (
577577- DeepseekV2MLP(
578578- config, intermediate_size=config.moe_intermediate_size
579579- )
580580- if i >= self.ep_rank * self.experts_per_rank
581581- and i < (self.ep_rank + 1) * self.experts_per_rank
582582- else None
583583- )
584584- for i in range(config.n_routed_experts)
585585- ]
586586- )
587587- else:
588588- self.ep_size = 1
589589- self.experts_per_rank = config.n_routed_experts
590590- self.ep_rank = 0
591591- self.experts = nn.ModuleList(
592592- [
593593- DeepseekV2MLP(
594594- config, intermediate_size=config.moe_intermediate_size
595595- )
596596- for i in range(config.n_routed_experts)
597597- ]
598598- )
599599- self.gate = MoEGate(config)
600600- if config.n_shared_experts is not None:
601601- intermediate_size = config.moe_intermediate_size * config.n_shared_experts
602602- self.shared_experts = DeepseekV2MLP(
603603- config=config, intermediate_size=intermediate_size
604604- )
605605-606606- def forward(self, hidden_states):
607607- identity = hidden_states
608608- orig_shape = hidden_states.shape
609609- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
610610- hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
611611- flat_topk_idx = topk_idx.view(-1)
612612- if self.training:
613613- hidden_states = hidden_states.repeat_interleave(
614614- self.num_experts_per_tok, dim=0
615615- )
616616- y = torch.empty_like(hidden_states)
617617- for i, expert in enumerate(self.experts):
618618- y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
619619- y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
620620- y = y.to(hidden_states.dtype).view(*orig_shape)
621621- y = AddAuxiliaryLoss.apply(y, aux_loss)
622622- else:
623623- y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
624624- if self.config.n_shared_experts is not None:
625625- y = y + self.shared_experts(identity)
626626- return y
627627-628628- @torch.no_grad()
629629- def moe_infer(self, x, topk_ids, topk_weight):
630630- cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
631631- cnts.scatter_(1, topk_ids, 1)
632632- tokens_per_expert = cnts.sum(dim=0)
633633- idxs = topk_ids.view(-1).argsort()
634634- sorted_tokens = x[idxs // topk_ids.shape[1]]
635635- sorted_tokens_shape = sorted_tokens.shape
636636- if self.ep_size > 1:
637637- tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
638638- tokens_per_expert_group = tokens_per_expert.new_empty(
639639- tokens_per_expert.shape[0]
640640- )
641641- dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
642642- output_splits = (
643643- tokens_per_expert_group.view(self.ep_size, -1)
644644- .sum(1)
645645- .cpu()
646646- .numpy()
647647- .tolist()
648648- )
649649- gathered_tokens = sorted_tokens.new_empty(
650650- tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
651651- )
652652- input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
653653- dist.all_to_all(
654654- list(gathered_tokens.split(output_splits)),
655655- list(sorted_tokens.split(input_split_sizes)),
656656- )
657657- tokens_per_expert_post_gather = tokens_per_expert_group.view(
658658- self.ep_size, self.experts_per_rank
659659- ).sum(dim=0)
660660- gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
661661- s = 0
662662- for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
663663- gatherd_idxs[s : s + k] = i % self.experts_per_rank
664664- s += k
665665- gatherd_idxs = gatherd_idxs.argsort()
666666- sorted_tokens = gathered_tokens[gatherd_idxs]
667667- tokens_per_expert = tokens_per_expert_post_gather
668668- tokens_per_expert = tokens_per_expert.cpu().numpy()
669669-670670- outputs = []
671671- start_idx = 0
672672- for i, num_tokens in enumerate(tokens_per_expert):
673673- end_idx = start_idx + num_tokens
674674- if num_tokens == 0:
675675- continue
676676- expert = self.experts[i + self.ep_rank * self.experts_per_rank]
677677- tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
678678- expert_out = expert(tokens_for_this_expert)
679679- outputs.append(expert_out)
680680- start_idx = end_idx
681681-682682- outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
683683- if self.ep_size > 1:
684684- new_x = torch.empty_like(outs)
685685- new_x[gatherd_idxs] = outs
686686- gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
687687- dist.all_to_all(
688688- list(gathered_tokens.split(input_split_sizes)),
689689- list(new_x.split(output_splits)),
690690- )
691691- outs = gathered_tokens
692692-693693- new_x = torch.empty_like(outs)
694694- new_x[idxs] = outs
695695- final_out = (
696696- new_x.view(*topk_ids.shape, -1)
697697- .type(topk_weight.dtype)
698698- .mul_(topk_weight.unsqueeze(dim=-1))
699699- .sum(dim=1)
700700- .type(new_x.dtype)
701701- )
702702- return final_out
703703-704704-705705-# Copied from transformers.models.llama.modeling_llama.repeat_kv
706706-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
707707- """
708708- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
709709- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
710710- """
711711- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
712712- if n_rep == 1:
713713- return hidden_states
714714- hidden_states = hidden_states[:, :, None, :, :].expand(
715715- batch, num_key_value_heads, n_rep, slen, head_dim
716716- )
717717- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
718718-719719-720720-# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2
721721-class DeepseekV2Attention(nn.Module):
722722- """Multi-headed attention from 'Attention Is All You Need' paper"""
723723-724724- def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None):
725725- super().__init__()
726726- self.config = config
727727- self.layer_idx = layer_idx
728728- if layer_idx is None:
729729- logger.warning_once(
730730- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
731731- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
732732- "when creating this class."
733733- )
734734-735735- self.attention_dropout = config.attention_dropout
736736- self.hidden_size = config.hidden_size
737737- self.num_heads = config.num_attention_heads
738738-739739- self.max_position_embeddings = config.max_position_embeddings
740740- self.rope_theta = config.rope_theta
741741- self.q_lora_rank = config.q_lora_rank
742742- self.qk_rope_head_dim = config.qk_rope_head_dim
743743- self.kv_lora_rank = config.kv_lora_rank
744744- self.v_head_dim = config.v_head_dim
745745- self.qk_nope_head_dim = config.qk_nope_head_dim
746746- self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
747747-748748- self.is_causal = True
749749-750750- if self.q_lora_rank is None:
751751- self.q_proj = nn.Linear(
752752- self.hidden_size, self.num_heads * self.q_head_dim, bias=False
753753- )
754754- else:
755755- self.q_a_proj = nn.Linear(
756756- self.hidden_size, config.q_lora_rank, bias=config.attention_bias
757757- )
758758- self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank)
759759- self.q_b_proj = nn.Linear(
760760- config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
761761- )
762762- # config.kv_lora_rank + config.qk_rope_head_dim,
763763- self.kv_a_proj_with_mqa = nn.Linear(
764764- self.hidden_size,
765765- config.kv_lora_rank + config.qk_rope_head_dim,
766766- bias=config.attention_bias,
767767- )
768768- self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank)
769769- self.kv_b_proj = nn.Linear(
770770- config.kv_lora_rank,
771771- self.num_heads
772772- * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
773773- bias=False,
774774- )
775775-776776- self.o_proj = nn.Linear(
777777- self.num_heads * self.v_head_dim,
778778- self.hidden_size,
779779- bias=config.attention_bias,
780780- )
781781- self._init_rope()
782782-783783- self.softmax_scale = self.q_head_dim ** (-0.5)
784784- if self.config.rope_scaling is not None:
785785- mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
786786- scaling_factor = self.config.rope_scaling["factor"]
787787- if mscale_all_dim:
788788- mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
789789- self.softmax_scale = self.softmax_scale * mscale * mscale
790790-791791- def _init_rope(self):
792792- if self.config.rope_scaling is None:
793793- self.rotary_emb = DeepseekV2RotaryEmbedding(
794794- self.qk_rope_head_dim,
795795- max_position_embeddings=self.max_position_embeddings,
796796- base=self.rope_theta,
797797- )
798798- # self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
799799- # self.qk_rope_head_dim,
800800- # max_position_embeddings=self.max_position_embeddings,
801801- # scaling_factor=scaling_factor,
802802- # base=self.rope_theta,
803803- # )
804804- else:
805805- scaling_type = self.config.rope_scaling["type"]
806806- scaling_factor = self.config.rope_scaling["factor"]
807807- if scaling_type == "linear":
808808- self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding(
809809- self.qk_rope_head_dim,
810810- max_position_embeddings=self.max_position_embeddings,
811811- scaling_factor=scaling_factor,
812812- base=self.rope_theta,
813813- )
814814- elif scaling_type == "dynamic":
815815- self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding(
816816- self.qk_rope_head_dim,
817817- max_position_embeddings=self.max_position_embeddings,
818818- scaling_factor=scaling_factor,
819819- base=self.rope_theta,
820820- )
821821- elif scaling_type == "yarn":
822822- kwargs = {
823823- key: self.config.rope_scaling[key]
824824- for key in [
825825- "original_max_position_embeddings",
826826- "beta_fast",
827827- "beta_slow",
828828- "mscale",
829829- "mscale_all_dim",
830830- ]
831831- if key in self.config.rope_scaling
832832- }
833833- self.rotary_emb = DeepseekV2YarnRotaryEmbedding(
834834- self.qk_rope_head_dim,
835835- max_position_embeddings=self.max_position_embeddings,
836836- scaling_factor=scaling_factor,
837837- base=self.rope_theta,
838838- **kwargs,
839839- )
840840- else:
841841- raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
842842-843843- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
844844- return (
845845- tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim)
846846- .transpose(1, 2)
847847- .contiguous()
848848- )
849849-850850- def forward(
851851- self,
852852- hidden_states: torch.Tensor,
853853- attention_mask: Optional[torch.Tensor] = None,
854854- position_ids: Optional[torch.LongTensor] = None,
855855- past_key_value: Optional[Cache] = None,
856856- output_attentions: bool = False,
857857- use_cache: bool = False,
858858- **kwargs,
859859- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
860860- if "padding_mask" in kwargs:
861861- warnings.warn(
862862- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
863863- )
864864- bsz, q_len, _ = hidden_states.size()
865865-866866- if self.q_lora_rank is None:
867867- q = self.q_proj(hidden_states)
868868- else:
869869- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
870870- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
871871-872872-873873- q_nope, q_pe = torch.split(
874874- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
875875- )
876876-877877- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
878878- compressed_kv, k_pe = torch.split(
879879- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
880880- )
881881- compressed_kv = self.kv_a_layernorm(compressed_kv)
882882- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
883883-884884- kv_seq_len = k_pe.shape[-2]
885885- if past_key_value is not None:
886886- if self.layer_idx is None:
887887- raise ValueError(
888888- f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
889889- "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
890890- "with a layer index."
891891- )
892892- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
893893-894894- cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
895895- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
896896-897897- if past_key_value is not None:
898898- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
899899- compressed_kv = compressed_kv.unsqueeze(1)
900900- k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs)
901901- compressed_kv = compressed_kv.squeeze(1)
902902-903903- kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
904904- q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :]
905905- out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :]
906906-907907- q_nope = torch.matmul(q_nope, q_absorb)
908908- attn_weights = (torch.matmul(q_pe, k_pe.mT) +
909909- torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale
910910- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
911911- raise ValueError(
912912- f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
913913- f" {attn_weights.size()}"
914914- )
915915- assert attention_mask is not None
916916- if attention_mask is not None:
917917- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
918918- raise ValueError(
919919- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
920920- )
921921- attn_weights = attn_weights + attention_mask
922922-923923- # upcast attention to fp32
924924- attn_weights = nn.functional.softmax(
925925- attn_weights, dim=-1, dtype=torch.float32
926926- ).to(q_pe.dtype)
927927- attn_weights = nn.functional.dropout(
928928- attn_weights, p=self.attention_dropout, training=self.training
929929- )
930930- attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
931931-932932- attn_output = torch.matmul(attn_output, out_absorb.mT)
933933-934934- if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
935935- raise ValueError(
936936- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
937937- f" {attn_output.size()}"
938938- )
939939-940940- attn_output = attn_output.transpose(1, 2).contiguous()
941941-942942- attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
943943-944944- attn_output = self.o_proj(attn_output)
945945-946946- if not output_attentions:
947947- attn_weights = None
948948-949949- return attn_output, attn_weights, past_key_value
950950-951951-952952-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2
953953-class DeepseekV2FlashAttention2(DeepseekV2Attention):
954954- """
955955- DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays
956956- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
957957- flash attention and deal with padding tokens in case the input contains any of them.
958958- """
959959-960960- def __init__(self, *args, **kwargs):
961961- super().__init__(*args, **kwargs)
962962-963963- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
964964- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
965965- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
966966- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
967967-968968- def forward(
969969- self,
970970- hidden_states: torch.Tensor,
971971- attention_mask: Optional[torch.LongTensor] = None,
972972- position_ids: Optional[torch.LongTensor] = None,
973973- past_key_value: Optional[Cache] = None,
974974- output_attentions: bool = False,
975975- use_cache: bool = False,
976976- **kwargs,
977977- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
978978- # DeepseekV2FlashAttention2 attention does not support output_attentions
979979- if "padding_mask" in kwargs:
980980- warnings.warn(
981981- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
982982- )
983983-984984- # overwrite attention_mask with padding_mask
985985- attention_mask = kwargs.pop("padding_mask")
986986-987987- output_attentions = False
988988-989989- bsz, q_len, _ = hidden_states.size()
990990-991991- if self.q_lora_rank is None:
992992- q = self.q_proj(hidden_states)
993993- else:
994994- q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
995995- q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
996996- q_nope, q_pe = torch.split(
997997- q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
998998- )
999999-10001000- # Flash attention requires the input to have the shape
10011001- # batch_size x seq_length x head_dim x hidden_dim
10021002- # therefore we just need to keep the original shape
10031003- compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
10041004- compressed_kv, k_pe = torch.split(
10051005- compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
10061006- )
10071007- k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
10081008- kv = (
10091009- self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
10101010- .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
10111011- .transpose(1, 2)
10121012- )
10131013-10141014- k_nope, value_states = torch.split(
10151015- kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
10161016- )
10171017- kv_seq_len = value_states.shape[-2]
10181018-10191019- kv_seq_len = value_states.shape[-2]
10201020- if past_key_value is not None:
10211021- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
10221022-10231023- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
10241024- q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
10251025-10261026- query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
10271027- query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
10281028- query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
10291029-10301030- key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
10311031- key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
10321032- key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
10331033-10341034- if self.q_head_dim != self.v_head_dim:
10351035- value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
10361036-10371037- # TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in flash_attention version
10381038- if past_key_value is not None:
10391039- cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
10401040- key_states, value_states = past_key_value.update(
10411041- key_states, value_states, self.layer_idx, cache_kwargs
10421042- )
10431043-10441044- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
10451045- # to be able to avoid many of these transpose/reshape/view.
10461046- query_states = query_states.transpose(1, 2)
10471047- key_states = key_states.transpose(1, 2)
10481048- value_states = value_states.transpose(1, 2)
10491049-10501050- dropout_rate = self.attention_dropout if self.training else 0.0
10511051-10521052- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
10531053- # therefore the input hidden states gets silently casted in float32. Hence, we need
10541054- # cast them back in the correct dtype just to be sure everything works as expected.
10551055- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
10561056- # in fp32. (DeepseekV2RMSNorm handles it correctly)
10571057-10581058- input_dtype = query_states.dtype
10591059- if input_dtype == torch.float32:
10601060- # Handle the case where the model is quantized
10611061- if hasattr(self.config, "_pre_quantization_dtype"):
10621062- target_dtype = self.config._pre_quantization_dtype
10631063- elif torch.is_autocast_enabled():
10641064- target_dtype = torch.get_autocast_gpu_dtype()
10651065- else:
10661066- target_dtype = (
10671067- self.q_proj.weight.dtype
10681068- if self.q_lora_rank is None
10691069- else self.q_a_proj.weight.dtype
10701070- )
10711071-10721072- logger.warning_once(
10731073- f"The input hidden states seems to be silently casted in float32, this might be related to"
10741074- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
10751075- f" {target_dtype}."
10761076- )
10771077-10781078- query_states = query_states.to(target_dtype)
10791079- key_states = key_states.to(target_dtype)
10801080- value_states = value_states.to(target_dtype)
10811081-10821082- attn_output = self._flash_attention_forward(
10831083- query_states,
10841084- key_states,
10851085- value_states,
10861086- attention_mask,
10871087- q_len,
10881088- dropout=dropout_rate,
10891089- softmax_scale=self.softmax_scale,
10901090- )
10911091- if self.q_head_dim != self.v_head_dim:
10921092- attn_output = attn_output[:, :, :, : self.v_head_dim]
10931093-10941094- attn_output = attn_output.reshape(
10951095- bsz, q_len, self.num_heads * self.v_head_dim
10961096- ).contiguous()
10971097- attn_output = self.o_proj(attn_output)
10981098-10991099- if not output_attentions:
11001100- attn_weights = None
11011101-11021102- return attn_output, attn_weights, past_key_value
11031103-11041104- def _flash_attention_forward(
11051105- self,
11061106- query_states,
11071107- key_states,
11081108- value_states,
11091109- attention_mask,
11101110- query_length,
11111111- dropout=0.0,
11121112- softmax_scale=None,
11131113- ):
11141114- """
11151115- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
11161116- first unpad the input, then computes the attention scores and pad the final attention scores.
11171117-11181118- Args:
11191119- query_states (`torch.Tensor`):
11201120- Input query states to be passed to Flash Attention API
11211121- key_states (`torch.Tensor`):
11221122- Input key states to be passed to Flash Attention API
11231123- value_states (`torch.Tensor`):
11241124- Input value states to be passed to Flash Attention API
11251125- attention_mask (`torch.Tensor`):
11261126- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
11271127- position of padding tokens and 1 for the position of non-padding tokens.
11281128- dropout (`int`, *optional*):
11291129- Attention dropout
11301130- softmax_scale (`float`, *optional*):
11311131- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
11321132- """
11331133- if not self._flash_attn_uses_top_left_mask:
11341134- causal = self.is_causal
11351135- else:
11361136- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__.
11371137- causal = self.is_causal and query_length != 1
11381138-11391139- # Contains at least one padding token in the sequence
11401140- if attention_mask is not None:
11411141- batch_size = query_states.shape[0]
11421142- (
11431143- query_states,
11441144- key_states,
11451145- value_states,
11461146- indices_q,
11471147- cu_seq_lens,
11481148- max_seq_lens,
11491149- ) = self._upad_input(
11501150- query_states, key_states, value_states, attention_mask, query_length
11511151- )
11521152-11531153- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
11541154- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
11551155-11561156- attn_output_unpad = flash_attn_varlen_func(
11571157- query_states,
11581158- key_states,
11591159- value_states,
11601160- cu_seqlens_q=cu_seqlens_q,
11611161- cu_seqlens_k=cu_seqlens_k,
11621162- max_seqlen_q=max_seqlen_in_batch_q,
11631163- max_seqlen_k=max_seqlen_in_batch_k,
11641164- dropout_p=dropout,
11651165- softmax_scale=softmax_scale,
11661166- causal=causal,
11671167- )
11681168-11691169- attn_output = pad_input(
11701170- attn_output_unpad, indices_q, batch_size, query_length
11711171- )
11721172- else:
11731173- attn_output = flash_attn_func(
11741174- query_states,
11751175- key_states,
11761176- value_states,
11771177- dropout,
11781178- softmax_scale=softmax_scale,
11791179- causal=causal,
11801180- )
11811181-11821182- return attn_output
11831183-11841184- def _upad_input(
11851185- self, query_layer, key_layer, value_layer, attention_mask, query_length
11861186- ):
11871187- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
11881188- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
11891189-11901190- key_layer = index_first_axis(
11911191- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
11921192- indices_k,
11931193- )
11941194- value_layer = index_first_axis(
11951195- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
11961196- indices_k,
11971197- )
11981198- if query_length == kv_seq_len:
11991199- query_layer = index_first_axis(
12001200- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
12011201- indices_k,
12021202- )
12031203- cu_seqlens_q = cu_seqlens_k
12041204- max_seqlen_in_batch_q = max_seqlen_in_batch_k
12051205- indices_q = indices_k
12061206- elif query_length == 1:
12071207- max_seqlen_in_batch_q = 1
12081208- cu_seqlens_q = torch.arange(
12091209- batch_size + 1, dtype=torch.int32, device=query_layer.device
12101210- ) # There is a memcpy here, that is very bad.
12111211- indices_q = cu_seqlens_q[:-1]
12121212- query_layer = query_layer.squeeze(1)
12131213- else:
12141214- # The -q_len: slice assumes left padding.
12151215- attention_mask = attention_mask[:, -query_length:]
12161216- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
12171217- query_layer, attention_mask
12181218- )
12191219-12201220- return (
12211221- query_layer,
12221222- key_layer,
12231223- value_layer,
12241224- indices_q,
12251225- (cu_seqlens_q, cu_seqlens_k),
12261226- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
12271227- )
12281228-12291229-12301230-ATTENTION_CLASSES = {
12311231- "eager": DeepseekV2Attention,
12321232- "flash_attention_2": DeepseekV2FlashAttention2,
12331233-12341234- "mla_eager": DeepseekV2Attention,
12351235- "mla_flash_attention_2": DeepseekV2FlashAttention2,
12361236-12371237- "mha_eager": LlamaAttention,
12381238- "mha_flash_attention_2": LlamaFlashAttention2
12391239-}
12401240-12411241-12421242-class DeepseekV2DecoderLayer(nn.Module):
12431243- def __init__(self, config: DeepseekV2Config, layer_idx: int):
12441244- super().__init__()
12451245- self.hidden_size = config.hidden_size
12461246-12471247-12481248- if config.use_mla:
12491249- attn_implementation = "mla_" + config._attn_implementation
12501250- else:
12511251- attn_implementation = "mha_" + config._attn_implementation
12521252-12531253- self.self_attn = ATTENTION_CLASSES[attn_implementation](
12541254- config=config, layer_idx=layer_idx
12551255- )
12561256-12571257- self.mlp = (
12581258- DeepseekV2MoE(config)
12591259- if (
12601260- config.n_routed_experts is not None
12611261- and layer_idx >= config.first_k_dense_replace
12621262- and layer_idx % config.moe_layer_freq == 0
12631263- )
12641264- else DeepseekV2MLP(config)
12651265- )
12661266- self.input_layernorm = DeepseekV2RMSNorm(
12671267- config.hidden_size, eps=config.rms_norm_eps
12681268- )
12691269- self.post_attention_layernorm = DeepseekV2RMSNorm(
12701270- config.hidden_size, eps=config.rms_norm_eps
12711271- )
12721272-12731273- def forward(
12741274- self,
12751275- hidden_states: torch.Tensor,
12761276- attention_mask: Optional[torch.Tensor] = None,
12771277- position_ids: Optional[torch.LongTensor] = None,
12781278- past_key_value: Optional[Tuple[torch.Tensor]] = None,
12791279- output_attentions: Optional[bool] = False,
12801280- use_cache: Optional[bool] = False,
12811281- **kwargs,
12821282- ) -> Tuple[
12831283- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
12841284- ]:
12851285- """
12861286- Args:
12871287- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
12881288- attention_mask (`torch.FloatTensor`, *optional*):
12891289- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
12901290- query_sequence_length, key_sequence_length)` if default attention is used.
12911291- output_attentions (`bool`, *optional*):
12921292- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
12931293- returned tensors for more detail.
12941294- use_cache (`bool`, *optional*):
12951295- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
12961296- (see `past_key_values`).
12971297- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
12981298- """
12991299- if "padding_mask" in kwargs:
13001300- warnings.warn(
13011301- "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
13021302- )
13031303- residual = hidden_states
13041304-13051305- hidden_states = self.input_layernorm(hidden_states)
13061306-13071307- # Self Attention
13081308- hidden_states, self_attn_weights, present_key_value = self.self_attn(
13091309- hidden_states=hidden_states,
13101310- attention_mask=attention_mask,
13111311- position_ids=position_ids,
13121312- past_key_value=past_key_value,
13131313- output_attentions=output_attentions,
13141314- use_cache=use_cache,
13151315- **kwargs,
13161316- )
13171317- hidden_states = residual + hidden_states
13181318-13191319- # Fully Connected
13201320- residual = hidden_states
13211321- hidden_states = self.post_attention_layernorm(hidden_states)
13221322- hidden_states = self.mlp(hidden_states)
13231323- hidden_states = residual + hidden_states
13241324-13251325- outputs = (hidden_states,)
13261326-13271327- if output_attentions:
13281328- outputs += (self_attn_weights,)
13291329-13301330- if use_cache:
13311331- outputs += (present_key_value,)
13321332-13331333- return outputs
13341334-13351335-13361336-DeepseekV2_START_DOCSTRING = r"""
13371337- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
13381338- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
13391339- etc.)
13401340-13411341- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
13421342- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
13431343- and behavior.
13441344-13451345- Parameters:
13461346- config ([`DeepseekV2Config`]):
13471347- Model configuration class with all the parameters of the model. Initializing with a config file does not
13481348- load the weights associated with the model, only the configuration. Check out the
13491349- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
13501350-"""
13511351-13521352-13531353-@add_start_docstrings(
13541354- "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
13551355- DeepseekV2_START_DOCSTRING,
13561356-)
13571357-class DeepseekV2PreTrainedModel(PreTrainedModel):
13581358- config_class = DeepseekV2Config
13591359- base_model_prefix = "model"
13601360- supports_gradient_checkpointing = True
13611361- _no_split_modules = ["DeepseekV2DecoderLayer"]
13621362- _skip_keys_device_placement = "past_key_values"
13631363- _supports_flash_attn_2 = True
13641364- _supports_cache_class = True
13651365-13661366- def _init_weights(self, module):
13671367- std = self.config.initializer_range
13681368- if isinstance(module, nn.Linear):
13691369- module.weight.data.normal_(mean=0.0, std=std)
13701370- if module.bias is not None:
13711371- module.bias.data.zero_()
13721372- elif isinstance(module, nn.Embedding):
13731373- module.weight.data.normal_(mean=0.0, std=std)
13741374- if module.padding_idx is not None:
13751375- module.weight.data[module.padding_idx].zero_()
13761376-13771377-13781378-DeepseekV2_INPUTS_DOCSTRING = r"""
13791379- Args:
13801380- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
13811381- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
13821382- it.
13831383-13841384- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
13851385- [`PreTrainedTokenizer.__call__`] for details.
13861386-13871387- [What are input IDs?](../glossary#input-ids)
13881388- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
13891389- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
13901390-13911391- - 1 for tokens that are **not masked**,
13921392- - 0 for tokens that are **masked**.
13931393-13941394- [What are attention masks?](../glossary#attention-mask)
13951395-13961396- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
13971397- [`PreTrainedTokenizer.__call__`] for details.
13981398-13991399- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
14001400- `past_key_values`).
14011401-14021402- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
14031403- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
14041404- information on the default strategy.
14051405-14061406- - 1 indicates the head is **not masked**,
14071407- - 0 indicates the head is **masked**.
14081408- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
14091409- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
14101410- config.n_positions - 1]`.
14111411-14121412- [What are position IDs?](../glossary#position-ids)
14131413- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
14141414- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
14151415- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
14161416- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
14171417-14181418- Two formats are allowed:
14191419- - a [`~cache_utils.Cache`] instance;
14201420- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
14211421- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
14221422- cache format.
14231423-14241424- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
14251425- legacy cache format will be returned.
14261426-14271427- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
14281428- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
14291429- of shape `(batch_size, sequence_length)`.
14301430- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
14311431- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
14321432- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
14331433- model's internal embedding lookup matrix.
14341434- use_cache (`bool`, *optional*):
14351435- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
14361436- `past_key_values`).
14371437- output_attentions (`bool`, *optional*):
14381438- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
14391439- tensors for more detail.
14401440- output_hidden_states (`bool`, *optional*):
14411441- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
14421442- more detail.
14431443- return_dict (`bool`, *optional*):
14441444- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
14451445-"""
14461446-14471447-14481448-@add_start_docstrings(
14491449- "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.",
14501450- DeepseekV2_START_DOCSTRING,
14511451-)
14521452-class DeepseekV2Model(DeepseekV2PreTrainedModel):
14531453- """
14541454- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]
14551455-14561456- Args:
14571457- config: DeepseekV2Config
14581458- """
14591459-14601460- def __init__(self, config: DeepseekV2Config):
14611461- super().__init__(config)
14621462- self.padding_idx = config.pad_token_id
14631463- self.vocab_size = config.vocab_size
14641464-14651465- self.embed_tokens = nn.Embedding(
14661466- config.vocab_size, config.hidden_size, self.padding_idx
14671467- )
14681468- self.layers = nn.ModuleList(
14691469- [
14701470- DeepseekV2DecoderLayer(config, layer_idx)
14711471- for layer_idx in range(config.num_hidden_layers)
14721472- ]
14731473- )
14741474- # print(config._attn_implementation)
14751475- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
14761476- self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
14771477-14781478- self.gradient_checkpointing = False
14791479- # Initialize weights and apply final processing
14801480- self.post_init()
14811481-14821482- def get_input_embeddings(self):
14831483- return self.embed_tokens
14841484-14851485- def set_input_embeddings(self, value):
14861486- self.embed_tokens = value
14871487-14881488- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
14891489- def forward(
14901490- self,
14911491- input_ids: torch.LongTensor = None,
14921492- attention_mask: Optional[torch.Tensor] = None,
14931493- position_ids: Optional[torch.LongTensor] = None,
14941494- past_key_values: Optional[List[torch.FloatTensor]] = None,
14951495- inputs_embeds: Optional[torch.FloatTensor] = None,
14961496- use_cache: Optional[bool] = None,
14971497- output_attentions: Optional[bool] = None,
14981498- output_hidden_states: Optional[bool] = None,
14991499- return_dict: Optional[bool] = None,
15001500- cache_position: Optional[torch.LongTensor] = None
15011501- ) -> Union[Tuple, BaseModelOutputWithPast]:
15021502- output_attentions = (
15031503- output_attentions
15041504- if output_attentions is not None
15051505- else self.config.output_attentions
15061506- )
15071507- output_hidden_states = (
15081508- output_hidden_states
15091509- if output_hidden_states is not None
15101510- else self.config.output_hidden_states
15111511- )
15121512- use_cache = use_cache if use_cache is not None else self.config.use_cache
15131513-15141514- return_dict = (
15151515- return_dict if return_dict is not None else self.config.use_return_dict
15161516- )
15171517-15181518- # retrieve input_ids and inputs_embeds
15191519- if input_ids is not None and inputs_embeds is not None:
15201520- raise ValueError(
15211521- "You cannot specify both input_ids and inputs_embeds at the same time"
15221522- )
15231523- elif input_ids is not None:
15241524- batch_size, seq_length = input_ids.shape[:2]
15251525- elif inputs_embeds is not None:
15261526- batch_size, seq_length = inputs_embeds.shape[:2]
15271527- else:
15281528- raise ValueError("You have to specify either input_ids or inputs_embeds")
15291529-15301530- if self.gradient_checkpointing and self.training:
15311531- if use_cache:
15321532- logger.warning_once(
15331533- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
15341534- )
15351535- use_cache = False
15361536-15371537- past_key_values_length = 0
15381538- if use_cache:
15391539- use_legacy_cache = not isinstance(past_key_values, Cache)
15401540- if use_legacy_cache:
15411541- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
15421542- past_key_values_length = past_key_values.get_usable_length(seq_length)
15431543-15441544- if position_ids is None:
15451545- device = input_ids.device if input_ids is not None else inputs_embeds.device
15461546- position_ids = torch.arange(
15471547- past_key_values_length,
15481548- seq_length + past_key_values_length,
15491549- dtype=torch.long,
15501550- device=device,
15511551- )
15521552- position_ids = position_ids.unsqueeze(0)
15531553-15541554- if inputs_embeds is None:
15551555- inputs_embeds = self.embed_tokens(input_ids)
15561556-15571557- if self._use_flash_attention_2:
15581558- # 2d mask is passed through the layers
15591559- attention_mask = (
15601560- attention_mask
15611561- if (attention_mask is not None and 0 in attention_mask)
15621562- else None
15631563- )
15641564- else:
15651565- # 4d mask is passed through the layers
15661566- attention_mask = _prepare_4d_causal_attention_mask(
15671567- attention_mask,
15681568- (batch_size, seq_length),
15691569- inputs_embeds,
15701570- past_key_values_length,
15711571- )
15721572-15731573- # embed positions
15741574- hidden_states = inputs_embeds
15751575-15761576- # decoder layers
15771577- all_hidden_states = () if output_hidden_states else None
15781578- all_self_attns = () if output_attentions else None
15791579- next_decoder_cache = None
15801580-15811581- for decoder_layer in self.layers:
15821582- if output_hidden_states:
15831583- all_hidden_states += (hidden_states,)
15841584-15851585- if self.gradient_checkpointing and self.training:
15861586- layer_outputs = self._gradient_checkpointing_func(
15871587- decoder_layer.__call__,
15881588- hidden_states,
15891589- attention_mask,
15901590- position_ids,
15911591- past_key_values,
15921592- output_attentions,
15931593- use_cache,
15941594- )
15951595- else:
15961596- layer_outputs = decoder_layer(
15971597- hidden_states,
15981598- attention_mask=attention_mask,
15991599- position_ids=position_ids,
16001600- past_key_value=past_key_values,
16011601- output_attentions=output_attentions,
16021602- use_cache=use_cache,
16031603- )
16041604-16051605- hidden_states = layer_outputs[0]
16061606-16071607- if use_cache:
16081608- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
16091609-16101610- if output_attentions:
16111611- all_self_attns += (layer_outputs[1],)
16121612-16131613- hidden_states = self.norm(hidden_states)
16141614-16151615- # add hidden states from the last decoder layer
16161616- if output_hidden_states:
16171617- all_hidden_states += (hidden_states,)
16181618-16191619- next_cache = None
16201620- if use_cache:
16211621- next_cache = (
16221622- next_decoder_cache.to_legacy_cache()
16231623- if use_legacy_cache
16241624- else next_decoder_cache
16251625- )
16261626- if not return_dict:
16271627- return tuple(
16281628- v
16291629- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
16301630- if v is not None
16311631- )
16321632- return BaseModelOutputWithPast(
16331633- last_hidden_state=hidden_states,
16341634- past_key_values=next_cache,
16351635- hidden_states=all_hidden_states,
16361636- attentions=all_self_attns,
16371637- )
16381638-16391639-16401640-class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
16411641- _tied_weights_keys = ["lm_head.weight"]
16421642-16431643- def __init__(self, config):
16441644- super().__init__(config)
16451645- self.model = DeepseekV2Model(config)
16461646- self.vocab_size = config.vocab_size
16471647- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
16481648-16491649- # Initialize weights and apply final processing
16501650- self.post_init()
16511651-16521652- def get_input_embeddings(self):
16531653- return self.model.embed_tokens
16541654-16551655- def set_input_embeddings(self, value):
16561656- self.model.embed_tokens = value
16571657-16581658- def get_output_embeddings(self):
16591659- return self.lm_head
16601660-16611661- def set_output_embeddings(self, new_embeddings):
16621662- self.lm_head = new_embeddings
16631663-16641664- def set_decoder(self, decoder):
16651665- self.model = decoder
16661666-16671667- def get_decoder(self):
16681668- return self.model
16691669-16701670- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
16711671- @replace_return_docstrings(
16721672- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
16731673- )
16741674- def forward(
16751675- self,
16761676- input_ids: torch.LongTensor = None,
16771677- attention_mask: Optional[torch.Tensor] = None,
16781678- position_ids: Optional[torch.LongTensor] = None,
16791679- past_key_values: Optional[List[torch.FloatTensor]] = None,
16801680- inputs_embeds: Optional[torch.FloatTensor] = None,
16811681- labels: Optional[torch.LongTensor] = None,
16821682- use_cache: Optional[bool] = None,
16831683- output_attentions: Optional[bool] = None,
16841684- output_hidden_states: Optional[bool] = None,
16851685- return_dict: Optional[bool] = None,
16861686- cache_position: Optional[torch.LongTensor] = None
16871687- ) -> Union[Tuple, CausalLMOutputWithPast]:
16881688- r"""
16891689- Args:
16901690- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
16911691- Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers.,
16921692- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
16931693- (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`.
16941694-16951695- Returns:
16961696-16971697- Example:
16981698-16991699- ```python
17001700- >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
17011701-17021702- >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
17031703- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
17041704-17051705- >>> prompt = "Hey, are you conscious? Can you talk to me?"
17061706- >>> inputs = tokenizer(prompt, return_tensors="pt")
17071707-17081708- >>> # Generate
17091709- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
17101710- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
17111711- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
17121712- ```"""
17131713- output_attentions = (
17141714- output_attentions
17151715- if output_attentions is not None
17161716- else self.config.output_attentions
17171717- )
17181718- output_hidden_states = (
17191719- output_hidden_states
17201720- if output_hidden_states is not None
17211721- else self.config.output_hidden_states
17221722- )
17231723- return_dict = (
17241724- return_dict if return_dict is not None else self.config.use_return_dict
17251725- )
17261726-17271727- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
17281728- outputs = self.model(
17291729- input_ids=input_ids,
17301730- attention_mask=attention_mask,
17311731- position_ids=position_ids,
17321732- past_key_values=past_key_values,
17331733- inputs_embeds=inputs_embeds,
17341734- use_cache=use_cache,
17351735- output_attentions=output_attentions,
17361736- output_hidden_states=output_hidden_states,
17371737- return_dict=return_dict,
17381738- cache_position=cache_position
17391739- )
17401740-17411741- hidden_states = outputs[0]
17421742- logits = self.lm_head(hidden_states)
17431743- logits = logits.float()
17441744-17451745- loss = None
17461746- if labels is not None:
17471747- # Shift so that tokens < n predict n
17481748- shift_logits = logits[..., :-1, :].contiguous()
17491749- shift_labels = labels[..., 1:].contiguous()
17501750- # Flatten the tokens
17511751- loss_fct = CrossEntropyLoss()
17521752- shift_logits = shift_logits.view(-1, self.config.vocab_size)
17531753- shift_labels = shift_labels.view(-1)
17541754- # Enable model parallelism
17551755- shift_labels = shift_labels.to(shift_logits.device)
17561756- loss = loss_fct(shift_logits, shift_labels)
17571757-17581758- if not return_dict:
17591759- output = (logits,) + outputs[1:]
17601760- return (loss,) + output if loss is not None else output
17611761-17621762- return CausalLMOutputWithPast(
17631763- loss=loss,
17641764- logits=logits,
17651765- past_key_values=outputs.past_key_values,
17661766- hidden_states=outputs.hidden_states,
17671767- attentions=outputs.attentions,
17681768- )
17691769-17701770- def prepare_inputs_for_generation(
17711771- self,
17721772- input_ids,
17731773- past_key_values=None,
17741774- attention_mask=None,
17751775- inputs_embeds=None,
17761776- **kwargs,
17771777- ):
17781778- past_length = 0
17791779- if past_key_values is not None:
17801780- if isinstance(past_key_values, Cache):
17811781- cache_length = past_key_values.get_seq_length()
17821782- past_length = past_key_values.seen_tokens
17831783- max_cache_length = past_key_values.get_max_length()
17841784- else:
17851785- cache_length = past_length = past_key_values[0][0].shape[2]
17861786- max_cache_length = None
17871787-17881788- # Keep only the unprocessed tokens:
17891789- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
17901790- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
17911791- # input)
17921792- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
17931793- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
17941794- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
17951795- # input_ids based on the past_length.
17961796- elif past_length < input_ids.shape[1]:
17971797- input_ids = input_ids[:, past_length:]
17981798- # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
17991799-18001800- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
18011801- if (
18021802- max_cache_length is not None
18031803- and attention_mask is not None
18041804- and cache_length + input_ids.shape[1] > max_cache_length
18051805- ):
18061806- attention_mask = attention_mask[:, -max_cache_length:]
18071807-18081808- position_ids = kwargs.get("position_ids", None)
18091809- if attention_mask is not None and position_ids is None:
18101810- # create position_ids on the fly for batch generation
18111811- position_ids = attention_mask.long().cumsum(-1) - 1
18121812- position_ids.masked_fill_(attention_mask == 0, 1)
18131813- if past_key_values:
18141814- position_ids = position_ids[:, -input_ids.shape[1]:]
18151815-18161816- if self.generation_config.cache_implementation == "static":
18171817- # generation with static cache
18181818- cache_position = kwargs.get("cache_position", None)
18191819- if cache_position is None:
18201820- past_length = 0
18211821- else:
18221822- past_length = cache_position[-1] + 1
18231823- input_ids = input_ids[:, past_length:]
18241824- position_ids = position_ids[:, past_length:]
18251825-18261826- # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
18271827- # same goes for position ids. Could also help with continued generation.
18281828- cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
18291829-18301830- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
18311831- if inputs_embeds is not None and past_key_values is None:
18321832- model_inputs = {"inputs_embeds": inputs_embeds}
18331833- else:
18341834- # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
18351835- # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
18361836- # TODO: use `next_tokens` directly instead.
18371837- model_inputs = {"input_ids": input_ids.contiguous()}
18381838-18391839- model_inputs.update(
18401840- {
18411841- "position_ids": position_ids.contiguous(),
18421842- "cache_position": cache_position,
18431843- "past_key_values": past_key_values,
18441844- "use_cache": kwargs.get("use_cache"),
18451845- "attention_mask": attention_mask,
18461846- }
18471847- )
18481848- return model_inputs
18491849-18501850- @staticmethod
18511851- def _reorder_cache(past_key_values, beam_idx):
18521852- reordered_past = ()
18531853- for layer_past in past_key_values:
18541854- reordered_past += (
18551855- tuple(
18561856- past_state.index_select(0, beam_idx.to(past_state.device))
18571857- for past_state in layer_past
18581858- ),
18591859- )
18601860- return reordered_past
18611861-18621862-18631863-@add_start_docstrings(
18641864- """
18651865- The DeepseekV2 Model transformer with a sequence classification head on top (linear layer).
18661866-18671867- [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
18681868- (e.g. GPT-2) do.
18691869-18701870- Since it does classification on the last token, it requires to know the position of the last token. If a
18711871- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
18721872- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
18731873- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
18741874- each row of the batch).
18751875- """,
18761876- DeepseekV2_START_DOCSTRING,
18771877-)
18781878-class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel):
18791879- def __init__(self, config):
18801880- super().__init__(config)
18811881- self.num_labels = config.num_labels
18821882- self.model = DeepseekV2Model(config)
18831883- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
18841884-18851885- # Initialize weights and apply final processing
18861886- self.post_init()
18871887-18881888- def get_input_embeddings(self):
18891889- return self.model.embed_tokens
18901890-18911891- def set_input_embeddings(self, value):
18921892- self.model.embed_tokens = value
18931893-18941894- @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
18951895- def forward(
18961896- self,
18971897- input_ids: torch.LongTensor = None,
18981898- attention_mask: Optional[torch.Tensor] = None,
18991899- position_ids: Optional[torch.LongTensor] = None,
19001900- past_key_values: Optional[List[torch.FloatTensor]] = None,
19011901- inputs_embeds: Optional[torch.FloatTensor] = None,
19021902- labels: Optional[torch.LongTensor] = None,
19031903- use_cache: Optional[bool] = None,
19041904- output_attentions: Optional[bool] = None,
19051905- output_hidden_states: Optional[bool] = None,
19061906- return_dict: Optional[bool] = None,
19071907- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
19081908- r"""
19091909- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
19101910- Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers.,
19111911- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
19121912- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
19131913- """
19141914- return_dict = (
19151915- return_dict if return_dict is not None else self.config.use_return_dict
19161916- )
19171917-19181918- transformer_outputs = self.model(
19191919- input_ids,
19201920- attention_mask=attention_mask,
19211921- position_ids=position_ids,
19221922- past_key_values=past_key_values,
19231923- inputs_embeds=inputs_embeds,
19241924- use_cache=use_cache,
19251925- output_attentions=output_attentions,
19261926- output_hidden_states=output_hidden_states,
19271927- return_dict=return_dict,
19281928- )
19291929- hidden_states = transformer_outputs[0]
19301930- logits = self.score(hidden_states)
19311931-19321932- if input_ids is not None:
19331933- batch_size = input_ids.shape[0]
19341934- else:
19351935- batch_size = inputs_embeds.shape[0]
19361936-19371937- if self.config.pad_token_id is None and batch_size != 1:
19381938- raise ValueError(
19391939- "Cannot handle batch sizes > 1 if no padding token is defined."
19401940- )
19411941- if self.config.pad_token_id is None:
19421942- sequence_lengths = -1
19431943- else:
19441944- if input_ids is not None:
19451945- sequence_lengths = (
19461946- torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
19471947- ).to(logits.device)
19481948- else:
19491949- sequence_lengths = -1
19501950-19511951- pooled_logits = logits[
19521952- torch.arange(batch_size, device=logits.device), sequence_lengths
19531953- ]
19541954-19551955- loss = None
19561956- if labels is not None:
19571957- labels = labels.to(logits.device)
19581958- if self.config.problem_type is None:
19591959- if self.num_labels == 1:
19601960- self.config.problem_type = "regression"
19611961- elif self.num_labels > 1 and (
19621962- labels.dtype == torch.long or labels.dtype == torch.int
19631963- ):
19641964- self.config.problem_type = "single_label_classification"
19651965- else:
19661966- self.config.problem_type = "multi_label_classification"
19671967-19681968- if self.config.problem_type == "regression":
19691969- loss_fct = MSELoss()
19701970- if self.num_labels == 1:
19711971- loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
19721972- else:
19731973- loss = loss_fct(pooled_logits, labels)
19741974- elif self.config.problem_type == "single_label_classification":
19751975- loss_fct = CrossEntropyLoss()
19761976- loss = loss_fct(
19771977- pooled_logits.view(-1, self.num_labels), labels.view(-1)
19781978- )
19791979- elif self.config.problem_type == "multi_label_classification":
19801980- loss_fct = BCEWithLogitsLoss()
19811981- loss = loss_fct(pooled_logits, labels)
19821982- if not return_dict:
19831983- output = (pooled_logits,) + transformer_outputs[1:]
19841984- return ((loss,) + output) if loss is not None else output
19851985-19861986- return SequenceClassifierOutputWithPast(
19871987- loss=loss,
19881988- logits=pooled_logits,
19891989- past_key_values=transformer_outputs.past_key_values,
19901990- hidden_states=transformer_outputs.hidden_states,
19911991- attentions=transformer_outputs.attentions,
19921992- )
-82
src/eval_deepseek.py
···11-"""
22-ExpRate evaluation of a fine-tuned DeepSeek-OCR-2 checkpoint on the test splits.
33-44-ExpRate = fraction of exact string matches after whitespace normalisation.
55-66-Usage:
77- uv run eval-deepseek [--checkpoint checkpoints/deepseek/final]
88- [--splits mathwriting_test typeset_test typeset_mixed_test]
99- [--n 500]
1010-"""
1111-1212-import argparse
1313-1414-import torch
1515-from PIL import Image
1616-from tqdm import tqdm
1717-1818-import re
1919-2020-import random
2121-2222-from .data import TEST_SPLITS, load_records
2323-from .mine_failures import _infer, _load
2424-2525-2626-def normalize(s: str) -> str:
2727- return re.sub(r"\s+", " ", s).strip()
2828-2929-3030-def evaluate(
3131- checkpoint: str,
3232- splits: list[str] | None = None,
3333- n: int | None = None,
3434-) -> float:
3535- model, tokenizer = _load(checkpoint)
3636- prompt_ids = tokenizer.encode(
3737- "\nTranscribe this image to Typst notation.\n", add_special_tokens=False
3838- )
3939-4040- split_names = splits or TEST_SPLITS
4141- rng = random.Random(0)
4242- records = []
4343- for name in split_names:
4444- recs = load_records([name], dedupe=False)
4545- if n is not None and len(recs) > n:
4646- recs = rng.sample(recs, n)
4747- records.extend(recs)
4848-4949- correct = 0
5050- per_split: dict[str, list[bool]] = {}
5151-5252- for r in tqdm(records, desc="Evaluating"):
5353- img = Image.open(r["image_path"]).convert("RGB")
5454- pred = normalize(_infer(model, tokenizer, prompt_ids, img))
5555- gt = normalize(r["typst"])
5656- hit = pred == gt
5757- correct += hit
5858- per_split.setdefault(r.get("split", "unknown"), []).append(hit)
5959- print(f" GT : {repr(gt)}")
6060- print(f" PRED: {repr(pred)}")
6161- print(f" HIT : {hit}\n")
6262-6363- total = len(records)
6464- print(f"\nExpRate: {correct/total:.4f} ({correct}/{total})")
6565- for split, hits in sorted(per_split.items()):
6666- print(f" {split}: {sum(hits)/len(hits):.4f} ({sum(hits)}/{len(hits)})")
6767- return correct / total
6868-6969-7070-def main() -> None:
7171- parser = argparse.ArgumentParser()
7272- parser.add_argument("--checkpoint", default="checkpoints/deepseek/final")
7373- parser.add_argument("--splits", nargs="+", default=None,
7474- metavar="SPLIT", help="Override test splits")
7575- parser.add_argument("--n", type=int, default=None,
7676- help="Max records per split (stratified); omit for full test set")
7777- args = parser.parse_args()
7878- evaluate(args.checkpoint, args.splits, args.n)
7979-8080-8181-if __name__ == "__main__":
8282- main()