this repo has no description
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

Overhaul augmentation pipeline and add augment-vis tool

- Extract all magic numbers into named AUG_* constants
- Fix affine crop to account for rotation corner shift and scale expansion
- Fix perspective transform clipping (pad before, crop after)
- Add _find_blocks, _transform_patch, _region_jitter for per-block jitter
- Cap per-block dy to half the inter-block gap to prevent overlap
- Reduce patch elastic distortion (alpha 6→4, sigma 3→5)
- Content-type detection: lists get AUG_JITTER_LIST_MAX_DX=10 vs 40 for prose
- Increase ruled-line opacity (28-55→60-110) and probability (20%→30%)
- augment-vis: show orig / aug / aug+jitter columns; save NN_typst.txt;
print list detection tag and 80-char typst preview

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

+310 -21
+1
pyproject.toml
··· 29 29 train-hnm = "src.train_hnm:main" 30 30 export = "src.export:main" 31 31 generate-typeset = "src.generate_typeset:main" 32 + augment-vis = "src.augment_vis:main" 32 33 download-hw-fonts = "src.download_hw_fonts:main" 33 34 probe = "src.probe:main" 34 35 app = "src.app:main"
+85
src/augment_vis.py
··· 1 + """ 2 + Visualise augmentation on a sample of training images. 3 + 4 + Saves three images per example: 5 + NN_orig.png -- original 6 + NN_aug.png -- full page-level pipeline (no region jitter) 7 + NN_aug_jitter.png -- same page transforms + region jitter applied on top 8 + 9 + Usage: 10 + uv run augment-vis # 12 images from typeset_mixed_train 11 + uv run augment-vis --split typeset_prose_train --n 12 12 + uv run augment-vis --split mathwriting_train --out /tmp/aug_mw 13 + uv run augment-vis --images data/foo/images/abc.png ... 14 + """ 15 + 16 + import argparse 17 + import json 18 + import random 19 + from pathlib import Path 20 + 21 + from PIL import Image 22 + 23 + from .data import ( 24 + DATA_ROOT, _augment, _find_blocks, _region_jitter, 25 + _LIST_RE, AUG_JITTER_MAX_DX, AUG_JITTER_LIST_MAX_DX, 26 + ) 27 + 28 + 29 + def main() -> None: 30 + parser = argparse.ArgumentParser( 31 + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter 32 + ) 33 + parser.add_argument("--split", default="typeset_mixed_train") 34 + parser.add_argument("--n", type=int, default=12) 35 + parser.add_argument("--seed", type=int, default=42) 36 + parser.add_argument("--out", default="/tmp/augment_vis") 37 + parser.add_argument("--images", nargs="+", metavar="IMG", 38 + help="Use specific image files instead of sampling a split") 39 + args = parser.parse_args() 40 + 41 + rng = random.Random(args.seed) 42 + out = Path(args.out) 43 + out.mkdir(parents=True, exist_ok=True) 44 + 45 + if args.images: 46 + entries = [{"path": Path(p), "typst": ""} for p in args.images] 47 + else: 48 + manifest = DATA_ROOT / args.split / "manifest.jsonl" 49 + lines = manifest.read_text().splitlines() 50 + rng.shuffle(lines) 51 + base = DATA_ROOT / args.split 52 + records = [json.loads(l) for l in lines[: args.n]] 53 + entries = [{"path": base / r["image"], "typst": r.get("typst", "")} for r in records] 54 + 55 + for i, entry in enumerate(entries): 56 + path, typst = entry["path"], entry["typst"] 57 + img = Image.open(path).convert("RGB") 58 + blocks = _find_blocks(img) 59 + is_list = bool(_LIST_RE.search(typst)) 60 + 61 + aug_base = _augment(img, region_jitter=False, typst=typst) 62 + max_dx = AUG_JITTER_LIST_MAX_DX if is_list else AUG_JITTER_MAX_DX 63 + aug_jitter = _region_jitter(aug_base, max_dx=max_dx) 64 + 65 + img.save(out / f"{i:02d}_orig.png") 66 + aug_base.save(out / f"{i:02d}_aug.png") 67 + aug_jitter.save(out / f"{i:02d}_aug_jitter.png") 68 + (out / f"{i:02d}_typst.txt").write_text(typst) 69 + 70 + list_tag = " [list]" if is_list else "" 71 + ranges = " ".join(f"{t}-{b}({b-t}px)" for t, b in blocks) 72 + preview = (typst[:80] + "…") if len(typst) > 80 else typst 73 + print(f"{i}: {path.name} size={img.size} blocks={len(blocks)}{list_tag}") 74 + print(f" {ranges}") 75 + print(f" {preview}") 76 + 77 + print(f"\nSaved to {out}/") 78 + print(" NN_orig.png -- original") 79 + print(" NN_aug.png -- page-level augmentation only (no jitter)") 80 + print(" NN_aug_jitter.png -- same transforms + region jitter") 81 + print(" NN_typst.txt -- Typst source label") 82 + 83 + 84 + if __name__ == "__main__": 85 + main()
+224 -21
src/data.py
··· 9 9 """ 10 10 11 11 import json 12 + import math 12 13 import re 13 14 import random 14 15 from collections import defaultdict 15 16 from pathlib import Path 16 17 18 + import numpy as np 17 19 from torch.utils.data import Dataset 18 20 from PIL import Image, ImageDraw, ImageFilter 19 21 from torchvision.transforms import ElasticTransform, RandomPerspective ··· 64 66 _VAR_RE = re.compile(r"\b[a-zA-Z]\b") 65 67 _MAX_PER_KEY = 5 66 68 67 - _ELASTIC = ElasticTransform(alpha=30.0, sigma=6.0, fill=255) 68 - _PERSPECTIVE = RandomPerspective(distortion_scale=0.04, p=1.0, fill=255) 69 + # --------------------------------------------------------------------------- 70 + # Augmentation hyperparameters 71 + # --------------------------------------------------------------------------- 72 + 73 + # Page-level affine 74 + AUG_ANGLE_DEG = 5.0 # max rotation in either direction 75 + AUG_SCALE_RANGE = 0.1 # scale drawn from [1-r, 1+r] 76 + AUG_TRANSLATE_FRAC = 0.05 # max translation as fraction of image dimension 77 + 78 + # Page-level elastic (sigma kept large so the warp field is globally smooth) 79 + AUG_ELASTIC_ALPHA = 15.0 80 + AUG_ELASTIC_SIGMA = 5.0 81 + AUG_P_ELASTIC = 0.35 82 + 83 + # Page-level perspective 84 + AUG_PERSP_DISTORTION = 0.02 85 + AUG_P_PERSPECTIVE = 0.25 86 + 87 + # Per-block (patch) transforms inside _region_jitter 88 + AUG_PATCH_ANGLE_DEG = 3.0 89 + AUG_PATCH_SCALE_RANGE = 0.05 90 + AUG_PATCH_ELASTIC_ALPHA = 4.0 91 + AUG_PATCH_ELASTIC_SIGMA = 5.0 92 + AUG_P_PATCH_ELASTIC = 0.4 93 + AUG_PATCH_MIN_H = 30 # minimum patch height to apply elastic 94 + 95 + # Region jitter offsets 96 + AUG_JITTER_MAX_DX = 40 # prose / non-list content 97 + AUG_JITTER_LIST_MAX_DX = 10 # list content: keep indentation legible 98 + AUG_JITTER_MAX_DY = 30 99 + AUG_P_JITTER = 0.35 100 + 101 + _LIST_RE = re.compile(r'^\s*[-+] ', re.MULTILINE) 102 + 103 + # Photometric 104 + AUG_P_BLUR = 0.6 105 + AUG_BLUR_SIGMA_MAX = 0.8 106 + AUG_BRIGHTNESS_RANGE = 0.25 # factor drawn from [1-r, 1+r] 107 + AUG_CONTRAST_RANGE = 0.25 108 + AUG_P_RULED_LINES = 0.30 109 + 110 + # --------------------------------------------------------------------------- 111 + 112 + _ELASTIC = ElasticTransform(alpha=AUG_ELASTIC_ALPHA, sigma=AUG_ELASTIC_SIGMA, fill=[255, 255, 255]) 113 + _PATCH_ELASTIC = ElasticTransform(alpha=AUG_PATCH_ELASTIC_ALPHA, sigma=AUG_PATCH_ELASTIC_SIGMA, fill=[255, 255, 255]) 114 + _PERSPECTIVE = RandomPerspective(distortion_scale=AUG_PERSP_DISTORTION, p=1.0, fill=[255, 255, 255]) 69 115 70 116 71 117 def _structural_key(typst: str) -> str: ··· 126 172 overlay = Image.new("RGBA", work.size, (255, 255, 255, 0)) 127 173 draw = ImageDraw.Draw(overlay) 128 174 spacing = random.randint(18, 28) 129 - opacity = random.randint(28, 55) 175 + opacity = random.randint(60, 110) 130 176 for y in range(spacing, work.height, spacing): 131 177 draw.line([(0, y), (work.width, y)], fill=(160, 160, 210, opacity), width=1) 132 178 return Image.alpha_composite(work, overlay).convert(orig_mode) 133 179 134 180 135 - def _augment(img: Image.Image) -> Image.Image: 181 + def _find_blocks(img: Image.Image, threshold: float = 0.005, min_gap: int = 16) -> list[tuple[int, int]]: 182 + """Return (top, bottom) pixel ranges for content blocks via horizontal projection. 183 + 184 + threshold: fraction of dark pixels per row below which a row is "empty". 185 + 0.005 captures rows with as few as ~5 dark pixels per 1000px width, 186 + which preserves sparse tops of superscripts and integral limits. 187 + min_gap: consecutive empty rows required to count as a block separator. 188 + 16px sits above intra-character gaps (i-dot, diacritic) but below 189 + typical inter-paragraph whitespace at 250 PPI. 190 + """ 191 + gray = np.array(img.convert("L")) 192 + norm = (gray < 200).sum(axis=1) / max(gray.shape[1], 1) 193 + is_gap = norm < threshold 194 + 195 + H = len(is_gap) 196 + gap_intervals: list[tuple[int, int]] = [] 197 + i = 0 198 + while i < H: 199 + if is_gap[i]: 200 + j = i 201 + while j < H and is_gap[j]: 202 + j += 1 203 + if j - i >= min_gap: 204 + gap_intervals.append((i, j)) 205 + i = j 206 + else: 207 + i += 1 208 + 209 + boundaries = [0] + [x for g in gap_intervals for x in g] + [H] 210 + raw = [ 211 + (boundaries[k], boundaries[k + 1]) 212 + for k in range(0, len(boundaries) - 1, 2) 213 + if boundaries[k + 1] > boundaries[k] 214 + ] 215 + 216 + # Expand each block outward to capture sparse character edges 217 + # (tops of superscripts, bottoms of subscripts) that fell below threshold. 218 + # Merge any blocks whose expansions now overlap. 219 + EDGE = 8 220 + merged: list[list[int]] = [] 221 + for top, bot in raw: 222 + top, bot = max(0, top - EDGE), min(H, bot + EDGE) 223 + if merged and top < merged[-1][1]: 224 + merged[-1][1] = max(merged[-1][1], bot) 225 + else: 226 + merged.append([top, bot]) 227 + return [(t, b) for t, b in merged] 228 + 229 + 230 + def _transform_patch(patch: Image.Image) -> Image.Image: 231 + """Per-block affine + elastic, simulating independently written chunks.""" 232 + pw, ph = patch.width, patch.height 233 + 234 + # Rotation: pad canvas, rotate, then crop to the tight bounding box of the 235 + # rotated content so that corner content isn't clipped. 236 + angle = random.uniform(-AUG_PATCH_ANGLE_DEG, AUG_PATCH_ANGLE_DEG) 237 + a_rad = abs(math.radians(angle)) 238 + cos_a, sin_a = math.cos(a_rad), math.sin(a_rad) 239 + rot_pad = int(max(pw, ph) * sin_a) + 4 240 + canvas = Image.new("RGB", (pw + 2 * rot_pad, ph + 2 * rot_pad), (255, 255, 255)) 241 + canvas.paste(patch, (rot_pad, rot_pad)) 242 + canvas = canvas.rotate(angle, resample=Image.BICUBIC, fillcolor=(255, 255, 255)) 243 + bb_w = int(pw * cos_a + ph * sin_a) + 2 244 + bb_h = int(pw * sin_a + ph * cos_a) + 2 245 + cx, cy = rot_pad + pw // 2, rot_pad + ph // 2 246 + patch = canvas.crop((cx - bb_w // 2, cy - bb_h // 2, 247 + cx - bb_w // 2 + bb_w, cy - bb_h // 2 + bb_h)) 248 + 249 + scale = random.uniform(1.0 - AUG_PATCH_SCALE_RANGE, 1.0 + AUG_PATCH_SCALE_RANGE) 250 + new_w = max(1, int(pw * scale)) 251 + new_h = max(1, int(ph * scale)) 252 + patch = patch.resize((new_w, new_h), Image.BICUBIC) 253 + 254 + if ph >= AUG_PATCH_MIN_H and random.random() < AUG_P_PATCH_ELASTIC: 255 + patch = _PATCH_ELASTIC(patch) 256 + 257 + return patch 258 + 259 + 260 + def _region_jitter(img: Image.Image, max_dx: int = AUG_JITTER_MAX_DX) -> Image.Image: 261 + """Apply independent affine + elastic transforms per content block, then 262 + re-composite with random position offsets. Simulates 'written in chunks' 263 + layout geometry: misaligned baselines, inconsistent left margins, local 264 + stroke wobble -- while preserving content identity in each block.""" 265 + blocks = _find_blocks(img) 266 + if len(blocks) < 2: 267 + return img 268 + 269 + src = np.array(img) 270 + H, W = src.shape[:2] 271 + MY = AUG_JITTER_MAX_DY + 2 272 + MX = AUG_JITTER_MAX_DX + 2 273 + 274 + out = np.full((H + 2 * MY, W + 2 * MX, 3), 255, dtype=np.uint8) 275 + for i, (top, bot) in enumerate(blocks): 276 + patch = _transform_patch(Image.fromarray(src[top:bot])) 277 + strip = np.array(patch) 278 + sh, sw = strip.shape[:2] 279 + 280 + # Cap dy to half the gap toward each neighbor so blocks never overlap. 281 + gap_above = top - (blocks[i - 1][1] if i > 0 else 0) 282 + gap_below = (blocks[i + 1][0] if i < len(blocks) - 1 else H) - bot 283 + max_up = min(AUG_JITTER_MAX_DY, gap_above // 2) 284 + max_down = min(AUG_JITTER_MAX_DY, gap_below // 2) 285 + dy = random.randint(-max_up, max_down) 286 + dx = random.randint(-max_dx, max_dx) 287 + dst_top = MY + top + dy 288 + dst_left = MX + dx 289 + dst_bot = dst_top + sh 290 + dst_right = dst_left + sw 291 + 292 + # source crop bounds (adjusted for canvas clamp) 293 + s_top, s_bot, s_left, s_right = 0, sh, 0, sw 294 + if dst_top < 0: 295 + s_top = -dst_top; dst_top = 0 296 + if dst_bot > out.shape[0]: 297 + s_bot = sh - (dst_bot - out.shape[0]); dst_bot = out.shape[0] 298 + if dst_left < 0: 299 + s_left = -dst_left; dst_left = 0 300 + if dst_right > out.shape[1]: 301 + s_right = sw - (dst_right - out.shape[1]); dst_right = out.shape[1] 302 + 303 + out[dst_top:dst_bot, dst_left:dst_right] = strip[s_top:s_bot, s_left:s_right] 304 + 305 + return Image.fromarray(out[MY: MY + H, MX: MX + W]) 306 + 307 + 308 + def _augment(img: Image.Image, region_jitter: bool = True, typst: str = "") -> Image.Image: 136 309 """ 137 310 Augmentation for synthetic-to-real-notes robustness. 138 311 139 - Geometric: affine (existing) + elastic deformation + mild perspective. 140 - Photometric: blur + brightness/contrast (existing) + optional ruled lines. 312 + Geometric: affine + elastic + perspective + per-region block jitter. 313 + Photometric: blur + brightness/contrast + optional ruled lines. 314 + 315 + region_jitter: when True, applies _region_jitter with AUG_P_JITTER probability. 316 + Pass False to run the page-level pipeline without block jitter. 141 317 """ 142 - angle = random.uniform(-5, 5) 143 - scale = random.uniform(0.9, 1.1) 144 - tx = int(random.uniform(-0.05, 0.05) * img.width) 145 - ty = int(random.uniform(-0.05, 0.05) * img.height) 146 - img = TF.affine(img, angle=angle, translate=(tx, ty), scale=scale, shear=0, fill=255) 318 + angle = random.uniform(-AUG_ANGLE_DEG, AUG_ANGLE_DEG) 319 + scale = random.uniform(1.0 - AUG_SCALE_RANGE, 1.0 + AUG_SCALE_RANGE) 320 + tx = int(random.uniform(-AUG_TRANSLATE_FRAC, AUG_TRANSLATE_FRAC) * img.width) 321 + ty = int(random.uniform(-AUG_TRANSLATE_FRAC, AUG_TRANSLATE_FRAC) * img.height) 322 + orig_w, orig_h = img.width, img.height 323 + a_sin = abs(math.sin(math.radians(angle))) 324 + # Rotation pushes corners outside the original bounding box by ~(max_dim/2)*sin. 325 + # Scale>1 expands content outward from center by (scale-1)*max_dim/2 on each side. 326 + # Both require extra canvas space AND a wider crop to stay lossless. 327 + rot_extra = int(max(orig_h, orig_w) / 2 * a_sin) + 4 328 + scale_extra = max(0, int((scale - 1.0) * max(orig_w, orig_h) / 2)) + 2 329 + pad = max( 330 + int(scale * max(orig_w, orig_h) * a_sin), 331 + rot_extra + scale_extra, 332 + ) + abs(tx) + abs(ty) + 8 333 + padded = Image.new("RGB", (orig_w + 2 * pad, orig_h + 2 * pad), (255, 255, 255)) 334 + padded.paste(img, (pad, pad)) 335 + padded = TF.affine(padded, angle=angle, translate=(tx, ty), scale=scale, shear=0, fill=(255, 255, 255)) 336 + img = padded.crop(( 337 + pad - rot_extra - scale_extra - max(0, -tx), 338 + pad - rot_extra - scale_extra - max(0, -ty), 339 + pad + orig_w + rot_extra + scale_extra + max(0, tx), 340 + pad + orig_h + rot_extra + scale_extra + max(0, ty), 341 + )) 147 342 148 343 # Elastic deformation: smooth warp mimics baseline wobble without 149 344 # corrupting subscript/superscript spatial relationships (sigma=6 keeps 150 345 # the displacement field globally smooth). 151 - if random.random() < 0.35: 346 + if random.random() < AUG_P_ELASTIC: 152 347 img = _ELASTIC(img) 153 348 154 - # Mild perspective: simulates photographed note-paper rather than scanned. 155 - if random.random() < 0.25: 156 - img = _PERSPECTIVE(img) 349 + if random.random() < AUG_P_PERSPECTIVE: 350 + persp_pad = int(AUG_PERSP_DISTORTION * min(img.width, img.height)) + 4 351 + pw, ph = img.width, img.height 352 + persp_canvas = Image.new("RGB", (pw + 2 * persp_pad, ph + 2 * persp_pad), (255, 255, 255)) 353 + persp_canvas.paste(img, (persp_pad, persp_pad)) 354 + persp_canvas = _PERSPECTIVE(persp_canvas) 355 + img = persp_canvas.crop((persp_pad, persp_pad, persp_pad + pw, persp_pad + ph)) 356 + 357 + if region_jitter and random.random() < AUG_P_JITTER: 358 + max_dx = AUG_JITTER_LIST_MAX_DX if _LIST_RE.search(typst) else AUG_JITTER_MAX_DX 359 + img = _region_jitter(img, max_dx=max_dx) 157 360 158 - if random.random() < 0.6: 159 - img = img.filter(ImageFilter.GaussianBlur(random.uniform(0.0, 0.8))) 361 + if random.random() < AUG_P_BLUR: 362 + img = img.filter(ImageFilter.GaussianBlur(random.uniform(0.0, AUG_BLUR_SIGMA_MAX))) 160 363 161 - img = TF.adjust_brightness(img, random.uniform(0.75, 1.25)) 162 - img = TF.adjust_contrast(img, random.uniform(0.75, 1.25)) 364 + img = TF.adjust_brightness(img, random.uniform(1.0 - AUG_BRIGHTNESS_RANGE, 1.0 + AUG_BRIGHTNESS_RANGE)) 365 + img = TF.adjust_contrast(img, random.uniform(1.0 - AUG_CONTRAST_RANGE, 1.0 + AUG_CONTRAST_RANGE)) 163 366 164 - if random.random() < 0.20: 367 + if random.random() < AUG_P_RULED_LINES: 165 368 img = _add_ruled_lines(img) 166 369 167 370 return img ··· 227 430 r = self.records[idx] 228 431 img = Image.open(r["image_path"]).convert("RGB") 229 432 if self.do_augment: 230 - img = _augment(img) 433 + img = _augment(img, typst=r.get("typst", "")) 231 434 return self.format_fn(r, img) 232 435 233 436