
Add DeepSeek-OCR-2 modeling files from HuggingFace hub (verbatim)

Copied from deepseek-ai/DeepSeek-OCR-2 commit aaa02f38.
No modifications -- patches for training compatibility follow in the next commit.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+4527
src/deepseek_ocr2/__init__.py

This is a binary file and will not be displayed.

+210
src/deepseek_ocr2/configuration_deepseek_v2.py
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV2Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`DeepseekV2Model`]. It is used to instantiate an DeepSeek
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the DeepSeek-V2 with multi-latent attention.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 102400):
            Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`DeepseekV2Model`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1407):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        n_shared_experts (`int`, *optional*, defaults to None):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to None):
            Number of routed experts, None means dense model.
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor or routed experts.
        topk_method (`str`, *optional*, defaults to `gready`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to None):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to None):
            Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to None):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 0):
            Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
                                                            \--k dense layers--/
        norm_topk_prob (`bool`, *optional*, defaults to False):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to 'softmax'):
            Method of computing expert weights.
        aux_loss_alpha (`float`, *optional*, defaults to 0.001):
            Auxiliary loss weight coefficient.
        seq_aux = (`bool`, *optional*, defaults to True):
            Whether to compute the auxiliary loss for each individual sample.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        use_mla (`bool`, *optional*, defaults to `True`): Use multi-latent attention or multi-head attention. If True,
            the model will use multi-latent attention, otherwise, it will use multi-head attention.

    ```python
    >>> from transformers import DeepseekV2Model, DeepseekV2Config

    >>> # Initializing a Deepseek-V2 style configuration
    >>> configuration = DeepseekV2Config()

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "deepseek_v2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=102400,
        hidden_size=4096,
        intermediate_size=11008,
        moe_intermediate_size = 1407,
        num_hidden_layers=30,
        num_attention_heads=32,
        num_key_value_heads=32,
        n_shared_experts = None,
        n_routed_experts = None,
        ep_size = 1,
        routed_scaling_factor = 1.0,
        kv_lora_rank = 512,
        q_lora_rank = 1536,
        qk_rope_head_dim = 64,
        v_head_dim = 128,
        qk_nope_head_dim = 128,
        topk_method = 'gready',
        n_group = None,
        topk_group = None,
        num_experts_per_tok = None,
        moe_layer_freq = 1,
        first_k_dense_replace = 0,
        norm_topk_prob = False,
        scoring_func = 'softmax',
        aux_loss_alpha = 0.001,
        seq_aux = True,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=100000,
        eos_token_id=100001,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        use_mla=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func
        self.aux_loss_alpha = aux_loss_alpha
        self.seq_aux = seq_aux
        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = float(rms_norm_eps)
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_mla = use_mla

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
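
For orientation, here is a minimal sketch of exercising this config class on its own. It is not part of the upstream files; the import path assumes the package is importable as `deepseek_ocr2` (adjust to your checkout), and the overridden values are illustrative rather than shipped defaults.

```python
from deepseek_ocr2.configuration_deepseek_v2 import DeepseekV2Config

# Defaults mirror DeepSeek-V2 with multi-latent attention (MLA).
config = DeepseekV2Config(
    num_hidden_layers=12,   # default is 30; small value purely for illustration
    hidden_size=1280,       # default is 4096
    use_mla=False,          # fall back to standard multi-head attention
)
print(config.model_type)    # "deepseek_v2"
print(config.rms_norm_eps)  # stored as a float, 1e-06
```

`DeepseekOCR2Config` in `modeling_deepseekocr2.py` subclasses this class unchanged, overriding only `model_type` to `"DeepseekOCR2"`.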
+280
src/deepseek_ocr2/conversation.py
"""
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
"""

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List


class SeparatorStyle(IntEnum):
    """Separator styles."""

    DeepSeek = auto()
    DeepSeekV2 = auto()
    PLAIN = auto()
    ALIGNMENT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles
    roles: List[str] = (("USER", "ASSISTANT"),)
    # All messages. Each item is (role, message).
    messages: List[List[str]] = ()
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.DeepSeek
    sep: str = "\n"
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        if self.sep_style == SeparatorStyle.DeepSeek:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DeepSeekV2:
            seps = [self.sep, self.sep2]
            if system_prompt == "" or system_prompt is None:
                ret = ""
            else:
                ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if role == "User":
                        ret += "<|sft▁begin|>\n" + message + self.sep  #<|sft▁begin|>User Input<|sft▁end|>\nResponse<|end▁of▁sentence|>
                    else:
                        ret += message + self.sep2
                else:
                    ret = ret
            return ret

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += message + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        elif self.sep_style == SeparatorStyle.ALIGNMENT:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        ret += '<image>\n' + seps[i % 2]
                    else:
                        ret += message + seps[i % 2]
                else:
                    ret += ""
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def reset_message(self):
        """Reset a new message."""
        self.messages = []

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        system_prompt = self.system_template.format(system_message=self.system_message)
        ret = [{"role": "system", "content": system_prompt}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert template.name not in conv_templates, f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


register_conv_template(
    Conversation(
        name="deepseek",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="\n\n",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"]
    )
)
register_conv_template(
    Conversation(
        name="deepseekv2",
        system_template="{system_message}",
        # system_message="You are a helpful assistant. Please answer truthfully and write out your "
        # "thinking step by step to be sure you get the right answer.",
        system_message="",
        roles=("<|User|>", "<|Assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DeepSeek,
        sep="",
        sep2="<|end▁of▁sentence|>",
        stop_token_ids=[100001],
        stop_str=["User:", "<|end▁of▁sentence|>"]
    )
)


register_conv_template(
    Conversation(
        name="plain",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PLAIN,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    )
)


register_conv_template(
    Conversation(
        name="alignment",
        system_template="",
        system_message="",
        roles=("", ""),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ALIGNMENT,
        sep="",
        sep2="",
        stop_token_ids=[100001],
        stop_str=['</s>'],
    )
)


if __name__ == "__main__":
    print("deepseek template:")
    conv = get_conv_template("deepseek")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())

    print("deepseekv2 template:")
    conv = get_conv_template("deepseekv2")
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! This is Tony.")
    conv.append_message(conv.roles[0], "Who are you?")
    conv.append_message(conv.roles[1], "I am a helpful assistant.")
    conv.append_message(conv.roles[0], "How are you?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
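
To make the template behaviour concrete, here is the exact string the registered `"deepseek"` template renders for a single unanswered turn (a small sketch derived from `get_prompt()` above; the message text is arbitrary and the import path assumes the package is importable as `deepseek_ocr2`):

```python
from deepseek_ocr2.conversation import get_conv_template

conv = get_conv_template("deepseek")
conv.append_message(conv.roles[0], "Hello!")  # role "<|User|>"
conv.append_message(conv.roles[1], None)      # leave the assistant slot open for generation
print(repr(conv.get_prompt()))
# '<|User|>: Hello!\n\n<|Assistant|>:'
```

With `sep="\n\n"` and an empty system message, turns are joined as `role + ": " + message`, and the trailing `<|Assistant|>:` marks where the model's answer is decoded until `<|end▁of▁sentence|>` (token id 100001).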
+1015
src/deepseek_ocr2/deepencoderv2.py
··· 1 + import torch.nn as nn 2 + import torch 3 + import torch.nn.functional as F 4 + import copy 5 + 6 + 7 + from typing import Optional, Tuple 8 + 9 + # from megatron.model import LayerNorm 10 + 11 + import transformers 12 + 13 + 14 + from typing import Optional, Tuple, Type 15 + from functools import partial 16 + 17 + 18 + 19 + class MlpProjector(nn.Module): 20 + 21 + def __init__(self, cfg): 22 + 23 + super().__init__() 24 + 25 + self.cfg = cfg 26 + 27 + if cfg.projector_type == "identity": 28 + modules = nn.Identity() 29 + 30 + elif cfg.projector_type == "linear": 31 + modules = nn.Linear(cfg.input_dim, cfg.n_embed) 32 + 33 + elif cfg.projector_type == "mlp_gelu": 34 + mlp_depth = cfg.get("depth", 1) 35 + modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] 36 + for _ in range(1, mlp_depth): 37 + modules.append(nn.GELU()) 38 + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 39 + modules = nn.Sequential(*modules) 40 + 41 + elif cfg.projector_type == "normlayer_downsample_mlp_gelu": 42 + mlp_depth = cfg.get("depth", 1) 43 + mlp_ratio = cfg.get("mlp_ratio", 1) 44 + modules = [ 45 + nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio), 46 + nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio) 47 + ] 48 + for _ in range(1, mlp_depth - 1): 49 + modules.append(nn.GELU()) 50 + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) 51 + modules.append(nn.GELU()) 52 + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) 53 + modules = nn.Sequential(*modules) 54 + 55 + elif cfg.projector_type == "downsample_mlp_gelu": 56 + mlp_depth = cfg.get("depth", 1) 57 + mlp_ratio = cfg.get("mlp_ratio", 1) 58 + modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)] 59 + for _ in range(1, mlp_depth - 1): 60 + modules.append(nn.GELU()) 61 + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio)) 62 + modules.append(nn.GELU()) 63 + modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed)) 64 + modules = nn.Sequential(*modules) 65 + 66 + elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 67 + mlp_depth = cfg.get("depth", 1) 68 + self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 69 + self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 70 + 71 + modules = [] 72 + for _ in range(1, mlp_depth): 73 + modules.append(nn.GELU()) 74 + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 75 + modules = nn.Sequential(*modules) 76 + 77 + elif cfg.projector_type == "hybrid_split_feature_mlp_gelu": 78 + mlp_depth = cfg.get("depth", 1) 79 + channel_div = cfg.get("channel_div", 0.5) 80 + self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div)) 81 + self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div)) 82 + 83 + modules = [] 84 + for _ in range(1, mlp_depth): 85 + modules.append(nn.GELU()) 86 + modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 87 + modules = nn.Sequential(*modules) 88 + 89 + elif cfg.projector_type == "low_high_split_mlp_gelu": 90 + mlp_depth = cfg.get("depth", 1) 91 + modules = [] 92 + for _ in range(1, mlp_depth): 93 + modules.append(nn.GELU()) 94 + modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2)) 95 + modules = nn.Sequential(*modules) 96 + self.high_layers = nn.Sequential(*modules) 97 + self.low_layers = copy.deepcopy(modules) 98 + 99 + else: 100 + raise ValueError(f"Unknown projector type: {cfg.projector_type}") 101 + 102 + if 
cfg.get("token_pooling", False): 103 + self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim) 104 + 105 + if cfg.get("conv_fusion_high_low_features", False): 106 + self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim) 107 + self.layers = modules 108 + 109 + def forward(self, x): 110 + if self.cfg.get("token_pooling", False): 111 + batch_size, wxh, channels = x.shape 112 + w = h = int(wxh**0.5) 113 + x = x.view(batch_size, w, h, channels) 114 + x = x.permute(0, 3, 1, 2) 115 + # import ipdb; ipdb.set_trace() 116 + patches = x.unfold(2, 2, 2).unfold(3, 2, 2) 117 + batch_size, channels, h_patches, w_patches, _, _ = patches.size() 118 + # 在通道维度上拼接 119 + patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1) 120 + 121 + # 通过线性层 122 + patches = patches.permute(0, 2, 1, 3).contiguous() 123 + patches = patches.view(batch_size, h_patches * w_patches, channels * 4) 124 + 125 + x = self.token_pooling_layer(patches) 126 + 127 + if self.cfg.get("conv_fusion_high_low_features", False): 128 + x = self.fusion_layer(x[:, 0]) + x[:, 1] 129 + 130 + if self.cfg.projector_type == 'low_high_hybrid_split_mlp_gelu': 131 + high_x, low_x = x[0], x[1] 132 + high_x = self.high_up_proj(high_x) 133 + low_x = self.low_up_proj(low_x) 134 + x = torch.concat([high_x, low_x], dim=-1) 135 + 136 + if self.cfg.projector_type == 'hybrid_split_feature_mlp_gelu': 137 + high_x = x[...,:self.cfg.input_dim[0]] 138 + low_x = x[...,self.cfg.input_dim[0]:] 139 + high_x = self.high_up_proj(high_x) 140 + low_x = self.low_up_proj(low_x) 141 + x = torch.concat([high_x, low_x], dim=-1) 142 + 143 + if self.cfg.projector_type == 'low_high_split_mlp_gelu': 144 + high_x, low_x = x[0], x[1] 145 + high_x = self.high_layers(high_x) 146 + low_x = self.low_layers(low_x) 147 + x = torch.concat([high_x, low_x], dim=-1) 148 + return x 149 + 150 + if self.cfg.projector_type == 'downsample_mlp_gelu' or self.cfg.projector_type == 'normlayer_downsample_mlp_gelu': 151 + bs, hw, input_dim = x.shape 152 + h = w = int((hw) ** 0.5) 153 + 154 + """compute padding""" 155 + if h % self.cfg.downsample_ratio: 156 + pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio 157 + else: 158 + pad = 0 159 + x = x.reshape(bs, h, w, input_dim) 160 + if pad > 0: 161 + x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0) 162 + 163 + """4 to 1 concat""" 164 + x = x.permute(0, 3, 1, 2) # B, C, H, W 165 + x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0) # B, C*4, HW // 4 166 + x = x.permute(0, 2, 1) 167 + 168 + return self.layers(x) 169 + 170 + @staticmethod 171 + def get_flops_per_sample(cfg): 172 + if cfg.projector_type == "linear": 173 + fwd = 2 * cfg.input_dim * cfg.n_embed 174 + 175 + elif "mlp_gelu" in cfg.projector_type : 176 + mlp_depth = cfg.get("depth", 1) 177 + downsample_ratio = cfg.get("downsample_ratio", 1) 178 + input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim 179 + input_dim = input_dim * downsample_ratio * downsample_ratio 180 + fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed 181 + else: 182 + fwd = 0 183 + 184 + return fwd * 3 185 + 186 + 187 + #===================qwen2================================ 188 + 189 + class CustomQwen2Decoder(nn.Module): 190 + """ 191 + Qwen2 visual encoder 192 + non-causal attention + causal attention 193 + token_type_ids :0=non-causal, 1=causal 194 + """ 195 + 196 + def __init__( 197 + self, 198 + decoder_layer: int = 24, 199 + max_position_embeddings: 
int = 131072, 200 + hidden_dimension: int = 896, 201 + num_attention_heads: int = 14, 202 + num_key_value_heads: int = 2, 203 + intermediate_size: int = 4864, 204 + vocab_size: int = 151936, 205 + attn_implementation: str = "sdpa", # ⭐ 206 + rms_norm_eps: float = 1e-06, 207 + rope_theta: float = 1000000.0, 208 + attention_dropout: float = 0.0, 209 + hidden_act: str = "silu", 210 + initializer_range: float = 0.02, 211 + ): 212 + super().__init__() 213 + 214 + # attn_implementation check 215 + if attn_implementation == "flash_attention_2": 216 + raise ValueError( 217 + "CustomQwen2Decoder do not support flash_attention_2," 218 + "new attention mask needs 'sdpa' or 'eager'" 219 + ) 220 + 221 + # load 222 + Qwen2Model = getattr(transformers.models.qwen2.modeling_qwen2, 'Qwen2Model') 223 + Qwen2Config = getattr(transformers, 'Qwen2Config') 224 + 225 + # config 226 + config = Qwen2Config( 227 + hidden_size=hidden_dimension, 228 + num_hidden_layers=decoder_layer, 229 + num_attention_heads=num_attention_heads, 230 + num_key_value_heads=num_key_value_heads, 231 + intermediate_size=intermediate_size, 232 + max_position_embeddings=max_position_embeddings, 233 + vocab_size=vocab_size, 234 + rms_norm_eps=rms_norm_eps, 235 + rope_theta=rope_theta, 236 + attention_dropout=attention_dropout, 237 + hidden_act=hidden_act, 238 + initializer_range=initializer_range, 239 + _attn_implementation=attn_implementation, # ⭐ 240 + ) 241 + 242 + # 243 + self.model = self._create_custom_model(Qwen2Model, config) 244 + 245 + del self.model.embed_tokens 246 + 247 + def _create_custom_model(self, Qwen2Model, config): 248 + """ Qwen2Model """ 249 + 250 + class CustomQwen2ModelInner(Qwen2Model): 251 + 252 + 253 + def forward( 254 + self, 255 + input_ids=None, 256 + attention_mask=None, 257 + position_ids=None, 258 + past_key_values=None, 259 + inputs_embeds=None, 260 + token_type_ids=None, # ⭐ 261 + use_cache=None, 262 + output_attentions=None, 263 + output_hidden_states=None, 264 + return_dict=None, 265 + cache_position=None, 266 + ): 267 + # token_type_ids 268 + self._current_token_type_ids = token_type_ids 269 + 270 + outputs = super().forward( 271 + input_ids=input_ids, 272 + attention_mask=attention_mask, 273 + position_ids=position_ids, 274 + past_key_values=past_key_values, 275 + inputs_embeds=inputs_embeds, 276 + use_cache=use_cache, 277 + output_attentions=output_attentions, 278 + output_hidden_states=output_hidden_states, 279 + return_dict=return_dict, 280 + cache_position=cache_position, 281 + ) 282 + 283 + return outputs 284 + 285 + def _update_causal_mask( 286 + self, 287 + attention_mask, 288 + input_tensor, 289 + cache_position, 290 + past_key_values, 291 + output_attentions, 292 + ): 293 + dtype, device = input_tensor.dtype, input_tensor.device 294 + min_dtype = torch.finfo(dtype).min 295 + batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1] 296 + 297 + token_type_ids = self._current_token_type_ids 298 + 299 + # attention mask 300 + causal_mask = self._create_custom_4d_mask( 301 + sequence_length=sequence_length, 302 + dtype=dtype, 303 + device=device, 304 + batch_size=batch_size, 305 + token_type_ids=token_type_ids, 306 + ) 307 + 308 + # padding mask 309 + if attention_mask is not None and attention_mask.dim() == 2: 310 + padding_mask = attention_mask[:, None, None, :].to(dtype=dtype) 311 + padding_mask = (1.0 - padding_mask) * min_dtype 312 + causal_mask = causal_mask + padding_mask 313 + 314 + return causal_mask 315 + 316 + def _create_custom_4d_mask( 317 + self, 318 + 
sequence_length, 319 + dtype, 320 + device, 321 + batch_size, 322 + token_type_ids, 323 + ): 324 + min_dtype = torch.finfo(dtype).min 325 + 326 + masks = [] 327 + for b in range(batch_size): 328 + mask = torch.full( 329 + (sequence_length, sequence_length), 330 + fill_value=min_dtype, 331 + dtype=dtype, 332 + device=device 333 + ) 334 + 335 + type_ids = token_type_ids[b] 336 + 337 + image_positions = (type_ids == 0).nonzero(as_tuple=True)[0] 338 + text_positions = (type_ids == 1).nonzero(as_tuple=True)[0] 339 + 340 + # non-casual 341 + if len(image_positions) > 0: 342 + mask[image_positions[:, None], image_positions] = 0.0 343 + 344 + # causal 345 + for i, text_pos in enumerate(text_positions): 346 + if len(image_positions) > 0: 347 + mask[text_pos, image_positions] = 0.0 348 + mask[text_pos, text_positions[:i+1]] = 0.0 349 + 350 + masks.append(mask) 351 + 352 + mask = torch.stack(masks, dim=0).unsqueeze(1) 353 + return mask 354 + 355 + return CustomQwen2ModelInner(config) 356 + 357 + def forward( 358 + self, 359 + inputs_embeds, 360 + token_type_ids, 361 + attention_mask=None, 362 + **kwargs 363 + ): 364 + """ 365 + Args: 366 + inputs_embeds: [batch_size, seq_len, hidden_dim] 367 + token_type_ids: [batch_size, seq_len], 0=non-causal, 1=causal 368 + attention_mask: [batch_size, seq_len], optional 369 + """ 370 + return self.model( 371 + inputs_embeds=inputs_embeds, 372 + token_type_ids=token_type_ids, 373 + attention_mask=attention_mask, 374 + **kwargs 375 + ) 376 + 377 + 378 + 379 + 380 + 381 + # batch_size = 2 382 + # inputs_embeds = torch.randn(batch_size, 512, 896).cuda() 383 + 384 + # inputs_embeds = torch.randn(batch_size, 512, 896).cuda() 385 + # token_type_ids = torch.cat([ 386 + # torch.zeros(batch_size, 256, dtype=torch.long), 387 + # torch.ones(batch_size, 256, dtype=torch.long), 388 + # ], dim=1).cuda() 389 + 390 + # # start = time.time() 391 + # with torch.no_grad(): 392 + # outputs_sdpa = decoder_sdpa(inputs_embeds, token_type_ids) 393 + # print(outputs_sdpa[0].shape) 394 + # print(f"SDPA time: {time.time() - start:.4f}s") 395 + 396 + 397 + 398 + class Qwen2Decoder2Encoder(nn.Module): 399 + """ 400 + Decoder based on Multilingual BART 401 + Set the initial weights and configuration with a pretrained multilingual BART model, 402 + and modify the detailed configurations as a Nougat decoder 403 + """ 404 + 405 + def __init__( 406 + self, 407 + decoder_layer: int, 408 + hidden_dimension: int, 409 + num_attention_heads: int, 410 + num_key_value_heads: int, 411 + intermediate_size: int, 412 + max_query: int, 413 + ): 414 + super().__init__() 415 + 416 + self.model = CustomQwen2Decoder( 417 + decoder_layer=decoder_layer, 418 + hidden_dimension=hidden_dimension, 419 + num_attention_heads=num_attention_heads, 420 + num_key_value_heads=num_key_value_heads, 421 + intermediate_size=intermediate_size, 422 + attn_implementation="sdpa", 423 + ) 424 + 425 + 426 + 427 + 428 + self.query_768 = nn.Embedding(144, hidden_dimension) 429 + self.query_1024 = nn.Embedding(256, hidden_dimension) 430 + 431 + 432 + # self.query_refixation = nn.Embedding(int(math.sqrt(max_query)), hidden_dimension) 433 + 434 + 435 + def forward(self, x: torch.Tensor) -> torch.Tensor: 436 + x = x.flatten(2).transpose(1, 2) 437 + 438 + bs, n_query, _ = x.shape 439 + 440 + if n_query == 144: 441 + param_img = self.query_768.weight 442 + elif n_query == 256: 443 + param_img = self.query_1024.weight 444 + 445 + batch_query_imgs = param_img.unsqueeze(0).expand( 446 + bs, -1, -1 447 + ) # (batch_size, num_queries, hidden_size) 
448 + 449 + 450 + 451 + x_combined = torch.cat([x, batch_query_imgs], dim=1) 452 + 453 + token_type_ids = torch.cat([ 454 + torch.zeros(bs, n_query, dtype=torch.long), 455 + torch.ones(bs, n_query, dtype=torch.long), 456 + ], dim=1) 457 + 458 + 459 + y = self.model(x_combined, token_type_ids)[0] 460 + 461 + 462 + y = y[:, n_query:, :] # causal flow query 463 + 464 + 465 + return y 466 + 467 + 468 + def build_qwen2_decoder_as_encoder( 469 + decoder_layer=24, 470 + hidden_dimension=896, 471 + num_attention_heads=14, 472 + num_key_value_heads=2, 473 + intermediate_size=4864, 474 + max_query = 400, 475 + checkpoint=None, 476 + ): 477 + 478 + decoder_as_encoder = Qwen2Decoder2Encoder( 479 + decoder_layer=decoder_layer, 480 + hidden_dimension = hidden_dimension, 481 + num_attention_heads = num_attention_heads, 482 + num_key_value_heads = num_key_value_heads, 483 + intermediate_size = intermediate_size, 484 + max_query = max_query 485 + ) 486 + 487 + 488 + 489 + 490 + if checkpoint is not None: 491 + # with open(checkpoint, "rb") as f: 492 + state_dict = torch.load(checkpoint) 493 + 494 + decoder_as_encoder.load_state_dict(state_dict, strict=True) 495 + # tob 496 + print(checkpoint) 497 + return decoder_as_encoder 498 + 499 + 500 + 501 + 502 + #=========================Sam-Vary================================= 503 + 504 + 505 + def get_abs_pos_sam(abs_pos, tgt_size): 506 + 507 + dtype = abs_pos.dtype 508 + 509 + src_size = abs_pos.size(1) 510 + 511 + if src_size != tgt_size: 512 + old_pos_embed = abs_pos.permute(0, 3, 1, 2) 513 + old_pos_embed = old_pos_embed.to(torch.float32) 514 + new_pos_embed = F.interpolate( 515 + old_pos_embed, 516 + size=(tgt_size, tgt_size), 517 + mode='bicubic', 518 + antialias=True, 519 + align_corners=False, 520 + ).to(dtype) 521 + new_pos_embed = new_pos_embed.permute(0, 2, 3, 1) 522 + return new_pos_embed 523 + else: 524 + return abs_pos 525 + 526 + 527 + 528 + 529 + class MLPBlock(nn.Module): 530 + def __init__( 531 + self, 532 + embedding_dim: int, 533 + mlp_dim: int, 534 + act: Type[nn.Module] = nn.GELU, 535 + ) -> None: 536 + super().__init__() 537 + self.lin1 = nn.Linear(embedding_dim, mlp_dim) 538 + self.lin2 = nn.Linear(mlp_dim, embedding_dim) 539 + self.act = act() 540 + 541 + def forward(self, x: torch.Tensor) -> torch.Tensor: 542 + return self.lin2(self.act(self.lin1(x))) 543 + 544 + 545 + # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 546 + # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 547 + class LayerNorm2d(nn.Module): 548 + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 549 + super().__init__() 550 + self.weight = nn.Parameter(torch.ones(num_channels)) 551 + self.bias = nn.Parameter(torch.zeros(num_channels)) 552 + self.eps = eps 553 + 554 + def forward(self, x: torch.Tensor) -> torch.Tensor: 555 + u = x.mean(1, keepdim=True) 556 + s = (x - u).pow(2).mean(1, keepdim=True) 557 + x = (x - u) / torch.sqrt(s + self.eps) 558 + x = self.weight[:, None, None] * x + self.bias[:, None, None] 559 + return x 560 + 561 + 562 + # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 563 + class ImageEncoderViT(nn.Module): 564 + def __init__( 565 + self, 566 + img_size: int = 1024, 567 + patch_size: int = 16, 568 + in_chans: int = 3, 569 + embed_dim: int = 768, 
570 + depth: int = 12, 571 + num_heads: int = 12, 572 + mlp_ratio: float = 4.0, 573 + out_chans: int = 256, 574 + qkv_bias: bool = True, 575 + norm_layer: Type[nn.Module] = nn.LayerNorm, 576 + act_layer: Type[nn.Module] = nn.GELU, 577 + use_abs_pos: bool = True, 578 + use_rel_pos: bool = False, 579 + rel_pos_zero_init: bool = True, 580 + window_size: int = 0, 581 + global_attn_indexes: Tuple[int, ...] = (), 582 + ) -> None: 583 + """ 584 + Args: 585 + img_size (int): Input image size. 586 + patch_size (int): Patch size. 587 + in_chans (int): Number of input image channels. 588 + embed_dim (int): Patch embedding dimension. 589 + depth (int): Depth of ViT. 590 + num_heads (int): Number of attention heads in each ViT block. 591 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 592 + qkv_bias (bool): If True, add a learnable bias to query, key, value. 593 + norm_layer (nn.Module): Normalization layer. 594 + act_layer (nn.Module): Activation layer. 595 + use_abs_pos (bool): If True, use absolute positional embeddings. 596 + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 597 + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 598 + window_size (int): Window size for window attention blocks. 599 + global_attn_indexes (list): Indexes for blocks using global attention. 600 + """ 601 + super().__init__() 602 + self.img_size = img_size 603 + 604 + self.patch_embed = PatchEmbed( 605 + kernel_size=(patch_size, patch_size), 606 + stride=(patch_size, patch_size), 607 + in_chans=in_chans, 608 + embed_dim=embed_dim, 609 + ) 610 + 611 + self.pos_embed: Optional[nn.Parameter] = None 612 + if use_abs_pos: 613 + # Initialize absolute positional embedding with pretrain image size. 614 + self.pos_embed = nn.Parameter( 615 + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 616 + ) 617 + 618 + self.blocks = nn.ModuleList() 619 + for i in range(depth): 620 + block = Block( 621 + dim=embed_dim, 622 + num_heads=num_heads, 623 + mlp_ratio=mlp_ratio, 624 + qkv_bias=qkv_bias, 625 + norm_layer=norm_layer, 626 + act_layer=act_layer, 627 + use_rel_pos=use_rel_pos, 628 + rel_pos_zero_init=rel_pos_zero_init, 629 + window_size=window_size if i not in global_attn_indexes else 0, 630 + input_size=(img_size // patch_size, img_size // patch_size), 631 + ) 632 + self.blocks.append(block) 633 + 634 + self.neck = nn.Sequential( 635 + nn.Conv2d( 636 + embed_dim, 637 + out_chans, 638 + kernel_size=1, 639 + bias=False, 640 + ), 641 + LayerNorm2d(out_chans), 642 + nn.Conv2d( 643 + out_chans, 644 + out_chans, 645 + kernel_size=3, 646 + padding=1, 647 + bias=False, 648 + ), 649 + LayerNorm2d(out_chans), 650 + ) 651 + 652 + self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) 653 + self.net_3 = nn.Conv2d(512, 896, kernel_size=3, stride=2, padding=1, bias=False) 654 + 655 + def forward(self, x: torch.Tensor) -> torch.Tensor: 656 + x = self.patch_embed(x) 657 + if self.pos_embed is not None: 658 + # x = x + self.pos_embed 659 + x = x + get_abs_pos_sam(self.pos_embed, x.size(1)) 660 + 661 + for blk in self.blocks: 662 + x = blk(x) 663 + 664 + x = self.neck(x.permute(0, 3, 1, 2)) 665 + x2 = self.net_2(x) 666 + x3 = self.net_3(x2.clone()) 667 + 668 + return x3 669 + 670 + 671 + class Block(nn.Module): 672 + """Transformer blocks with support of window attention and residual propagation blocks""" 673 + 674 + def __init__( 675 + self, 676 + dim: int, 677 + num_heads: int, 678 + mlp_ratio: float = 4.0, 679 + 
qkv_bias: bool = True, 680 + norm_layer: Type[nn.Module] = nn.LayerNorm, 681 + act_layer: Type[nn.Module] = nn.GELU, 682 + use_rel_pos: bool = False, 683 + rel_pos_zero_init: bool = True, 684 + window_size: int = 0, 685 + input_size: Optional[Tuple[int, int]] = None, 686 + ) -> None: 687 + """ 688 + Args: 689 + dim (int): Number of input channels. 690 + num_heads (int): Number of attention heads in each ViT block. 691 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 692 + qkv_bias (bool): If True, add a learnable bias to query, key, value. 693 + norm_layer (nn.Module): Normalization layer. 694 + act_layer (nn.Module): Activation layer. 695 + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 696 + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 697 + window_size (int): Window size for window attention blocks. If it equals 0, then 698 + use global attention. 699 + input_size (tuple(int, int) or None): Input resolution for calculating the relative 700 + positional parameter size. 701 + """ 702 + super().__init__() 703 + self.norm1 = norm_layer(dim) 704 + self.attn = Attention( 705 + dim, 706 + num_heads=num_heads, 707 + qkv_bias=qkv_bias, 708 + use_rel_pos=use_rel_pos, 709 + rel_pos_zero_init=rel_pos_zero_init, 710 + input_size=input_size if window_size == 0 else (window_size, window_size), 711 + ) 712 + 713 + self.norm2 = norm_layer(dim) 714 + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 715 + 716 + self.window_size = window_size 717 + 718 + def forward(self, x: torch.Tensor) -> torch.Tensor: 719 + shortcut = x 720 + x = self.norm1(x) 721 + # Window partition 722 + if self.window_size > 0: 723 + H, W = x.shape[1], x.shape[2] 724 + x, pad_hw = window_partition(x, self.window_size) 725 + 726 + x = self.attn(x) 727 + # Reverse window partition 728 + if self.window_size > 0: 729 + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 730 + 731 + x = shortcut + x 732 + x = x + self.mlp(self.norm2(x)) 733 + 734 + return x 735 + 736 + 737 + class Attention(nn.Module): 738 + """Multi-head Attention block with relative position embeddings.""" 739 + 740 + def __init__( 741 + self, 742 + dim: int, 743 + num_heads: int = 8, 744 + qkv_bias: bool = True, 745 + use_rel_pos: bool = False, 746 + rel_pos_zero_init: bool = True, 747 + input_size: Optional[Tuple[int, int]] = None, 748 + ) -> None: 749 + """ 750 + Args: 751 + dim (int): Number of input channels. 752 + num_heads (int): Number of attention heads. 753 + qkv_bias (bool): If True, add a learnable bias to query, key, value. 754 + rel_pos (bool): If True, add relative positional embeddings to the attention map. 755 + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 756 + input_size (tuple(int, int) or None): Input resolution for calculating the relative 757 + positional parameter size. 758 + """ 759 + super().__init__() 760 + self.num_heads = num_heads 761 + head_dim = dim // num_heads 762 + self.scale = head_dim**-0.5 763 + 764 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 765 + self.proj = nn.Linear(dim, dim) 766 + 767 + self.use_rel_pos = use_rel_pos 768 + if self.use_rel_pos: 769 + assert ( 770 + input_size is not None 771 + ), "Input size must be provided if using relative positional encoding." 
772 + # initialize relative positional embeddings 773 + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 774 + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 775 + 776 + def forward(self, x: torch.Tensor) -> torch.Tensor: 777 + B, H, W, _ = x.shape 778 + # qkv with shape (3, B, nHead, H * W, C) 779 + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 780 + # q, k, v with shape (B * nHead, H * W, C) 781 + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 782 + 783 + rel_h, rel_w = None, None 784 + if self.use_rel_pos: 785 + rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 786 + 787 + q = q.view(B, self.num_heads, H * W, -1) 788 + k = k.view(B, self.num_heads, H * W, -1) 789 + v = v.view(B, self.num_heads, H * W, -1) 790 + 791 + if self.use_rel_pos: 792 + rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)) 793 + rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)) 794 + attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)) 795 + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) 796 + # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w) 797 + else: 798 + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 799 + 800 + x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 801 + 802 + x = self.proj(x) 803 + 804 + return x 805 + 806 + 807 + def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 808 + """ 809 + Partition into non-overlapping windows with padding if needed. 810 + Args: 811 + x (tensor): input tokens with [B, H, W, C]. 812 + window_size (int): window size. 813 + 814 + Returns: 815 + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 816 + (Hp, Wp): padded height and width before partition 817 + """ 818 + B, H, W, C = x.shape 819 + 820 + pad_h = (window_size - H % window_size) % window_size 821 + pad_w = (window_size - W % window_size) % window_size 822 + if pad_h > 0 or pad_w > 0: 823 + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 824 + Hp, Wp = H + pad_h, W + pad_w 825 + 826 + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 827 + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 828 + return windows, (Hp, Wp) 829 + 830 + 831 + def window_unpartition( 832 + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 833 + ) -> torch.Tensor: 834 + """ 835 + Window unpartition into original sequences and removing padding. 836 + Args: 837 + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 838 + window_size (int): window size. 839 + pad_hw (Tuple): padded height and width (Hp, Wp). 840 + hw (Tuple): original height and width (H, W) before padding. 841 + 842 + Returns: 843 + x: unpartitioned sequences with [B, H, W, C]. 
844 + """ 845 + Hp, Wp = pad_hw 846 + H, W = hw 847 + B = windows.shape[0] // (Hp * Wp // window_size // window_size) 848 + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 849 + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 850 + 851 + if Hp > H or Wp > W: 852 + x = x[:, :H, :W, :].contiguous() 853 + return x 854 + 855 + 856 + def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 857 + """ 858 + Get relative positional embeddings according to the relative positions of 859 + query and key sizes. 860 + Args: 861 + q_size (int): size of query q. 862 + k_size (int): size of key k. 863 + rel_pos (Tensor): relative position embeddings (L, C). 864 + 865 + Returns: 866 + Extracted positional embeddings according to relative positions. 867 + """ 868 + max_rel_dist = int(2 * max(q_size, k_size) - 1) 869 + # Interpolate rel pos if needed. 870 + if rel_pos.shape[0] != max_rel_dist: 871 + # Interpolate rel pos. 872 + dtype = rel_pos.dtype 873 + rel_pos = rel_pos.to(torch.float32) 874 + rel_pos_resized = F.interpolate( 875 + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 876 + size=max_rel_dist, 877 + mode="linear", 878 + ).to(dtype) 879 + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 880 + else: 881 + rel_pos_resized = rel_pos 882 + 883 + # Scale the coords with short length if shapes for q and k are different. 884 + q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0) 885 + k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0) 886 + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 887 + 888 + return rel_pos_resized[relative_coords.long()] 889 + 890 + 891 + def add_decomposed_rel_pos( 892 + q: torch.Tensor, 893 + rel_pos_h: torch.Tensor, 894 + rel_pos_w: torch.Tensor, 895 + q_size: Tuple[int, int], 896 + k_size: Tuple[int, int], 897 + ) -> torch.Tensor: 898 + """ 899 + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 900 + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 901 + Args: 902 + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 903 + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 904 + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 905 + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 906 + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 907 + 908 + Returns: 909 + attn (Tensor): attention map with added relative positional embeddings. 910 + """ 911 + q_h, q_w = q_size 912 + k_h, k_w = k_size 913 + Rh = get_rel_pos(q_h, k_h, rel_pos_h) 914 + Rw = get_rel_pos(q_w, k_w, rel_pos_w) 915 + 916 + B, _, dim = q.shape 917 + r_q = q.reshape(B, q_h, q_w, dim) 918 + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 919 + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 920 + rel_h = rel_h.unsqueeze(-1) 921 + rel_w = rel_w.unsqueeze(-2) 922 + rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1) 923 + rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w) 924 + 925 + return rel_h, rel_w 926 + 927 + 928 + class PatchEmbed(nn.Module): 929 + """ 930 + Image to Patch Embedding. 
931 + """ 932 + 933 + def __init__( 934 + self, 935 + kernel_size: Tuple[int, int] = (16, 16), 936 + stride: Tuple[int, int] = (16, 16), 937 + padding: Tuple[int, int] = (0, 0), 938 + in_chans: int = 3, 939 + embed_dim: int = 768, 940 + ) -> None: 941 + """ 942 + Args: 943 + kernel_size (Tuple): kernel size of the projection layer. 944 + stride (Tuple): stride of the projection layer. 945 + padding (Tuple): padding size of the projection layer. 946 + in_chans (int): Number of input image channels. 947 + embed_dim (int): Patch embedding dimension. 948 + """ 949 + super().__init__() 950 + 951 + self.proj = nn.Conv2d( 952 + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 953 + ) 954 + 955 + def forward(self, x: torch.Tensor) -> torch.Tensor: 956 + x = self.proj(x) 957 + # B C H W -> B H W C 958 + x = x.permute(0, 2, 3, 1) 959 + return x 960 + 961 + 962 + def build_sam_vit_b(checkpoint=None): 963 + return _build_sam( 964 + encoder_embed_dim=768, 965 + encoder_depth=12, 966 + encoder_num_heads=12, 967 + encoder_global_attn_indexes=[2, 5, 8, 11], 968 + checkpoint=checkpoint, 969 + ) 970 + 971 + def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16): 972 + image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype) 973 + # sam = _apply_eval_dtype_sam(sam, dtype) 974 + image_encoder = torch.compile(image_encoder, mode=compile_mode) 975 + return image_encoder 976 + 977 + 978 + def _build_sam( 979 + encoder_embed_dim, 980 + encoder_depth, 981 + encoder_num_heads, 982 + encoder_global_attn_indexes, 983 + checkpoint=None, 984 + ): 985 + prompt_embed_dim = 256 986 + image_size = 1024 987 + vit_patch_size = 16 988 + image_embedding_size = image_size // vit_patch_size 989 + image_encoder=ImageEncoderViT( 990 + depth=encoder_depth, 991 + embed_dim=encoder_embed_dim, 992 + img_size=image_size, 993 + mlp_ratio=4, 994 + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 995 + num_heads=encoder_num_heads, 996 + patch_size=vit_patch_size, 997 + qkv_bias=True, 998 + use_rel_pos=True, 999 + global_attn_indexes=encoder_global_attn_indexes, 1000 + window_size=14, 1001 + out_chans=prompt_embed_dim, 1002 + ) 1003 + image_encoder.eval() 1004 + if checkpoint is not None: 1005 + # with open(checkpoint, "rb") as f: 1006 + state_dict = torch.load(checkpoint) 1007 + # print(state_dict.keys()) 1008 + # for key in state_dict: 1009 + # image_encoder.load_state_dict({k[14:]: v for k, v in state_dict.items() if 'image_encoder' in k}, strict=False) 1010 + # ocr-anyting 1011 + # image_encoder.load_state_dict(state_dict, strict=True) 1012 + # tob 1013 + image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True) 1014 + print(checkpoint) 1015 + return image_encoder
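
Taken together, the pieces above form the DeepSeek-OCR-2 vision stack: the SAM ViT-B backbone turns a 1024x1024 crop into an 896-channel 16x16 feature map, the Qwen2 decoder reused as an encoder (non-causal over image tokens, causal over its learned queries) compresses that map into 256 query tokens, and `MlpProjector` maps them into the language model's embedding space. Below is a rough, untested shape-check sketch, not upstream test code: random weights, CPU, the package assumed importable as `deepseek_ocr2`, and the projector dimensions mirroring how `modeling_deepseekocr2.py` wires it up.

```python
import torch
from addict import Dict
from deepseek_ocr2.deepencoderv2 import (
    MlpProjector,
    build_qwen2_decoder_as_encoder,
    build_sam_vit_b,
)

sam = build_sam_vit_b()                      # ImageEncoderViT: 1024x1024 image -> (B, 896, 16, 16)
qwen2 = build_qwen2_decoder_as_encoder()     # Qwen2 decoder acting as a query-based encoder
proj = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=1280))

with torch.no_grad():
    feats = sam(torch.randn(1, 3, 1024, 1024))   # (1, 896, 16, 16)
    tokens = qwen2(feats)                        # (1, 256, 896): outputs at the 256 learned query positions
    embeds = proj(tokens)                        # (1, 256, 1280): what the language model consumes
print(feats.shape, tokens.shape, embeds.shape)
```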
+1030
src/deepseek_ocr2/modeling_deepseekocr2.py
··· 1 + from .modeling_deepseekv2 import DeepseekV2Model, DeepseekV2ForCausalLM 2 + from .configuration_deepseek_v2 import DeepseekV2Config 3 + from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 4 + from typing import List, Optional, Tuple, Union 5 + from transformers.cache_utils import Cache 6 + import requests 7 + from PIL import Image, ImageOps, ImageDraw, ImageFont 8 + from io import BytesIO 9 + import torch 10 + import torch.nn as nn 11 + from torch.nn import CrossEntropyLoss 12 + from torchvision import transforms 13 + # from torchvision.transforms.functional import InterpolationMode 14 + import os 15 + from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector 16 + from addict import Dict 17 + from transformers import TextStreamer 18 + from .conversation import get_conv_template 19 + from abc import ABC 20 + import math 21 + import re 22 + from tqdm import tqdm 23 + import numpy as np 24 + # import time 25 + 26 + 27 + 28 + def load_image(image_path): 29 + 30 + try: 31 + image = Image.open(image_path) 32 + 33 + corrected_image = ImageOps.exif_transpose(image) 34 + 35 + return corrected_image 36 + 37 + except Exception as e: 38 + print(f"error: {e}") 39 + try: 40 + return Image.open(image_path) 41 + except: 42 + return None 43 + 44 + 45 + def re_match(text): 46 + pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' 47 + matches = re.findall(pattern, text, re.DOTALL) 48 + 49 + # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n' 50 + # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL) 51 + 52 + mathes_image = [] 53 + mathes_other = [] 54 + for a_match in matches: 55 + if '<|ref|>image<|/ref|>' in a_match[0]: 56 + mathes_image.append(a_match[0]) 57 + else: 58 + mathes_other.append(a_match[0]) 59 + return matches, mathes_image, mathes_other 60 + 61 + 62 + def extract_coordinates_and_label(ref_text, image_width, image_height): 63 + 64 + try: 65 + label_type = ref_text[1] 66 + cor_list = eval(ref_text[2]) 67 + except Exception as e: 68 + print(e) 69 + return None 70 + 71 + return (label_type, cor_list) 72 + 73 + 74 + def draw_bounding_boxes(image, refs, ouput_path): 75 + 76 + image_width, image_height = image.size 77 + 78 + img_draw = image.copy() 79 + draw = ImageDraw.Draw(img_draw) 80 + 81 + overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) 82 + draw2 = ImageDraw.Draw(overlay) 83 + 84 + # try: 85 + # except IOError: 86 + # try: 87 + # font = ImageFont.truetype("DejaVuSans.ttf", 20) 88 + # except IOError: 89 + font = ImageFont.load_default() 90 + 91 + img_idx = 0 92 + 93 + for i, ref in enumerate(refs): 94 + try: 95 + result = extract_coordinates_and_label(ref, image_width, image_height) 96 + if result: 97 + label_type, points_list = result 98 + 99 + color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) 100 + 101 + color_a = color + (20, ) 102 + for points in points_list: 103 + x1, y1, x2, y2 = points 104 + 105 + x1 = int(x1 / 999 * image_width) 106 + y1 = int(y1 / 999 * image_height) 107 + 108 + x2 = int(x2 / 999 * image_width) 109 + y2 = int(y2 / 999 * image_height) 110 + 111 + if label_type == 'image': 112 + try: 113 + cropped = image.crop((x1, y1, x2, y2)) 114 + cropped.save(f"{ouput_path}/images/{img_idx}.jpg") 115 + except Exception as e: 116 + print(e) 117 + pass 118 + img_idx += 1 119 + 120 + try: 121 + if label_type == 'title': 122 + draw.rectangle([x1, y1, x2, y2], outline=color, width=4) 123 + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), 
width=1) 124 + else: 125 + draw.rectangle([x1, y1, x2, y2], outline=color, width=2) 126 + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) 127 + text_x = x1 128 + text_y = max(0, y1 - 15) 129 + 130 + 131 + text_bbox = draw.textbbox((0, 0), label_type, font=font) 132 + text_width = text_bbox[2] - text_bbox[0] 133 + text_height = text_bbox[3] - text_bbox[1] 134 + draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], 135 + fill=(255, 255, 255, 30)) 136 + 137 + draw.text((text_x, text_y), label_type, font=font, fill=color) 138 + except: 139 + pass 140 + except: 141 + continue 142 + img_draw.paste(overlay, (0, 0), overlay) 143 + return img_draw 144 + 145 + 146 + def process_image_with_refs(image, ref_texts, output_path): 147 + 148 + result_image = draw_bounding_boxes(image, ref_texts, output_path) 149 + 150 + return result_image 151 + 152 + 153 + 154 + 155 + 156 + def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 157 + best_ratio_diff = float('inf') 158 + best_ratio = (1, 1) 159 + area = width * height 160 + for ratio in target_ratios: 161 + target_aspect_ratio = ratio[0] / ratio[1] 162 + ratio_diff = abs(aspect_ratio - target_aspect_ratio) 163 + if ratio_diff < best_ratio_diff: 164 + best_ratio_diff = ratio_diff 165 + best_ratio = ratio 166 + elif ratio_diff == best_ratio_diff: 167 + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 168 + best_ratio = ratio 169 + # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') 170 + return best_ratio 171 + 172 + 173 + def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False): 174 + orig_width, orig_height = image.size 175 + aspect_ratio = orig_width / orig_height 176 + 177 + # calculate the existing image aspect ratio 178 + target_ratios = set( 179 + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 180 + i * j <= max_num and i * j >= min_num) 181 + # print(target_ratios) 182 + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 183 + 184 + # find the closest aspect ratio to the target 185 + target_aspect_ratio = find_closest_aspect_ratio( 186 + aspect_ratio, target_ratios, orig_width, orig_height, image_size) 187 + 188 + # print(target_aspect_ratio) 189 + # calculate the target width and height 190 + target_width = image_size * target_aspect_ratio[0] 191 + target_height = image_size * target_aspect_ratio[1] 192 + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 193 + 194 + # resize the image 195 + resized_img = image.resize((target_width, target_height)) 196 + processed_images = [] 197 + for i in range(blocks): 198 + box = ( 199 + (i % (target_width // image_size)) * image_size, 200 + (i // (target_width // image_size)) * image_size, 201 + ((i % (target_width // image_size)) + 1) * image_size, 202 + ((i // (target_width // image_size)) + 1) * image_size 203 + ) 204 + # split the image 205 + split_img = resized_img.crop(box) 206 + processed_images.append(split_img) 207 + assert len(processed_images) == blocks 208 + if use_thumbnail and len(processed_images) != 1: 209 + thumbnail_img = image.resize((image_size, image_size)) 210 + processed_images.append(thumbnail_img) 211 + return processed_images, target_aspect_ratio 212 + 213 + 214 + 215 + def normalize_transform(mean, std): 216 + if mean is None and std is None: 217 + transform = None 218 + elif mean is None and std is not None: 219 + mean = [0.] 
* len(std) 220 + transform = transforms.Normalize(mean=mean, std=std) 221 + elif mean is not None and std is None: 222 + std = [1.] * len(mean) 223 + transform = transforms.Normalize(mean=mean, std=std) 224 + else: 225 + transform = transforms.Normalize(mean=mean, std=std) 226 + 227 + return transform 228 + 229 + 230 + 231 + def format_messages( 232 + conversations: List[Dict[str, str]], 233 + sft_format: str = "deepseek", 234 + system_prompt: str = "", 235 + ): 236 + """ 237 + Applies the SFT template to conversation. 238 + 239 + Args: 240 + conversations (List[Dict]): A List of messages. 241 + sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". 242 + system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". 243 + 244 + Returns: 245 + sft_prompt (str): The formatted text. 246 + """ 247 + 248 + conv = get_conv_template(sft_format) 249 + conv.set_system_message(system_prompt) 250 + for message in conversations: 251 + conv.append_message(message["role"], message["content"].strip()) 252 + sft_prompt = conv.get_prompt().strip() 253 + 254 + return sft_prompt 255 + 256 + 257 + def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False): 258 + t = tokenizer.encode(text, add_special_tokens=False) 259 + bos_id = 0 260 + eos_id = 1 261 + if bos: 262 + t = [bos_id] + t 263 + if eos: 264 + t = t + [eos_id] 265 + 266 + return t 267 + 268 + def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]: 269 + """ 270 + 271 + Args: 272 + conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : 273 + [ 274 + { 275 + "role": "User", 276 + "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.", 277 + "images": ["./examples/table_datasets.png"] 278 + }, 279 + {"role": "Assistant", "content": ""}, 280 + ] 281 + 282 + Returns: 283 + pil_images (List[PIL.Image.Image]): the list of PIL images. 
284 + 285 + """ 286 + 287 + pil_images = [] 288 + 289 + for message in conversations: 290 + if "images" not in message: 291 + continue 292 + 293 + for image_path in message["images"]: 294 + # print('----------------') 295 + # print(image_path) 296 + # print('----------------') 297 + # exit() 298 + 299 + # pil_img = Image.open(image_path) 300 + pil_img = load_image(image_path) 301 + pil_img = pil_img.convert("RGB") 302 + pil_images.append(pil_img) 303 + 304 + return pil_images 305 + 306 + 307 + class BaseTransform(ABC): 308 + 309 + def set_rng(self, *args, **kwargs): 310 + pass 311 + 312 + def __call__(self, *args, **kwargs) -> torch.Tensor: 313 + pass 314 + 315 + @property 316 + def default_shape(self): 317 + raise NotImplementedError 318 + 319 + 320 + class BasicImageTransform(BaseTransform): 321 + def __init__( 322 + self, 323 + mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), 324 + std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), 325 + normalize: bool = True 326 + ): 327 + self.mean = mean 328 + self.std = std 329 + 330 + transform_pipelines = [ 331 + transforms.ToTensor() 332 + ] 333 + 334 + normalize = normalize_transform(mean, std) if normalize else nn.Identity() 335 + if normalize is not None: 336 + transform_pipelines.append(normalize) 337 + 338 + self.transform = transforms.Compose(transform_pipelines) 339 + 340 + def __call__(self, x): 341 + x = self.transform(x) 342 + return x 343 + 344 + class NoEOSTextStreamer(TextStreamer): 345 + def on_finalized_text(self, text: str, stream_end: bool = False): 346 + 347 + eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False) 348 + text = text.replace(eos_text, "\n") 349 + print(text, flush=True, end="") 350 + 351 + 352 + class DeepseekOCR2Config(DeepseekV2Config): 353 + model_type = "DeepseekOCR2" 354 + 355 + class DeepseekOCR2Model(DeepseekV2Model): 356 + config_class = DeepseekOCR2Config 357 + 358 + def __init__(self, config: DeepseekV2Config): 359 + super(DeepseekOCR2Model, self).__init__(config) 360 + 361 + self.sam_model = build_sam_vit_b() 362 + self.qwen2_model = build_qwen2_decoder_as_encoder() 363 + # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2) 364 + n_embed = 1280 365 + self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed)) 366 + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) 367 + # self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) 368 + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) 369 + 370 + 371 + 372 + 373 + def forward( 374 + self, 375 + input_ids: torch.LongTensor = None, 376 + attention_mask: Optional[torch.Tensor] = None, 377 + position_ids: Optional[torch.LongTensor] = None, 378 + past_key_values: Optional[List[torch.FloatTensor]] = None, 379 + inputs_embeds: Optional[torch.FloatTensor] = None, 380 + use_cache: Optional[bool] = None, 381 + output_attentions: Optional[bool] = None, 382 + output_hidden_states: Optional[bool] = None, 383 + images: Optional[torch.FloatTensor] = None, 384 + images_seq_mask: Optional[torch.FloatTensor] = None, 385 + images_spatial_crop: Optional[torch.FloatTensor] = None, 386 + return_dict: Optional[bool] = None, 387 + ) -> Union[Tuple, BaseModelOutputWithPast]: 388 + 389 + 390 + 391 + 392 + if inputs_embeds is None: 393 + # inputs_embeds = self.embed_tokens(input_ids) 394 + # inputs_embeds = self.embed_tokens(input_ids) 395 + inputs_embeds = self.get_input_embeddings()(input_ids) 396 + 397 
+ 398 + 399 + sam_model = getattr(self, 'sam_model', None) 400 + # sam_model = self.sam_model 401 + qwen2_model = getattr(self, 'qwen2_model', None) 402 + 403 + 404 + 405 + if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0: 406 + 407 + idx = 0 408 + 409 + # sam_model = torch.jit.script(sam_model) 410 + 411 + # start_time = time.time() 412 + for image, crop_shape in zip(images, images_spatial_crop): 413 + images_in_this_batch = [] 414 + 415 + patches = image[0] 416 + image_ori = image[1] 417 + 418 + with torch.no_grad(): 419 + # with torch.inference_mode(): 420 + 421 + if torch.sum(patches).item() != 0: 422 + # P, C, H, W = patches.shape 423 + crop_flag = 1 424 + local_features_1 = sam_model(patches) 425 + 426 + local_features_2 = qwen2_model(local_features_1) 427 + # vit_time = time.time() 428 + local_features = local_features_2 429 + local_features = self.projector(local_features) 430 + 431 + 432 + global_features_1 = sam_model(image_ori) 433 + global_features_2 = qwen2_model(global_features_1) 434 + global_features = global_features_2 435 + global_features = self.projector(global_features) 436 + 437 + print('=====================') 438 + print('BASE: ', global_features.shape) 439 + print('PATCHES: ', local_features.shape) 440 + print('=====================') 441 + 442 + _, hw, n_dim = global_features.shape 443 + # h = w = int(hw ** 0.5) 444 + 445 + _2, hw2, n_dim2 = local_features.shape 446 + # h2 = w2 = int(hw2 ** 0.5) 447 + 448 + 449 + global_features = global_features.view(-1, n_dim) 450 + 451 + 452 + local_features = local_features.view(-1, n_dim2) 453 + 454 + global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0) 455 + 456 + # end_time = time.time() 457 + 458 + # print('sam: ', sam_time - start_time) 459 + # print('vit: ', vit_time - sam_time) 460 + # print('all: ', end_time - start_time) 461 + 462 + # exit() 463 + 464 + else: 465 + global_features_1 = sam_model(image_ori) 466 + global_features_2 = qwen2_model(global_features_1) 467 + global_features = global_features_2 468 + global_features = self.projector(global_features) 469 + print('=====================') 470 + print('BASE: ', global_features.shape) 471 + print('NO PATCHES') 472 + print('=====================') 473 + _, hw, n_dim = global_features.shape 474 + # h = w = int(hw ** 0.5) 475 + 476 + 477 + # global_features = global_features.view(h, w, n_dim) 478 + 479 + # global_features = torch.cat( 480 + # [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 481 + # ) 482 + 483 + global_features = global_features.view(-1, n_dim) 484 + 485 + global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0) 486 + 487 + images_in_this_batch.append(global_local_features) 488 + 489 + 490 + # print(inputs_embeds.shape) 491 + 492 + if images_in_this_batch: 493 + images_in_this_batch = torch.cat(images_in_this_batch, dim=0) 494 + # exit() 495 + 496 + inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch) 497 + 498 + idx += 1 499 + 500 + 501 + return super(DeepseekOCR2Model, self).forward( 502 + input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 503 + inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, 504 + output_attentions=output_attentions, output_hidden_states=output_hidden_states, 505 + return_dict=return_dict 506 + ) 507 + 508 + 509 + class 
DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM): 510 + 511 + config_class = DeepseekOCR2Config 512 + # supports_gradient_checkpointing = True 513 + 514 + def __init__(self, config): 515 + super(DeepseekV2ForCausalLM, self).__init__(config) 516 + self.model = DeepseekOCR2Model(config) 517 + 518 + self.vocab_size = config.vocab_size 519 + 520 + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 521 + 522 + # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 523 + 524 + # Initialize weights and apply final processing 525 + self.post_init() 526 + 527 + def get_model(self): 528 + return self.model 529 + 530 + 531 + def forward( 532 + self, 533 + input_ids: torch.LongTensor = None, 534 + attention_mask: Optional[torch.Tensor] = None, 535 + position_ids: Optional[torch.LongTensor] = None, 536 + past_key_values: Optional[List[torch.FloatTensor]] = None, 537 + inputs_embeds: Optional[torch.FloatTensor] = None, 538 + labels: Optional[torch.LongTensor] = None, 539 + use_cache: Optional[bool] = None, 540 + output_attentions: Optional[bool] = None, 541 + output_hidden_states: Optional[bool] = None, 542 + images: Optional[torch.FloatTensor] = None, 543 + images_seq_mask: Optional[torch.FloatTensor] = None, 544 + images_spatial_crop: Optional[torch.FloatTensor] = None, 545 + return_dict: Optional[bool] = None, 546 + 547 + ) -> Union[Tuple, CausalLMOutputWithPast]: 548 + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 549 + output_hidden_states = ( 550 + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 551 + ) 552 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict 553 + 554 + 555 + 556 + outputs = self.model( 557 + input_ids=input_ids, 558 + past_key_values=past_key_values, 559 + attention_mask=attention_mask, 560 + position_ids=position_ids, 561 + inputs_embeds=inputs_embeds, 562 + use_cache=use_cache, 563 + output_attentions=output_attentions, 564 + output_hidden_states=output_hidden_states, 565 + images=images, 566 + images_seq_mask = images_seq_mask, 567 + images_spatial_crop = images_spatial_crop, 568 + return_dict=return_dict 569 + 570 + ) 571 + 572 + 573 + 574 + # print(transformer_outputs) 575 + 576 + hidden_states = outputs[0] 577 + logits = self.lm_head(hidden_states) 578 + logits = logits.float() 579 + 580 + # logits 581 + 582 + loss = None 583 + if labels is not None: 584 + # Shift so that tokens < n predict n 585 + shift_logits = logits[..., :-1, :].contiguous() 586 + shift_labels = labels[..., 1:].contiguous() 587 + # Flatten the tokens 588 + loss_fct = CrossEntropyLoss() 589 + shift_logits = shift_logits.view(-1, self.config.vocab_size) 590 + shift_labels = shift_labels.view(-1) 591 + # Enable model parallelism 592 + shift_labels = shift_labels.to(shift_logits.device) 593 + loss = loss_fct(shift_logits, shift_labels) 594 + 595 + if not return_dict: 596 + output = (logits,) + outputs[1:] 597 + return (loss,) + output if loss is not None else output 598 + 599 + return CausalLMOutputWithPast( 600 + loss=loss, 601 + logits=logits, 602 + past_key_values=outputs.past_key_values, 603 + hidden_states=outputs.hidden_states, 604 + attentions=outputs.attentions, 605 + ) 606 + 607 + 608 + def prepare_inputs_for_generation( 609 + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 610 + ): 611 + # Omit tokens covered by past_key_values 612 + past_length = 0 613 + if 
past_key_values is not None: 614 + if isinstance(past_key_values, Cache): 615 + cache_length = past_key_values.get_seq_length() 616 + past_length = past_key_values.seen_tokens 617 + max_cache_length = past_key_values.get_max_length() 618 + else: 619 + cache_length = past_length = past_key_values[0][0].shape[2] 620 + max_cache_length = None 621 + 622 + # Keep only the unprocessed tokens: 623 + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 624 + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 625 + # input) 626 + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 627 + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 628 + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 629 + # input_ids based on the past_length. 630 + elif past_length < input_ids.shape[1]: 631 + input_ids = input_ids[:, past_length:] 632 + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 633 + 634 + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 635 + if ( 636 + max_cache_length is not None 637 + and attention_mask is not None 638 + and cache_length + input_ids.shape[1] > max_cache_length 639 + ): 640 + attention_mask = attention_mask[:, -max_cache_length:] 641 + 642 + position_ids = kwargs.get("position_ids", None) 643 + if attention_mask is not None and position_ids is None: 644 + # create position_ids on the fly for batch generation 645 + position_ids = attention_mask.long().cumsum(-1) - 1 646 + position_ids.masked_fill_(attention_mask == 0, 1) 647 + if past_key_values: 648 + position_ids = position_ids[:, -input_ids.shape[1] :] 649 + 650 + # if self.generation_config.cache_implementation == "static": 651 + # # generation with static cache 652 + # cache_position = kwargs.get("cache_position", None) 653 + # if cache_position is None: 654 + # past_length = 0 655 + # else: 656 + # past_length = cache_position[-1] + 1 657 + # input_ids = input_ids[:, past_length:] 658 + # position_ids = position_ids[:, past_length:] 659 + 660 + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. 661 + # same goes for position ids. Could also help with continued generation. 662 + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) 663 + 664 + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 665 + if inputs_embeds is not None and past_key_values is None: 666 + model_inputs = {"inputs_embeds": inputs_embeds} 667 + else: 668 + model_inputs = {"input_ids": input_ids} 669 + 670 + model_inputs.update( 671 + { 672 + "position_ids": position_ids, 673 + "past_key_values": past_key_values, 674 + "use_cache": kwargs.get("use_cache"), 675 + "attention_mask": attention_mask, 676 + "images": kwargs.get("images", None), 677 + "images_seq_mask": kwargs.get("images_seq_mask", None), 678 + "images_spatial_crop": kwargs.get("images_spatial_crop", None), 679 + } 680 + ) 681 + return model_inputs 682 + 683 + 684 + def disable_torch_init(self): 685 + """ 686 + Disable the redundant torch default initialization to accelerate model creation. 
687 + """ 688 + import torch 689 + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 690 + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 691 + 692 + 693 + 694 + def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): 695 + self.disable_torch_init() 696 + 697 + os.makedirs(output_path, exist_ok=True) 698 + os.makedirs(f'{output_path}/images', exist_ok=True) 699 + 700 + if prompt and image_file: 701 + conversation = [ 702 + { 703 + "role": "<|User|>", 704 + # "content": "<image>\n<|grounding|>Given the layout of the image. ", 705 + "content": f'{prompt}', 706 + # "content": "君不见黄河之水天上来的下一句是什么?", 707 + # "content": "<image>\nFree OCR. ", 708 + # "content": "<image>\nParse the figure. ", 709 + # "content": "<image>\nExtract the text in the image. ", 710 + "images": [f'{image_file}'], 711 + }, 712 + {"role": "<|Assistant|>", "content": ""}, 713 + ] 714 + 715 + elif prompt: 716 + conversation = [ 717 + { 718 + "role": "<|User|>", 719 + # "content": "<image>\n<|grounding|>Given the layout of the image. ", 720 + "content": f'{prompt}', 721 + # "content": "君不见黄河之水天上来的下一句是什么?", 722 + # "content": "<image>\nFree OCR. ", 723 + # "content": "<image>\nParse the figure. ", 724 + # "content": "<image>\nExtract the text in the image. ", 725 + # "images": [f'{image_file}'], 726 + }, 727 + {"role": "<|Assistant|>", "content": ""}, 728 + ] 729 + else: 730 + assert False, f'prompt is none!' 731 + 732 + prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='') 733 + 734 + patch_size = 16 735 + downsample_ratio = 4 736 + images = load_pil_images(conversation) 737 + 738 + valid_img_tokens = 0 739 + ratio = 1 740 + 741 + image_draw = images[0].copy() 742 + 743 + w,h = image_draw.size 744 + # print(w, h) 745 + ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h))) 746 + 747 + 748 + image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) 749 + images_seq_mask = [] 750 + 751 + image_token = '<image>' 752 + image_token_id = 128815 753 + text_splits = prompt.split(image_token) 754 + 755 + images_list, images_crop_list, images_seq_mask = [], [], [] 756 + tokenized_str = [] 757 + images_spatial_crop = [] 758 + for text_sep, image in zip(text_splits, images): 759 + 760 + tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) 761 + tokenized_str += tokenized_sep 762 + images_seq_mask += [False] * len(tokenized_sep) 763 + 764 + if crop_mode: 765 + 766 + if image.size[0] <= 768 and image.size[1] <= 768: 767 + crop_ratio = [1, 1] 768 + 769 + else: 770 + if crop_mode: 771 + # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions) 772 + images_crop_raw, crop_ratio = dynamic_preprocess(image) 773 + else: 774 + # best_width, best_height = self.image_size, self.image_size 775 + crop_ratio = [1, 1] 776 + 777 + """process the global view""" 778 + # image = image.resize((base_size, base_size)) 779 + global_view = ImageOps.pad(image, (base_size, base_size), 780 + color=tuple(int(x * 255) for x in image_transform.mean)) 781 + 782 + if base_size == 1024: 783 + valid_img_tokens += int(256 * ratio) 784 + elif base_size == 1280: 785 + valid_img_tokens += int(400 * ratio) 786 + # elif base_size == 640: 787 + # valid_img_tokens += int(100 * ratio) 788 + 789 + 790 + 791 + 792 + 793 + images_list.append(image_transform(global_view).to(torch.bfloat16)) 794 + 795 + # 
global_view_tensor = image_transform(global_view).to(torch.bfloat16) 796 + 797 + width_crop_num, height_crop_num = crop_ratio 798 + 799 + images_spatial_crop.append([width_crop_num, height_crop_num]) 800 + 801 + 802 + if width_crop_num > 1 or height_crop_num > 1: 803 + """process the local views""" 804 + 805 + for i in range(len(images_crop_raw)): 806 + images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16)) 807 + 808 + if image_size == 768: 809 + valid_img_tokens += len(images_crop_list) * 144 810 + 811 + num_queries = math.ceil((image_size // patch_size) / downsample_ratio) 812 + num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio) 813 + 814 + 815 + 816 + """add image tokens""" 817 + 818 + 819 + 820 + tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base 821 + tokenized_image += [image_token_id] 822 + if width_crop_num > 1 or height_crop_num > 1: 823 + tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * ( 824 + num_queries * height_crop_num) 825 + tokenized_str += tokenized_image 826 + images_seq_mask += [True] * len(tokenized_image) 827 + # num_image_tokens.append(len(tokenized_image)) 828 + 829 + else: 830 + # best_width, best_height = self.image_size, self.image_size 831 + # print(image.size, (best_width, best_height)) # check the select_best_resolutions func 832 + 833 + """process the global view""" 834 + if image_size <= 768: 835 + print('directly resize') 836 + image = image.resize((image_size, image_size)) 837 + # else: 838 + global_view = ImageOps.pad(image, (image_size, image_size), 839 + color=tuple(int(x * 255) for x in image_transform.mean)) 840 + images_list.append(image_transform(global_view).to(torch.bfloat16)) 841 + 842 + if base_size == 1024: 843 + valid_img_tokens += int(256 * ratio) 844 + elif base_size == 1280: 845 + valid_img_tokens += int(400 * ratio) 846 + elif base_size == 640: 847 + valid_img_tokens += int(100 * 1) 848 + elif base_size == 512: 849 + valid_img_tokens += int(64 * 1) 850 + elif base_size == 768: 851 + valid_img_tokens += int(144 * 1) 852 + 853 + width_crop_num, height_crop_num = 1, 1 854 + 855 + images_spatial_crop.append([width_crop_num, height_crop_num]) 856 + 857 + 858 + """add image tokens""" 859 + num_queries = math.ceil((image_size // patch_size) / downsample_ratio) 860 + 861 + tokenized_image = ([image_token_id] * num_queries) * num_queries 862 + tokenized_image += [image_token_id] 863 + # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * ( 864 + # num_queries * height_crop_num) 865 + tokenized_str += tokenized_image 866 + images_seq_mask += [True] * len(tokenized_image) 867 + # num_image_tokens.append(len(tokenized_image)) 868 + 869 + 870 + """process the last text split""" 871 + tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) 872 + tokenized_str += tokenized_sep 873 + images_seq_mask += [False] * len(tokenized_sep) 874 + 875 + """add the bos tokens""" 876 + bos_id = 0 877 + tokenized_str = [bos_id] + tokenized_str 878 + images_seq_mask = [False] + images_seq_mask 879 + 880 + 881 + 882 + input_ids = torch.LongTensor(tokenized_str) 883 + 884 + 885 + 886 + 887 + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) 888 + 889 + 890 + if len(images_list) == 0: 891 + images_ori = torch.zeros((1, 3, image_size, image_size)) 892 + images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) 893 + images_crop = torch.zeros((1, 3, base_size, base_size)) 894 + 895 + else: 
896 + images_ori = torch.stack(images_list, dim=0) 897 + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) 898 + if images_crop_list: 899 + images_crop = torch.stack(images_crop_list, dim=0) 900 + else: 901 + images_crop = torch.zeros((1, 3, base_size, base_size)) 902 + 903 + 904 + 905 + if not eval_mode: 906 + streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) 907 + with torch.autocast("cuda", dtype=torch.bfloat16): 908 + with torch.no_grad(): 909 + output_ids = self.generate( 910 + input_ids.unsqueeze(0).cuda(), 911 + images=[(images_crop.cuda(), images_ori.cuda())], 912 + images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), 913 + images_spatial_crop = images_spatial_crop, 914 + # do_sample=False, 915 + # num_beams = 1, 916 + temperature=0.0, 917 + eos_token_id=tokenizer.eos_token_id, 918 + streamer=streamer, 919 + max_new_tokens=8192, 920 + no_repeat_ngram_size = 20, 921 + use_cache = True 922 + ) 923 + 924 + else: 925 + with torch.autocast("cuda", dtype=torch.bfloat16): 926 + with torch.no_grad(): 927 + output_ids = self.generate( 928 + input_ids.unsqueeze(0).cuda(), 929 + images=[(images_crop.cuda(), images_ori.cuda())], 930 + images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), 931 + images_spatial_crop = images_spatial_crop, 932 + # do_sample=False, 933 + # num_beams = 1, 934 + temperature=0.0, 935 + eos_token_id=tokenizer.eos_token_id, 936 + max_new_tokens=8192, 937 + no_repeat_ngram_size = 35, 938 + use_cache = True 939 + ) 940 + 941 + 942 + if '<image>' in conversation[0]['content'] and eval_mode: 943 + outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 944 + stop_str = '<|end▁of▁sentence|>' 945 + if outputs.endswith(stop_str): 946 + outputs = outputs[:-len(stop_str)] 947 + # re_match 948 + outputs = outputs.strip() 949 + 950 + return outputs 951 + 952 + if '<image>' in conversation[0]['content'] and test_compress: 953 + outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 954 + pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False)) 955 + print('='*50) 956 + print('image size: ', (w, h)) 957 + print('valid image tokens: ', int(valid_img_tokens)) 958 + print('output texts tokens (valid): ', pure_texts_outputs_token_length) 959 + print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2)) 960 + print('='*50) 961 + 962 + 963 + if '<image>' in conversation[0]['content'] and save_results: 964 + outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 965 + stop_str = '<|end▁of▁sentence|>' 966 + 967 + print('='*15 + 'save results:' + '='*15) 968 + 969 + # # # # conv.messages[-1][-1] = outputs 970 + if outputs.endswith(stop_str): 971 + outputs = outputs[:-len(stop_str)] 972 + outputs = outputs.strip() 973 + 974 + matches_ref, matches_images, mathes_other = re_match(outputs) 975 + # print(matches_ref) 976 + result = process_image_with_refs(image_draw, matches_ref, output_path) 977 + 978 + 979 + for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): 980 + outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n') 981 + 982 + for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): 983 + outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') 984 + 985 + 986 + # if 'structural formula' in conversation[0]['content']: 987 + # outputs = '<smiles>' + outputs + '</smiles>' 988 + with 
open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile: 989 + afile.write(outputs) 990 + 991 + if 'line_type' in outputs: 992 + import matplotlib.pyplot as plt 993 + lines = eval(outputs)['Line']['line'] 994 + 995 + line_type = eval(outputs)['Line']['line_type'] 996 + # print(lines) 997 + 998 + endpoints = eval(outputs)['Line']['line_endpoint'] 999 + 1000 + fig, ax = plt.subplots(figsize=(3,3), dpi=200) 1001 + ax.set_xlim(-15, 15) 1002 + ax.set_ylim(-15, 15) 1003 + 1004 + for idx, line in enumerate(lines): 1005 + try: 1006 + p0 = eval(line.split(' -- ')[0]) 1007 + p1 = eval(line.split(' -- ')[-1]) 1008 + 1009 + if line_type[idx] == '--': 1010 + ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') 1011 + else: 1012 + ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') 1013 + 1014 + ax.scatter(p0[0], p0[1], s=5, color = 'k') 1015 + ax.scatter(p1[0], p1[1], s=5, color = 'k') 1016 + except: 1017 + pass 1018 + 1019 + for endpoint in endpoints: 1020 + 1021 + label = endpoint.split(': ')[0] 1022 + (x, y) = eval(endpoint.split(': ')[1]) 1023 + ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points', 1024 + fontsize=5, fontweight='light') 1025 + 1026 + 1027 + plt.savefig(f'{output_path}/geo.jpg') 1028 + plt.close() 1029 + 1030 + result.save(f"{output_path}/result_with_boxes.jpg")
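
For orientation, here is a minimal usage sketch of the `infer()` entry point defined in the file above. The method signature, the bfloat16/CUDA expectations, and the prompt strings are taken from the copied code itself; the checkpoint path, the `AutoModel`/`AutoTokenizer` loading with `trust_remote_code=True`, and the example file paths are assumptions for illustration only and are not part of this commit.

    # Minimal sketch, assuming the modeling files are loadable via
    # transformers' trust_remote_code path from a local or hub checkpoint.
    import torch
    from transformers import AutoModel, AutoTokenizer

    model_path = "path/to/deepseek-ocr-2-checkpoint"  # assumed location

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,  # infer() autocasts to bfloat16 on CUDA
    )
    model = model.eval().cuda()      # infer() moves inputs to CUDA explicitly

    # Prompt formats are taken from the commented examples inside infer(),
    # e.g. "<image>\nFree OCR. " or "<image>\n<|grounding|>Given the layout of the image. "
    model.infer(
        tokenizer,
        prompt="<image>\nFree OCR. ",
        image_file="./examples/table_datasets.png",  # placeholder path
        output_path="./ocr_out",   # result.mmd, cropped images, boxed render go here
        base_size=1024,
        image_size=640,
        crop_mode=True,
        save_results=True,
    )

With `eval_mode=False` (the default) the decoded text is streamed to stdout and, when `save_results=True`, written under `output_path`; passing `eval_mode=True` instead returns the decoded string.
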
+1992
src/deepseek_ocr2/modeling_deepseekv2.py
··· 1 + # coding=utf-8 2 + # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. 3 + # 4 + # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 + # and OPT implementations in this library. It has been modified from its 6 + # original forms to accommodate minor architectural differences compared 7 + # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 + # 9 + # Licensed under the Apache License, Version 2.0 (the "License"); 10 + # you may not use this file except in compliance with the License. 11 + # You may obtain a copy of the License at 12 + # 13 + # http://www.apache.org/licenses/LICENSE-2.0 14 + # 15 + # Unless required by applicable law or agreed to in writing, software 16 + # distributed under the License is distributed on an "AS IS" BASIS, 17 + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 + # See the License for the specific language governing permissions and 19 + # limitations under the License. 20 + """ PyTorch DeepSeek model and compatible with both DeepSeekV2 and DeepSeekV3""" 21 + import math 22 + import warnings 23 + from typing import List, Optional, Tuple, Union 24 + import numpy as np 25 + 26 + import torch 27 + import torch.nn.functional as F 28 + import torch.utils.checkpoint 29 + import torch.distributed as dist 30 + from einops import repeat 31 + from torch import nn 32 + from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 33 + 34 + from transformers.activations import ACT2FN 35 + from transformers.cache_utils import Cache, DynamicCache 36 + from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask 37 + from transformers.models.llama.modeling_llama import ( 38 + LlamaAttention, 39 + LlamaFlashAttention2 40 + ) 41 + from transformers.modeling_outputs import ( 42 + BaseModelOutputWithPast, 43 + CausalLMOutputWithPast, 44 + SequenceClassifierOutputWithPast, 45 + ) 46 + from transformers.modeling_utils import PreTrainedModel 47 + from transformers.pytorch_utils import ( 48 + ALL_LAYERNORM_LAYERS, 49 + is_torch_greater_or_equal_than_1_13, 50 + ) 51 + from transformers.utils import ( 52 + add_start_docstrings, 53 + add_start_docstrings_to_model_forward, 54 + is_flash_attn_2_available, 55 + is_flash_attn_greater_or_equal_2_10, 56 + logging, 57 + replace_return_docstrings, 58 + ) 59 + from transformers.utils.import_utils import is_torch_fx_available 60 + 61 + from .configuration_deepseek_v2 import DeepseekV2Config 62 + 63 + if is_flash_attn_2_available(): 64 + from flash_attn import flash_attn_func, flash_attn_varlen_func 65 + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 66 + 67 + # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. 68 + # It means that the function will not be traced through and simply appear as a node in the graph. 
69 + if is_torch_fx_available(): 70 + if not is_torch_greater_or_equal_than_1_13: 71 + import torch.fx 72 + 73 + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) 74 + 75 + logger = logging.get_logger(__name__) 76 + 77 + _CONFIG_FOR_DOC = "DeepseekV2Config" 78 + 79 + 80 + def _get_unpad_data(attention_mask): 81 + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 82 + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 83 + max_seqlen_in_batch = seqlens_in_batch.max().item() 84 + cu_seqlens = F.pad( 85 + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) 86 + ) 87 + return ( 88 + indices, 89 + cu_seqlens, 90 + max_seqlen_in_batch, 91 + ) 92 + 93 + 94 + class DeepseekV2RMSNorm(nn.Module): 95 + def __init__(self, hidden_size, eps=1e-6): 96 + """ 97 + DeepseekV2RMSNorm is equivalent to T5LayerNorm 98 + """ 99 + super().__init__() 100 + self.weight = nn.Parameter(torch.ones(hidden_size)) 101 + self.variance_epsilon = eps 102 + 103 + def forward(self, hidden_states): 104 + input_dtype = hidden_states.dtype 105 + hidden_states = hidden_states.to(torch.float32) 106 + variance = hidden_states.pow(2).mean(-1, keepdim=True) 107 + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 108 + return self.weight * hidden_states.to(input_dtype) 109 + 110 + 111 + ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) 112 + 113 + 114 + 115 + 116 + class DeepseekV2RotaryEmbedding(nn.Module): 117 + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 118 + super().__init__() 119 + 120 + self.dim = dim 121 + self.max_position_embeddings = max_position_embeddings 122 + self.base = base 123 + inv_freq = 1.0 / ( 124 + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 125 + ) 126 + self.register_buffer("inv_freq", inv_freq, persistent=False) 127 + 128 + # Build here to make `torch.jit.trace` work. 129 + self._set_cos_sin_cache( 130 + seq_len=max_position_embeddings, 131 + device=self.inv_freq.device, 132 + dtype=torch.get_default_dtype(), 133 + ) 134 + self.max_seq_len_cached = None 135 + 136 + def _set_cos_sin_cache(self, seq_len, device, dtype): 137 + self.max_seq_len_cached = seq_len 138 + t = torch.arange( 139 + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 140 + ) 141 + 142 + freqs = torch.outer(t, self.inv_freq.to(t.device)) 143 + # Different from paper, but it uses a different permutation in order to obtain the same calculation 144 + emb = torch.cat((freqs, freqs), dim=-1) 145 + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 146 + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 147 + 148 + def forward(self, x, seq_len=None): 149 + # x: [bs, num_attention_heads, seq_len, head_size] 150 + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: 151 + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 152 + 153 + return ( 154 + self.cos_cached[:seq_len].to(dtype=x.dtype), 155 + self.sin_cached[:seq_len].to(dtype=x.dtype), 156 + ) 157 + 158 + 159 + # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 160 + class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): 161 + """DeepseekV2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" 162 + 163 + def __init__( 164 + self, 165 + dim, 166 + max_position_embeddings=2048, 167 + base=10000, 168 + device=None, 169 + scaling_factor=1.0, 170 + ): 171 + self.scaling_factor = scaling_factor 172 + super().__init__(dim, max_position_embeddings, base, device) 173 + 174 + def _set_cos_sin_cache(self, seq_len, device, dtype): 175 + self.max_seq_len_cached = seq_len 176 + t = torch.arange( 177 + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 178 + ) 179 + t = t / self.scaling_factor 180 + 181 + freqs = torch.outer(t, self.inv_freq) 182 + # Different from paper, but it uses a different permutation in order to obtain the same calculation 183 + emb = torch.cat((freqs, freqs), dim=-1) 184 + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 185 + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 186 + 187 + 188 + # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 189 + class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): 190 + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 191 + 192 + def __init__( 193 + self, 194 + dim, 195 + max_position_embeddings=2048, 196 + base=10000, 197 + device=None, 198 + scaling_factor=1.0, 199 + ): 200 + self.scaling_factor = scaling_factor 201 + super().__init__(dim, max_position_embeddings, base, device) 202 + 203 + def _set_cos_sin_cache(self, seq_len, device, dtype): 204 + self.max_seq_len_cached = seq_len 205 + 206 + if seq_len > self.max_position_embeddings: 207 + base = self.base * ( 208 + (self.scaling_factor * seq_len / self.max_position_embeddings) 209 + - (self.scaling_factor - 1) 210 + ) ** (self.dim / (self.dim - 2)) 211 + inv_freq = 1.0 / ( 212 + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) 213 + ) 214 + self.register_buffer("inv_freq", inv_freq, persistent=False) 215 + 216 + t = torch.arange( 217 + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype 218 + ) 219 + 220 + freqs = torch.outer(t, self.inv_freq) 221 + # Different from paper, but it uses a different permutation in order to obtain the same calculation 222 + emb = torch.cat((freqs, freqs), dim=-1) 223 + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 224 + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 225 + 226 + 227 + # Inverse dim formula to find dim based on number of rotations 228 + def yarn_find_correction_dim( 229 + num_rotations, dim, base=10000, max_position_embeddings=2048 230 + ): 231 + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( 232 + 2 * math.log(base) 233 + ) 234 + 235 + 236 + # Find dim range bounds based on rotations 237 + def yarn_find_correction_range( 238 + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 239 + ): 240 + low = math.floor( 241 + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) 242 + ) 243 + high = math.ceil( 244 + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) 245 + ) 246 + return max(low, 0), min(high, dim - 1) # Clamp values just in case 247 + 248 + 249 + def yarn_get_mscale(scale=1, mscale=1): 250 + if scale <= 1: 251 + return 1.0 252 + return 0.1 * mscale * math.log(scale) + 1.0 253 + 254 + 255 + def yarn_linear_ramp_mask(min, max, dim): 256 + if min == max: 257 + max += 0.001 # Prevent 
singularity 258 + 259 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) 260 + ramp_func = torch.clamp(linear_func, 0, 1) 261 + return ramp_func 262 + 263 + 264 + class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): 265 + 266 + def __init__( 267 + self, 268 + dim, 269 + max_position_embeddings=2048, 270 + base=10000, 271 + device=None, 272 + scaling_factor=1.0, 273 + original_max_position_embeddings=4096, 274 + beta_fast=32, 275 + beta_slow=1, 276 + mscale=1, 277 + mscale_all_dim=0, 278 + ): 279 + self.scaling_factor = scaling_factor 280 + self.original_max_position_embeddings = original_max_position_embeddings 281 + self.beta_fast = beta_fast 282 + self.beta_slow = beta_slow 283 + self.mscale = mscale 284 + self.mscale_all_dim = mscale_all_dim 285 + super().__init__(dim, max_position_embeddings, base, device) 286 + 287 + def _set_cos_sin_cache(self, seq_len, device, dtype): 288 + self.max_seq_len_cached = seq_len 289 + dim = self.dim 290 + 291 + freq_extra = 1.0 / ( 292 + self.base 293 + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) 294 + ) 295 + freq_inter = 1.0 / ( 296 + self.scaling_factor 297 + * self.base 298 + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) 299 + ) 300 + 301 + low, high = yarn_find_correction_range( 302 + self.beta_fast, 303 + self.beta_slow, 304 + dim, 305 + self.base, 306 + self.original_max_position_embeddings, 307 + ) 308 + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( 309 + device=device, dtype=torch.float32 310 + ) 311 + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask 312 + self.register_buffer("inv_freq", inv_freq, persistent=False) 313 + 314 + t = torch.arange(seq_len, device=device, dtype=torch.float32) 315 + 316 + freqs = torch.outer(t, inv_freq) 317 + 318 + _mscale = float( 319 + yarn_get_mscale(self.scaling_factor, self.mscale) 320 + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) 321 + ) 322 + 323 + emb = torch.cat((freqs, freqs), dim=-1) 324 + self.register_buffer( 325 + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False 326 + ) 327 + self.register_buffer( 328 + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False 329 + ) 330 + 331 + 332 + # Copied from transformers.models.llama.modeling_llama.rotate_half 333 + def rotate_half(x): 334 + """Rotates half the hidden dims of the input.""" 335 + x1 = x[..., : x.shape[-1] // 2] 336 + x2 = x[..., x.shape[-1] // 2 :] 337 + return torch.cat((-x2, x1), dim=-1) 338 + 339 + 340 + # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 341 + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 342 + """Applies Rotary Position Embedding to the query and key tensors. 343 + 344 + Args: 345 + q (`torch.Tensor`): The query tensor. 346 + k (`torch.Tensor`): The key tensor. 347 + cos (`torch.Tensor`): The cosine part of the rotary embedding. 348 + sin (`torch.Tensor`): The sine part of the rotary embedding. 349 + position_ids (`torch.Tensor`): 350 + The position indices of the tokens corresponding to the query and key tensors. For example, this can be 351 + used to pass offsetted position ids when working with a KV-cache. 352 + unsqueeze_dim (`int`, *optional*, defaults to 1): 353 + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 354 + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. 
For example, note 355 + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 356 + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 357 + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 358 + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 359 + Returns: 360 + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 361 + """ 362 + cos = cos[position_ids].unsqueeze(unsqueeze_dim) 363 + sin = sin[position_ids].unsqueeze(unsqueeze_dim) 364 + 365 + 366 + # print() 367 + 368 + b, h, s, d = q.shape 369 + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 370 + 371 + b, h, s, d = k.shape 372 + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) 373 + 374 + q_embed = (q * cos) + (rotate_half(q) * sin) 375 + k_embed = (k * cos) + (rotate_half(k) * sin) 376 + 377 + 378 + return q_embed, k_embed 379 + 380 + 381 + class DeepseekV2MLP(nn.Module): 382 + def __init__(self, config, hidden_size=None, intermediate_size=None): 383 + super().__init__() 384 + self.config = config 385 + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size 386 + self.intermediate_size = ( 387 + config.intermediate_size if intermediate_size is None else intermediate_size 388 + ) 389 + 390 + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 391 + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 392 + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 393 + self.act_fn = ACT2FN[config.hidden_act] 394 + 395 + def forward(self, x): 396 + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 397 + return down_proj 398 + 399 + 400 + class MoEGate(nn.Module): 401 + def __init__(self, config): 402 + super().__init__() 403 + self.config = config 404 + self.top_k = config.num_experts_per_tok 405 + self.n_routed_experts = config.n_routed_experts 406 + self.routed_scaling_factor = config.routed_scaling_factor 407 + self.scoring_func = config.scoring_func 408 + self.alpha = config.aux_loss_alpha 409 + self.seq_aux = config.seq_aux 410 + self.topk_method = config.topk_method 411 + self.n_group = config.n_group 412 + self.topk_group = config.topk_group 413 + 414 + # topk selection algorithm 415 + self.norm_topk_prob = config.norm_topk_prob 416 + self.gating_dim = config.hidden_size 417 + self.weight = nn.Parameter( 418 + torch.empty((self.n_routed_experts, self.gating_dim)) 419 + ) 420 + if self.topk_method == "noaux_tc": 421 + self.e_score_correction_bias = nn.Parameter( 422 + torch.empty((self.n_routed_experts)) 423 + ) 424 + self.reset_parameters() 425 + 426 + def reset_parameters(self) -> None: 427 + import torch.nn.init as init 428 + 429 + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 430 + 431 + def forward(self, hidden_states): 432 + bsz, seq_len, h = hidden_states.shape 433 + ### compute gating score 434 + hidden_states = hidden_states.view(-1, h) 435 + logits = F.linear( 436 + hidden_states.type(torch.float32), self.weight.type(torch.float32), None 437 + ) 438 + if self.scoring_func == "softmax": 439 + scores = logits.softmax(dim=-1, dtype=torch.float32) 440 + elif self.scoring_func == "sigmoid": 441 + scores = logits.sigmoid() 442 + else: 443 + raise NotImplementedError( 444 + f"insupportable scoring function for MoE gating: 
{self.scoring_func}" 445 + ) 446 + 447 + ### select top-k experts 448 + if self.topk_method == "greedy": 449 + topk_weight, topk_idx = torch.topk( 450 + scores, k=self.top_k, dim=-1, sorted=False 451 + ) 452 + elif self.topk_method == "group_limited_greedy": 453 + group_scores = ( 454 + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values 455 + ) # [n, n_group] 456 + group_idx = torch.topk( 457 + group_scores, k=self.topk_group, dim=-1, sorted=False 458 + )[ 459 + 1 460 + ] # [n, top_k_group] 461 + group_mask = torch.zeros_like(group_scores) # [n, n_group] 462 + group_mask.scatter_(1, group_idx, 1) # [n, n_group] 463 + score_mask = ( 464 + group_mask.unsqueeze(-1) 465 + .expand( 466 + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group 467 + ) 468 + .reshape(bsz * seq_len, -1) 469 + ) # [n, e] 470 + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] 471 + topk_weight, topk_idx = torch.topk( 472 + tmp_scores, k=self.top_k, dim=-1, sorted=False 473 + ) 474 + elif self.topk_method == "noaux_tc": 475 + assert not self.training 476 + scores_for_choice = scores.view(bsz * seq_len, -1) + self.e_score_correction_bias.unsqueeze(0) 477 + group_scores = ( 478 + scores_for_choice.view(bsz * seq_len, self.n_group, -1).topk(2, dim=-1)[0].sum(dim = -1) 479 + ) # [n, n_group] 480 + group_idx = torch.topk( 481 + group_scores, k=self.topk_group, dim=-1, sorted=False 482 + )[ 483 + 1 484 + ] # [n, top_k_group] 485 + group_mask = torch.zeros_like(group_scores) # [n, n_group] 486 + group_mask.scatter_(1, group_idx, 1) # [n, n_group] 487 + score_mask = ( 488 + group_mask.unsqueeze(-1) 489 + .expand( 490 + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group 491 + ) 492 + .reshape(bsz * seq_len, -1) 493 + ) # [n, e] 494 + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] 495 + _, topk_idx = torch.topk( 496 + tmp_scores, k=self.top_k, dim=-1, sorted=False 497 + ) 498 + topk_weight = scores.gather(1, topk_idx) 499 + 500 + ### norm gate to sum 1 501 + if self.top_k > 1 and self.norm_topk_prob: 502 + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 503 + topk_weight = topk_weight / denominator * self.routed_scaling_factor 504 + else: 505 + topk_weight = topk_weight * self.routed_scaling_factor 506 + ### expert-level computation auxiliary loss 507 + if self.training and self.alpha > 0.0: 508 + scores_for_aux = scores 509 + aux_topk = self.top_k 510 + # always compute aux loss based on the naive greedy topk method 511 + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) 512 + if self.seq_aux: 513 + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) 514 + ce = torch.zeros( 515 + bsz, self.n_routed_experts, device=hidden_states.device 516 + ) 517 + ce.scatter_add_( 518 + 1, 519 + topk_idx_for_aux_loss, 520 + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), 521 + ).div_(seq_len * aux_topk / self.n_routed_experts) 522 + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( 523 + dim=1 524 + ).mean() * self.alpha 525 + else: 526 + mask_ce = F.one_hot( 527 + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts 528 + ) 529 + ce = mask_ce.float().mean(0) 530 + Pi = scores_for_aux.mean(0) 531 + fi = ce * self.n_routed_experts 532 + aux_loss = (Pi * fi).sum() * self.alpha 533 + else: 534 + aux_loss = None 535 + return topk_idx, topk_weight, aux_loss 536 + 537 + 538 + class AddAuxiliaryLoss(torch.autograd.Function): 539 + """ 540 + The trick function of adding auxiliary (aux) loss, 541 + which includes 
the gradient of the aux loss during backpropagation. 542 + """ 543 + 544 + @staticmethod 545 + def forward(ctx, x, loss): 546 + assert loss.numel() == 1 547 + ctx.dtype = loss.dtype 548 + ctx.required_aux_loss = loss.requires_grad 549 + return x 550 + 551 + @staticmethod 552 + def backward(ctx, grad_output): 553 + grad_loss = None 554 + if ctx.required_aux_loss: 555 + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) 556 + return grad_output, grad_loss 557 + 558 + 559 + class DeepseekV2MoE(nn.Module): 560 + """ 561 + A mixed expert module containing shared experts. 562 + """ 563 + 564 + def __init__(self, config): 565 + super().__init__() 566 + self.config = config 567 + self.num_experts_per_tok = config.num_experts_per_tok 568 + 569 + if hasattr(config, "ep_size") and config.ep_size > 1: 570 + assert config.ep_size == dist.get_world_size() 571 + self.ep_size = config.ep_size 572 + self.experts_per_rank = config.n_routed_experts // config.ep_size 573 + self.ep_rank = dist.get_rank() 574 + self.experts = nn.ModuleList( 575 + [ 576 + ( 577 + DeepseekV2MLP( 578 + config, intermediate_size=config.moe_intermediate_size 579 + ) 580 + if i >= self.ep_rank * self.experts_per_rank 581 + and i < (self.ep_rank + 1) * self.experts_per_rank 582 + else None 583 + ) 584 + for i in range(config.n_routed_experts) 585 + ] 586 + ) 587 + else: 588 + self.ep_size = 1 589 + self.experts_per_rank = config.n_routed_experts 590 + self.ep_rank = 0 591 + self.experts = nn.ModuleList( 592 + [ 593 + DeepseekV2MLP( 594 + config, intermediate_size=config.moe_intermediate_size 595 + ) 596 + for i in range(config.n_routed_experts) 597 + ] 598 + ) 599 + self.gate = MoEGate(config) 600 + if config.n_shared_experts is not None: 601 + intermediate_size = config.moe_intermediate_size * config.n_shared_experts 602 + self.shared_experts = DeepseekV2MLP( 603 + config=config, intermediate_size=intermediate_size 604 + ) 605 + 606 + def forward(self, hidden_states): 607 + identity = hidden_states 608 + orig_shape = hidden_states.shape 609 + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) 610 + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) 611 + flat_topk_idx = topk_idx.view(-1) 612 + if self.training: 613 + hidden_states = hidden_states.repeat_interleave( 614 + self.num_experts_per_tok, dim=0 615 + ) 616 + y = torch.empty_like(hidden_states) 617 + for i, expert in enumerate(self.experts): 618 + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) 619 + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) 620 + y = y.to(hidden_states.dtype).view(*orig_shape) 621 + y = AddAuxiliaryLoss.apply(y, aux_loss) 622 + else: 623 + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) 624 + if self.config.n_shared_experts is not None: 625 + y = y + self.shared_experts(identity) 626 + return y 627 + 628 + @torch.no_grad() 629 + def moe_infer(self, x, topk_ids, topk_weight): 630 + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) 631 + cnts.scatter_(1, topk_ids, 1) 632 + tokens_per_expert = cnts.sum(dim=0) 633 + idxs = topk_ids.view(-1).argsort() 634 + sorted_tokens = x[idxs // topk_ids.shape[1]] 635 + sorted_tokens_shape = sorted_tokens.shape 636 + if self.ep_size > 1: 637 + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) 638 + tokens_per_expert_group = tokens_per_expert.new_empty( 639 + tokens_per_expert.shape[0] 640 + ) 641 + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) 642 + 
output_splits = ( 643 + tokens_per_expert_group.view(self.ep_size, -1) 644 + .sum(1) 645 + .cpu() 646 + .numpy() 647 + .tolist() 648 + ) 649 + gathered_tokens = sorted_tokens.new_empty( 650 + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] 651 + ) 652 + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() 653 + dist.all_to_all( 654 + list(gathered_tokens.split(output_splits)), 655 + list(sorted_tokens.split(input_split_sizes)), 656 + ) 657 + tokens_per_expert_post_gather = tokens_per_expert_group.view( 658 + self.ep_size, self.experts_per_rank 659 + ).sum(dim=0) 660 + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) 661 + s = 0 662 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): 663 + gatherd_idxs[s : s + k] = i % self.experts_per_rank 664 + s += k 665 + gatherd_idxs = gatherd_idxs.argsort() 666 + sorted_tokens = gathered_tokens[gatherd_idxs] 667 + tokens_per_expert = tokens_per_expert_post_gather 668 + tokens_per_expert = tokens_per_expert.cpu().numpy() 669 + 670 + outputs = [] 671 + start_idx = 0 672 + for i, num_tokens in enumerate(tokens_per_expert): 673 + end_idx = start_idx + num_tokens 674 + if num_tokens == 0: 675 + continue 676 + expert = self.experts[i + self.ep_rank * self.experts_per_rank] 677 + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] 678 + expert_out = expert(tokens_for_this_expert) 679 + outputs.append(expert_out) 680 + start_idx = end_idx 681 + 682 + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) 683 + if self.ep_size > 1: 684 + new_x = torch.empty_like(outs) 685 + new_x[gatherd_idxs] = outs 686 + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) 687 + dist.all_to_all( 688 + list(gathered_tokens.split(input_split_sizes)), 689 + list(new_x.split(output_splits)), 690 + ) 691 + outs = gathered_tokens 692 + 693 + new_x = torch.empty_like(outs) 694 + new_x[idxs] = outs 695 + final_out = ( 696 + new_x.view(*topk_ids.shape, -1) 697 + .type(topk_weight.dtype) 698 + .mul_(topk_weight.unsqueeze(dim=-1)) 699 + .sum(dim=1) 700 + .type(new_x.dtype) 701 + ) 702 + return final_out 703 + 704 + 705 + # Copied from transformers.models.llama.modeling_llama.repeat_kv 706 + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 707 + """ 708 + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 709 + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 710 + """ 711 + batch, num_key_value_heads, slen, head_dim = hidden_states.shape 712 + if n_rep == 1: 713 + return hidden_states 714 + hidden_states = hidden_states[:, :, None, :, :].expand( 715 + batch, num_key_value_heads, n_rep, slen, head_dim 716 + ) 717 + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 718 + 719 + 720 + # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 721 + class DeepseekV2Attention(nn.Module): 722 + """Multi-headed attention from 'Attention Is All You Need' paper""" 723 + 724 + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): 725 + super().__init__() 726 + self.config = config 727 + self.layer_idx = layer_idx 728 + if layer_idx is None: 729 + logger.warning_once( 730 + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 731 + "to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` " 732 + "when creating this class." 733 + ) 734 + 735 + self.attention_dropout = config.attention_dropout 736 + self.hidden_size = config.hidden_size 737 + self.num_heads = config.num_attention_heads 738 + 739 + self.max_position_embeddings = config.max_position_embeddings 740 + self.rope_theta = config.rope_theta 741 + self.q_lora_rank = config.q_lora_rank 742 + self.qk_rope_head_dim = config.qk_rope_head_dim 743 + self.kv_lora_rank = config.kv_lora_rank 744 + self.v_head_dim = config.v_head_dim 745 + self.qk_nope_head_dim = config.qk_nope_head_dim 746 + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim 747 + 748 + self.is_causal = True 749 + 750 + if self.q_lora_rank is None: 751 + self.q_proj = nn.Linear( 752 + self.hidden_size, self.num_heads * self.q_head_dim, bias=False 753 + ) 754 + else: 755 + self.q_a_proj = nn.Linear( 756 + self.hidden_size, config.q_lora_rank, bias=config.attention_bias 757 + ) 758 + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) 759 + self.q_b_proj = nn.Linear( 760 + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False 761 + ) 762 + # config.kv_lora_rank + config.qk_rope_head_dim, 763 + self.kv_a_proj_with_mqa = nn.Linear( 764 + self.hidden_size, 765 + config.kv_lora_rank + config.qk_rope_head_dim, 766 + bias=config.attention_bias, 767 + ) 768 + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) 769 + self.kv_b_proj = nn.Linear( 770 + config.kv_lora_rank, 771 + self.num_heads 772 + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), 773 + bias=False, 774 + ) 775 + 776 + self.o_proj = nn.Linear( 777 + self.num_heads * self.v_head_dim, 778 + self.hidden_size, 779 + bias=config.attention_bias, 780 + ) 781 + self._init_rope() 782 + 783 + self.softmax_scale = self.q_head_dim ** (-0.5) 784 + if self.config.rope_scaling is not None: 785 + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) 786 + scaling_factor = self.config.rope_scaling["factor"] 787 + if mscale_all_dim: 788 + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) 789 + self.softmax_scale = self.softmax_scale * mscale * mscale 790 + 791 + def _init_rope(self): 792 + if self.config.rope_scaling is None: 793 + self.rotary_emb = DeepseekV2RotaryEmbedding( 794 + self.qk_rope_head_dim, 795 + max_position_embeddings=self.max_position_embeddings, 796 + base=self.rope_theta, 797 + ) 798 + # self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( 799 + # self.qk_rope_head_dim, 800 + # max_position_embeddings=self.max_position_embeddings, 801 + # scaling_factor=scaling_factor, 802 + # base=self.rope_theta, 803 + # ) 804 + else: 805 + scaling_type = self.config.rope_scaling["type"] 806 + scaling_factor = self.config.rope_scaling["factor"] 807 + if scaling_type == "linear": 808 + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( 809 + self.qk_rope_head_dim, 810 + max_position_embeddings=self.max_position_embeddings, 811 + scaling_factor=scaling_factor, 812 + base=self.rope_theta, 813 + ) 814 + elif scaling_type == "dynamic": 815 + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( 816 + self.qk_rope_head_dim, 817 + max_position_embeddings=self.max_position_embeddings, 818 + scaling_factor=scaling_factor, 819 + base=self.rope_theta, 820 + ) 821 + elif scaling_type == "yarn": 822 + kwargs = { 823 + key: self.config.rope_scaling[key] 824 + for key in [ 825 + "original_max_position_embeddings", 826 + "beta_fast", 827 + "beta_slow", 828 + "mscale", 829 + "mscale_all_dim", 830 + 
] 831 + if key in self.config.rope_scaling 832 + } 833 + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( 834 + self.qk_rope_head_dim, 835 + max_position_embeddings=self.max_position_embeddings, 836 + scaling_factor=scaling_factor, 837 + base=self.rope_theta, 838 + **kwargs, 839 + ) 840 + else: 841 + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 842 + 843 + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 844 + return ( 845 + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) 846 + .transpose(1, 2) 847 + .contiguous() 848 + ) 849 + 850 + def forward( 851 + self, 852 + hidden_states: torch.Tensor, 853 + attention_mask: Optional[torch.Tensor] = None, 854 + position_ids: Optional[torch.LongTensor] = None, 855 + past_key_value: Optional[Cache] = None, 856 + output_attentions: bool = False, 857 + use_cache: bool = False, 858 + **kwargs, 859 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 860 + if "padding_mask" in kwargs: 861 + warnings.warn( 862 + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead." 863 + ) 864 + bsz, q_len, _ = hidden_states.size() 865 + 866 + if self.q_lora_rank is None: 867 + q = self.q_proj(hidden_states) 868 + else: 869 + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) 870 + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) 871 + 872 + 873 + q_nope, q_pe = torch.split( 874 + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 875 + ) 876 + 877 + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) 878 + compressed_kv, k_pe = torch.split( 879 + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 880 + ) 881 + compressed_kv = self.kv_a_layernorm(compressed_kv) 882 + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) 883 + 884 + kv_seq_len = k_pe.shape[-2] 885 + if past_key_value is not None: 886 + if self.layer_idx is None: 887 + raise ValueError( 888 + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 889 + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 890 + "with a layer index."
891 + ) 892 + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 893 + 894 + cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len) 895 + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) 896 + 897 + if past_key_value is not None: 898 + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 899 + compressed_kv = compressed_kv.unsqueeze(1) 900 + k_pe, compressed_kv = past_key_value.update(k_pe, compressed_kv, self.layer_idx, cache_kwargs) 901 + compressed_kv = compressed_kv.squeeze(1) 902 + 903 + kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) 904 + q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :] 905 + out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :] 906 + 907 + q_nope = torch.matmul(q_nope, q_absorb) 908 + attn_weights = (torch.matmul(q_pe, k_pe.mT) + 909 + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT)) * self.softmax_scale 910 + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 911 + raise ValueError( 912 + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 913 + f" {attn_weights.size()}" 914 + ) 915 + assert attention_mask is not None 916 + if attention_mask is not None: 917 + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 918 + raise ValueError( 919 + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 920 + ) 921 + attn_weights = attn_weights + attention_mask 922 + 923 + # upcast attention to fp32 924 + attn_weights = nn.functional.softmax( 925 + attn_weights, dim=-1, dtype=torch.float32 926 + ).to(q_pe.dtype) 927 + attn_weights = nn.functional.dropout( 928 + attn_weights, p=self.attention_dropout, training=self.training 929 + ) 930 + attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv) 931 + 932 + attn_output = torch.matmul(attn_output, out_absorb.mT) 933 + 934 + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): 935 + raise ValueError( 936 + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" 937 + f" {attn_output.size()}" 938 + ) 939 + 940 + attn_output = attn_output.transpose(1, 2).contiguous() 941 + 942 + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) 943 + 944 + attn_output = self.o_proj(attn_output) 945 + 946 + if not output_attentions: 947 + attn_weights = None 948 + 949 + return attn_output, attn_weights, past_key_value 950 + 951 + 952 + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 953 + class DeepseekV2FlashAttention2(DeepseekV2Attention): 954 + """ 955 + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays 956 + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 957 + flash attention and deal with padding tokens in case the input contains any of them. 958 + """ 959 + 960 + def __init__(self, *args, **kwargs): 961 + super().__init__(*args, **kwargs) 962 + 963 + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 964 + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
965 + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 966 + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 967 + 968 + def forward( 969 + self, 970 + hidden_states: torch.Tensor, 971 + attention_mask: Optional[torch.LongTensor] = None, 972 + position_ids: Optional[torch.LongTensor] = None, 973 + past_key_value: Optional[Cache] = None, 974 + output_attentions: bool = False, 975 + use_cache: bool = False, 976 + **kwargs, 977 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 978 + # DeepseekV2FlashAttention2 attention does not support output_attentions 979 + if "padding_mask" in kwargs: 980 + warnings.warn( 981 + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" 982 + ) 983 + 984 + # overwrite attention_mask with padding_mask 985 + attention_mask = kwargs.pop("padding_mask") 986 + 987 + output_attentions = False 988 + 989 + bsz, q_len, _ = hidden_states.size() 990 + 991 + if self.q_lora_rank is None: 992 + q = self.q_proj(hidden_states) 993 + else: 994 + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) 995 + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) 996 + q_nope, q_pe = torch.split( 997 + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 998 + ) 999 + 1000 + # Flash attention requires the input to have the shape 1001 + # batch_size x seq_length x head_dim x hidden_dim 1002 + # therefore we just need to keep the original shape 1003 + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) 1004 + compressed_kv, k_pe = torch.split( 1005 + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 1006 + ) 1007 + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) 1008 + kv = ( 1009 + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) 1010 + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) 1011 + .transpose(1, 2) 1012 + ) 1013 + 1014 + k_nope, value_states = torch.split( 1015 + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 1016 + ) 1017 + kv_seq_len = value_states.shape[-2] 1018 + 1019 + kv_seq_len = value_states.shape[-2] 1020 + if past_key_value is not None: 1021 + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 1022 + 1023 + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 1024 + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) 1025 + 1026 + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) 1027 + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope 1028 + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe 1029 + 1030 + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) 1031 + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope 1032 + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe 1033 + 1034 + if self.q_head_dim != self.v_head_dim: 1035 + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) 1036 + 1037 + # TODO: support compressed_kv for kv_cache (instead of key_states, value_states) in flash_attention version 1038 + if past_key_value is not None: 1039 + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 1040 + key_states, value_states = past_key_value.update( 1041 + key_states, value_states, self.layer_idx, cache_kwargs 1042 + ) 1043 + 1044 + # TODO: These transpose are quite inefficient but Flash Attention 
requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 1045 + # to be able to avoid many of these transpose/reshape/view. 1046 + query_states = query_states.transpose(1, 2) 1047 + key_states = key_states.transpose(1, 2) 1048 + value_states = value_states.transpose(1, 2) 1049 + 1050 + dropout_rate = self.attention_dropout if self.training else 0.0 1051 + 1052 + # In PEFT, usually we cast the layer norms in float32 for training stability reasons 1053 + # therefore the input hidden states gets silently casted in float32. Hence, we need 1054 + # cast them back in the correct dtype just to be sure everything works as expected. 1055 + # This might slowdown training & inference so it is recommended to not cast the LayerNorms 1056 + # in fp32. (DeepseekV2RMSNorm handles it correctly) 1057 + 1058 + input_dtype = query_states.dtype 1059 + if input_dtype == torch.float32: 1060 + # Handle the case where the model is quantized 1061 + if hasattr(self.config, "_pre_quantization_dtype"): 1062 + target_dtype = self.config._pre_quantization_dtype 1063 + elif torch.is_autocast_enabled(): 1064 + target_dtype = torch.get_autocast_gpu_dtype() 1065 + else: 1066 + target_dtype = ( 1067 + self.q_proj.weight.dtype 1068 + if self.q_lora_rank is None 1069 + else self.q_a_proj.weight.dtype 1070 + ) 1071 + 1072 + logger.warning_once( 1073 + f"The input hidden states seems to be silently casted in float32, this might be related to" 1074 + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 1075 + f" {target_dtype}." 1076 + ) 1077 + 1078 + query_states = query_states.to(target_dtype) 1079 + key_states = key_states.to(target_dtype) 1080 + value_states = value_states.to(target_dtype) 1081 + 1082 + attn_output = self._flash_attention_forward( 1083 + query_states, 1084 + key_states, 1085 + value_states, 1086 + attention_mask, 1087 + q_len, 1088 + dropout=dropout_rate, 1089 + softmax_scale=self.softmax_scale, 1090 + ) 1091 + if self.q_head_dim != self.v_head_dim: 1092 + attn_output = attn_output[:, :, :, : self.v_head_dim] 1093 + 1094 + attn_output = attn_output.reshape( 1095 + bsz, q_len, self.num_heads * self.v_head_dim 1096 + ).contiguous() 1097 + attn_output = self.o_proj(attn_output) 1098 + 1099 + if not output_attentions: 1100 + attn_weights = None 1101 + 1102 + return attn_output, attn_weights, past_key_value 1103 + 1104 + def _flash_attention_forward( 1105 + self, 1106 + query_states, 1107 + key_states, 1108 + value_states, 1109 + attention_mask, 1110 + query_length, 1111 + dropout=0.0, 1112 + softmax_scale=None, 1113 + ): 1114 + """ 1115 + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 1116 + first unpad the input, then computes the attention scores and pad the final attention scores. 1117 + 1118 + Args: 1119 + query_states (`torch.Tensor`): 1120 + Input query states to be passed to Flash Attention API 1121 + key_states (`torch.Tensor`): 1122 + Input key states to be passed to Flash Attention API 1123 + value_states (`torch.Tensor`): 1124 + Input value states to be passed to Flash Attention API 1125 + attention_mask (`torch.Tensor`): 1126 + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 1127 + position of padding tokens and 1 for the position of non-padding tokens. 
1128 + dropout (`int`, *optional*): 1129 + Attention dropout 1130 + softmax_scale (`float`, *optional*): 1131 + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 1132 + """ 1133 + if not self._flash_attn_uses_top_left_mask: 1134 + causal = self.is_causal 1135 + else: 1136 + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. 1137 + causal = self.is_causal and query_length != 1 1138 + 1139 + # Contains at least one padding token in the sequence 1140 + if attention_mask is not None: 1141 + batch_size = query_states.shape[0] 1142 + ( 1143 + query_states, 1144 + key_states, 1145 + value_states, 1146 + indices_q, 1147 + cu_seq_lens, 1148 + max_seq_lens, 1149 + ) = self._upad_input( 1150 + query_states, key_states, value_states, attention_mask, query_length 1151 + ) 1152 + 1153 + cu_seqlens_q, cu_seqlens_k = cu_seq_lens 1154 + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 1155 + 1156 + attn_output_unpad = flash_attn_varlen_func( 1157 + query_states, 1158 + key_states, 1159 + value_states, 1160 + cu_seqlens_q=cu_seqlens_q, 1161 + cu_seqlens_k=cu_seqlens_k, 1162 + max_seqlen_q=max_seqlen_in_batch_q, 1163 + max_seqlen_k=max_seqlen_in_batch_k, 1164 + dropout_p=dropout, 1165 + softmax_scale=softmax_scale, 1166 + causal=causal, 1167 + ) 1168 + 1169 + attn_output = pad_input( 1170 + attn_output_unpad, indices_q, batch_size, query_length 1171 + ) 1172 + else: 1173 + attn_output = flash_attn_func( 1174 + query_states, 1175 + key_states, 1176 + value_states, 1177 + dropout, 1178 + softmax_scale=softmax_scale, 1179 + causal=causal, 1180 + ) 1181 + 1182 + return attn_output 1183 + 1184 + def _upad_input( 1185 + self, query_layer, key_layer, value_layer, attention_mask, query_length 1186 + ): 1187 + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 1188 + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 1189 + 1190 + key_layer = index_first_axis( 1191 + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 1192 + indices_k, 1193 + ) 1194 + value_layer = index_first_axis( 1195 + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 1196 + indices_k, 1197 + ) 1198 + if query_length == kv_seq_len: 1199 + query_layer = index_first_axis( 1200 + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), 1201 + indices_k, 1202 + ) 1203 + cu_seqlens_q = cu_seqlens_k 1204 + max_seqlen_in_batch_q = max_seqlen_in_batch_k 1205 + indices_q = indices_k 1206 + elif query_length == 1: 1207 + max_seqlen_in_batch_q = 1 1208 + cu_seqlens_q = torch.arange( 1209 + batch_size + 1, dtype=torch.int32, device=query_layer.device 1210 + ) # There is a memcpy here, that is very bad. 1211 + indices_q = cu_seqlens_q[:-1] 1212 + query_layer = query_layer.squeeze(1) 1213 + else: 1214 + # The -q_len: slice assumes left padding. 
1215 + attention_mask = attention_mask[:, -query_length:] 1216 + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( 1217 + query_layer, attention_mask 1218 + ) 1219 + 1220 + return ( 1221 + query_layer, 1222 + key_layer, 1223 + value_layer, 1224 + indices_q, 1225 + (cu_seqlens_q, cu_seqlens_k), 1226 + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 1227 + ) 1228 + 1229 + 1230 + ATTENTION_CLASSES = { 1231 + "eager": DeepseekV2Attention, 1232 + "flash_attention_2": DeepseekV2FlashAttention2, 1233 + 1234 + "mla_eager": DeepseekV2Attention, 1235 + "mla_flash_attention_2": DeepseekV2FlashAttention2, 1236 + 1237 + "mha_eager": LlamaAttention, 1238 + "mha_flash_attention_2": LlamaFlashAttention2 1239 + } 1240 + 1241 + 1242 + class DeepseekV2DecoderLayer(nn.Module): 1243 + def __init__(self, config: DeepseekV2Config, layer_idx: int): 1244 + super().__init__() 1245 + self.hidden_size = config.hidden_size 1246 + 1247 + 1248 + if config.use_mla: 1249 + attn_implementation = "mla_" + config._attn_implementation 1250 + else: 1251 + attn_implementation = "mha_" + config._attn_implementation 1252 + 1253 + self.self_attn = ATTENTION_CLASSES[attn_implementation]( 1254 + config=config, layer_idx=layer_idx 1255 + ) 1256 + 1257 + self.mlp = ( 1258 + DeepseekV2MoE(config) 1259 + if ( 1260 + config.n_routed_experts is not None 1261 + and layer_idx >= config.first_k_dense_replace 1262 + and layer_idx % config.moe_layer_freq == 0 1263 + ) 1264 + else DeepseekV2MLP(config) 1265 + ) 1266 + self.input_layernorm = DeepseekV2RMSNorm( 1267 + config.hidden_size, eps=config.rms_norm_eps 1268 + ) 1269 + self.post_attention_layernorm = DeepseekV2RMSNorm( 1270 + config.hidden_size, eps=config.rms_norm_eps 1271 + ) 1272 + 1273 + def forward( 1274 + self, 1275 + hidden_states: torch.Tensor, 1276 + attention_mask: Optional[torch.Tensor] = None, 1277 + position_ids: Optional[torch.LongTensor] = None, 1278 + past_key_value: Optional[Tuple[torch.Tensor]] = None, 1279 + output_attentions: Optional[bool] = False, 1280 + use_cache: Optional[bool] = False, 1281 + **kwargs, 1282 + ) -> Tuple[ 1283 + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] 1284 + ]: 1285 + """ 1286 + Args: 1287 + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 1288 + attention_mask (`torch.FloatTensor`, *optional*): 1289 + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, 1290 + query_sequence_length, key_sequence_length)` if default attention is used. 1291 + output_attentions (`bool`, *optional*): 1292 + Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1293 + returned tensors for more detail. 1294 + use_cache (`bool`, *optional*): 1295 + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 1296 + (see `past_key_values`). 1297 + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 1298 + """ 1299 + if "padding_mask" in kwargs: 1300 + warnings.warn( 1301 + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" 1302 + ) 1303 + residual = hidden_states 1304 + 1305 + hidden_states = self.input_layernorm(hidden_states) 1306 + 1307 + # Self Attention 1308 + hidden_states, self_attn_weights, present_key_value = self.self_attn( 1309 + hidden_states=hidden_states, 1310 + attention_mask=attention_mask, 1311 + position_ids=position_ids, 1312 + past_key_value=past_key_value, 1313 + output_attentions=output_attentions, 1314 + use_cache=use_cache, 1315 + **kwargs, 1316 + ) 1317 + hidden_states = residual + hidden_states 1318 + 1319 + # Fully Connected 1320 + residual = hidden_states 1321 + hidden_states = self.post_attention_layernorm(hidden_states) 1322 + hidden_states = self.mlp(hidden_states) 1323 + hidden_states = residual + hidden_states 1324 + 1325 + outputs = (hidden_states,) 1326 + 1327 + if output_attentions: 1328 + outputs += (self_attn_weights,) 1329 + 1330 + if use_cache: 1331 + outputs += (present_key_value,) 1332 + 1333 + return outputs 1334 + 1335 + 1336 + DeepseekV2_START_DOCSTRING = r""" 1337 + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 1338 + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 1339 + etc.) 1340 + 1341 + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 1342 + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 1343 + and behavior. 1344 + 1345 + Parameters: 1346 + config ([`DeepseekV2Config`]): 1347 + Model configuration class with all the parameters of the model. Initializing with a config file does not 1348 + load the weights associated with the model, only the configuration. Check out the 1349 + [`~PreTrainedModel.from_pretrained`] method to load the model weights. 1350 + """ 1351 + 1352 + 1353 + @add_start_docstrings( 1354 + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", 1355 + DeepseekV2_START_DOCSTRING, 1356 + ) 1357 + class DeepseekV2PreTrainedModel(PreTrainedModel): 1358 + config_class = DeepseekV2Config 1359 + base_model_prefix = "model" 1360 + supports_gradient_checkpointing = True 1361 + _no_split_modules = ["DeepseekV2DecoderLayer"] 1362 + _skip_keys_device_placement = "past_key_values" 1363 + _supports_flash_attn_2 = True 1364 + _supports_cache_class = True 1365 + 1366 + def _init_weights(self, module): 1367 + std = self.config.initializer_range 1368 + if isinstance(module, nn.Linear): 1369 + module.weight.data.normal_(mean=0.0, std=std) 1370 + if module.bias is not None: 1371 + module.bias.data.zero_() 1372 + elif isinstance(module, nn.Embedding): 1373 + module.weight.data.normal_(mean=0.0, std=std) 1374 + if module.padding_idx is not None: 1375 + module.weight.data[module.padding_idx].zero_() 1376 + 1377 + 1378 + DeepseekV2_INPUTS_DOCSTRING = r""" 1379 + Args: 1380 + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1381 + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 1382 + it. 1383 + 1384 + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1385 + [`PreTrainedTokenizer.__call__`] for details. 
1386 + 1387 + [What are input IDs?](../glossary#input-ids) 1388 + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1389 + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1390 + 1391 + - 1 for tokens that are **not masked**, 1392 + - 0 for tokens that are **masked**. 1393 + 1394 + [What are attention masks?](../glossary#attention-mask) 1395 + 1396 + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1397 + [`PreTrainedTokenizer.__call__`] for details. 1398 + 1399 + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 1400 + `past_key_values`). 1401 + 1402 + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 1403 + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 1404 + information on the default strategy. 1405 + 1406 + - 1 indicates the head is **not masked**, 1407 + - 0 indicates the head is **masked**. 1408 + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1409 + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 1410 + config.n_positions - 1]`. 1411 + 1412 + [What are position IDs?](../glossary#position-ids) 1413 + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 1414 + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 1415 + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 1416 + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 1417 + 1418 + Two formats are allowed: 1419 + - a [`~cache_utils.Cache`] instance; 1420 + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 1421 + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 1422 + cache format. 1423 + 1424 + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the 1425 + legacy cache format will be returned. 1426 + 1427 + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 1428 + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 1429 + of shape `(batch_size, sequence_length)`. 1430 + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1431 + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 1432 + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 1433 + model's internal embedding lookup matrix. 1434 + use_cache (`bool`, *optional*): 1435 + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 1436 + `past_key_values`). 1437 + output_attentions (`bool`, *optional*): 1438 + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1439 + tensors for more detail. 1440 + output_hidden_states (`bool`, *optional*): 1441 + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1442 + more detail. 
1443 + return_dict (`bool`, *optional*): 1444 + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1445 + """ 1446 + 1447 + 1448 + @add_start_docstrings( 1449 + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", 1450 + DeepseekV2_START_DOCSTRING, 1451 + ) 1452 + class DeepseekV2Model(DeepseekV2PreTrainedModel): 1453 + """ 1454 + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] 1455 + 1456 + Args: 1457 + config: DeepseekV2Config 1458 + """ 1459 + 1460 + def __init__(self, config: DeepseekV2Config): 1461 + super().__init__(config) 1462 + self.padding_idx = config.pad_token_id 1463 + self.vocab_size = config.vocab_size 1464 + 1465 + self.embed_tokens = nn.Embedding( 1466 + config.vocab_size, config.hidden_size, self.padding_idx 1467 + ) 1468 + self.layers = nn.ModuleList( 1469 + [ 1470 + DeepseekV2DecoderLayer(config, layer_idx) 1471 + for layer_idx in range(config.num_hidden_layers) 1472 + ] 1473 + ) 1474 + # print(config._attn_implementation) 1475 + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 1476 + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 1477 + 1478 + self.gradient_checkpointing = False 1479 + # Initialize weights and apply final processing 1480 + self.post_init() 1481 + 1482 + def get_input_embeddings(self): 1483 + return self.embed_tokens 1484 + 1485 + def set_input_embeddings(self, value): 1486 + self.embed_tokens = value 1487 + 1488 + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) 1489 + def forward( 1490 + self, 1491 + input_ids: torch.LongTensor = None, 1492 + attention_mask: Optional[torch.Tensor] = None, 1493 + position_ids: Optional[torch.LongTensor] = None, 1494 + past_key_values: Optional[List[torch.FloatTensor]] = None, 1495 + inputs_embeds: Optional[torch.FloatTensor] = None, 1496 + use_cache: Optional[bool] = None, 1497 + output_attentions: Optional[bool] = None, 1498 + output_hidden_states: Optional[bool] = None, 1499 + return_dict: Optional[bool] = None, 1500 + cache_position: Optional[torch.LongTensor] = None 1501 + ) -> Union[Tuple, BaseModelOutputWithPast]: 1502 + output_attentions = ( 1503 + output_attentions 1504 + if output_attentions is not None 1505 + else self.config.output_attentions 1506 + ) 1507 + output_hidden_states = ( 1508 + output_hidden_states 1509 + if output_hidden_states is not None 1510 + else self.config.output_hidden_states 1511 + ) 1512 + use_cache = use_cache if use_cache is not None else self.config.use_cache 1513 + 1514 + return_dict = ( 1515 + return_dict if return_dict is not None else self.config.use_return_dict 1516 + ) 1517 + 1518 + # retrieve input_ids and inputs_embeds 1519 + if input_ids is not None and inputs_embeds is not None: 1520 + raise ValueError( 1521 + "You cannot specify both input_ids and inputs_embeds at the same time" 1522 + ) 1523 + elif input_ids is not None: 1524 + batch_size, seq_length = input_ids.shape[:2] 1525 + elif inputs_embeds is not None: 1526 + batch_size, seq_length = inputs_embeds.shape[:2] 1527 + else: 1528 + raise ValueError("You have to specify either input_ids or inputs_embeds") 1529 + 1530 + if self.gradient_checkpointing and self.training: 1531 + if use_cache: 1532 + logger.warning_once( 1533 + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 
1534 + ) 1535 + use_cache = False 1536 + 1537 + past_key_values_length = 0 1538 + if use_cache: 1539 + use_legacy_cache = not isinstance(past_key_values, Cache) 1540 + if use_legacy_cache: 1541 + past_key_values = DynamicCache.from_legacy_cache(past_key_values) 1542 + past_key_values_length = past_key_values.get_usable_length(seq_length) 1543 + 1544 + if position_ids is None: 1545 + device = input_ids.device if input_ids is not None else inputs_embeds.device 1546 + position_ids = torch.arange( 1547 + past_key_values_length, 1548 + seq_length + past_key_values_length, 1549 + dtype=torch.long, 1550 + device=device, 1551 + ) 1552 + position_ids = position_ids.unsqueeze(0) 1553 + 1554 + if inputs_embeds is None: 1555 + inputs_embeds = self.embed_tokens(input_ids) 1556 + 1557 + if self._use_flash_attention_2: 1558 + # 2d mask is passed through the layers 1559 + attention_mask = ( 1560 + attention_mask 1561 + if (attention_mask is not None and 0 in attention_mask) 1562 + else None 1563 + ) 1564 + else: 1565 + # 4d mask is passed through the layers 1566 + attention_mask = _prepare_4d_causal_attention_mask( 1567 + attention_mask, 1568 + (batch_size, seq_length), 1569 + inputs_embeds, 1570 + past_key_values_length, 1571 + ) 1572 + 1573 + # embed positions 1574 + hidden_states = inputs_embeds 1575 + 1576 + # decoder layers 1577 + all_hidden_states = () if output_hidden_states else None 1578 + all_self_attns = () if output_attentions else None 1579 + next_decoder_cache = None 1580 + 1581 + for decoder_layer in self.layers: 1582 + if output_hidden_states: 1583 + all_hidden_states += (hidden_states,) 1584 + 1585 + if self.gradient_checkpointing and self.training: 1586 + layer_outputs = self._gradient_checkpointing_func( 1587 + decoder_layer.__call__, 1588 + hidden_states, 1589 + attention_mask, 1590 + position_ids, 1591 + past_key_values, 1592 + output_attentions, 1593 + use_cache, 1594 + ) 1595 + else: 1596 + layer_outputs = decoder_layer( 1597 + hidden_states, 1598 + attention_mask=attention_mask, 1599 + position_ids=position_ids, 1600 + past_key_value=past_key_values, 1601 + output_attentions=output_attentions, 1602 + use_cache=use_cache, 1603 + ) 1604 + 1605 + hidden_states = layer_outputs[0] 1606 + 1607 + if use_cache: 1608 + next_decoder_cache = layer_outputs[2 if output_attentions else 1] 1609 + 1610 + if output_attentions: 1611 + all_self_attns += (layer_outputs[1],) 1612 + 1613 + hidden_states = self.norm(hidden_states) 1614 + 1615 + # add hidden states from the last decoder layer 1616 + if output_hidden_states: 1617 + all_hidden_states += (hidden_states,) 1618 + 1619 + next_cache = None 1620 + if use_cache: 1621 + next_cache = ( 1622 + next_decoder_cache.to_legacy_cache() 1623 + if use_legacy_cache 1624 + else next_decoder_cache 1625 + ) 1626 + if not return_dict: 1627 + return tuple( 1628 + v 1629 + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] 1630 + if v is not None 1631 + ) 1632 + return BaseModelOutputWithPast( 1633 + last_hidden_state=hidden_states, 1634 + past_key_values=next_cache, 1635 + hidden_states=all_hidden_states, 1636 + attentions=all_self_attns, 1637 + ) 1638 + 1639 + 1640 + class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): 1641 + _tied_weights_keys = ["lm_head.weight"] 1642 + 1643 + def __init__(self, config): 1644 + super().__init__(config) 1645 + self.model = DeepseekV2Model(config) 1646 + self.vocab_size = config.vocab_size 1647 + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1648 + 1649 + # Initialize weights 
and apply final processing 1650 + self.post_init() 1651 + 1652 + def get_input_embeddings(self): 1653 + return self.model.embed_tokens 1654 + 1655 + def set_input_embeddings(self, value): 1656 + self.model.embed_tokens = value 1657 + 1658 + def get_output_embeddings(self): 1659 + return self.lm_head 1660 + 1661 + def set_output_embeddings(self, new_embeddings): 1662 + self.lm_head = new_embeddings 1663 + 1664 + def set_decoder(self, decoder): 1665 + self.model = decoder 1666 + 1667 + def get_decoder(self): 1668 + return self.model 1669 + 1670 + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) 1671 + @replace_return_docstrings( 1672 + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC 1673 + ) 1674 + def forward( 1675 + self, 1676 + input_ids: torch.LongTensor = None, 1677 + attention_mask: Optional[torch.Tensor] = None, 1678 + position_ids: Optional[torch.LongTensor] = None, 1679 + past_key_values: Optional[List[torch.FloatTensor]] = None, 1680 + inputs_embeds: Optional[torch.FloatTensor] = None, 1681 + labels: Optional[torch.LongTensor] = None, 1682 + use_cache: Optional[bool] = None, 1683 + output_attentions: Optional[bool] = None, 1684 + output_hidden_states: Optional[bool] = None, 1685 + return_dict: Optional[bool] = None, 1686 + cache_position: Optional[torch.LongTensor] = None 1687 + ) -> Union[Tuple, CausalLMOutputWithPast]: 1688 + r""" 1689 + Args: 1690 + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1691 + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1692 + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1693 + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1694 + 1695 + Returns: 1696 + 1697 + Example: 1698 + 1699 + ```python 1700 + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM 1701 + 1702 + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1703 + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1704 + 1705 + >>> prompt = "Hey, are you conscious? Can you talk to me?" 1706 + >>> inputs = tokenizer(prompt, return_tensors="pt") 1707 + 1708 + >>> # Generate 1709 + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1710 + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1711 + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1712 + ```""" 1713 + output_attentions = ( 1714 + output_attentions 1715 + if output_attentions is not None 1716 + else self.config.output_attentions 1717 + ) 1718 + output_hidden_states = ( 1719 + output_hidden_states 1720 + if output_hidden_states is not None 1721 + else self.config.output_hidden_states 1722 + ) 1723 + return_dict = ( 1724 + return_dict if return_dict is not None else self.config.use_return_dict 1725 + ) 1726 + 1727 + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1728 + outputs = self.model( 1729 + input_ids=input_ids, 1730 + attention_mask=attention_mask, 1731 + position_ids=position_ids, 1732 + past_key_values=past_key_values, 1733 + inputs_embeds=inputs_embeds, 1734 + use_cache=use_cache, 1735 + output_attentions=output_attentions, 1736 + output_hidden_states=output_hidden_states, 1737 + return_dict=return_dict, 1738 + cache_position=cache_position 1739 + ) 1740 + 1741 + hidden_states = outputs[0] 1742 + logits = self.lm_head(hidden_states) 1743 + logits = logits.float() 1744 + 1745 + loss = None 1746 + if labels is not None: 1747 + # Shift so that tokens < n predict n 1748 + shift_logits = logits[..., :-1, :].contiguous() 1749 + shift_labels = labels[..., 1:].contiguous() 1750 + # Flatten the tokens 1751 + loss_fct = CrossEntropyLoss() 1752 + shift_logits = shift_logits.view(-1, self.config.vocab_size) 1753 + shift_labels = shift_labels.view(-1) 1754 + # Enable model parallelism 1755 + shift_labels = shift_labels.to(shift_logits.device) 1756 + loss = loss_fct(shift_logits, shift_labels) 1757 + 1758 + if not return_dict: 1759 + output = (logits,) + outputs[1:] 1760 + return (loss,) + output if loss is not None else output 1761 + 1762 + return CausalLMOutputWithPast( 1763 + loss=loss, 1764 + logits=logits, 1765 + past_key_values=outputs.past_key_values, 1766 + hidden_states=outputs.hidden_states, 1767 + attentions=outputs.attentions, 1768 + ) 1769 + 1770 + def prepare_inputs_for_generation( 1771 + self, 1772 + input_ids, 1773 + past_key_values=None, 1774 + attention_mask=None, 1775 + inputs_embeds=None, 1776 + **kwargs, 1777 + ): 1778 + past_length = 0 1779 + if past_key_values is not None: 1780 + if isinstance(past_key_values, Cache): 1781 + cache_length = past_key_values.get_seq_length() 1782 + past_length = past_key_values.seen_tokens 1783 + max_cache_length = past_key_values.get_max_length() 1784 + else: 1785 + cache_length = past_length = past_key_values[0][0].shape[2] 1786 + max_cache_length = None 1787 + 1788 + # Keep only the unprocessed tokens: 1789 + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 1790 + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 1791 + # input) 1792 + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 1793 + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):] 1794 + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 1795 + # input_ids based on the past_length. 1796 + elif past_length < input_ids.shape[1]: 1797 + input_ids = input_ids[:, past_length:] 1798 + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 1799 + 1800 + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
1801 + if ( 1802 + max_cache_length is not None 1803 + and attention_mask is not None 1804 + and cache_length + input_ids.shape[1] > max_cache_length 1805 + ): 1806 + attention_mask = attention_mask[:, -max_cache_length:] 1807 + 1808 + position_ids = kwargs.get("position_ids", None) 1809 + if attention_mask is not None and position_ids is None: 1810 + # create position_ids on the fly for batch generation 1811 + position_ids = attention_mask.long().cumsum(-1) - 1 1812 + position_ids.masked_fill_(attention_mask == 0, 1) 1813 + if past_key_values: 1814 + position_ids = position_ids[:, -input_ids.shape[1]:] 1815 + 1816 + if self.generation_config.cache_implementation == "static": 1817 + # generation with static cache 1818 + cache_position = kwargs.get("cache_position", None) 1819 + if cache_position is None: 1820 + past_length = 0 1821 + else: 1822 + past_length = cache_position[-1] + 1 1823 + input_ids = input_ids[:, past_length:] 1824 + position_ids = position_ids[:, past_length:] 1825 + 1826 + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. 1827 + # same goes for position ids. Could also help with continued generation. 1828 + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) 1829 + 1830 + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1831 + if inputs_embeds is not None and past_key_values is None: 1832 + model_inputs = {"inputs_embeds": inputs_embeds} 1833 + else: 1834 + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 1835 + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 1836 + # TODO: use `next_tokens` directly instead. 1837 + model_inputs = {"input_ids": input_ids.contiguous()} 1838 + 1839 + model_inputs.update( 1840 + { 1841 + "position_ids": position_ids.contiguous(), 1842 + "cache_position": cache_position, 1843 + "past_key_values": past_key_values, 1844 + "use_cache": kwargs.get("use_cache"), 1845 + "attention_mask": attention_mask, 1846 + } 1847 + ) 1848 + return model_inputs 1849 + 1850 + @staticmethod 1851 + def _reorder_cache(past_key_values, beam_idx): 1852 + reordered_past = () 1853 + for layer_past in past_key_values: 1854 + reordered_past += ( 1855 + tuple( 1856 + past_state.index_select(0, beam_idx.to(past_state.device)) 1857 + for past_state in layer_past 1858 + ), 1859 + ) 1860 + return reordered_past 1861 + 1862 + 1863 + @add_start_docstrings( 1864 + """ 1865 + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). 1866 + 1867 + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models 1868 + (e.g. GPT-2) do. 1869 + 1870 + Since it does classification on the last token, it requires to know the position of the last token. If a 1871 + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If 1872 + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the 1873 + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in 1874 + each row of the batch). 
1875 + """, 1876 + DeepseekV2_START_DOCSTRING, 1877 + ) 1878 + class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): 1879 + def __init__(self, config): 1880 + super().__init__(config) 1881 + self.num_labels = config.num_labels 1882 + self.model = DeepseekV2Model(config) 1883 + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 1884 + 1885 + # Initialize weights and apply final processing 1886 + self.post_init() 1887 + 1888 + def get_input_embeddings(self): 1889 + return self.model.embed_tokens 1890 + 1891 + def set_input_embeddings(self, value): 1892 + self.model.embed_tokens = value 1893 + 1894 + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) 1895 + def forward( 1896 + self, 1897 + input_ids: torch.LongTensor = None, 1898 + attention_mask: Optional[torch.Tensor] = None, 1899 + position_ids: Optional[torch.LongTensor] = None, 1900 + past_key_values: Optional[List[torch.FloatTensor]] = None, 1901 + inputs_embeds: Optional[torch.FloatTensor] = None, 1902 + labels: Optional[torch.LongTensor] = None, 1903 + use_cache: Optional[bool] = None, 1904 + output_attentions: Optional[bool] = None, 1905 + output_hidden_states: Optional[bool] = None, 1906 + return_dict: Optional[bool] = None, 1907 + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 1908 + r""" 1909 + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1910 + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., 1911 + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 1912 + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 1913 + """ 1914 + return_dict = ( 1915 + return_dict if return_dict is not None else self.config.use_return_dict 1916 + ) 1917 + 1918 + transformer_outputs = self.model( 1919 + input_ids, 1920 + attention_mask=attention_mask, 1921 + position_ids=position_ids, 1922 + past_key_values=past_key_values, 1923 + inputs_embeds=inputs_embeds, 1924 + use_cache=use_cache, 1925 + output_attentions=output_attentions, 1926 + output_hidden_states=output_hidden_states, 1927 + return_dict=return_dict, 1928 + ) 1929 + hidden_states = transformer_outputs[0] 1930 + logits = self.score(hidden_states) 1931 + 1932 + if input_ids is not None: 1933 + batch_size = input_ids.shape[0] 1934 + else: 1935 + batch_size = inputs_embeds.shape[0] 1936 + 1937 + if self.config.pad_token_id is None and batch_size != 1: 1938 + raise ValueError( 1939 + "Cannot handle batch sizes > 1 if no padding token is defined." 
1940 + ) 1941 + if self.config.pad_token_id is None: 1942 + sequence_lengths = -1 1943 + else: 1944 + if input_ids is not None: 1945 + sequence_lengths = ( 1946 + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 1947 + ).to(logits.device) 1948 + else: 1949 + sequence_lengths = -1 1950 + 1951 + pooled_logits = logits[ 1952 + torch.arange(batch_size, device=logits.device), sequence_lengths 1953 + ] 1954 + 1955 + loss = None 1956 + if labels is not None: 1957 + labels = labels.to(logits.device) 1958 + if self.config.problem_type is None: 1959 + if self.num_labels == 1: 1960 + self.config.problem_type = "regression" 1961 + elif self.num_labels > 1 and ( 1962 + labels.dtype == torch.long or labels.dtype == torch.int 1963 + ): 1964 + self.config.problem_type = "single_label_classification" 1965 + else: 1966 + self.config.problem_type = "multi_label_classification" 1967 + 1968 + if self.config.problem_type == "regression": 1969 + loss_fct = MSELoss() 1970 + if self.num_labels == 1: 1971 + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 1972 + else: 1973 + loss = loss_fct(pooled_logits, labels) 1974 + elif self.config.problem_type == "single_label_classification": 1975 + loss_fct = CrossEntropyLoss() 1976 + loss = loss_fct( 1977 + pooled_logits.view(-1, self.num_labels), labels.view(-1) 1978 + ) 1979 + elif self.config.problem_type == "multi_label_classification": 1980 + loss_fct = BCEWithLogitsLoss() 1981 + loss = loss_fct(pooled_logits, labels) 1982 + if not return_dict: 1983 + output = (pooled_logits,) + transformer_outputs[1:] 1984 + return ((loss,) + output) if loss is not None else output 1985 + 1986 + return SequenceClassifierOutputWithPast( 1987 + loss=loss, 1988 + logits=pooled_logits, 1989 + past_key_values=transformer_outputs.past_key_values, 1990 + hidden_states=transformer_outputs.hidden_states, 1991 + attentions=transformer_outputs.attentions, 1992 + )
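A quick orientation sketch (not part of the copied modeling file): the multi-latent attention (MLA) code above does a lot of shape bookkeeping around its low-rank projections, so the snippet below reproduces just that bookkeeping with plain `torch.nn.Linear` layers. All dimension values here are illustrative assumptions, not the ones a real `DeepseekV2Config` supplies, and `kv_a_layernorm` is skipped for brevity.

```python
# Sketch of the MLA projection shapes used by DeepseekV2Attention,
# reproduced in isolation with assumed toy dimensions.
import torch
import torch.nn as nn

# Illustrative dimensions -- real values come from DeepseekV2Config.
hidden_size, num_heads = 1024, 8
qk_nope_head_dim, qk_rope_head_dim = 64, 32
kv_lora_rank, v_head_dim = 256, 64
q_head_dim = qk_nope_head_dim + qk_rope_head_dim  # per-head query/key width

bsz, q_len = 2, 16
hidden_states = torch.randn(bsz, q_len, hidden_size)

# Query path (the q_lora_rank-is-None branch: a single q_proj).
q_proj = nn.Linear(hidden_size, num_heads * q_head_dim, bias=False)
q = q_proj(hidden_states).view(bsz, q_len, num_heads, q_head_dim).transpose(1, 2)
q_nope, q_pe = torch.split(q, [qk_nope_head_dim, qk_rope_head_dim], dim=-1)

# Key/value path: one low-rank latent per token plus a single shared rotary key.
kv_a_proj_with_mqa = nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=False)
compressed_kv, k_pe = torch.split(
    kv_a_proj_with_mqa(hidden_states), [kv_lora_rank, qk_rope_head_dim], dim=-1
)
kv_b_proj = nn.Linear(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False)
kv = (
    kv_b_proj(compressed_kv)
    .view(bsz, q_len, num_heads, qk_nope_head_dim + v_head_dim)
    .transpose(1, 2)
)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)

print(q_nope.shape, q_pe.shape)          # [2, 8, 16, 64] and [2, 8, 16, 32]
print(compressed_kv.shape, k_pe.shape)   # [2, 16, 256] and [2, 16, 32]
print(k_nope.shape, value_states.shape)  # [2, 8, 16, 64] and [2, 8, 16, 64]
```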
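One design point visible in the two attention variants above: the eager `DeepseekV2Attention` caches only `k_pe` and the low-rank `compressed_kv`, folding `kv_b_proj` into the score and output computation via `q_absorb` and `out_absorb`, whereas `DeepseekV2FlashAttention2` still materializes full per-head `key_states`/`value_states` before calling `past_key_value.update` (the TODO above that call notes the gap).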