WIP: new deepseek-ocr 2 impl

+530 -29
+11 -11
README.md
···
  ### Overview
  
  The training set combines handwritten math datasets with synthetically rendered
- Typst documents. Handwriting-font splits (`hw_*`) cover both single-equation
+ Typst documents. Handwriting-font splits (`typeset_*`) cover both single-equation
  and structured document content rendered in 6 diverse fonts plus the Typst
  default, sampled uniformly.
  
···
  
  | Split | Samples | Notes |
  |---|---|---|
- | `hw_structured_train` | 15,000 | Whole document uses one uniformly-sampled font. |
- | `hw_mixed_train` | 10,000 | Per-paragraph font mixing (~55% of blocks get hw font); requires multi-block bodies. |
+ | `typeset_uniform_train` | 15,000 | Whole document uses one uniformly-sampled font. |
+ | `typeset_mixed_train` | 10,000 | Per-paragraph font mixing (~55% of blocks get hw font); requires multi-block bodies. |
  
  Body types and generation weights (see `src/generate_mixed.py`):
  
···
  
  - **Math-only splits** (`mathwriting_*`, `crohme_*`): manifest stores bare math
    expressions. `data.load_records()` wraps these as `$ ... $` at load time.
- - **hw_* splits**: manifest stores complete body content with inline `$...$`
+ - **typeset_* splits**: manifest stores complete body content with inline `$...$`
    delimiters already present. No wrapping applied.
  
  ### Effective training mix (after caps)
···
  | `mathwriting_synthetic` | ~85,879 | **20,000** | 21% |
  | `crohme_gen_2019` | ~51,855 | **15,000** | 16% |
  | `crohme_gen_syntactic` | ~69,397 | **15,000** | 16% |
- | `hw_structured_train` | 15,000 | 15,000 | 16% |
- | `hw_mixed_train` | 10,000 | 10,000 | 11% |
+ | `typeset_uniform_train` | 15,000 | 15,000 | 16% |
+ | `typeset_mixed_train` | 10,000 | 10,000 | 11% |
  | `crohme_gen_2023` | ~3,072 | 3,072 | 3% |
  | `mathwriting_symbols` | ~6,091 | 6,091 | 6% |
  | `crohme_real_train` | ~9 | 9 | <1% |
···
  | Split | Samples | Notes |
  |---|---|---|
  | `mathwriting_val` | 250 | Real handwritten single equations. |
- | `hw_structured_val` | 250 | Whole-doc font document fragments. |
- | `hw_mixed_val` | 250 | Per-block mixed-font document fragments. |
+ | `typeset_uniform_val` | 250 | Whole-doc font document fragments. |
+ | `typeset_mixed_val` | 250 | Per-block mixed-font document fragments. |
  | **Total** | **750** | |
  
  ### Test
···
  | Split | Notes |
  |---|---|
  | `mathwriting_test` | Held-out real handwritten equations. |
- | `hw_structured_test` | Held-out whole-doc font document fragments. |
- | `hw_mixed_test` | Held-out mixed-font document fragments. |
+ | `typeset_uniform_test` | Held-out whole-doc font document fragments. |
+ | `typeset_mixed_test` | Held-out mixed-font document fragments. |
  
  ### Known gaps
  
- - `hw_*` splits are font-based renders, not real handwriting photos. The model
+ - `typeset_*` splits are font-based renders, not real handwriting photos. The model
    still lacks real handwritten document fragments at scale.
  - `crohme_real_train` has only 9 samples.
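The README's load-time rule boils down to a small conditional. The sketch below only restates the description above; `wrap_body` is a hypothetical helper for illustration, not a function in `src/data.py` (the real logic lives in `data.load_records()`).

```python
# Illustrative only: mirrors the wrapping rule described in the README.
def wrap_body(split: str, manifest_text: str) -> str:
    if split.startswith(("mathwriting_", "crohme_")):
        # Math-only splits store bare expressions; wrap as Typst display math.
        return f"$ {manifest_text} $"
    # typeset_* splits already store full body content with inline $...$.
    return manifest_text


assert wrap_body("mathwriting_train", "x^2 + y^2") == "$ x^2 + y^2 $"
assert wrap_body("typeset_mixed_train", "Euler: $e^(i pi) = -1$.") == "Euler: $e^(i pi) = -1$."
```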
+2
pyproject.toml
···
  
  [project.scripts]
  train = "src.train:main"
+ train-deepseek = "src.train_deepseek:main"
  evaluate = "src.eval:main"
+ evaluate-deepseek = "src.eval_deepseek:main"
  train-hnm = "src.train_hnm:main"
  export = "src.export:main"
  generate-typeset = "src.generate_typeset:main"
+54 -18
src/data.py
···
  PROMPT = "Transcribe this image to Typst notation. Output only the raw Typst, without explanation. No LaTeX, only Typst."
  BASE_MODEL = "unsloth/gemma-4-E2B-it"
  
+ DEEPSEEK_PROMPT = "<image>\nFree OCR. "
+ DEEPSEEK_MODEL_DIR = str(Path(__file__).parent.parent / "deepseek_ocr2")
+ 
  
  _NUM_RE = re.compile(r"\d+(\.\d+)?")
  _VAR_RE = re.compile(r"\b[a-zA-Z]\b")
···
      return img
  
  
+ def gemma_format(record: dict, img: Image.Image) -> dict:
+     return {
+         "messages": [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image", "image": img},
+                     {"type": "text", "text": PROMPT},
+                 ],
+             },
+             {
+                 "role": "assistant",
+                 "content": [{"type": "text", "text": record["typst"]}],
+             },
+         ]
+     }
+ 
+ 
+ def deepseek_format(record: dict, img: Image.Image) -> dict:
+     return {
+         "messages": [
+             {
+                 "role": "<|User|>",
+                 "content": DEEPSEEK_PROMPT,
+                 "images": [img],
+             },
+             {
+                 "role": "<|Assistant|>",
+                 "content": record["typst"],
+             },
+         ]
+     }
+ 
+ 
  class MathOCRDataset(Dataset):
      """
      Torch Dataset with lazy image loading. Avoids HuggingFace datasets
      fingerprinting, which uses dill and breaks on Python 3.14.
+ 
+     format_fn(record, img) -> dict with a "messages" key in the model's
+     expected conversation format. Defaults to Gemma format.
      """
  
-     def __init__(self, records: list[dict], do_augment: bool = False) -> None:
+     def __init__(
+         self,
+         records: list[dict],
+         do_augment: bool = False,
+         format_fn=None,
+     ) -> None:
          self.records = records
          self.do_augment = do_augment
+         self.format_fn = format_fn if format_fn is not None else gemma_format
  
      def __len__(self) -> int:
          return len(self.records)
···
          img = Image.open(r["image_path"]).convert("RGB")
          if self.do_augment:
              img = _augment(img)
-         return {
-             "messages": [
-                 {
-                     "role": "user",
-                     "content": [
-                         {"type": "image", "image": img},
-                         {"type": "text", "text": PROMPT},
-                     ],
-                 },
-                 {
-                     "role": "assistant",
-                     "content": [{"type": "text", "text": r["typst"]}],
-                 },
-             ]
-         }
+         return self.format_fn(r, img)
  
  
- def make_dataset(records: list[dict], do_augment: bool = False) -> MathOCRDataset:
-     return MathOCRDataset(records, do_augment)
+ def make_dataset(
+     records: list[dict],
+     do_augment: bool = False,
+     format_fn=None,
+ ) -> MathOCRDataset:
+     return MathOCRDataset(records, do_augment, format_fn)
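A minimal sketch of how the new pluggable `format_fn` is meant to be wired into a dataset. It assumes `load_records()` and the split names shown in the README; it is not code taken verbatim from this PR.

```python
from src.data import deepseek_format, load_records, make_dataset

records = load_records(["typeset_mixed_train"], dedupe=False)
ds = make_dataset(records, do_augment=True, format_fn=deepseek_format)

sample = ds[0]
print(sample["messages"][0]["role"])     # "<|User|>" turn carrying the image + DEEPSEEK_PROMPT
print(sample["messages"][1]["content"])  # the record's Typst ground truth
```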
+115
src/eval_deepseek.py
"""
ExpRate evaluation of DeepSeek-OCR-2 on the test splits.

Inference uses model.infer() (the model's native API), which takes a file path,
so images are written to a temp file per sample.

Usage: uv run evaluate-deepseek [--checkpoint checkpoints/deepseek-ocr2/final]
"""

import argparse
import os
import random
import re
import shutil
import tempfile

import torch
from tqdm import tqdm
from transformers import AutoModel
from unsloth import FastVisionModel

from .data import DEEPSEEK_MODEL_DIR, TEST_SPLITS, load_records

os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"


def normalize(s: str) -> str:
    return re.sub(r"\s+", " ", s).strip()


def evaluate(
    checkpoint: str,
    n: int | None = None,
    verbose: bool = False,
    image_size: int = 768,
    base_size: int = 1024,
    crop_mode: bool = True,
) -> float:
    model, tokenizer = FastVisionModel.from_pretrained(
        checkpoint,
        load_in_4bit=False,
        auto_model=AutoModel,
        trust_remote_code=True,
        unsloth_force_compile=True,
        use_gradient_checkpointing="unsloth",
    )
    FastVisionModel.for_inference(model)
    model.eval()

    rng = random.Random(42)
    records = []
    for split in TEST_SPLITS:
        split_recs = load_records([split], dedupe=False)
        sample = rng.sample(split_recs, min(n, len(split_recs))) if n is not None else split_recs
        records.extend(sample)

    correct = 0
    prompt = "<image>\nFree OCR. "

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_img = os.path.join(tmpdir, "img.png")
        for r in tqdm(records, desc="Evaluating"):
            # model.infer requires a file path
            shutil.copy(r["image_path"], tmp_img)

            result = model.infer(
                tokenizer,
                prompt=prompt,
                image_file=tmp_img,
                output_path=tmpdir,
                image_size=image_size,
                base_size=base_size,
                crop_mode=crop_mode,
                save_results=False,
                test_compress=False,
            )
            # model.infer returns the transcription string directly
            pred = result if isinstance(result, str) else str(result)

            match = normalize(pred) == normalize(r["typst"])
            if match:
                correct += 1
            if verbose:
                status = "OK" if match else "FAIL"
                print(f"[{status}] split={r['split']}")
                print(f"  GT:   {r['typst']}")
                print(f"  PRED: {pred}")

    exprate = correct / len(records)
    split_label = f" ({n}/split)" if n is not None else ""
    print(f"ExpRate{split_label}: {exprate:.4f} ({correct}/{len(records)})")
    return exprate


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/deepseek-ocr2/final")
    parser.add_argument("--n", type=int, default=None, help="Evaluate N random examples per split")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--image-size", type=int, default=768)
    parser.add_argument("--base-size", type=int, default=1024)
    parser.add_argument("--no-crop", action="store_true", help="Disable dynamic cropping")
    args = parser.parse_args()
    evaluate(
        args.checkpoint,
        n=args.n,
        verbose=args.verbose,
        image_size=args.image_size,
        base_size=args.base_size,
        crop_mode=not args.no_crop,
    )


if __name__ == "__main__":
    main()
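ExpRate here is whitespace-normalized exact match against the ground-truth Typst. A tiny self-contained illustration of what counts as correct (the inputs are made-up examples, but `normalize` matches the function above):

```python
import re


def normalize(s: str) -> str:
    # Same normalization as eval_deepseek.normalize: collapse all whitespace.
    return re.sub(r"\s+", " ", s).strip()


# Whitespace-only differences still count as an exact match...
assert normalize("x^2 +  y^2") == normalize("x^2 + y^2")
# ...but any other deviation from the ground truth is a miss.
assert normalize("x^(2) + y^2") != normalize("x^2 + y^2")
```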
+348
src/train_deepseek.py
"""
LoRA fine-tuning of DeepSeek-OCR-2 for Typst math OCR.

The model uses a fixed "Free OCR" prompt with no way to specify output format,
so finetuning replaces the LaTeX prior with Typst via supervised examples.

Usage: uv run train-deepseek [--epochs 2] [--cap mathwriting_synthetic 20000]

Requires the model weights to be downloaded first:
    python -c "from huggingface_hub import snapshot_download; snapshot_download('unsloth/DeepSeek-OCR-2', local_dir='deepseek_ocr2')"
"""

from unsloth import FastVisionModel, is_bf16_supported
from PIL import Image, ImageOps
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel, Trainer, TrainingArguments

import argparse
import io
import math
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple


from .data import (
    DEEPSEEK_MODEL_DIR,
    TRAIN_SPLITS,
    VAL_SPLITS,
    deepseek_format,
    load_records,
    make_dataset,
)

os.environ["UNSLOTH_WARN_UNINITIALIZED"] = "0"

_DEFAULT_CAPS = {
    "mathwriting_synthetic": 20_000,
    "crohme_gen_2019": 15_000,
    "crohme_gen_syntactic": 15_000,
    "mathwriting_train": 10_000,
}


# ---------------------------------------------------------------------------
# Data collator (DeepSeek-OCR-2 specific)
# ---------------------------------------------------------------------------

@dataclass
class DeepSeekOCR2DataCollator:
    tokenizer: Any
    model: Any
    image_size: int = 768
    base_size: int = 1024
    crop_mode: bool = True
    train_on_responses_only: bool = True

    def __post_init__(self) -> None:
        from deepseek_ocr2.modeling_deepseekocr2 import BasicImageTransform, dynamic_preprocess  # noqa: F401
        self._dynamic_preprocess = dynamic_preprocess
        self.image_token_id = 128815
        self.dtype = self.model.dtype
        self.patch_size = 16
        self.downsample_ratio = 4
        self.image_transform = BasicImageTransform(
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
            normalize=True,
        )
        if hasattr(self.tokenizer, "bos_token_id") and self.tokenizer.bos_token_id is not None:
            self.bos_id = self.tokenizer.bos_token_id
        else:
            self.bos_id = 0

    def _to_pil(self, image_data) -> Image.Image:
        if isinstance(image_data, Image.Image):
            return image_data.convert("RGB")
        if isinstance(image_data, dict) and "bytes" in image_data:
            return Image.open(io.BytesIO(image_data["bytes"])).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _process_image(self, image: Image.Image) -> Tuple[List, List, List, List, Tuple]:
        from deepseek_ocr2.modeling_deepseekocr2 import dynamic_preprocess

        images_list, images_crop_list, images_spatial_crop = [], [], []
        nq = math.ceil((self.image_size // self.patch_size) / self.downsample_ratio)
        nq_base = math.ceil((self.base_size // self.patch_size) / self.downsample_ratio)

        if self.crop_mode:
            if image.size[0] <= 768 and image.size[1] <= 768:
                crop_ratio = (1, 1)
                crops_raw = []
            else:
                crops_raw, crop_ratio = dynamic_preprocess(
                    image, min_num=2, max_num=6,
                    image_size=self.image_size, use_thumbnail=False,
                )
            pad_color = tuple(int(x * 255) for x in self.image_transform.mean)
            global_view = ImageOps.pad(image, (self.base_size, self.base_size), color=pad_color)
            images_list.append(self.image_transform(global_view).to(self.dtype))
            w, h = crop_ratio
            images_spatial_crop.append([w, h])
            for c in crops_raw:
                images_crop_list.append(self.image_transform(c).to(self.dtype))
            tok = ([self.image_token_id] * nq_base) * nq_base + [self.image_token_id]
            if w > 1 or h > 1:
                tok += ([self.image_token_id] * (nq * w)) * (nq * h)
        else:
            crop_ratio = (1, 1)
            images_spatial_crop.append([1, 1])
            if self.base_size <= 768:
                resized = image.resize((self.base_size, self.base_size), Image.LANCZOS)
                images_list.append(self.image_transform(resized).to(self.dtype))
            else:
                pad_color = tuple(int(x * 255) for x in self.image_transform.mean)
                global_view = ImageOps.pad(image, (self.base_size, self.base_size), color=pad_color)
                images_list.append(self.image_transform(global_view).to(self.dtype))
            tok = ([self.image_token_id] * nq) * nq + [self.image_token_id]

        return images_list, images_crop_list, images_spatial_crop, tok, crop_ratio

    def _process_sample(self, messages: List[Dict]) -> Dict[str, Any]:
        from deepseek_ocr2.modeling_deepseekocr2 import text_encode

        images = [
            self._to_pil(img)
            for msg in messages
            for img in msg.get("images", [])
            if img is not None
        ]
        if not images:
            raise ValueError("No images in sample.")

        tokenized_str, images_seq_mask = [self.bos_id], [False]
        images_list, images_crop_list, images_spatial_crop = [], [], []
        prompt_token_count = -1
        assistant_started = False
        image_idx = 0

        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "<|Assistant|>":
                if not assistant_started:
                    prompt_token_count = len(tokenized_str)
                    assistant_started = True
                content = f"{content.strip()} {self.tokenizer.eos_token}"

            for i, part in enumerate(content.split("<image>")):
                toks = text_encode(self.tokenizer, part, bos=False, eos=False)
                tokenized_str.extend(toks)
                images_seq_mask.extend([False] * len(toks))
                if i < len(content.split("<image>")) - 1:
                    img_list, crop_list, spatial, tok_img, _ = self._process_image(images[image_idx])
                    images_list.extend(img_list)
                    images_crop_list.extend(crop_list)
                    images_spatial_crop.extend(spatial)
                    tokenized_str.extend(tok_img)
                    images_seq_mask.extend([True] * len(tok_img))
                    image_idx += 1

        if not assistant_started:
            prompt_token_count = len(tokenized_str)

        images_ori = torch.stack(images_list, dim=0)
        images_spatial_crop_t = torch.tensor(images_spatial_crop, dtype=torch.long)
        if images_crop_list:
            images_crop = torch.stack(images_crop_list, dim=0)
        else:
            images_crop = torch.zeros((1, 3, self.base_size, self.base_size), dtype=self.dtype)

        return {
            "input_ids": torch.tensor(tokenized_str, dtype=torch.long),
            "images_seq_mask": torch.tensor(images_seq_mask, dtype=torch.bool),
            "images_ori": images_ori,
            "images_crop": images_crop,
            "images_spatial_crop": images_spatial_crop_t,
            "prompt_token_count": prompt_token_count,
        }

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        batch_data = []
        for feat in features:
            try:
                batch_data.append(self._process_sample(feat["messages"]))
            except Exception as e:
                print(f"Skipping sample: {e}")
        if not batch_data:
            raise ValueError("Empty batch after collation.")

        input_ids = pad_sequence(
            [d["input_ids"] for d in batch_data],
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        images_seq_mask = pad_sequence(
            [d["images_seq_mask"] for d in batch_data],
            batch_first=True,
            padding_value=False,
        )
        labels = input_ids.clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        labels[images_seq_mask] = -100
        if self.train_on_responses_only:
            for i, d in enumerate(batch_data):
                pc = d["prompt_token_count"]
                if pc > 0:
                    labels[i, :pc] = -100

        attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
        images_batch = [(d["images_crop"], d["images_ori"]) for d in batch_data]
        images_spatial_crop = torch.cat([d["images_spatial_crop"] for d in batch_data], dim=0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "images": images_batch,
            "images_seq_mask": images_seq_mask,
            "images_spatial_crop": images_spatial_crop,
        }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--output-dir", default="checkpoints/deepseek-ocr2")
    parser.add_argument("--model-dir", default=DEEPSEEK_MODEL_DIR)
    parser.add_argument("--cap", nargs=2, action="append", metavar=("SPLIT", "N"),
                        default=None)
    args = parser.parse_args()

    caps = dict(_DEFAULT_CAPS)
    if args.cap:
        for split, n in args.cap:
            caps[split] = int(n)

    if not Path(args.model_dir).resolve().exists():
        print(f"{args.model_dir} not found -- downloading from HuggingFace...")
        from huggingface_hub import snapshot_download
        snapshot_download("unsloth/DeepSeek-OCR-2", local_dir=args.model_dir)

    model, tokenizer = FastVisionModel.from_pretrained(
        args.model_dir,
        load_in_4bit=False,
        auto_model=AutoModel,
        trust_remote_code=True,
        unsloth_force_compile=True,
        use_gradient_checkpointing="unsloth",
    )

    model = FastVisionModel.get_peft_model(
        model,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                        "gate_proj", "up_proj", "down_proj"],
        r=16,
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )
    FastVisionModel.for_training(model)

    rng = random.Random(42)
    train_records: list[dict] = []
    for split in TRAIN_SPLITS:
        recs = load_records([split], dedupe=False)
        cap = caps.get(split)
        if cap and len(recs) > cap:
            recs = rng.sample(recs, cap)
        train_records.extend(recs)
    rng.shuffle(train_records)

    val_rng = random.Random(42)
    val_records: list[dict] = []
    for split in VAL_SPLITS:
        recs = load_records([split], dedupe=False)
        val_records += val_rng.sample(recs, min(250, len(recs)))
    val_rng.shuffle(val_records)

    print(f"Train: {len(train_records):,} Val: {len(val_records):,}")
    for split in TRAIN_SPLITS:
        n = sum(1 for r in train_records if r["split"] == split)
        cap = caps.get(split)
        print(f"  {split}: {n:,}" + (f" (cap {cap:,})" if cap else ""))

    train_ds = make_dataset(train_records, do_augment=True, format_fn=deepseek_format)
    val_ds = make_dataset(val_records, do_augment=False, format_fn=deepseek_format)

    data_collator = DeepSeekOCR2DataCollator(
        tokenizer=tokenizer,
        model=model,
        image_size=768,
        base_size=1024,
        crop_mode=True,
        train_on_responses_only=True,
    )

    out_dir = args.output_dir
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            gradient_accumulation_steps=8,
            num_train_epochs=args.epochs,
            learning_rate=args.lr,
            warmup_steps=500,
            lr_scheduler_type="cosine",
            fp16=not is_bf16_supported(),
            bf16=is_bf16_supported(),
            dataloader_num_workers=2,
            logging_steps=50,
            eval_strategy="steps",
            eval_steps=500,
            save_steps=500,
            save_total_limit=7,
            load_best_model_at_end=False,
            output_dir=out_dir,
            run_name="deepseek-ocr2",
            report_to=["tensorboard"],
            remove_unused_columns=False,
        ),
    )

    has_checkpoint = any(Path(out_dir).glob("checkpoint-*")) if Path(out_dir).exists() else False
    trainer.train(resume_from_checkpoint=has_checkpoint)
    model.save_pretrained(f"{out_dir}/final")
    tokenizer.save_pretrained(f"{out_dir}/final")
    print(f"Saved to {out_dir}/final")


if __name__ == "__main__":
    main()
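The collator's loss masking is the key training detail: padding, image placeholder tokens, and (with `train_on_responses_only`) the entire prompt are set to -100 so only the Typst response contributes to the loss. A toy illustration with made-up token ids (pad id 0 here is an assumption for the example, not the model's real pad token):

```python
import torch

# Toy illustration of the masking in DeepSeekOCR2DataCollator.__call__.
input_ids = torch.tensor([[7, 8, 9, 21, 22, 23, 0, 0]])
images_seq_mask = torch.tensor([[False, True, True, False, False, False, False, False]])
prompt_token_count = 3  # tokens before the first "<|Assistant|>" turn

labels = input_ids.clone()
labels[input_ids == 0] = -100           # ignore padding
labels[images_seq_mask] = -100          # ignore image placeholder tokens
labels[0, :prompt_token_count] = -100   # train_on_responses_only: mask the prompt

print(labels)  # tensor([[-100, -100, -100, 21, 22, 23, -100, -100]])
```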