this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Add DeepSeek-OCR-2 model code with eval_mode inference fix

modeling_deepseekocr2.py is the upstream model package with one local patch:
in the eval_mode branch of infer(), pass attention_mask=torch.ones_like(input_ids)
to generate() to suppress the spurious warning caused by pad_token_id == eos_token_id,
and reuse the already-computed _input_ids_cuda tensor in the decode step.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+1055 -1
+1053
deepseek/deepseek_ocr2/modeling_deepseekocr2.py
··· 1 + import os 2 + import math 3 + import re 4 + from tqdm import tqdm 5 + from abc import ABC 6 + from typing import List, Optional, Tuple, Union 7 + 8 + from addict import Dict 9 + from PIL import Image, ImageOps, ImageDraw, ImageFont 10 + import numpy as np 11 + 12 + import torch 13 + import torch.nn as nn 14 + from torch.nn import CrossEntropyLoss 15 + from torchvision import transforms 16 + 17 + from transformers.cache_utils import Cache 18 + from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 19 + from transformers import DeepseekV2Model, DeepseekV2ForCausalLM 20 + from transformers import DeepseekV2Config 21 + from transformers.models.deepseek_v2.modeling_deepseek_v2 import ( 22 + DeepseekV2Attention, 23 + DeepseekV2MLP, 24 + DeepseekV2MoE, 25 + DeepseekV2RMSNorm, 26 + DeepseekV2DecoderLayer, 27 + ) 28 + from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding 29 + from transformers import TextStreamer 30 + from .deepencoderv2 import build_sam_vit_b, build_qwen2_decoder_as_encoder, MlpProjector 31 + from .conversation import get_conv_template 32 + 33 + torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 34 + 35 + def load_image(image_path): 36 + 37 + try: 38 + image = Image.open(image_path) 39 + 40 + corrected_image = ImageOps.exif_transpose(image) 41 + 42 + return corrected_image 43 + 44 + except Exception as e: 45 + print(f"error: {e}") 46 + try: 47 + return Image.open(image_path) 48 + except: 49 + return None 50 + 51 + 52 + def re_match(text): 53 + pattern = r'(<\|ref\|>(.*?)<\|/ref\|><\|det\|>(.*?)<\|/det\|>)' 54 + matches = re.findall(pattern, text, re.DOTALL) 55 + 56 + # pattern1 = r'<\|ref\|>.*?<\|/ref\|>\n' 57 + # new_text1 = re.sub(pattern1, '', text, flags=re.DOTALL) 58 + 59 + mathes_image = [] 60 + mathes_other = [] 61 + for a_match in matches: 62 + if '<|ref|>image<|/ref|>' in a_match[0]: 63 + mathes_image.append(a_match[0]) 64 + else: 65 + mathes_other.append(a_match[0]) 66 + return matches, mathes_image, mathes_other 67 + 68 + 69 + def extract_coordinates_and_label(ref_text, image_width, image_height): 70 + 71 + try: 72 + label_type = ref_text[1] 73 + cor_list = eval(ref_text[2]) 74 + except Exception as e: 75 + print(e) 76 + return None 77 + 78 + return (label_type, cor_list) 79 + 80 + 81 + def draw_bounding_boxes(image, refs, ouput_path): 82 + 83 + image_width, image_height = image.size 84 + 85 + img_draw = image.copy() 86 + draw = ImageDraw.Draw(img_draw) 87 + 88 + overlay = Image.new('RGBA', img_draw.size, (0, 0, 0, 0)) 89 + draw2 = ImageDraw.Draw(overlay) 90 + 91 + # try: 92 + # except IOError: 93 + # try: 94 + # font = ImageFont.truetype("DejaVuSans.ttf", 20) 95 + # except IOError: 96 + font = ImageFont.load_default() 97 + 98 + img_idx = 0 99 + 100 + for i, ref in enumerate(refs): 101 + try: 102 + result = extract_coordinates_and_label(ref, image_width, image_height) 103 + if result: 104 + label_type, points_list = result 105 + 106 + color = (np.random.randint(0, 200), np.random.randint(0, 200), np.random.randint(0, 255)) 107 + 108 + color_a = color + (20, ) 109 + for points in points_list: 110 + x1, y1, x2, y2 = points 111 + 112 + x1 = int(x1 / 999 * image_width) 113 + y1 = int(y1 / 999 * image_height) 114 + 115 + x2 = int(x2 / 999 * image_width) 116 + y2 = int(y2 / 999 * image_height) 117 + 118 + if label_type == 'image': 119 + try: 120 + cropped = image.crop((x1, y1, x2, y2)) 121 + cropped.save(f"{ouput_path}/images/{img_idx}.jpg") 122 + except Exception as e: 123 + print(e) 124 + pass 125 + img_idx += 1 126 + 127 + try: 128 + if label_type == 'title': 129 + draw.rectangle([x1, y1, x2, y2], outline=color, width=4) 130 + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) 131 + else: 132 + draw.rectangle([x1, y1, x2, y2], outline=color, width=2) 133 + draw2.rectangle([x1, y1, x2, y2], fill=color_a, outline=(0, 0, 0, 0), width=1) 134 + text_x = x1 135 + text_y = max(0, y1 - 15) 136 + 137 + 138 + text_bbox = draw.textbbox((0, 0), label_type, font=font) 139 + text_width = text_bbox[2] - text_bbox[0] 140 + text_height = text_bbox[3] - text_bbox[1] 141 + draw.rectangle([text_x, text_y, text_x + text_width, text_y + text_height], 142 + fill=(255, 255, 255, 30)) 143 + 144 + draw.text((text_x, text_y), label_type, font=font, fill=color) 145 + except: 146 + pass 147 + except: 148 + continue 149 + img_draw.paste(overlay, (0, 0), overlay) 150 + return img_draw 151 + 152 + 153 + def process_image_with_refs(image, ref_texts, output_path): 154 + 155 + result_image = draw_bounding_boxes(image, ref_texts, output_path) 156 + 157 + return result_image 158 + 159 + 160 + 161 + 162 + 163 + def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 164 + best_ratio_diff = float('inf') 165 + best_ratio = (1, 1) 166 + area = width * height 167 + for ratio in target_ratios: 168 + target_aspect_ratio = ratio[0] / ratio[1] 169 + ratio_diff = abs(aspect_ratio - target_aspect_ratio) 170 + if ratio_diff < best_ratio_diff: 171 + best_ratio_diff = ratio_diff 172 + best_ratio = ratio 173 + elif ratio_diff == best_ratio_diff: 174 + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 175 + best_ratio = ratio 176 + # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') 177 + return best_ratio 178 + 179 + 180 + def dynamic_preprocess(image, min_num=2, max_num=6, image_size=768, use_thumbnail=False): 181 + orig_width, orig_height = image.size 182 + aspect_ratio = orig_width / orig_height 183 + 184 + # calculate the existing image aspect ratio 185 + target_ratios = set( 186 + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 187 + i * j <= max_num and i * j >= min_num) 188 + # print(target_ratios) 189 + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 190 + 191 + # find the closest aspect ratio to the target 192 + target_aspect_ratio = find_closest_aspect_ratio( 193 + aspect_ratio, target_ratios, orig_width, orig_height, image_size) 194 + 195 + # print(target_aspect_ratio) 196 + # calculate the target width and height 197 + target_width = image_size * target_aspect_ratio[0] 198 + target_height = image_size * target_aspect_ratio[1] 199 + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 200 + 201 + # resize the image 202 + resized_img = image.resize((target_width, target_height)) 203 + processed_images = [] 204 + for i in range(blocks): 205 + box = ( 206 + (i % (target_width // image_size)) * image_size, 207 + (i // (target_width // image_size)) * image_size, 208 + ((i % (target_width // image_size)) + 1) * image_size, 209 + ((i // (target_width // image_size)) + 1) * image_size 210 + ) 211 + # split the image 212 + split_img = resized_img.crop(box) 213 + processed_images.append(split_img) 214 + assert len(processed_images) == blocks 215 + if use_thumbnail and len(processed_images) != 1: 216 + thumbnail_img = image.resize((image_size, image_size)) 217 + processed_images.append(thumbnail_img) 218 + return processed_images, target_aspect_ratio 219 + 220 + 221 + 222 + def normalize_transform(mean, std): 223 + if mean is None and std is None: 224 + transform = None 225 + elif mean is None and std is not None: 226 + mean = [0.] * len(std) 227 + transform = transforms.Normalize(mean=mean, std=std) 228 + elif mean is not None and std is None: 229 + std = [1.] * len(mean) 230 + transform = transforms.Normalize(mean=mean, std=std) 231 + else: 232 + transform = transforms.Normalize(mean=mean, std=std) 233 + 234 + return transform 235 + 236 + 237 + 238 + def format_messages( 239 + conversations: List[Dict[str, str]], 240 + sft_format: str = "deepseek", 241 + system_prompt: str = "", 242 + ): 243 + """ 244 + Applies the SFT template to conversation. 245 + 246 + Args: 247 + conversations (List[Dict]): A List of messages. 248 + sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". 249 + system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". 250 + 251 + Returns: 252 + sft_prompt (str): The formatted text. 253 + """ 254 + 255 + conv = get_conv_template(sft_format) 256 + conv.set_system_message(system_prompt) 257 + for message in conversations: 258 + conv.append_message(message["role"], message["content"].strip()) 259 + sft_prompt = conv.get_prompt().strip() 260 + 261 + return sft_prompt 262 + 263 + 264 + def text_encode(tokenizer, text: str, bos: bool = True, eos: bool = False): 265 + t = tokenizer.encode(text, add_special_tokens=False) 266 + bos_id = 0 267 + eos_id = 1 268 + if bos: 269 + t = [bos_id] + t 270 + if eos: 271 + t = t + [eos_id] 272 + 273 + return t 274 + 275 + def load_pil_images(conversations: List[Dict[str, str]]) -> List[Image.Image]: 276 + """ 277 + 278 + Args: 279 + conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : 280 + [ 281 + { 282 + "role": "User", 283 + "content": "<image_placeholder>\nExtract all information from this image and convert them into markdown format.", 284 + "images": ["./examples/table_datasets.png"] 285 + }, 286 + {"role": "Assistant", "content": ""}, 287 + ] 288 + 289 + Returns: 290 + pil_images (List[PIL.Image.Image]): the list of PIL images. 291 + 292 + """ 293 + 294 + pil_images = [] 295 + 296 + for message in conversations: 297 + if "images" not in message: 298 + continue 299 + 300 + for image_path in message["images"]: 301 + # print('----------------') 302 + # print(image_path) 303 + # print('----------------') 304 + # exit() 305 + 306 + # pil_img = Image.open(image_path) 307 + pil_img = load_image(image_path) 308 + pil_img = pil_img.convert("RGB") 309 + pil_images.append(pil_img) 310 + 311 + return pil_images 312 + 313 + 314 + class BaseTransform(ABC): 315 + 316 + def set_rng(self, *args, **kwargs): 317 + pass 318 + 319 + def __call__(self, *args, **kwargs) -> torch.Tensor: 320 + pass 321 + 322 + @property 323 + def default_shape(self): 324 + raise NotImplementedError 325 + 326 + 327 + class BasicImageTransform(BaseTransform): 328 + def __init__( 329 + self, 330 + mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), 331 + std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5), 332 + normalize: bool = True 333 + ): 334 + self.mean = mean 335 + self.std = std 336 + 337 + transform_pipelines = [ 338 + transforms.ToTensor() 339 + ] 340 + 341 + normalize = normalize_transform(mean, std) if normalize else nn.Identity() 342 + if normalize is not None: 343 + transform_pipelines.append(normalize) 344 + 345 + self.transform = transforms.Compose(transform_pipelines) 346 + 347 + def __call__(self, x): 348 + x = self.transform(x) 349 + return x 350 + 351 + class NoEOSTextStreamer(TextStreamer): 352 + def on_finalized_text(self, text: str, stream_end: bool = False): 353 + 354 + eos_text = self.tokenizer.decode([self.tokenizer.eos_token_id], skip_special_tokens=False) 355 + text = text.replace(eos_text, "\n") 356 + print(text, flush=True, end="") 357 + 358 + def decoder_layer_init(self, config: DeepseekV2Config, layer_idx: int): 359 + nn.Module.__init__(self) 360 + self.hidden_size = config.hidden_size 361 + 362 + if config.use_mla: 363 + self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx) 364 + else: 365 + config.head_dim = config.hidden_size // config.num_attention_heads 366 + self.self_attn = LlamaAttention(config, layer_idx) 367 + self.mlp = DeepseekV2MoE(config) if layer_idx >= config.first_k_dense_replace else DeepseekV2MLP(config) 368 + 369 + self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 370 + self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 371 + 372 + 373 + DeepseekV2DecoderLayer.__init__ = decoder_layer_init 374 + 375 + class DeepseekOCR2Config(DeepseekV2Config): 376 + model_type = "DeepseekOCR2" 377 + 378 + class DeepseekOCR2Model(DeepseekV2Model): 379 + config_class = DeepseekOCR2Config 380 + 381 + def __init__(self, config: DeepseekV2Config): 382 + super(DeepseekOCR2Model, self).__init__(config) 383 + 384 + self.sam_model = build_sam_vit_b() 385 + self.qwen2_model = build_qwen2_decoder_as_encoder() 386 + # self.conv_2 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=2, stride=2) 387 + n_embed = 1280 388 + self.projector = MlpProjector(Dict(projector_type="linear", input_dim=896, n_embed=n_embed)) 389 + embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32)) 390 + # self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std) 391 + self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std) 392 + self.rotary_emb = LlamaRotaryEmbedding(config=config) 393 + 394 + 395 + def forward( 396 + self, 397 + input_ids: torch.LongTensor = None, 398 + attention_mask: Optional[torch.Tensor] = None, 399 + position_ids: Optional[torch.LongTensor] = None, 400 + past_key_values: Optional[List[torch.FloatTensor]] = None, 401 + inputs_embeds: Optional[torch.FloatTensor] = None, 402 + use_cache: Optional[bool] = None, 403 + output_attentions: Optional[bool] = None, 404 + output_hidden_states: Optional[bool] = None, 405 + images: Optional[torch.FloatTensor] = None, 406 + images_seq_mask: Optional[torch.FloatTensor] = None, 407 + images_spatial_crop: Optional[torch.FloatTensor] = None, 408 + return_dict: Optional[bool] = None, 409 + ) -> Union[Tuple, BaseModelOutputWithPast]: 410 + 411 + if inputs_embeds is None: 412 + # inputs_embeds = self.embed_tokens(input_ids) 413 + inputs_embeds = self.get_input_embeddings()(input_ids) 414 + inputs_embeds = inputs_embeds.clone() 415 + 416 + sam_model = getattr(self, 'sam_model', None) 417 + # sam_model = self.sam_model 418 + qwen2_model = getattr(self, 'qwen2_model', None) 419 + 420 + if sam_model is not None and (input_ids.shape[1] != 1 or self.training) and torch.sum(images[0][1]).item() != 0: 421 + 422 + idx = 0 423 + 424 + # sam_model = torch.jit.script(sam_model) 425 + 426 + # start_time = time.time() 427 + for image, crop_shape in zip(images, images_spatial_crop): 428 + images_in_this_batch = [] 429 + 430 + patches = image[0] 431 + image_ori = image[1] 432 + 433 + with torch.no_grad(): 434 + # with torch.inference_mode(): 435 + 436 + if torch.sum(patches).item() != 0: 437 + # P, C, H, W = patches.shape 438 + crop_flag = 1 439 + local_features_1 = sam_model(patches) 440 + 441 + local_features_2 = qwen2_model(local_features_1) 442 + # vit_time = time.time() 443 + local_features = local_features_2 444 + local_features = self.projector(local_features) 445 + 446 + 447 + global_features_1 = sam_model(image_ori) 448 + global_features_2 = qwen2_model(global_features_1) 449 + global_features = global_features_2 450 + global_features = self.projector(global_features) 451 + 452 + # print('=====================') 453 + # print('BASE: ', global_features.shape) 454 + # print('PATCHES: ', local_features.shape) 455 + # print('=====================') 456 + 457 + _, hw, n_dim = global_features.shape 458 + # h = w = int(hw ** 0.5) 459 + 460 + _2, hw2, n_dim2 = local_features.shape 461 + # h2 = w2 = int(hw2 ** 0.5) 462 + 463 + 464 + global_features = global_features.view(-1, n_dim) 465 + 466 + 467 + local_features = local_features.view(-1, n_dim2) 468 + 469 + global_local_features = torch.cat([local_features, global_features, self.view_seperator[None, :]], dim=0) 470 + 471 + # end_time = time.time() 472 + 473 + # print('sam: ', sam_time - start_time) 474 + # print('vit: ', vit_time - sam_time) 475 + # print('all: ', end_time - start_time) 476 + 477 + # exit() 478 + 479 + else: 480 + global_features_1 = sam_model(image_ori) 481 + global_features_2 = qwen2_model(global_features_1) 482 + global_features = global_features_2 483 + global_features = self.projector(global_features) 484 + # print('=====================') 485 + # print('BASE: ', global_features.shape) 486 + # print('NO PATCHES') 487 + # print('=====================') 488 + _, hw, n_dim = global_features.shape 489 + # h = w = int(hw ** 0.5) 490 + 491 + 492 + # global_features = global_features.view(h, w, n_dim) 493 + 494 + # global_features = torch.cat( 495 + # [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1 496 + # ) 497 + 498 + global_features = global_features.view(-1, n_dim) 499 + 500 + global_local_features = torch.cat([global_features, self.view_seperator[None, :]], dim=0) 501 + 502 + images_in_this_batch.append(global_local_features) 503 + 504 + 505 + # print(inputs_embeds.shape) 506 + 507 + if images_in_this_batch: 508 + images_in_this_batch = torch.cat(images_in_this_batch, dim=0) 509 + # exit() 510 + 511 + # inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch) 512 + images_in_this_batch = images_in_this_batch.to( 513 + device=inputs_embeds.device, dtype=inputs_embeds.dtype 514 + ) 515 + mask = images_seq_mask[idx].unsqueeze(-1).to(inputs_embeds.device) # bool [T, 1] 516 + updated_row = inputs_embeds[idx].masked_scatter(mask, images_in_this_batch) 517 + inputs_embeds[idx] = updated_row 518 + 519 + idx += 1 520 + 521 + 522 + return super(DeepseekOCR2Model, self).forward( 523 + input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 524 + inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids = position_ids, 525 + output_attentions=output_attentions, output_hidden_states=output_hidden_states, 526 + return_dict=return_dict 527 + ) 528 + 529 + 530 + class DeepseekOCR2ForCausalLM(DeepseekV2ForCausalLM): 531 + 532 + config_class = DeepseekOCR2Config 533 + # supports_gradient_checkpointing = True 534 + 535 + def __init__(self, config): 536 + super(DeepseekV2ForCausalLM, self).__init__(config) 537 + self.model = DeepseekOCR2Model(config) 538 + 539 + self.vocab_size = config.vocab_size 540 + 541 + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 542 + 543 + # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 544 + 545 + # Initialize weights and apply final processing 546 + self.post_init() 547 + 548 + def get_model(self): 549 + return self.model 550 + 551 + 552 + def forward( 553 + self, 554 + input_ids: torch.LongTensor = None, 555 + attention_mask: Optional[torch.Tensor] = None, 556 + position_ids: Optional[torch.LongTensor] = None, 557 + past_key_values: Optional[List[torch.FloatTensor]] = None, 558 + inputs_embeds: Optional[torch.FloatTensor] = None, 559 + labels: Optional[torch.LongTensor] = None, 560 + use_cache: Optional[bool] = None, 561 + output_attentions: Optional[bool] = None, 562 + output_hidden_states: Optional[bool] = None, 563 + images: Optional[torch.FloatTensor] = None, 564 + images_seq_mask: Optional[torch.FloatTensor] = None, 565 + images_spatial_crop: Optional[torch.FloatTensor] = None, 566 + return_dict: Optional[bool] = None, 567 + 568 + ) -> Union[Tuple, CausalLMOutputWithPast]: 569 + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 570 + output_hidden_states = ( 571 + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 572 + ) 573 + return_dict = return_dict if return_dict is not None else self.config.use_return_dict 574 + 575 + 576 + 577 + outputs = self.model( 578 + input_ids=input_ids, 579 + past_key_values=past_key_values, 580 + attention_mask=attention_mask, 581 + position_ids=position_ids, 582 + inputs_embeds=inputs_embeds, 583 + use_cache=use_cache, 584 + output_attentions=output_attentions, 585 + output_hidden_states=output_hidden_states, 586 + images=images, 587 + images_seq_mask = images_seq_mask, 588 + images_spatial_crop = images_spatial_crop, 589 + return_dict=return_dict 590 + 591 + ) 592 + 593 + 594 + 595 + # print(transformer_outputs) 596 + 597 + hidden_states = outputs[0] 598 + logits = self.lm_head(hidden_states) 599 + logits = logits.float() 600 + 601 + # logits 602 + 603 + loss = None 604 + if labels is not None: 605 + # Shift so that tokens < n predict n 606 + shift_logits = logits[..., :-1, :].contiguous() 607 + shift_labels = labels[..., 1:].contiguous() 608 + # Flatten the tokens 609 + loss_fct = CrossEntropyLoss() 610 + shift_logits = shift_logits.view(-1, self.config.vocab_size) 611 + shift_labels = shift_labels.view(-1) 612 + # Enable model parallelism 613 + shift_labels = shift_labels.to(shift_logits.device) 614 + loss = loss_fct(shift_logits, shift_labels) 615 + 616 + if not return_dict: 617 + output = (logits,) + outputs[1:] 618 + return (loss,) + output if loss is not None else output 619 + 620 + return CausalLMOutputWithPast( 621 + loss=loss, 622 + logits=logits, 623 + past_key_values=outputs.past_key_values, 624 + hidden_states=outputs.hidden_states, 625 + attentions=outputs.attentions, 626 + ) 627 + 628 + 629 + def prepare_inputs_for_generation( 630 + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 631 + ): 632 + # Omit tokens covered by past_key_values 633 + past_length = 0 634 + if past_key_values is not None: 635 + if isinstance(past_key_values, Cache): 636 + cache_length = past_key_values.get_seq_length() 637 + past_length = past_key_values.get_seq_length() 638 + max_cache_length = None 639 + else: 640 + cache_length = past_length = past_key_values[0][0].shape[2] 641 + max_cache_length = None 642 + 643 + # Keep only the unprocessed tokens: 644 + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 645 + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 646 + # input) 647 + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 648 + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 649 + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 650 + # input_ids based on the past_length. 651 + elif past_length < input_ids.shape[1]: 652 + input_ids = input_ids[:, past_length:] 653 + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 654 + 655 + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 656 + if ( 657 + max_cache_length is not None 658 + and attention_mask is not None 659 + and cache_length + input_ids.shape[1] > max_cache_length 660 + ): 661 + attention_mask = attention_mask[:, -max_cache_length:] 662 + 663 + position_ids = kwargs.get("position_ids", None) 664 + if attention_mask is not None and position_ids is None: 665 + # create position_ids on the fly for batch generation 666 + position_ids = attention_mask.long().cumsum(-1) - 1 667 + position_ids.masked_fill_(attention_mask == 0, 1) 668 + if past_key_values: 669 + position_ids = position_ids[:, -input_ids.shape[1] :] 670 + 671 + # if self.generation_config.cache_implementation == "static": 672 + # # generation with static cache 673 + # cache_position = kwargs.get("cache_position", None) 674 + # if cache_position is None: 675 + # past_length = 0 676 + # else: 677 + # past_length = cache_position[-1] + 1 678 + # input_ids = input_ids[:, past_length:] 679 + # position_ids = position_ids[:, past_length:] 680 + 681 + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. 682 + # same goes for position ids. Could also help with continued generation. 683 + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) 684 + 685 + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 686 + if inputs_embeds is not None and past_key_values is None: 687 + model_inputs = {"inputs_embeds": inputs_embeds} 688 + else: 689 + model_inputs = {"input_ids": input_ids} 690 + 691 + model_inputs.update( 692 + { 693 + "position_ids": position_ids, 694 + "past_key_values": past_key_values, 695 + "use_cache": kwargs.get("use_cache"), 696 + "attention_mask": attention_mask, 697 + "images": kwargs.get("images", None), 698 + "images_seq_mask": kwargs.get("images_seq_mask", None), 699 + "images_spatial_crop": kwargs.get("images_spatial_crop", None), 700 + } 701 + ) 702 + return model_inputs 703 + 704 + 705 + def disable_torch_init(self): 706 + """ 707 + Disable the redundant torch default initialization to accelerate model creation. 708 + """ 709 + import torch 710 + setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 711 + setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 712 + 713 + 714 + 715 + def infer(self, tokenizer, prompt='', image_file='', output_path = '', base_size=1024, image_size=640, crop_mode=True, test_compress=False, save_results=False, eval_mode=False): 716 + self.disable_torch_init() 717 + 718 + os.makedirs(output_path, exist_ok=True) 719 + os.makedirs(f'{output_path}/images', exist_ok=True) 720 + 721 + if prompt and image_file: 722 + conversation = [ 723 + { 724 + "role": "<|User|>", 725 + # "content": "<image>\n<|grounding|>Given the layout of the image. ", 726 + "content": f'{prompt}', 727 + # "content": "君不见黄河之水天上来的下一句是什么?", 728 + # "content": "<image>\nFree OCR. ", 729 + # "content": "<image>\nParse the figure. ", 730 + # "content": "<image>\nExtract the text in the image. ", 731 + "images": [f'{image_file}'], 732 + }, 733 + {"role": "<|Assistant|>", "content": ""}, 734 + ] 735 + 736 + elif prompt: 737 + conversation = [ 738 + { 739 + "role": "<|User|>", 740 + # "content": "<image>\n<|grounding|>Given the layout of the image. ", 741 + "content": f'{prompt}', 742 + # "content": "君不见黄河之水天上来的下一句是什么?", 743 + # "content": "<image>\nFree OCR. ", 744 + # "content": "<image>\nParse the figure. ", 745 + # "content": "<image>\nExtract the text in the image. ", 746 + # "images": [f'{image_file}'], 747 + }, 748 + {"role": "<|Assistant|>", "content": ""}, 749 + ] 750 + else: 751 + assert False, f'prompt is none!' 752 + 753 + prompt = format_messages(conversations=conversation, sft_format='plain', system_prompt='') 754 + 755 + patch_size = 16 756 + downsample_ratio = 4 757 + images = load_pil_images(conversation) 758 + 759 + valid_img_tokens = 0 760 + ratio = 1 761 + 762 + image_draw = images[0].copy() 763 + 764 + w,h = image_draw.size 765 + # print(w, h) 766 + ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h))) 767 + 768 + 769 + image_transform=BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True) 770 + images_seq_mask = [] 771 + 772 + image_token = '<image>' 773 + image_token_id = 128815 774 + text_splits = prompt.split(image_token) 775 + 776 + images_list, images_crop_list, images_seq_mask = [], [], [] 777 + tokenized_str = [] 778 + images_spatial_crop = [] 779 + for text_sep, image in zip(text_splits, images): 780 + 781 + tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False) 782 + tokenized_str += tokenized_sep 783 + images_seq_mask += [False] * len(tokenized_sep) 784 + 785 + if crop_mode: 786 + 787 + if image.size[0] <= 768 and image.size[1] <= 768: 788 + crop_ratio = [1, 1] 789 + 790 + else: 791 + if crop_mode: 792 + # best_width, best_height = select_best_resolution(image.size, self.candidate_resolutions) 793 + images_crop_raw, crop_ratio = dynamic_preprocess(image) 794 + else: 795 + # best_width, best_height = self.image_size, self.image_size 796 + crop_ratio = [1, 1] 797 + 798 + """process the global view""" 799 + # image = image.resize((base_size, base_size)) 800 + global_view = ImageOps.pad(image, (base_size, base_size), 801 + color=tuple(int(x * 255) for x in image_transform.mean)) 802 + 803 + if base_size == 1024: 804 + valid_img_tokens += int(256 * ratio) 805 + elif base_size == 1280: 806 + valid_img_tokens += int(400 * ratio) 807 + # elif base_size == 640: 808 + # valid_img_tokens += int(100 * ratio) 809 + 810 + 811 + 812 + 813 + 814 + images_list.append(image_transform(global_view).to(torch_dtype)) 815 + 816 + # global_view_tensor = image_transform(global_view).to(torch_dtype) 817 + 818 + width_crop_num, height_crop_num = crop_ratio 819 + 820 + images_spatial_crop.append([width_crop_num, height_crop_num]) 821 + 822 + 823 + if width_crop_num > 1 or height_crop_num > 1: 824 + """process the local views""" 825 + 826 + for i in range(len(images_crop_raw)): 827 + images_crop_list.append(image_transform(images_crop_raw[i]).to(torch_dtype)) 828 + 829 + if image_size == 768: 830 + valid_img_tokens += len(images_crop_list) * 144 831 + 832 + num_queries = math.ceil((image_size // patch_size) / downsample_ratio) 833 + num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio) 834 + 835 + 836 + 837 + """add image tokens""" 838 + 839 + 840 + 841 + tokenized_image = ([image_token_id] * num_queries_base) * num_queries_base 842 + tokenized_image += [image_token_id] 843 + if width_crop_num > 1 or height_crop_num > 1: 844 + tokenized_image += ([image_token_id] * (num_queries * width_crop_num)) * ( 845 + num_queries * height_crop_num) 846 + tokenized_str += tokenized_image 847 + images_seq_mask += [True] * len(tokenized_image) 848 + # num_image_tokens.append(len(tokenized_image)) 849 + 850 + else: 851 + # best_width, best_height = self.image_size, self.image_size 852 + # print(image.size, (best_width, best_height)) # check the select_best_resolutions func 853 + 854 + """process the global view""" 855 + if image_size <= 768: 856 + print('directly resize') 857 + image = image.resize((image_size, image_size)) 858 + # else: 859 + global_view = ImageOps.pad(image, (image_size, image_size), 860 + color=tuple(int(x * 255) for x in image_transform.mean)) 861 + images_list.append(image_transform(global_view).to(torch_dtype)) 862 + 863 + if base_size == 1024: 864 + valid_img_tokens += int(256 * ratio) 865 + elif base_size == 1280: 866 + valid_img_tokens += int(400 * ratio) 867 + elif base_size == 640: 868 + valid_img_tokens += int(100 * 1) 869 + elif base_size == 512: 870 + valid_img_tokens += int(64 * 1) 871 + elif base_size == 768: 872 + valid_img_tokens += int(144 * 1) 873 + 874 + width_crop_num, height_crop_num = 1, 1 875 + 876 + images_spatial_crop.append([width_crop_num, height_crop_num]) 877 + 878 + 879 + """add image tokens""" 880 + num_queries = math.ceil((image_size // patch_size) / downsample_ratio) 881 + 882 + tokenized_image = ([image_token_id] * num_queries) * num_queries 883 + tokenized_image += [image_token_id] 884 + # tokenized_image += ([self.image_token_id] * (num_queries * width_crop_num) + [self.image_token_id]) * ( 885 + # num_queries * height_crop_num) 886 + tokenized_str += tokenized_image 887 + images_seq_mask += [True] * len(tokenized_image) 888 + # num_image_tokens.append(len(tokenized_image)) 889 + 890 + 891 + """process the last text split""" 892 + tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False) 893 + tokenized_str += tokenized_sep 894 + images_seq_mask += [False] * len(tokenized_sep) 895 + 896 + """add the bos tokens""" 897 + bos_id = 0 898 + tokenized_str = [bos_id] + tokenized_str 899 + images_seq_mask = [False] + images_seq_mask 900 + 901 + 902 + 903 + input_ids = torch.LongTensor(tokenized_str) 904 + 905 + 906 + 907 + 908 + images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool) 909 + 910 + 911 + if len(images_list) == 0: 912 + images_ori = torch.zeros((1, 3, image_size, image_size)) 913 + images_spatial_crop = torch.zeros((1, 2), dtype=torch.long) 914 + images_crop = torch.zeros((1, 3, base_size, base_size)) 915 + 916 + else: 917 + images_ori = torch.stack(images_list, dim=0) 918 + images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long) 919 + if images_crop_list: 920 + images_crop = torch.stack(images_crop_list, dim=0) 921 + else: 922 + images_crop = torch.zeros((1, 3, base_size, base_size)) 923 + 924 + 925 + 926 + if not eval_mode: 927 + streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) 928 + with torch.autocast("cuda", dtype=torch_dtype): 929 + with torch.no_grad(): 930 + output_ids = self.generate( 931 + input_ids.unsqueeze(0).cuda(), 932 + images=[(images_crop.cuda(), images_ori.cuda())], 933 + images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), 934 + images_spatial_crop = images_spatial_crop, 935 + # do_sample=False, 936 + # num_beams = 1, 937 + temperature=0.0, 938 + eos_token_id=tokenizer.eos_token_id, 939 + streamer=streamer, 940 + max_new_tokens=8192, 941 + no_repeat_ngram_size = 20, 942 + use_cache = True 943 + ) 944 + 945 + else: 946 + with torch.autocast("cuda", dtype=torch_dtype): 947 + with torch.no_grad(): 948 + _input_ids_cuda = input_ids.unsqueeze(0).cuda() 949 + output_ids = self.generate( 950 + _input_ids_cuda, 951 + attention_mask=torch.ones_like(_input_ids_cuda), 952 + images=[(images_crop.cuda(), images_ori.cuda())], 953 + images_seq_mask = images_seq_mask.unsqueeze(0).cuda(), 954 + images_spatial_crop = images_spatial_crop, 955 + # do_sample=False, 956 + # num_beams = 1, 957 + temperature=0.0, 958 + eos_token_id=tokenizer.eos_token_id, 959 + max_new_tokens=8192, 960 + no_repeat_ngram_size = 35, 961 + use_cache = True 962 + ) 963 + 964 + 965 + if '<image>' in conversation[0]['content'] and eval_mode: 966 + outputs = tokenizer.decode(output_ids[0, _input_ids_cuda.shape[1]:]) 967 + stop_str = '<|end▁of▁sentence|>' 968 + if outputs.endswith(stop_str): 969 + outputs = outputs[:-len(stop_str)] 970 + # re_match 971 + outputs = outputs.strip() 972 + 973 + return outputs 974 + 975 + if '<image>' in conversation[0]['content'] and test_compress: 976 + outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 977 + pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False)) 978 + print('='*50) 979 + print('image size: ', (w, h)) 980 + print('valid image tokens: ', int(valid_img_tokens)) 981 + print('output texts tokens (valid): ', pure_texts_outputs_token_length) 982 + print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2)) 983 + print('='*50) 984 + 985 + 986 + if '<image>' in conversation[0]['content'] and save_results: 987 + outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:]) 988 + stop_str = '<|end▁of▁sentence|>' 989 + 990 + print('='*15 + 'save results:' + '='*15) 991 + 992 + # # # # conv.messages[-1][-1] = outputs 993 + if outputs.endswith(stop_str): 994 + outputs = outputs[:-len(stop_str)] 995 + outputs = outputs.strip() 996 + 997 + matches_ref, matches_images, mathes_other = re_match(outputs) 998 + # print(matches_ref) 999 + result = process_image_with_refs(image_draw, matches_ref, output_path) 1000 + 1001 + 1002 + for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")): 1003 + outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n') 1004 + 1005 + for idx, a_match_other in enumerate(tqdm(mathes_other, desc="other")): 1006 + outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:') 1007 + 1008 + 1009 + # if 'structural formula' in conversation[0]['content']: 1010 + # outputs = '<smiles>' + outputs + '</smiles>' 1011 + with open(f'{output_path}/result.mmd', 'w', encoding = 'utf-8') as afile: 1012 + afile.write(outputs) 1013 + 1014 + if 'line_type' in outputs: 1015 + import matplotlib.pyplot as plt 1016 + lines = eval(outputs)['Line']['line'] 1017 + 1018 + line_type = eval(outputs)['Line']['line_type'] 1019 + # print(lines) 1020 + 1021 + endpoints = eval(outputs)['Line']['line_endpoint'] 1022 + 1023 + fig, ax = plt.subplots(figsize=(3,3), dpi=200) 1024 + ax.set_xlim(-15, 15) 1025 + ax.set_ylim(-15, 15) 1026 + 1027 + for idx, line in enumerate(lines): 1028 + try: 1029 + p0 = eval(line.split(' -- ')[0]) 1030 + p1 = eval(line.split(' -- ')[-1]) 1031 + 1032 + if line_type[idx] == '--': 1033 + ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k') 1034 + else: 1035 + ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth = 0.8, color = 'k') 1036 + 1037 + ax.scatter(p0[0], p0[1], s=5, color = 'k') 1038 + ax.scatter(p1[0], p1[1], s=5, color = 'k') 1039 + except: 1040 + pass 1041 + 1042 + for endpoint in endpoints: 1043 + 1044 + label = endpoint.split(': ')[0] 1045 + (x, y) = eval(endpoint.split(': ')[1]) 1046 + ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points', 1047 + fontsize=5, fontweight='light') 1048 + 1049 + 1050 + plt.savefig(f'{output_path}/geo.jpg') 1051 + plt.close() 1052 + 1053 + result.save(f"{output_path}/result_with_boxes.jpg")
+2 -1
deepseek/src/eval.py
··· 77 77 crop_mode=crop_mode, 78 78 save_results=False, 79 79 test_compress=False, 80 + eval_mode=True, 80 81 ) 81 - pred = result if isinstance(result, str) else str(result) 82 + pred = result if isinstance(result, str) else "" 82 83 match = normalize(pred) == normalize(r["typst"]) 83 84 if match: 84 85 correct += 1