Import eff-mer CNN+Transformer model for comparison baseline

Copies eff_mer package (encoder, decoder, vocab, data, train, infer)
into src/eff_mer/ and adds eff-mer-evaluate entrypoint. Paths updated
to resolve eff-mer data relative to sibling repo location.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
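
With the entrypoint registered in pyproject.toml, evaluation should be invokable as `uv run eff-mer-evaluate --checkpoint checkpoints/epoch_009.pt` (checkpoint name illustrative; see src/eff_mer/infer.py below).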

+1326 -1
+1 -1
.gitignore
···
-src/__pycache__/
+__pycache__/
 data/
 unsloth_compiled_cache/
+3
pyproject.toml
···
     "torch>=2.3",
     "editdistance",
     "tqdm",
+    "numpy",
+    "torchvision>=0.18",
 ]

 [project.scripts]
···
 export = "src.export:main"
 generate-typeset = "src.generate_typeset:main"
 probe = "src.probe:main"
+eff-mer-evaluate = "src.eff_mer.infer:main"

 [build-system]
 requires = ["hatchling"]
src/eff_mer/__init__.py

This is a binary file and will not be displayed.

+160
src/eff_mer/data.py
"""
Dataset for math expression recognition.

Each sample: (image_tensor, token_ids)
    image_tensor: float32 (1, H, W), normalized to ImageNet grayscale stats
    token_ids:    int64 [<sos>, tok..., <eos>], length <= max_len

Images are scaled to a fixed height (IMG_H) preserving aspect ratio, then
width is capped at MAX_W and padded to a multiple of STRIDE.

collate_fn pads images to the batch-max width and returns a spatial memory
mask (B, H_tok * W_tok) -- True at padding token positions -- alongside the
usual token padding mask.
"""

import json
import random
from pathlib import Path

import torch
import torchvision.transforms.functional as TF
from PIL import Image, ImageFilter, ImageOps
from torch.utils.data import Dataset

from .vocab import Vocab

# ImageNet RGB mean/std averaged across channels, for grayscale input
_MEAN = 0.449
_STD = 0.226

# Image geometry
IMG_H = 256   # fixed height; must be a multiple of STRIDE
MAX_W = 1024  # maximum width before scaling down
STRIDE = 32   # EfficientNet-B0 spatial stride


def scale_and_pad(img: Image.Image, h: int = IMG_H, max_w: int = MAX_W) -> Image.Image:
    """
    Scale image to fixed height h preserving aspect ratio.
    If the resulting width exceeds max_w, scale down further.
    Pad width to the nearest multiple of STRIDE with white.
    """
    orig_w, orig_h = img.size
    scale = h / orig_h
    new_w = round(orig_w * scale)
    if new_w > max_w:
        scale = max_w / orig_w
        new_h = round(orig_h * scale)
        new_w = max_w
    else:
        new_h = h
    img = img.resize((new_w, new_h), Image.LANCZOS)
    # Pad height to exactly h (handles rounding), width to multiple of STRIDE
    pad_w = -new_w % STRIDE  # 0 when already aligned
    if new_h != h or pad_w:
        img = ImageOps.expand(img, (0, 0, pad_w, h - new_h), fill=255)
    return img


def _augment(img: Image.Image) -> Image.Image:
    """
    Online augmentation for training images.
    Random affine (rotation, scale, translate) + blur + pixel noise.
    All parameters are small to preserve legibility.
    """
    angle = random.uniform(-5, 5)
    scale = random.uniform(0.9, 1.1)
    tx = int(random.uniform(-0.05, 0.05) * img.width)
    ty = int(random.uniform(-0.05, 0.05) * img.height)
    img = TF.affine(img, angle=angle, translate=(tx, ty), scale=scale, shear=0, fill=255)

    radius = random.uniform(0.0, 0.8)
    if radius > 0.15:
        img = img.filter(ImageFilter.GaussianBlur(radius))

    return img


def _spatial_mask(w_toks: list[int], H_tok: int, W_tok_max: int) -> torch.Tensor:
    """
    Build memory key padding mask (B, H_tok * W_tok_max).
    True at positions corresponding to width-padding columns.
    """
    B = len(w_toks)
    mask = torch.zeros(B, H_tok, W_tok_max, dtype=torch.bool)
    for i, wt in enumerate(w_toks):
        if wt < W_tok_max:
            mask[i, :, wt:] = True
    return mask.reshape(B, H_tok * W_tok_max)


class MathDataset(Dataset):
    def __init__(
        self,
        manifest_paths: list[Path],
        vocab: Vocab,
        max_len: int = 256,
        augment: bool = False,
    ):
        self.vocab = vocab
        self.max_len = max_len
        self.augment = augment
        self.records: list[tuple[Path, str]] = []

        for mp in manifest_paths:
            base = mp.parent
            for line in mp.open():
                r = json.loads(line)
                typst = r.get("typst", "")
                if not typst or typst.startswith("ERROR:"):
                    continue
                self.records.append((base / r["image"], typst))

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        img_path, typst = self.records[idx]

        img = Image.open(img_path).convert("L")
        if self.augment:
            img = _augment(img)
        img = scale_and_pad(img)
        img_t = TF.to_tensor(img)  # (1, H, W), float32 in [0, 1]
        if self.augment:
            img_t = (img_t + torch.randn_like(img_t) * 0.02).clamp(0.0, 1.0)
        img_t = TF.normalize(img_t, [_MEAN], [_STD])

        ids = [self.vocab.sos_id] + self.vocab.encode(typst) + [self.vocab.eos_id]
        ids = ids[: self.max_len]

        return img_t, torch.tensor(ids, dtype=torch.long)

    @staticmethod
    def collate(batch: list[tuple[torch.Tensor, torch.Tensor]]):
        imgs, seqs = zip(*batch)

        # Pad images to batch-max width
        H = imgs[0].shape[1]
        W_max = max(t.shape[2] for t in imgs)
        imgs_t = torch.zeros(len(imgs), 1, H, W_max)
        for i, t in enumerate(imgs):
            imgs_t[i, :, :, : t.shape[2]] = t

        # Spatial memory mask: True at width-padding token positions
        H_tok = H // STRIDE
        W_tok_max = W_max // STRIDE
        w_toks = [t.shape[2] // STRIDE for t in imgs]
        mem_mask = _spatial_mask(w_toks, H_tok, W_tok_max)  # (B, H_tok*W_tok_max)

        # Token sequence padding
        max_len = max(s.shape[0] for s in seqs)
        padded = torch.zeros(len(seqs), max_len, dtype=torch.long)
        lengths = torch.zeros(len(seqs), dtype=torch.long)
        for i, s in enumerate(seqs):
            padded[i, : s.shape[0]] = s
            lengths[i] = s.shape[0]

        pad_mask = padded == 0  # (B, T)

        return imgs_t, padded, lengths, pad_mask, mem_mask
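
A minimal usage sketch of the dataset plus collate function (vocab/manifest paths and batch size are illustrative, not from this diff; the src.eff_mer import path follows the pyproject.toml entrypoint):

from pathlib import Path
from torch.utils.data import DataLoader

from src.eff_mer.data import MathDataset
from src.eff_mer.vocab import Vocab

# Hypothetical paths -- substitute the real vocab.json and manifest locations.
vocab = Vocab.load(Path("data/vocab.json"))
ds = MathDataset([Path("data/raster/crohme_real_train/manifest.jsonl")], vocab, augment=True)
loader = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=MathDataset.collate)

imgs, padded, lengths, pad_mask, mem_mask = next(iter(loader))
print(imgs.shape)      # (8, 1, 256, W_max) -- width varies per batch
print(mem_mask.shape)  # (8, 8 * (W_max // 32)) -- True at width-padding tokens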
+221
src/eff_mer/infer.py
"""
Inference and evaluation for MathOCR.

Evaluate ExpRate on test splits:
    uv run eff-mer-evaluate --checkpoint checkpoints/epoch_009.pt

Run on arbitrary image files:
    uv run eff-mer-evaluate --checkpoint checkpoints/epoch_009.pt --images a.png b.png

Both modes can be combined.
"""

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image, ImageOps
from tqdm import tqdm

from .data import _MEAN, _STD, STRIDE, _spatial_mask, scale_and_pad
from .model import MathOCR
from .vocab import Vocab

# eff-mer data lives in the sibling eff-mer repo
_EFF_MER_ROOT = Path(__file__).parent.parent.parent.parent / "eff-mer"

TEST_SPLITS = {"crohme_test", "mathwriting_Test"}


def load_model(checkpoint: Path, vocab: Vocab, device: torch.device) -> MathOCR:
    ckpt = torch.load(checkpoint, map_location=device, weights_only=True)
    model = MathOCR(vocab_size=len(vocab)).to(device)
    model.load_state_dict(ckpt["model"])
    model.eval()
    return model


def preprocess_photo(img: Image.Image, margin: float = 0.05) -> Image.Image:
    """
    Crop a phone photo to the bounding box of dark (ink) pixels.

    Steps: grayscale -> iterative threshold -> bbox of dark pixels -> crop + margin.
    Works for clean shots of a single expression on blank paper.
    Not robust to ruled lines; preprocessing is opt-in via --preprocess.
    """
    gray = np.array(img.convert("L"))
    # Iterative mean-split (ISODATA) threshold -- a cheap Otsu approximation
    thresh = int(gray.mean())
    for _ in range(10):
        fg = gray[gray < thresh].mean() if (gray < thresh).any() else 0
        bg = gray[gray >= thresh].mean() if (gray >= thresh).any() else 255
        thresh = int((fg + bg) / 2)
    binary = gray < thresh  # True = ink

    rows = np.where(binary.any(axis=1))[0]
    cols = np.where(binary.any(axis=0))[0]
    if rows.size == 0 or cols.size == 0:
        return img  # no ink found, return as-is

    r0, r1 = int(rows[0]), int(rows[-1])
    c0, c1 = int(cols[0]), int(cols[-1])

    h, w = gray.shape
    pad_r = max(1, int((r1 - r0) * margin))
    pad_c = max(1, int((c1 - c0) * margin))
    r0 = max(0, r0 - pad_r)
    r1 = min(h, r1 + pad_r)
    c0 = max(0, c0 - pad_c)
    c1 = min(w, c1 + pad_c)

    cropped = img.crop((c0, r0, c1, r1))
    # Pad to square to preserve aspect ratio before resize
    side = max(cropped.width, cropped.height)
    squared = ImageOps.pad(cropped, (side, side), color=255)
    return squared


def binarize(img: Image.Image) -> Image.Image:
    """Iterative mean-split threshold -> pure black-on-white, color/lighting invariant."""
    arr = np.array(img, dtype=np.float32)
    thresh = arr.mean()
    for _ in range(10):
        fg = arr[arr < thresh].mean() if (arr < thresh).any() else 0
        bg = arr[arr >= thresh].mean() if (arr >= thresh).any() else 255
        thresh = (fg + bg) / 2
    ink = arr < thresh
    # Auto-invert if the majority of pixels are dark (light-on-dark input)
    if ink.mean() > 0.5:
        ink = ~ink
    return Image.fromarray(np.where(ink, np.uint8(0), np.uint8(255)))


def load_image(path: Path, preprocess: bool = False) -> torch.Tensor:
    img = Image.open(path).convert("L")
    img = binarize(img)
    if preprocess:
        img = preprocess_photo(img)
    img = scale_and_pad(img)
    t = TF.to_tensor(img)
    return TF.normalize(t, [_MEAN], [_STD])


def pad_images(tensors: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Pad a list of (1, H, W) tensors to the batch-max width.
    Returns (imgs, mem_mask) where mem_mask is None if all widths are equal.
    """
    H = tensors[0].shape[1]
    W_max = max(t.shape[2] for t in tensors)
    out = torch.zeros(len(tensors), 1, H, W_max)
    for i, t in enumerate(tensors):
        out[i, :, :, : t.shape[2]] = t
    w_toks = [t.shape[2] // STRIDE for t in tensors]
    H_tok = H // STRIDE
    W_tok_max = W_max // STRIDE
    if all(w == W_tok_max for w in w_toks):
        return out, None
    return out, _spatial_mask(w_toks, H_tok, W_tok_max)


def decode_batch(model: MathOCR, imgs: torch.Tensor, vocab: Vocab,
                 mem_mask: torch.Tensor | None = None) -> list[str]:
    ids_list = model.generate(imgs, vocab.sos_id, vocab.eos_id,
                              memory_key_padding_mask=mem_mask)
    return [vocab.decode(ids) for ids in ids_list]


def evaluate(model: MathOCR, vocab: Vocab, raster: Path, device: torch.device, batch_size: int):
    manifests = [mp for mp in sorted(raster.glob("*/manifest.jsonl"))
                 if mp.parent.name in TEST_SPLITS]
    if not manifests:
        print(f"No test manifests found under {raster}")
        return

    for mp in manifests:
        split = mp.parent.name
        records = []
        for line in mp.open():
            r = json.loads(line)
            typst = r.get("typst", "")
            if typst and not typst.startswith("ERROR:"):
                records.append((mp.parent / r["image"], typst))

        exact = 0
        bar = tqdm(range(0, len(records), batch_size), desc=split, leave=True)
        examples = []

        for start in bar:
            batch = records[start : start + batch_size]
            imgs, mem_mask = pad_images([load_image(p) for p, _ in batch])
            imgs = imgs.to(device)
            mem_mask = mem_mask.to(device) if mem_mask is not None else None
            preds = decode_batch(model, imgs, vocab, mem_mask)

            for (_, gt), pred in zip(batch, preds):
                gt_norm = vocab.decode(vocab.encode(gt))
                if pred == gt_norm:
                    exact += 1
                if len(examples) < 5:
                    examples.append((gt_norm, pred))

            bar.set_postfix(ExpRate=f"{exact / (start + len(batch)):.3f}")

        exprate = exact / len(records)
        print(f"\n{split}: ExpRate={exprate:.4f} ({exact}/{len(records)})")
        print("Examples (gt | pred):")
        for gt, pred in examples:
            match = "OK" if gt == pred else "!!"
            print(f"  [{match}] {gt}")
            if gt != pred:
                print(f"       {pred}")


def infer_files(model: MathOCR, vocab: Vocab, paths: list[Path], device: torch.device, preprocess: bool = False):
    imgs, mem_mask = pad_images([load_image(p, preprocess=preprocess) for p in paths])
    imgs = imgs.to(device)
    mem_mask = mem_mask.to(device) if mem_mask is not None else None
    preds = decode_batch(model, imgs, vocab, mem_mask)
    for path, pred in zip(paths, preds):
        print(f"{path.name}: {pred}")


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--checkpoint", type=Path, required=True)
    parser.add_argument("--vocab", type=Path, default=_EFF_MER_ROOT / "data" / "vocab.json")
    parser.add_argument("--raster", type=Path, default=_EFF_MER_ROOT / "data" / "raster")
    parser.add_argument("--batch", type=int, default=64)
    parser.add_argument("--images", type=Path, nargs="+", metavar="IMG",
                        help="Image files to run inference on")
    parser.add_argument("--preprocess", action="store_true",
                        help="Binarize + crop to ink bbox before inference (for phone photos)")
    parser.add_argument("--eval", action="store_true", default=True,
                        help="Evaluate ExpRate on test splits (default: on unless --images only)")
    parser.add_argument("--no-eval", dest="eval", action="store_false")
    args = parser.parse_args()

    # If --images given without explicit --eval, skip test-set eval
    if args.images and "--eval" not in sys.argv:
        args.eval = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab = Vocab.load(args.vocab)
    model = load_model(args.checkpoint, vocab, device)
    print(f"Loaded {args.checkpoint.name} | vocab={len(vocab)} | device={device}")

    if args.images:
        infer_files(model, vocab, args.images, device, preprocess=args.preprocess)

    if args.eval:
        evaluate(model, vocab, args.raster, device, args.batch)


if __name__ == "__main__":
    main()
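
For reference, a standalone sketch of the iterative mean-split threshold that binarize() and preprocess_photo() both use (toy values mine):

import numpy as np

# Bimodal toy image: ink pixels near 20, paper pixels near 230.
arr = np.array([20, 25, 30, 220, 230, 240], dtype=np.float32)

thresh = arr.mean()          # start from the global mean
for _ in range(10):          # iterate: midpoint of the two class means
    fg = arr[arr < thresh].mean()
    bg = arr[arr >= thresh].mean()
    thresh = (fg + bg) / 2
print(thresh)                # converges to 127.5 here; ink = arr < thresh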
+83
src/eff_mer/model/__init__.py
"""
MathOCR: encoder-decoder model for math expression recognition.

images -> Encoder -> memory -> Decoder -> logits
                                       -> greedy/beam generate
"""

import torch
import torch.nn as nn

from .decoder import Decoder
from .encoder import Encoder


class MathOCR(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_decoder_layers: int = 3,
        dropout: float = 0.1,
        max_len: int = 256,
    ):
        super().__init__()
        self.encoder = Encoder(d_model)
        self.decoder = Decoder(vocab_size, d_model, nhead, num_decoder_layers, dropout, max_len)

    def forward(
        self,
        images: torch.Tensor,
        tgt: torch.Tensor,
        tgt_padding_mask: torch.Tensor | None = None,
        memory_key_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        images: (B, 1, H, W)
        tgt: (B, T) teacher-forced token ids
        tgt_padding_mask: (B, T) True at pad positions
        memory_key_padding_mask: (B, S) True at spatial padding token positions

        Returns logits (B, T, vocab_size).
        """
        memory = self.encoder(images)
        return self.decoder(
            tgt, memory,
            tgt_padding_mask=tgt_padding_mask,
            memory_padding_mask=memory_key_padding_mask,
        )

    @torch.no_grad()
    def generate(
        self,
        images: torch.Tensor,
        sos_id: int,
        eos_id: int,
        max_len: int = 256,
        memory_key_padding_mask: torch.Tensor | None = None,
    ) -> list[list[int]]:
        """Greedy decoding. Returns list of token id lists (eos stripped)."""
        memory = self.encoder(images)
        B = images.shape[0]

        seqs = torch.full((B, 1), sos_id, dtype=torch.long, device=images.device)
        done = torch.zeros(B, dtype=torch.bool, device=images.device)

        for _ in range(max_len):
            logits = self.decoder(seqs, memory,
                                  memory_padding_mask=memory_key_padding_mask)  # (B, T, V)
            next_tok = logits[:, -1, :].argmax(-1)  # (B,)
            next_tok = torch.where(done, torch.full_like(next_tok, eos_id), next_tok)
            seqs = torch.cat([seqs, next_tok.unsqueeze(1)], dim=1)
            done = done | (next_tok == eos_id)
            if done.all():
                break

        results = []
        for i in range(B):
            toks = seqs[i, 1:].tolist()  # strip <sos>
            if eos_id in toks:
                toks = toks[: toks.index(eos_id)]
            results.append(toks)
        return results
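
A shape-check sketch (instantiating the model downloads the EfficientNet-B0 weights; the numbers assume the 256-px geometry from data.py, and the src.eff_mer import path is taken from pyproject.toml):

import torch
from src.eff_mer.model import MathOCR

model = MathOCR(vocab_size=100).eval()
imgs = torch.randn(2, 1, 256, 256)     # (B, 1, H, W)
tgt = torch.randint(0, 100, (2, 5))    # (B, T) teacher-forced ids
logits = model(imgs, tgt)
print(logits.shape)                    # torch.Size([2, 5, 100])

ids = model.generate(imgs, sos_id=1, eos_id=2, max_len=16)
print(len(ids), type(ids[0]))          # 2 <class 'list'>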
+79
src/eff_mer/model/decoder.py
"""
Autoregressive transformer decoder.

Uses pre-norm (norm_first=True) for training stability, learned positional
embeddings, and a causal mask over the target sequence.

Coverage attention is a planned addition (see CoMER/WAP) -- the spatial
coverage accumulator should be added to cross-attention logits to prevent
repeatedly attending to the same encoder positions. Deferred to post-baseline.
"""

import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 3,
        dropout: float = 0.1,
        max_len: int = 256,
    ):
        super().__init__()
        self.d_model = d_model

        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_embed = nn.Embedding(max_len, d_model)

        layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            norm_first=True,  # pre-norm: more stable training
        )
        self.transformer = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.out_proj = nn.Linear(d_model, vocab_size)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.embed.weight, std=0.02)
        nn.init.normal_(self.pos_embed.weight, std=0.02)
        nn.init.zeros_(self.out_proj.bias)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_padding_mask: torch.Tensor | None = None,
        memory_padding_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        tgt: (B, T) token ids (teacher-forced)
        memory: (B, S, d) encoder output
        tgt_padding_mask: (B, T) True at pad positions
        memory_padding_mask: (B, S) True at pad positions (unused for fixed grids)

        Returns logits (B, T, vocab_size).
        """
        B, T = tgt.shape
        causal_mask = nn.Transformer.generate_square_subsequent_mask(T, device=tgt.device, dtype=torch.bool)

        pos = torch.arange(T, device=tgt.device)
        x = self.embed(tgt) + self.pos_embed(pos)  # (B, T, d_model)

        x = self.transformer(
            x,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_padding_mask,
            tgt_is_causal=True,
        )
        return self.out_proj(x)  # (B, T, vocab_size)
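
As an illustration of the boolean causal mask the forward pass builds (True marks positions a query must not attend to; generate_square_subsequent_mask with dtype=torch.bool should produce the same pattern):

import torch

T = 4
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])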
+59
src/eff_mer/model/encoder.py
"""
Encoder: EfficientNet-B0 backbone + 1x1 projection + 2D sinusoidal positional encoding.

Input:  (B, 1, H, W) grayscale, normalized to ImageNet grayscale stats
Output: (B, h*w, d_model) sequence of spatial tokens

At 256x256 input, EfficientNet-B0 produces an 8x8 feature grid (stride 32),
giving 64 spatial tokens per image.
"""

import math

import torch
import torch.nn as nn
import torchvision.models as tvm


class Encoder(nn.Module):
    def __init__(self, d_model: int = 256):
        super().__init__()
        assert d_model % 4 == 0, "d_model must be divisible by 4 for 2D sinusoidal PE"

        backbone = tvm.efficientnet_b0(weights=tvm.EfficientNet_B0_Weights.IMAGENET1K_V1)
        self.features = backbone.features  # (B, 1280, h, w)

        self.proj = nn.Conv2d(1280, d_model, kernel_size=1, bias=False)
        self.d_model = d_model

    def _sinusoidal_pe_2d(self, H: int, W: int, device: torch.device) -> torch.Tensor:
        """Returns (1, d_model, H, W) positional encoding."""
        half_d = self.d_model // 4  # quarter for each of sin/cos × row/col
        inv_freq = 1.0 / (10000 ** (torch.arange(half_d, device=device).float() / half_d))

        row = torch.arange(H, device=device).float()
        col = torch.arange(W, device=device).float()

        row_enc = torch.cat([torch.sin(row[:, None] * inv_freq),
                             torch.cos(row[:, None] * inv_freq)], dim=-1)  # (H, d/2)
        col_enc = torch.cat([torch.sin(col[:, None] * inv_freq),
                             torch.cos(col[:, None] * inv_freq)], dim=-1)  # (W, d/2)

        pe = torch.cat([
            row_enc.unsqueeze(1).expand(H, W, -1),  # (H, W, d/2)
            col_enc.unsqueeze(0).expand(H, W, -1),  # (H, W, d/2)
        ], dim=-1)  # (H, W, d_model)

        return pe.permute(2, 0, 1).unsqueeze(0)  # (1, d_model, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Grayscale -> 3-channel (EfficientNet expects RGB)
        x3 = x.expand(-1, 3, -1, -1)

        feat = self.features(x3)  # (B, 1280, h, w)
        feat = self.proj(feat)    # (B, d_model, h, w)

        _, _, H, W = feat.shape
        feat = feat + self._sinusoidal_pe_2d(H, W, feat.device)

        return feat.flatten(2).transpose(1, 2)  # (B, h*w, d_model)
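
A quick sanity sketch of the claimed geometry (first run downloads the ImageNet weights; import path assumed from pyproject.toml):

import torch
from src.eff_mer.model.encoder import Encoder

enc = Encoder(d_model=256).eval()
x = torch.randn(2, 1, 256, 256)   # grayscale batch at the training height
with torch.no_grad():
    tokens = enc(x)
print(tokens.shape)               # torch.Size([2, 64, 256]) -- 8x8 grid of d_model tokens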
src/eff_mer/scripts/__init__.py

This is a binary file and will not be displayed.

+125
src/eff_mer/scripts/convert_labels.py
"""
Convert LaTeX labels in raster manifests to Typst using a batch converter
(default: node tex2typst-js-batch/batch.mjs).

Reads each manifest.jsonl, pipes LaTeX strings through the converter, and writes
a new manifest.jsonl with an added "typst" field. Records that fail conversion
are dropped. Processes manifests in-place (originals backed up as .bak).

Usage:
    uv run convert-labels                       # all splits under data/raster
    uv run convert-labels --raster data/raster  # explicit raster root
    uv run convert-labels --binary ../latex2typst-batch
"""

import argparse
import json
import re
import shlex
import shutil
import subprocess
from pathlib import Path

DATA_ROOT = Path(__file__).parent.parent.parent.parent / "data"
_PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
DEFAULT_CMD = f"node {_PROJECT_ROOT / 'tex2typst-js-batch' / 'batch.mjs'}"

_DELIM_RE = re.compile(r"^\$(.+)\$$|^\\\((.+)\\\)$|^\\\[(.+)\\\]$", re.DOTALL)


def strip_delimiters(s: str) -> str:
    """Strip LaTeX/Typst math mode delimiters ($...$, \\(...\\), \\[...\\])."""
    s = s.strip()
    m = _DELIM_RE.match(s)
    if m:
        return next(g for g in m.groups() if g is not None).strip()
    return s


def convert_manifest(manifest_path: Path, binary: str) -> tuple[int, int]:
    """
    Convert a single manifest.jsonl in-place.
    Returns (written, dropped) counts.
    """
    records = [json.loads(line) for line in manifest_path.open()]
    if not records:
        return 0, 0

    # Skip if already converted
    if "typst" in records[0]:
        print(f"  [skip] already converted: {manifest_path}")
        return 0, 0

    latexes = [strip_delimiters(r["latex"]) for r in records]
    batch_input = "\n".join(latexes)

    result = subprocess.run(
        shlex.split(binary),
        input=batch_input,
        capture_output=True,
        text=True,
    )
    typst_lines = result.stdout.splitlines()

    if len(typst_lines) != len(records):
        raise RuntimeError(
            f"{manifest_path}: expected {len(records)} output lines, "
            f"got {len(typst_lines)}"
        )

    # Backup original
    bak = manifest_path.with_suffix(".jsonl.bak")
    shutil.copy2(manifest_path, bak)

    written = dropped = 0
    with manifest_path.open("w") as f:
        for rec, typst in zip(records, typst_lines):
            if typst.startswith("ERROR:"):
                dropped += 1
                continue
            # Converter re-wraps some outputs in $...$; strip them
            typst = strip_delimiters(typst)
            rec["typst"] = typst
            f.write(json.dumps(rec) + "\n")
            written += 1

    return written, dropped


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--raster",
        type=Path,
        default=DATA_ROOT / "raster",
        help="Raster root containing split subdirectories (default: data/raster)",
    )
    parser.add_argument(
        "--binary",
        type=str,
        default=DEFAULT_CMD,
        help="Converter command (shell-split), default: node tex2typst-js-batch/batch.mjs",
    )
    args = parser.parse_args()

    manifests = sorted(args.raster.glob("*/manifest.jsonl"))
    if not manifests:
        parser.error(f"No manifest.jsonl files found under {args.raster}")

    total_written = total_dropped = 0
    for mp in manifests:
        tag = mp.parent.name
        print(f"[{tag}]")
        written, dropped = convert_manifest(mp, args.binary)
        print(f"  {written:,} kept, {dropped:,} dropped")
        total_written += written
        total_dropped += dropped

    total = total_written + total_dropped
    drop_pct = 100 * total_dropped / total if total else 0
    print(f"\nTotal: {total_written:,} kept, {total_dropped:,} dropped ({drop_pct:.1f}%)")


if __name__ == "__main__":
    main()
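
Expected behaviour of strip_delimiters on the three delimiter styles (examples mine):

from src.eff_mer.scripts.convert_labels import strip_delimiters

assert strip_delimiters(r"$x^2$") == "x^2"
assert strip_delimiters(r"\(a + b\)") == "a + b"
assert strip_delimiters(r"\[\frac{1}{2}\]") == r"\frac{1}{2}"
assert strip_delimiters("x^2") == "x^2"  # undelimited input passes through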
+306
src/eff_mer/scripts/rasterize.py
#!/usr/bin/env python3
"""
Rasterize InkML math expression datasets to PNG images + JSONL manifest.

Supported datasets:
    crohme       splits: real_train, val, test, gen_LaTeX_data_CROHME_2019,
                         gen_LaTeX_data_CROHME_2023_corpus, gen_syntactic_data
    mathwriting  splits: Train, Val, Test, Symbols, Synthetic
    typeset      [stub] not yet implemented

Output structure:
    <out>/<dataset>_<split>/
        images/         *.png
        manifest.jsonl  {"image": "images/<stem>.png", "latex": "...", "mathml": "..."}

Usage:
    python src/eff_mer/scripts/rasterize.py --dataset crohme --split real_train
    python src/eff_mer/scripts/rasterize.py --dataset mathwriting --split Train --augment
    python src/eff_mer/scripts/rasterize.py --all
    python src/eff_mer/scripts/rasterize.py --all --size 128 --out data/raster_128
"""

import argparse
import json
import multiprocessing as mp
import random
from pathlib import Path
from xml.etree import ElementTree as ET

import numpy as np
from PIL import Image, ImageDraw, ImageFilter

# ── Config ────────────────────────────────────────────────────────────────────

DATA_ROOT = Path(__file__).parent.parent.parent.parent / "data"

CROHME_SPLITS = [
    "real_train",
    "val",
    "test",
    "gen_LaTeX_data_CROHME_2019",
    "gen_LaTeX_data_CROHME_2023_corpus",
    "gen_syntactic_data",
]

MATHWRITING_SPLITS = ["Train", "Val", "Test", "Symbols", "Synthetic"]

INK_NS = "http://www.w3.org/2003/InkML"
NS = {"ink": INK_NS}

# Fraction of canvas size used as padding on each side
MARGIN_FRAC = 0.08

# ── InkML parsing ─────────────────────────────────────────────────────────────


def _find_annotation(root: ET.Element, *type_keys: str) -> str | None:
    for key in type_keys:
        el = root.find(f"{{{INK_NS}}}annotation[@type='{key}']")
        if el is not None and el.text:
            return el.text.strip()
    return None


def parse_inkml(path: Path) -> dict | None:
    """
    Returns dict with:
        latex   (str)  ground-truth LaTeX string
        mathml  (str)  serialized MathML annotationXML block, or ""
        strokes (list of np.ndarray shape (N, 2), float32 xy)
    Returns None on parse error or missing label.
    """
    try:
        tree = ET.parse(path)
    except ET.ParseError:
        return None

    root = tree.getroot()

    # CROHME+ uses "truth"; MathWriting+ uses "label" / "normalizedLabel"
    latex = _find_annotation(root, "truth", "label", "normalizedLabel")
    if latex is None:
        return None

    # MathML tree: serialize <annotationXML type="truth"> for later use by tree head
    mathml_el = root.find(f"{{{INK_NS}}}annotationXML[@type='truth']")
    mathml_str = ET.tostring(mathml_el, encoding="unicode") if mathml_el is not None else ""

    # Collect strokes: map trace id -> xy array (ignore timestamps / pressure)
    traces: dict[str, np.ndarray] = {}
    for trace in root.findall(f"{{{INK_NS}}}trace"):
        tid = trace.get("id", "")
        pts = []
        if trace.text:
            for token in trace.text.strip().split(","):
                coords = token.strip().split()
                if len(coords) >= 2:
                    try:
                        pts.append((float(coords[0]), float(coords[1])))
                    except ValueError:
                        continue
        if pts:
            traces[tid] = np.array(pts, dtype=np.float32)

    strokes = list(traces.values())
    if not strokes:
        return None

    return {"latex": latex, "mathml": mathml_str, "strokes": strokes}


# ── Rendering ─────────────────────────────────────────────────────────────────


def render_strokes(
    strokes: list[np.ndarray],
    size: int = 256,
    pen_width: float = 2.0,
    augment: bool = False,
) -> Image.Image:
    """Render stroke list to a grayscale PIL image (white bg, black ink)."""
    all_pts = np.concatenate(strokes, axis=0)
    xmin, ymin = float(all_pts[:, 0].min()), float(all_pts[:, 1].min())
    xmax, ymax = float(all_pts[:, 0].max()), float(all_pts[:, 1].max())

    span = max(xmax - xmin, ymax - ymin, 1.0)
    margin_px = int(size * MARGIN_FRAC)
    canvas_px = size - 2 * margin_px

    if augment:
        pen_width = pen_width * random.uniform(0.7, 1.6)

    def to_px(pts: np.ndarray) -> list[tuple[int, int]]:
        x = ((pts[:, 0] - xmin) / span * canvas_px + margin_px).round().astype(int)
        y = ((pts[:, 1] - ymin) / span * canvas_px + margin_px).round().astype(int)
        return list(zip(x.tolist(), y.tolist()))

    img = Image.new("L", (size, size), color=255)
    draw = ImageDraw.Draw(img)
    pw = max(1, round(pen_width))

    for stroke in strokes:
        px = to_px(stroke)
        if len(px) == 1:
            draw.ellipse([px[0][0] - pw, px[0][1] - pw, px[0][0] + pw, px[0][1] + pw], fill=0)
        else:
            draw.line(px, fill=0, width=pw)

    if augment:
        # Soften jagged strokes
        img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.0, 0.5)))
        # Light pixel noise to simulate paper texture
        arr = np.array(img, dtype=np.int16)
        arr = np.clip(arr + np.random.normal(0, 4, arr.shape).astype(np.int16), 0, 255).astype(np.uint8)
        img = Image.fromarray(arr)

    return img


# ── Worker (top-level for multiprocessing pickling) ───────────────────────────


def _worker(args: tuple) -> dict | None:
    inkml_path, out_img_dir, size, augment = args
    parsed = parse_inkml(inkml_path)
    if parsed is None:
        return None
    img = render_strokes(parsed["strokes"], size=size, augment=augment)
    img_name = inkml_path.stem + ".png"
    img_path = Path(out_img_dir) / img_name
    img.save(img_path, optimize=False)
    return {
        "image": "images/" + img_name,
        "latex": parsed["latex"],
        "mathml": parsed["mathml"],
    }


# ── Split processor ───────────────────────────────────────────────────────────


def process_split(
    inkml_dir: Path,
    out_dir: Path,
    size: int,
    augment: bool,
    workers: int,
) -> int:
    inkml_files = sorted(inkml_dir.glob("*.inkml"))
    if not inkml_files:
        print(f"  [skip] no .inkml files in {inkml_dir}")
        return 0

    out_img_dir = out_dir / "images"
    out_img_dir.mkdir(parents=True, exist_ok=True)

    job_args = [(f, str(out_img_dir), size, augment) for f in inkml_files]

    count = 0
    skipped = 0
    with open(out_dir / "manifest.jsonl", "w") as mf:
        with mp.Pool(workers) as pool:
            for rec in pool.imap_unordered(_worker, job_args, chunksize=128):
                if rec is not None:
                    mf.write(json.dumps(rec) + "\n")
                    count += 1
                else:
                    skipped += 1

    if skipped:
        print(f"  {count:,} written, {skipped} skipped (parse errors / no label)")
    return count


# ── Typeset stub ──────────────────────────────────────────────────────────────


def process_typeset(out_dir: Path, size: int, augment: bool) -> int:
    """
    [STUB] Render typeset math expressions to images.

    Planned sources:
      - Typst CLI: generate synthetic expressions from a formal grammar,
        render via `typst compile` to PNG.
      - im2latex-100k or similar LaTeX-rendered dataset.

    Expected manifest format is identical to handwriting splits:
        {"image": "images/<stem>.png", "latex": "...", "mathml": ""}
    """
    raise NotImplementedError(
        "Typeset rasterization not yet implemented.\n"
        "Planned: Typst CLI rendering of grammar-generated expressions.\n"
        "See scripts/rasterize_typeset.py (to be written)."
    )


# ── CLI ───────────────────────────────────────────────────────────────────────


def build_todo(args) -> list[tuple[str, str, Path]]:
    """Returns list of (dataset_tag, split, inkml_dir) tuples."""
    if args.all:
        todo = []
        for split in CROHME_SPLITS:
            d = DATA_ROOT / "CROHME+" / split
            if d.exists():
                todo.append(("crohme", split, d))
            else:
                print(f"[warn] CROHME+ split not found: {d}")
        for split in MATHWRITING_SPLITS:
            d = DATA_ROOT / "MathWriting+" / split
            if d.exists():
                todo.append(("mathwriting", split, d))
            else:
                print(f"[warn] MathWriting+ split not found: {d}")
        return todo

    if args.dataset == "typeset":
        return [("typeset", "typeset", Path())]

    dataset_root = {
        "crohme": DATA_ROOT / "CROHME+",
        "mathwriting": DATA_ROOT / "MathWriting+",
    }[args.dataset]

    inkml_dir = dataset_root / args.split
    if not inkml_dir.exists():
        raise FileNotFoundError(f"Split directory not found: {inkml_dir}")
    return [(args.dataset, args.split, inkml_dir)]


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--dataset", choices=["crohme", "mathwriting", "typeset"])
    parser.add_argument("--split", help="Split name (e.g. real_train, Train)")
    parser.add_argument("--all", action="store_true", help="Process all known handwriting splits")
    parser.add_argument("--out", type=Path, default=DATA_ROOT / "raster", help="Output root (default: data/raster)")
    parser.add_argument("--size", type=int, default=256, help="Square canvas size in pixels (default: 256)")
    parser.add_argument("--augment", action="store_true", help="Apply pen-width jitter + noise at render time")
    parser.add_argument("--workers", type=int, default=max(1, mp.cpu_count() - 1))
    args = parser.parse_args()

    if not args.all and args.dataset is None:
        parser.error("Specify --dataset or --all")
    if not args.all and args.dataset != "typeset" and args.split is None:
        parser.error("--split required unless --all or --dataset typeset")

    todo = build_todo(args)

    for dataset, split, inkml_dir in todo:
        tag = f"{dataset}_{split}"
        out_dir = args.out / tag

        if dataset == "typeset":
            process_typeset(out_dir, args.size, args.augment)
            continue

        print(f"[{tag}] {inkml_dir} -> {out_dir} (workers={args.workers}, size={args.size})")
        n = process_split(inkml_dir, out_dir, args.size, args.augment, args.workers)
        print(f"  done: {n:,} images")


if __name__ == "__main__":
    main()
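
A minimal end-to-end sketch of parse_inkml + render_strokes on a hand-written InkML snippet (document content mine, not a real CROHME file; import path assumed from pyproject.toml):

import tempfile
from pathlib import Path

from src.eff_mer.scripts.rasterize import parse_inkml, render_strokes

doc = """<ink xmlns="http://www.w3.org/2003/InkML">
  <annotation type="truth">x</annotation>
  <trace id="0">0 0, 10 10, 20 0</trace>
</ink>"""

with tempfile.NamedTemporaryFile("w", suffix=".inkml", delete=False) as f:
    f.write(doc)

parsed = parse_inkml(Path(f.name))
print(parsed["latex"])                             # x
img = render_strokes(parsed["strokes"], size=256)
img.save("demo.png")                               # white background, black strokes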
+222
src/eff_mer/train.py
"""
Training loop for MathOCR.

Usage:
    uv run train
    uv run train --epochs 30 --batch 32 --lr 3e-4 --log-every 100

Checkpoints saved to checkpoints/epoch_{N}.pt every epoch.
Resumes from latest checkpoint if present.
"""

import argparse
import csv
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from .data import MathDataset
from .model import MathOCR
from .vocab import Vocab

DATA_ROOT = Path(__file__).parent.parent.parent / "data"
CKPT_DIR = Path(__file__).parent.parent.parent / "checkpoints"

TRAIN_SPLITS = {
    "crohme_gen_LaTeX_data_CROHME_2019",
    "crohme_gen_LaTeX_data_CROHME_2023_corpus",
    "crohme_gen_syntactic_data",
    "crohme_real_train",
    "mathwriting_Train",
    "mathwriting_Synthetic",
    "mathwriting_Symbols",
}
VAL_SPLITS = {"crohme_val", "mathwriting_Val"}


def get_manifests(raster: Path, splits: set[str]) -> list[Path]:
    return sorted(mp for mp in raster.glob("*/manifest.jsonl") if mp.parent.name in splits)


def load_latest_checkpoint(
    ckpt_dir: Path,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
):
    ckpts = sorted(ckpt_dir.glob("epoch_*.pt"))
    if not ckpts:
        return 0
    ckpt = torch.load(ckpts[-1], weights_only=True)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    if "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    epoch = ckpt["epoch"]
    print(f"Resumed from {ckpts[-1].name} (epoch {epoch})")
    return epoch


def run_epoch(model, loader, optimizer, criterion, device, train: bool, log_every: int, epoch: int):
    model.train(train)
    total_loss = total_tokens = 0
    step = 0

    phase = "train" if train else "val"
    bar = tqdm(loader, desc=f"Epoch {epoch} {phase}", leave=False, dynamic_ncols=True)

    with torch.set_grad_enabled(train):
        for imgs, tgt, lengths, pad_mask, mem_mask in bar:
            imgs = imgs.to(device)
            tgt = tgt.to(device)
            pad_mask = pad_mask.to(device)
            mem_mask = mem_mask.to(device)

            tgt_in = tgt[:, :-1]
            tgt_out = tgt[:, 1:]
            pad_mask_in = pad_mask[:, :-1]

            logits = model(imgs, tgt_in, tgt_padding_mask=pad_mask_in,
                           memory_key_padding_mask=mem_mask)

            loss = criterion(
                logits.reshape(-1, logits.shape[-1]),
                tgt_out.reshape(-1),
            )

            if train:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            n_tokens = (tgt_out != 0).sum().item()
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens
            step += 1

            if train and step % log_every == 0:
                running = total_loss / total_tokens
                bar.set_postfix(loss=f"{running:.4f}")

    bar.close()
    return total_loss / total_tokens if total_tokens else float("inf")


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--raster", type=Path, default=DATA_ROOT / "raster")
    parser.add_argument("--vocab", type=Path, default=DATA_ROOT / "vocab.json")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch", type=int, default=32)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--max-len", type=int, default=256)
    parser.add_argument("--d-model", type=int, default=256)
    parser.add_argument("--nhead", type=int, default=8)
    parser.add_argument("--dec-layers", type=int, default=3)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--log-every", type=int, default=100, help="Log loss every N steps")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    vocab = Vocab.load(args.vocab)
    print(f"Vocab: {len(vocab)} tokens")

    train_manifests = get_manifests(args.raster, TRAIN_SPLITS)
    val_manifests = get_manifests(args.raster, VAL_SPLITS)
    print(f"Train manifests: {len(train_manifests)}, Val manifests: {len(val_manifests)}")

    train_ds = MathDataset(train_manifests, vocab, max_len=args.max_len, augment=True)
    val_ds = MathDataset(val_manifests, vocab, max_len=args.max_len, augment=False)
    print(f"Train samples: {len(train_ds):,}, Val samples: {len(val_ds):,}")

    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch,
        shuffle=True,
        num_workers=args.workers,
        collate_fn=MathDataset.collate,
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=args.batch,
        shuffle=False,
        num_workers=args.workers,
        collate_fn=MathDataset.collate,
        pin_memory=device.type == "cuda",
    )

    model = MathOCR(
        vocab_size=len(vocab),
        d_model=args.d_model,
        nhead=args.nhead,
        num_decoder_layers=args.dec_layers,
        dropout=args.dropout,
        max_len=args.max_len,
    ).to(device)

    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable params: {n_params/1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=args.epochs, eta_min=args.lr * 0.01
    )

    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_id, label_smoothing=0.1)

    CKPT_DIR.mkdir(exist_ok=True)
    start_epoch = load_latest_checkpoint(CKPT_DIR, model, optimizer, scheduler)

    log_path = CKPT_DIR / "metrics.csv"
    log_exists = log_path.exists()
    log_file = log_path.open("a", newline="")
    log_writer = csv.writer(log_file)
    if not log_exists:
        log_writer.writerow(["epoch", "train_loss", "val_loss", "lr", "elapsed_s"])

    for epoch in range(start_epoch + 1, args.epochs + 1):
        t0 = time.time()
        train_loss = run_epoch(model, train_loader, optimizer, criterion, device, True, args.log_every, epoch)
        scheduler.step()
        val_loss = run_epoch(model, val_loader, optimizer, criterion, device, False, args.log_every, epoch)
        elapsed = time.time() - t0
        lr = optimizer.param_groups[0]["lr"]

        print(
            f"Epoch {epoch:3d}/{args.epochs} | "
            f"train={train_loss:.4f} | val={val_loss:.4f} | "
            f"lr={lr:.2e} | {elapsed:.0f}s"
        )

        ckpt_path = CKPT_DIR / f"epoch_{epoch:03d}.pt"
        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "val_loss": val_loss,
            },
            ckpt_path,
        )
        print(f"  checkpoint -> {ckpt_path.name}")

        log_writer.writerow([epoch, f"{train_loss:.6f}", f"{val_loss:.6f}", f"{lr:.6e}", f"{elapsed:.1f}"])
        log_file.flush()

    log_file.close()


if __name__ == "__main__":
    main()
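
To make the teacher-forcing shift in run_epoch concrete (token values mine; 1=<sos>, 2=<eos>, 0=<pad> per vocab.py):

import torch

tgt = torch.tensor([[1, 7, 9, 2, 0]])   # <sos> t1 t2 <eos> <pad>
tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
# tgt_in  = [[1, 7, 9, 2]] -- decoder input
# tgt_out = [[7, 9, 2, 0]] -- targets; <pad>=0 dropped via ignore_index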
+67
src/eff_mer/vocab.py
"""
Fixed vocabulary over Typst math tokens.

Tokenization: regex split into word-character runs and individual symbols,
matching how Typst math is written (multi-char identifiers like `frac`, `sqrt`
stay atomic; operators like `+`, `(`, `^` are single tokens).
"""

import json
import re
from collections import Counter
from pathlib import Path

SPECIAL = ["<pad>", "<sos>", "<eos>", "<unk>"]

_TOKEN_RE = re.compile(r"[a-zA-Z0-9]+|[^\w\s]|_")


def tokenize(typst: str) -> list[str]:
    return _TOKEN_RE.findall(typst)


class Vocab:
    def __init__(self, tokens: list[str]):
        self._id2tok = tokens
        self._tok2id = {t: i for i, t in enumerate(tokens)}

    def __len__(self) -> int:
        return len(self._id2tok)

    @property
    def pad_id(self) -> int:
        return 0  # always first by construction

    @property
    def sos_id(self) -> int:
        return 1

    @property
    def eos_id(self) -> int:
        return 2

    @property
    def unk_id(self) -> int:
        return 3

    def encode(self, typst: str) -> list[int]:
        return [self._tok2id.get(t, self.unk_id) for t in tokenize(typst)]

    def decode(self, ids: list[int], skip_special: bool = True) -> str:
        skip = {self.pad_id, self.sos_id, self.eos_id} if skip_special else set()
        return " ".join(self._id2tok[i] for i in ids if i not in skip)

    def save(self, path: Path) -> None:
        path.write_text(json.dumps(self._id2tok, ensure_ascii=False))

    @classmethod
    def load(cls, path: Path) -> "Vocab":
        return cls(json.loads(path.read_text()))

    @classmethod
    def build(cls, typst_strings: list[str], min_freq: int = 2) -> "Vocab":
        counter: Counter = Counter()
        for s in typst_strings:
            counter.update(tokenize(s))
        tokens = SPECIAL + [t for t, c in counter.most_common() if c >= min_freq]
        return cls(tokens)
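
A round-trip sketch of the tokenizer and vocabulary (strings mine; import path assumed from pyproject.toml):

from src.eff_mer.vocab import Vocab, tokenize

print(tokenize("frac(1, 2) + sqrt(x)"))
# ['frac', '(', '1', ',', '2', ')', '+', 'sqrt', '(', 'x', ')']

v = Vocab.build(["frac(1, 2) + sqrt(x)", "frac(a, b) + sqrt(y)"], min_freq=2)
ids = v.encode("frac(1, 2)")
print(v.decode(ids))
# 'frac ( <unk> , <unk> )' -- the digits fell below min_freq, so they map to <unk>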