this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Add DeepSeek-OCR-2 data collator

src/collate_deepseek.py:
- Letterbox-pads images to 768×768 (gray fill, mean=0.5), normalizes
- Inserts 145 image tokens (12²+1) at sequence start
- Masks image+prompt prefix with -100; trains on response+EOS only
- Builds images_seq_mask and images_spatial_crop for forward()
- crop_mode=False: single global view, same tensor for both tuple slots
(TODO: validate images tuple format against model source)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+165
+165
src/collate_deepseek.py
··· 1 + """ 2 + Data collator for DeepSeek-OCR-2 fine-tuning. 3 + 4 + Preprocesses image+typst pairs into the tensor format expected by 5 + DeepseekOCR2ForCausalLM.forward(): 6 + 7 + - PIL image letterbox-padded to 768×768 (neutral gray fill, mean=0.5) 8 + - 145 image tokens at the start of the sequence (12² spatial + 1 separator) 9 + - prompt tokens follow, then response tokens + EOS 10 + - labels mask the image+prompt prefix with -100 (train on response only) 11 + 12 + Design decisions: 13 + - crop_mode=False: single global view, no dynamic tiling. Our math 14 + expression images are small fragments -- tiling adds no information and 15 + wastes tokens. 16 + - image_size=768: fixed encoder input (SAM ViT-B, patch=16, downsample=4 17 + → 12×12 grid = 144 + 1 separator = 145 tokens). 18 + - The `images` argument to forward() is a list of (local_crops, global_view) 19 + tuples. With crop_mode=False there are no distinct local crops; we pass 20 + the same 768×768 tensor for both slots. This matches the single-tile 21 + case in infer() but should be validated against the model source once 22 + available locally. 23 + """ 24 + 25 + import math 26 + 27 + import torch 28 + from PIL import Image, ImageOps 29 + from torch.nn.utils.rnn import pad_sequence 30 + from torchvision import transforms 31 + 32 + # ── Constants ────────────────────────────────────────────────────────────────── 33 + 34 + IMAGE_TOKEN_ID = 128815 # <image> placeholder token in the DeepSeek-OCR-2 vocab 35 + IMAGE_SIZE = 768 # fixed encoder input resolution (px) 36 + PATCH_SIZE = 16 # SAM ViT-B patch size (px) 37 + DOWNSAMPLE = 4 # spatial downsample ratio inside the encoder 38 + 39 + # Number of image tokens inserted per image: 40 + # ceil(768 / 16 / 4) = 12 → 12×12 + 1 separator = 145 41 + _N_GRID = math.ceil(IMAGE_SIZE / PATCH_SIZE / DOWNSAMPLE) # 12 42 + N_IMAGE_TOKENS = _N_GRID * _N_GRID + 1 # 145 43 + 44 + _MEAN = (0.5, 0.5, 0.5) 45 + _STD = (0.5, 0.5, 0.5) 46 + _PAD_COLOR = tuple(int(x * 255) for x in _MEAN) # (127, 127, 127) gray 47 + 48 + _transform = transforms.Compose([ 49 + transforms.ToTensor(), 50 + transforms.Normalize(mean=_MEAN, std=_STD), 51 + ]) 52 + 53 + # Prompt appended after the image tokens. Leading \n because <image> is the 54 + # very first token and the model expects a newline before the text instruction. 55 + _PROMPT = "\nTranscribe this image to Typst notation.\n" 56 + 57 + 58 + # ── Image preprocessing ──────────────────────────────────────────────────────── 59 + 60 + def preprocess_image(img: Image.Image) -> torch.Tensor: 61 + """ 62 + Letterbox-pad a PIL image to IMAGE_SIZE×IMAGE_SIZE, normalize. 63 + 64 + ImageOps.pad preserves aspect ratio by centering and filling borders with 65 + _PAD_COLOR, so non-square images are not distorted. 66 + 67 + Returns float16 tensor of shape [3, IMAGE_SIZE, IMAGE_SIZE]. 68 + """ 69 + img = img.convert("RGB") 70 + img = ImageOps.pad(img, (IMAGE_SIZE, IMAGE_SIZE), color=_PAD_COLOR) 71 + return _transform(img).to(torch.bfloat16) 72 + 73 + 74 + # ── Collator ─────────────────────────────────────────────────────────────────── 75 + 76 + class DeepSeekOCRCollator: 77 + """ 78 + Collates MathOCRDataset items into inputs for DeepseekOCR2ForCausalLM. 79 + 80 + Each dataset item has the structure produced by MathOCRDataset.__getitem__: 81 + { 82 + "messages": [ 83 + {"role": "user", "content": [{"type": "image", "image": <PIL>}, 84 + {"type": "text", "text": PROMPT}]}, 85 + {"role": "assistant", "content": [{"type": "text", "text": <typst>}]}, 86 + ] 87 + } 88 + 89 + Output dict keys match DeepseekOCR2ForCausalLM.forward() signature: 90 + input_ids, attention_mask, labels, 91 + images, images_seq_mask, images_spatial_crop 92 + """ 93 + 94 + def __init__(self, tokenizer) -> None: 95 + self.tokenizer = tokenizer 96 + self.prompt_ids = tokenizer.encode(_PROMPT, add_special_tokens=False) 97 + self.eos_id = tokenizer.eos_token_id 98 + self.pad_id = (tokenizer.pad_token_id 99 + if tokenizer.pad_token_id is not None 100 + else tokenizer.eos_token_id) 101 + 102 + def __call__(self, batch: list[dict]) -> dict: 103 + input_ids_list = [] 104 + labels_list = [] 105 + seq_mask_list = [] 106 + img_tensors = [] 107 + spatial_crops = [] 108 + 109 + for item in batch: 110 + user_content = item["messages"][0]["content"] 111 + pil_img = user_content[0]["image"] # PIL Image 112 + typst = item["messages"][1]["content"][0]["text"] # target string 113 + 114 + # Image → tensor ─────────────────────────────────────────────────── 115 + img_t = preprocess_image(pil_img) # [3, 768, 768] bfloat16 116 + img_tensors.append(img_t) 117 + 118 + # Token sequence ─────────────────────────────────────────────────── 119 + # Layout: [img×145] [prompt] [response] [EOS] 120 + response_ids = self.tokenizer.encode(typst, add_special_tokens=False) 121 + img_ids = [IMAGE_TOKEN_ID] * N_IMAGE_TOKENS 122 + ids = img_ids + self.prompt_ids + response_ids + [self.eos_id] 123 + 124 + # Labels: -100 on image+prompt; train only on response+EOS 125 + n_prefix = N_IMAGE_TOKENS + len(self.prompt_ids) 126 + lbl = [-100] * n_prefix + response_ids + [self.eos_id] 127 + 128 + # images_seq_mask: True at the N_IMAGE_TOKENS image token positions 129 + mask = [False] * len(ids) 130 + for i in range(N_IMAGE_TOKENS): 131 + mask[i] = True 132 + 133 + input_ids_list.append(torch.tensor(ids, dtype=torch.long)) 134 + labels_list.append(torch.tensor(lbl, dtype=torch.long)) 135 + seq_mask_list.append(torch.tensor(mask, dtype=torch.bool)) 136 + # [width_crops, height_crops] = [1, 1] for single global view 137 + spatial_crops.append([1, 1]) 138 + 139 + # Pad sequences ──────────────────────────────────────────────────────── 140 + input_ids = pad_sequence(input_ids_list, batch_first=True, 141 + padding_value=self.pad_id) 142 + labels = pad_sequence(labels_list, batch_first=True, 143 + padding_value=-100) 144 + seq_mask = pad_sequence(seq_mask_list, batch_first=True, 145 + padding_value=False) 146 + attn_mask = (input_ids != self.pad_id) 147 + 148 + # images: list of (local_crops, global_view) tuples, one per sample. 149 + # With crop_mode=False there is one tile; the same tensor serves both 150 + # slots. Shape per slot: [1, 3, 768, 768]. 151 + # TODO: validate tuple format against model source once loaded locally. 152 + imgs_stacked = torch.stack(img_tensors) # [B, 3, 768, 768] 153 + images = [ 154 + (imgs_stacked[i].unsqueeze(0), imgs_stacked[i].unsqueeze(0)) 155 + for i in range(len(batch)) 156 + ] 157 + 158 + return { 159 + "input_ids": input_ids, # [B, T] 160 + "attention_mask": attn_mask, # [B, T] 161 + "labels": labels, # [B, T] 162 + "images": images, # list[B] of (tensor, tensor) 163 + "images_seq_mask": seq_mask, # [B, T] 164 + "images_spatial_crop": torch.tensor(spatial_crops, dtype=torch.long), # [B, 2] 165 + }