iterative image reconstruction using random cubic bézier strokes, accelerated on metal
metal graphics rust
7
fork

Configure Feed

Select the types of activity you want to include in your feed.

first commit

luthenwald 60ea9d7f

+2089
+1
.gitignore
··· 1 + target/
+294
Cargo.lock
··· 1 + # This file is automatically @generated by Cargo. 2 + # It is not intended for manual editing. 3 + version = 4 4 + 5 + [[package]] 6 + name = "adler2" 7 + version = "2.0.1" 8 + source = "registry+https://github.com/rust-lang/crates.io-index" 9 + checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" 10 + 11 + [[package]] 12 + name = "autocfg" 13 + version = "1.5.0" 14 + source = "registry+https://github.com/rust-lang/crates.io-index" 15 + checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" 16 + 17 + [[package]] 18 + name = "bitflags" 19 + version = "2.10.0" 20 + source = "registry+https://github.com/rust-lang/crates.io-index" 21 + checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" 22 + 23 + [[package]] 24 + name = "block2" 25 + version = "0.6.2" 26 + source = "registry+https://github.com/rust-lang/crates.io-index" 27 + checksum = "cdeb9d870516001442e364c5220d3574d2da8dc765554b4a617230d33fa58ef5" 28 + dependencies = [ 29 + "objc2", 30 + ] 31 + 32 + [[package]] 33 + name = "bytemuck" 34 + version = "1.24.0" 35 + source = "registry+https://github.com/rust-lang/crates.io-index" 36 + checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" 37 + 38 + [[package]] 39 + name = "byteorder-lite" 40 + version = "0.1.0" 41 + source = "registry+https://github.com/rust-lang/crates.io-index" 42 + checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" 43 + 44 + [[package]] 45 + name = "cfg-if" 46 + version = "1.0.4" 47 + source = "registry+https://github.com/rust-lang/crates.io-index" 48 + checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" 49 + 50 + [[package]] 51 + name = "crc32fast" 52 + version = "1.5.0" 53 + source = "registry+https://github.com/rust-lang/crates.io-index" 54 + checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" 55 + dependencies = [ 56 + "cfg-if", 57 + ] 58 + 59 + [[package]] 60 + name = "crossbeam-deque" 61 + version = "0.8.6" 62 + source = "registry+https://github.com/rust-lang/crates.io-index" 63 + checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" 64 + dependencies = [ 65 + "crossbeam-epoch", 66 + "crossbeam-utils", 67 + ] 68 + 69 + [[package]] 70 + name = "crossbeam-epoch" 71 + version = "0.9.18" 72 + source = "registry+https://github.com/rust-lang/crates.io-index" 73 + checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 74 + dependencies = [ 75 + "crossbeam-utils", 76 + ] 77 + 78 + [[package]] 79 + name = "crossbeam-utils" 80 + version = "0.8.21" 81 + source = "registry+https://github.com/rust-lang/crates.io-index" 82 + checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" 83 + 84 + [[package]] 85 + name = "dispatch2" 86 + version = "0.3.0" 87 + source = "registry+https://github.com/rust-lang/crates.io-index" 88 + checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" 89 + dependencies = [ 90 + "bitflags", 91 + "objc2", 92 + ] 93 + 94 + [[package]] 95 + name = "either" 96 + version = "1.15.0" 97 + source = "registry+https://github.com/rust-lang/crates.io-index" 98 + checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" 99 + 100 + [[package]] 101 + name = "fdeflate" 102 + version = "0.3.7" 103 + source = "registry+https://github.com/rust-lang/crates.io-index" 104 + checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" 105 + dependencies = [ 106 + "simd-adler32", 107 + ] 108 + 109 + [[package]] 110 + name = "flate2" 111 + version = "1.1.5" 112 + source = "registry+https://github.com/rust-lang/crates.io-index" 113 + checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" 114 + dependencies = [ 115 + "crc32fast", 116 + "miniz_oxide", 117 + ] 118 + 119 + [[package]] 120 + name = "image" 121 + version = "0.25.9" 122 + source = "registry+https://github.com/rust-lang/crates.io-index" 123 + checksum = "e6506c6c10786659413faa717ceebcb8f70731c0a60cbae39795fdf114519c1a" 124 + dependencies = [ 125 + "bytemuck", 126 + "byteorder-lite", 127 + "moxcms", 128 + "num-traits", 129 + "png", 130 + "zune-core", 131 + "zune-jpeg", 132 + ] 133 + 134 + [[package]] 135 + name = "libc" 136 + version = "0.2.178" 137 + source = "registry+https://github.com/rust-lang/crates.io-index" 138 + checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" 139 + 140 + [[package]] 141 + name = "miniz_oxide" 142 + version = "0.8.9" 143 + source = "registry+https://github.com/rust-lang/crates.io-index" 144 + checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" 145 + dependencies = [ 146 + "adler2", 147 + "simd-adler32", 148 + ] 149 + 150 + [[package]] 151 + name = "moxcms" 152 + version = "0.7.11" 153 + source = "registry+https://github.com/rust-lang/crates.io-index" 154 + checksum = "ac9557c559cd6fc9867e122e20d2cbefc9ca29d80d027a8e39310920ed2f0a97" 155 + dependencies = [ 156 + "num-traits", 157 + "pxfm", 158 + ] 159 + 160 + [[package]] 161 + name = "num-traits" 162 + version = "0.2.19" 163 + source = "registry+https://github.com/rust-lang/crates.io-index" 164 + checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" 165 + dependencies = [ 166 + "autocfg", 167 + ] 168 + 169 + [[package]] 170 + name = "objc2" 171 + version = "0.6.3" 172 + source = "registry+https://github.com/rust-lang/crates.io-index" 173 + checksum = "b7c2599ce0ec54857b29ce62166b0ed9b4f6f1a70ccc9a71165b6154caca8c05" 174 + dependencies = [ 175 + "objc2-encode", 176 + ] 177 + 178 + [[package]] 179 + name = "objc2-core-foundation" 180 + version = "0.3.2" 181 + source = "registry+https://github.com/rust-lang/crates.io-index" 182 + checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" 183 + dependencies = [ 184 + "bitflags", 185 + "dispatch2", 186 + "objc2", 187 + ] 188 + 189 + [[package]] 190 + name = "objc2-encode" 191 + version = "4.1.0" 192 + source = "registry+https://github.com/rust-lang/crates.io-index" 193 + checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" 194 + 195 + [[package]] 196 + name = "objc2-foundation" 197 + version = "0.3.2" 198 + source = "registry+https://github.com/rust-lang/crates.io-index" 199 + checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" 200 + dependencies = [ 201 + "bitflags", 202 + "block2", 203 + "libc", 204 + "objc2", 205 + "objc2-core-foundation", 206 + ] 207 + 208 + [[package]] 209 + name = "objc2-metal" 210 + version = "0.3.2" 211 + source = "registry+https://github.com/rust-lang/crates.io-index" 212 + checksum = "a0125f776a10d00af4152d74616409f0d4a2053a6f57fa5b7d6aa2854ac04794" 213 + dependencies = [ 214 + "bitflags", 215 + "block2", 216 + "dispatch2", 217 + "objc2", 218 + "objc2-core-foundation", 219 + "objc2-foundation", 220 + ] 221 + 222 + [[package]] 223 + name = "png" 224 + version = "0.18.0" 225 + source = "registry+https://github.com/rust-lang/crates.io-index" 226 + checksum = "97baced388464909d42d89643fe4361939af9b7ce7a31ee32a168f832a70f2a0" 227 + dependencies = [ 228 + "bitflags", 229 + "crc32fast", 230 + "fdeflate", 231 + "flate2", 232 + "miniz_oxide", 233 + ] 234 + 235 + [[package]] 236 + name = "pxfm" 237 + version = "0.1.27" 238 + source = "registry+https://github.com/rust-lang/crates.io-index" 239 + checksum = "7186d3822593aa4393561d186d1393b3923e9d6163d3fbfd6e825e3e6cf3e6a8" 240 + dependencies = [ 241 + "num-traits", 242 + ] 243 + 244 + [[package]] 245 + name = "rayon" 246 + version = "1.11.0" 247 + source = "registry+https://github.com/rust-lang/crates.io-index" 248 + checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" 249 + dependencies = [ 250 + "either", 251 + "rayon-core", 252 + ] 253 + 254 + [[package]] 255 + name = "rayon-core" 256 + version = "1.13.0" 257 + source = "registry+https://github.com/rust-lang/crates.io-index" 258 + checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" 259 + dependencies = [ 260 + "crossbeam-deque", 261 + "crossbeam-utils", 262 + ] 263 + 264 + [[package]] 265 + name = "simd-adler32" 266 + version = "0.3.8" 267 + source = "registry+https://github.com/rust-lang/crates.io-index" 268 + checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" 269 + 270 + [[package]] 271 + name = "splined" 272 + version = "0.1.0" 273 + dependencies = [ 274 + "image", 275 + "objc2", 276 + "objc2-foundation", 277 + "objc2-metal", 278 + "rayon", 279 + ] 280 + 281 + [[package]] 282 + name = "zune-core" 283 + version = "0.5.0" 284 + source = "registry+https://github.com/rust-lang/crates.io-index" 285 + checksum = "111f7d9820f05fd715df3144e254d6fc02ee4088b0644c0ffd0efc9e6d9d2773" 286 + 287 + [[package]] 288 + name = "zune-jpeg" 289 + version = "0.5.7" 290 + source = "registry+https://github.com/rust-lang/crates.io-index" 291 + checksum = "51d915729b0e7d5fe35c2f294c5dc10b30207cc637920e5b59077bfa3da63f28" 292 + dependencies = [ 293 + "zune-core", 294 + ]
+15
Cargo.toml
··· 1 + [package] 2 + name = "splined" 3 + version = "0.1.0" 4 + edition = "2024" 5 + 6 + [dependencies] 7 + image = { version = "0.25", default-features = false, features = ["jpeg", "png"] } 8 + objc2 = "0.6" 9 + objc2-metal = "0.3" 10 + objc2-foundation = "0.3" 11 + rayon = "1.11" 12 + 13 + [profile.release] 14 + lto = true 15 + codegen-units = 1
assets/01-inp.webp

This is a binary file and will not be displayed.

assets/01-out.webp

This is a binary file and will not be displayed.

assets/02-inp.webp

This is a binary file and will not be displayed.

assets/02-out.webp

This is a binary file and will not be displayed.

assets/02.gif

This is a binary file and will not be displayed.

assets/03-inp.webp

This is a binary file and will not be displayed.

assets/03-out.webp

This is a binary file and will not be displayed.

assets/04-inp.webp

This is a binary file and will not be displayed.

assets/04-out.webp

This is a binary file and will not be displayed.

+3
changelog
··· 1 + version 0.1.0 2 + 2026-01-15 3 + first release
+29
license
··· 1 + BSD 3-Clause License 2 + 3 + Copyright (c) 2026, lμthenwałd 4 + All rights reserved. 5 + 6 + Redistribution and use in source and binary forms, with or without 7 + modification, are permitted provided that the following conditions are met: 8 + 9 + * Redistributions of source code must retain the above copyright notice, this 10 + list of conditions and the following disclaimer. 11 + 12 + * Redistributions in binary form must reproduce the above copyright notice, 13 + this list of conditions and the following disclaimer in the documentation 14 + and/or other materials provided with the distribution. 15 + 16 + * Neither the name of the copyright holder nor the names of its 17 + contributors may be used to endorse or promote products derived from 18 + this software without specific prior written permission. 19 + 20 + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+98
readme.md
··· 1 + # splined 2 + 3 + iterative image reconstruction using random cubic bézier strokes, accelerated on [metal][metal] 4 + 5 + ## showcase 6 + 7 + > [!NOTE] 8 + > images used here are all under open access by [The Met][met] 9 + 10 + <table align="center"> 11 + <tr> 12 + <td><img src="assets/01-inp.webp" width="520"/></td> 13 + <td><img src="assets/01-out.webp" width="520"/></td> 14 + </tr> 15 + <tr> 16 + <td><img src="assets/02-inp.webp" width="520"/></td> 17 + <td><img src="assets/02-out.webp" width="520"/></td> 18 + </tr> 19 + <tr> 20 + <td><img src="assets/03-inp.webp" width="520"/></td> 21 + <td><img src="assets/03-out.webp" width="520"/></td> 22 + </tr> 23 + <tr> 24 + <td><img src="assets/04-inp.webp" width="520"/></td> 25 + <td><img src="assets/04-out.webp" width="520"/></td> 26 + </tr> 27 + </table> 28 + 29 + same input & different seeds → different reconstructions → simple animation: 30 + 31 + <p align="center"> 32 + <img src="assets/02.gif" width="720"/> 33 + </p> 34 + 35 + ## build 36 + 37 + ```sh 38 + cargo build -r 39 + ``` 40 + 41 + ## usage 42 + 43 + ```sh 44 + splined: iterative image reconstruction with random cubic bézier strokes (metal-accelerated) 45 + 46 + usage: splined <input> [args] 47 + 48 + args: 49 + -n, --number <u32> max splines to draw (default: (w*h)^0.7) 50 + -b, --batch <u32> batch size per gpu step (default: 32) 51 + -s, --seed <u64> rng seed (default: 0) 52 + --max-gpu <f32> max gpu usage in (0, 1] (default: 1.0) 53 + -l, --log <0..3> logging level (default: 1) 54 + -o, --output <path> output file or dir (default: output.png) 55 + -c, --current <path> current canvas image to resume from (single-file only) 56 + --nth <u32> save every nth accepted stroke (uses -o as dir) 57 + --bg <avg|r,g,b> initial canvas color (default: avg) 58 + -a, --alpha <f32> stroke alpha in [0, 1] (default: 1) 59 + --min-accept-ratio <f32> stagnant if accepted < batch*ratio (default: 0.02) 60 + --max-stagnant-batches <u32> stop after this many stagnant batches (default: 10) 61 + 62 + input: 63 + - file: writes one image to -o/--output (default: output.png) 64 + - dir: -o/--output must be a dir; mirrors input tree under it 65 + - --nth: saves frames to output dir every nth accepted stroke; also writes final.png 66 + 67 + examples: 68 + splined in.png -o out.png 69 + splined in.png -n 5000 -b 64 -s 42 -o out.png 70 + splined in.png --nth 50 -o frames/ 71 + splined images/ -o results/ -n 5000 -b 64 -s 42 --nth 50 72 + ``` 73 + 74 + ## algorithm 75 + 76 + - convert input to [oklab][oklab] color space 77 + - initialize canvas to image average (or `--bg`) 78 + - repeat until target stroke count or convergence: 79 + - sample batch of random cubic béziers (4 control points, uniform over image) 80 + - rasterize each curve to coverage mask 81 + - set stroke color to coverage-weighted mean of target pixels 82 + - accept curves that strictly reduce squared oklab error (Δε² < 0) 83 + - commit accepted strokes to canvas 84 + - export final canvas 85 + 86 + ## reference 87 + 88 + - [Geometrize][geometrize]: a desktop app that geometrizes images into geometric primitives 89 + 90 + ## todo 91 + 92 + - (better) antialiasing algorithm for drawing cubic bézier strokes 93 + - support other gpu backends (e.g., wasm) 94 + 95 + [metal]: https://en.wikipedia.org/wiki/Metal_(API) 96 + [oklab]: https://en.wikipedia.org/wiki/Oklab_color_space 97 + [met]: https://www.metmuseum.org/hubs/open-access 98 + [geometrize]: https://github.com/Tw1ddle/geometrize
+1
rustfmt.toml
··· 1 + tab_spaces = 3
+221
src/args.rs
··· 1 + fn usage() -> String { 2 + [ 3 + "splined: iterative image reconstruction with random cubic bézier strokes (metal-accelerated)", 4 + "", 5 + "usage: splined <input> [args]", 6 + "", 7 + "args:", 8 + " -n, --number <u32> max splines to draw (default: (w*h)^0.7)", 9 + " -b, --batch <u32> batch size per gpu step (default: 32)", 10 + " -s, --seed <u64> rng seed (default: 0)", 11 + " --max-gpu <f32> max gpu usage in (0, 1] (default: 1.0)", 12 + " -l, --log <0..3> logging level (default: 1)", 13 + " -o, --output <path> output file or dir (default: output.png)", 14 + " -c, --current <path> current canvas image to resume from (single-file only)", 15 + " --nth <u32> save every nth accepted stroke (uses -o as dir)", 16 + " --bg <avg|r,g,b> initial canvas color (default: avg)", 17 + " -a, --alpha <f32> stroke alpha in [0, 1] (default: 1)", 18 + " --min-accept-ratio <f32> stagnant if accepted < batch*ratio (default: 0.02)", 19 + " --max-stagnant-batches <u32> stop after this many stagnant batches (default: 10)", 20 + "", 21 + "input:", 22 + " - file: writes one image to -o/--output (default: output.png)", 23 + " - dir: -o/--output must be a dir; mirrors input tree under it", 24 + " - --nth: saves frames to output dir every nth accepted stroke; also writes final.png", 25 + "", 26 + "examples:", 27 + " splined in.png -o out.png", 28 + " splined in.png -n 5000 -b 64 -s 42 -o out.png", 29 + " splined in.png --nth 50 -o frames/", 30 + " splined images/ -o results/ -n 5000 -b 64 -s 42 --nth 50", 31 + "", 32 + ] 33 + .join("\n") 34 + } 35 + 36 + #[derive(Clone, Debug)] 37 + pub struct Config { 38 + pub input: String, 39 + pub number: Option<u32>, 40 + pub batch: u32, 41 + pub min_accept_ratio: f32, 42 + pub seed: u64, 43 + pub max_gpu: f32, 44 + pub log: u8, 45 + pub output: String, 46 + pub output_provided: bool, 47 + pub current: Option<String>, 48 + pub nth: Option<u32>, 49 + pub bg: Bg, 50 + pub alpha: f32, 51 + pub max_stagnant_batches: u32, 52 + } 53 + 54 + #[derive(Clone, Copy, Debug)] 55 + pub enum Bg { 56 + Avg, 57 + RgbU8 { r: u8, g: u8, b: u8 }, 58 + } 59 + 60 + impl Config { 61 + pub fn parse(argv: &[String]) -> Result<Self, String> { 62 + if argv.len() < 2 { 63 + return Err(usage()); 64 + } 65 + 66 + let mut input: Option<String> = None; 67 + let mut number: Option<u32> = None; 68 + let mut batch: u32 = 32; 69 + let mut min_accept_ratio: Option<f32> = None; 70 + let mut seed: u64 = 0; 71 + let mut max_gpu: f32 = 1.0; 72 + let mut log: u8 = 1; 73 + let mut output: String = "output.png".to_string(); 74 + let mut output_provided: bool = false; 75 + let mut current: Option<String> = None; 76 + let mut nth: Option<u32> = None; 77 + let mut bg: Bg = Bg::Avg; 78 + let mut alpha: f32 = 1.0; 79 + let mut max_stagnant_batches: u32 = 10; 80 + 81 + let mut i = 1usize; 82 + while i < argv.len() { 83 + let a = argv[i].as_str(); 84 + if !a.starts_with('-') && input.is_none() { 85 + input = Some(argv[i].clone()); 86 + i += 1; 87 + continue; 88 + } 89 + 90 + match a { 91 + "-n" | "--number" => { 92 + i += 1; 93 + number = Some(parse_val(argv, i, a)?); 94 + i += 1; 95 + } 96 + "-b" | "--batch" => { 97 + i += 1; 98 + batch = parse_val(argv, i, a)?; 99 + i += 1; 100 + } 101 + "--min-accept-ratio" => { 102 + i += 1; 103 + min_accept_ratio = Some(parse_val(argv, i, a)?); 104 + i += 1; 105 + } 106 + "-s" | "--seed" => { 107 + i += 1; 108 + seed = parse_val(argv, i, a)?; 109 + i += 1; 110 + } 111 + "--max-gpu" => { 112 + i += 1; 113 + max_gpu = parse_val(argv, i, a)?; 114 + i += 1; 115 + } 116 + "-l" | "--log" => { 117 + i += 1; 118 + log = parse_val(argv, i, a)?; 119 + i += 1; 120 + } 121 + "-o" | "--output" => { 122 + i += 1; 123 + output = parse_val(argv, i, a)?; 124 + output_provided = true; 125 + i += 1; 126 + } 127 + "-c" | "--current" => { 128 + i += 1; 129 + current = Some(parse_val(argv, i, a)?); 130 + i += 1; 131 + } 132 + "--nth" => { 133 + i += 1; 134 + nth = Some(parse_val(argv, i, a)?); 135 + i += 1; 136 + } 137 + "--bg" => { 138 + i += 1; 139 + bg = parse_bg(argv, i, a)?; 140 + i += 1; 141 + } 142 + "-a" | "--alpha" => { 143 + i += 1; 144 + alpha = parse_val(argv, i, a)?; 145 + i += 1; 146 + } 147 + "--max-stagnant-batches" => { 148 + i += 1; 149 + max_stagnant_batches = parse_val(argv, i, a)?; 150 + i += 1; 151 + } 152 + "-h" | "--help" => return Err(usage()), 153 + _ => return Err(format!("unknown arg: {a}\n\n{}", usage())), 154 + } 155 + } 156 + 157 + let Some(input) = input else { 158 + return Err(usage()); 159 + }; 160 + 161 + if batch == 0 { 162 + return Err("batch must be > 0".to_string()); 163 + } 164 + let min_accept_ratio = min_accept_ratio.unwrap_or(0.02); 165 + if !(0.0..=1.0).contains(&min_accept_ratio) { 166 + return Err("min-accept-ratio must be in [0, 1]".to_string()); 167 + } 168 + if log > 3 { 169 + return Err("log must be in [0, 3]".to_string()); 170 + } 171 + if !(0.0 < max_gpu && max_gpu <= 1.0) { 172 + return Err("max-gpu must be in (0, 1]".to_string()); 173 + } 174 + if let Some(nth) = nth { 175 + if nth == 0 { 176 + return Err("nth must be > 0".to_string()); 177 + } 178 + } 179 + if !(0.0..=1.0).contains(&alpha) { 180 + return Err("alpha must be in [0, 1]".to_string()); 181 + } 182 + if max_stagnant_batches == 0 { 183 + return Err("max-stagnant-batches must be > 0".to_string()); 184 + } 185 + 186 + Ok(Self { 187 + input, 188 + number, 189 + batch, 190 + min_accept_ratio, 191 + seed, 192 + max_gpu, 193 + log, 194 + output, 195 + output_provided, 196 + current, 197 + nth, 198 + bg, 199 + alpha, 200 + max_stagnant_batches, 201 + }) 202 + } 203 + } 204 + 205 + fn parse_val<T: std::str::FromStr>(argv: &[String], idx: usize, flag: &str) -> Result<T, String> { 206 + let s = argv.get(idx).ok_or_else(|| format!("missing value for {flag}"))?; 207 + s.parse::<T>().map_err(|_| format!("invalid value for {flag}: {s}")) 208 + } 209 + 210 + fn parse_bg(argv: &[String], idx: usize, flag: &str) -> Result<Bg, String> { 211 + let s = argv.get(idx).ok_or_else(|| format!("missing value for {flag}"))?; 212 + if s == "avg" { 213 + return Ok(Bg::Avg); 214 + } 215 + let parts: Vec<_> = s.split(',').collect(); 216 + if parts.len() != 3 { 217 + return Err(format!("invalid bg, expected avg or r,g,b: {s}")); 218 + } 219 + let parse = |i: usize| parts[i].parse::<u8>().map_err(|_| format!("invalid bg component: {s}")); 220 + Ok(Bg::RgbU8 { r: parse(0)?, g: parse(1)?, b: parse(2)? }) 221 + }
+39
src/fs.rs
··· 1 + use crate::args; 2 + use std::path::{Path, PathBuf}; 3 + 4 + pub fn validate_and_prepare_dir_output(cfg: &args::Config) -> Result<PathBuf, String> { 5 + if !cfg.output_provided { 6 + return Err("when input is a directory, -o/--output must be provided".into()); 7 + } 8 + 9 + let out_root = PathBuf::from(&cfg.output); 10 + if out_root.exists() && !out_root.is_dir() { 11 + return Err(format!("output must be a directory, got file: {}", out_root.display())); 12 + } 13 + std::fs::create_dir_all(&out_root) 14 + .map_err(|e| format!("failed to create output dir {}: {e}", out_root.display()))?; 15 + Ok(out_root) 16 + } 17 + 18 + pub fn walk_files_recursive(root: &Path) -> Result<Vec<PathBuf>, String> { 19 + let mut out = Vec::new(); 20 + let mut stack = vec![root.to_path_buf()]; 21 + 22 + while let Some(dir) = stack.pop() { 23 + let rd = std::fs::read_dir(&dir) 24 + .map_err(|e| format!("failed to read dir {}: {e}", dir.display()))?; 25 + for entry in rd { 26 + let entry = entry.map_err(|e| format!("failed to read dir entry: {e}"))?; 27 + let ty = entry 28 + .file_type() 29 + .map_err(|e| format!("failed to stat {}: {e}", entry.path().display()))?; 30 + if ty.is_dir() { 31 + stack.push(entry.path()); 32 + } else if ty.is_file() { 33 + out.push(entry.path()); 34 + } 35 + } 36 + } 37 + 38 + Ok(out) 39 + }
+115
src/log.rs
··· 1 + use std::fmt::Arguments; 2 + use std::sync::atomic::{AtomicU8, Ordering}; 3 + 4 + static LEVEL: AtomicU8 = AtomicU8::new(1); 5 + 6 + pub fn set_level(level: u8) { 7 + LEVEL.store(level.min(3), Ordering::Relaxed); 8 + } 9 + 10 + pub fn level() -> u8 { 11 + LEVEL.load(Ordering::Relaxed) 12 + } 13 + 14 + pub fn l1(args: Arguments) { 15 + emit(1, args); 16 + } 17 + 18 + pub fn l2(args: Arguments) { 19 + emit(2, args); 20 + } 21 + 22 + fn emit(min: u8, args: Arguments) { 23 + if level() >= min { 24 + eprintln!("{args}"); 25 + } 26 + } 27 + 28 + pub struct BatchInfo { 29 + pub batch_idx: u64, 30 + pub accepted_total: u32, 31 + pub max_splines: u32, 32 + pub accepted_now: u32, 33 + pub batch_size: u32, 34 + pub remaining: u32, 35 + pub consecutive_stagnant: u32, 36 + pub max_stagnant_batches: u32, 37 + pub elapsed_s: f64, 38 + } 39 + 40 + pub struct BatchLogger { 41 + rate_avg: Option<f64>, 42 + accept_ratio_avg: Option<f64>, 43 + alpha: f64, 44 + last_elapsed: f64, 45 + } 46 + 47 + impl BatchLogger { 48 + pub fn new() -> Self { 49 + Self { 50 + rate_avg: None, 51 + accept_ratio_avg: None, 52 + alpha: 0.25, 53 + last_elapsed: 0.0, 54 + } 55 + } 56 + 57 + pub fn log_batch(&mut self, info: BatchInfo) { 58 + let stop = info.consecutive_stagnant >= info.max_stagnant_batches; 59 + let delta = (info.elapsed_s - self.last_elapsed).max(0.0); 60 + self.last_elapsed = info.elapsed_s; 61 + 62 + let rate_sample = if delta > 0.0 { 63 + info.accepted_now as f64 / delta 64 + } else { 65 + 0.0 66 + }; 67 + let ratio_sample = if info.batch_size > 0 { 68 + info.accepted_now as f64 / info.batch_size as f64 69 + } else { 70 + 0.0 71 + }; 72 + 73 + self.rate_avg = Some(self.update_ema(self.rate_avg, rate_sample)); 74 + self.accept_ratio_avg = Some(self.update_ema(self.accept_ratio_avg, ratio_sample)); 75 + 76 + let rate_avg = self.rate_avg.unwrap_or(0.0); 77 + let accept_avg = self.accept_ratio_avg.unwrap_or(0.0); 78 + 79 + self.emit(&info, ratio_sample, rate_avg, accept_avg, stop); 80 + } 81 + 82 + fn update_ema(&self, current: Option<f64>, sample: f64) -> f64 { 83 + match current { 84 + Some(prev) => prev * (1.0 - self.alpha) + sample * self.alpha, 85 + None => sample, 86 + } 87 + } 88 + 89 + fn emit(&self, info: &BatchInfo, batch_ratio: f64, rate_avg: f64, accept_avg: f64, stop: bool) { 90 + let lv = level(); 91 + let stop_note = if stop { " stop:stagnant" } else { "" }; 92 + 93 + if lv >= 3 { 94 + eprintln!( 95 + "b:{} t:{:.1}s acc:{}/{} (+{}, {:.3}) stagnant:{}/{} rem:{} avg_rate:{:.2}/s avg_accept:{:.3}{}", 96 + info.batch_idx, info.elapsed_s, info.accepted_total, info.max_splines, 97 + info.accepted_now, batch_ratio, info.consecutive_stagnant, info.max_stagnant_batches, 98 + info.remaining, rate_avg, accept_avg, stop_note 99 + ); 100 + } else if lv >= 2 { 101 + eprintln!( 102 + "b:{} t:{:.1}s acc:{}/{} (+{}, {:.3}) stagnant:{}/{} avg_rate:{:.2}/s avg_accept:{:.3}{}", 103 + info.batch_idx, info.elapsed_s, info.accepted_total, info.max_splines, 104 + info.accepted_now, batch_ratio, info.consecutive_stagnant, info.max_stagnant_batches, 105 + rate_avg, accept_avg, stop_note 106 + ); 107 + } else if lv >= 1 && (info.accepted_now > 0 || stop) { 108 + eprintln!( 109 + "{}/{} (+{}) stagnant:{}/{}{}", 110 + info.accepted_total, info.max_splines, info.accepted_now, 111 + info.consecutive_stagnant, info.max_stagnant_batches, stop_note 112 + ); 113 + } 114 + } 115 + }
+41
src/main.rs
··· 1 + mod args; 2 + mod fs; 3 + mod log; 4 + mod metal; 5 + mod oklab; 6 + mod rng; 7 + mod sampling; 8 + mod pipeline; 9 + 10 + use crate::pipeline::{process_dir, process_one, ResolvedOutput}; 11 + use std::path::{Path, PathBuf}; 12 + use std::process::ExitCode; 13 + 14 + fn main() -> ExitCode { 15 + match real_main() { 16 + Ok(()) => ExitCode::SUCCESS, 17 + Err(msg) => { 18 + eprintln!("{msg}"); 19 + ExitCode::FAILURE 20 + } 21 + } 22 + } 23 + 24 + fn real_main() -> Result<(), String> { 25 + let args = std::env::args().collect::<Vec<_>>(); 26 + let cfg = args::Config::parse(&args)?; 27 + log::set_level(cfg.log); 28 + let input_path = Path::new(&cfg.input); 29 + if input_path.is_dir() { 30 + if cfg.current.is_some() { 31 + return Err("-c/--current is not supported when input is a directory".into()); 32 + } 33 + let out_root = fs::validate_and_prepare_dir_output(&cfg)?; 34 + return process_dir(&cfg, input_path, &out_root); 35 + } 36 + 37 + let output = ResolvedOutput::from_config(&cfg, PathBuf::from(&cfg.output)); 38 + let final_path = process_one(&cfg, input_path, output)?; 39 + println!("output: {}", final_path.display()); 40 + Ok(()) 41 + }
+230
src/metal/kernels.metal
··· 1 + #include <metal_stdlib> 2 + using namespace metal; 3 + 4 + struct Candidate { 5 + float2 p0; 6 + float2 p1; 7 + float2 p2; 8 + float2 p3; 9 + uint bx; 10 + uint by; 11 + uint bw; 12 + uint bh; 13 + }; 14 + 15 + struct MeanOut { 16 + float sum_w; 17 + float3 sum_lab; 18 + }; 19 + 20 + struct ScoreOut { 21 + float delta_e2; 22 + }; 23 + 24 + struct CommonParams { 25 + uint width; 26 + uint height; 27 + float alpha; 28 + uint _pad0; 29 + }; 30 + 31 + struct ApplyParams { 32 + uint cand_index; 33 + float alpha; 34 + uint _pad0; 35 + uint _pad1; 36 + }; 37 + 38 + static inline float2 cubic_eval(thread const Candidate& c, float t) { 39 + float u = 1.0f - t; 40 + float b0 = u * u * u; 41 + float b1 = 3.0f * u * u * t; 42 + float b2 = 3.0f * u * t * t; 43 + float b3 = t * t * t; 44 + return c.p0 * b0 + c.p1 * b1 + c.p2 * b2 + c.p3 * b3; 45 + } 46 + 47 + static inline float dist_to_segment(float2 p, float2 a, float2 b) { 48 + float2 ab = b - a; 49 + float denom = dot(ab, ab); 50 + if (denom <= 1e-12f) { 51 + return length(p - a); 52 + } 53 + float t = clamp(dot(p - a, ab) / denom, 0.0f, 1.0f); 54 + float2 q = a + t * ab; 55 + return length(p - q); 56 + } 57 + 58 + static inline float coverage_at(thread const Candidate& c, float2 p) { 59 + constexpr uint segments = 32; 60 + float min_d = 1e9f; 61 + float2 prev = cubic_eval(c, 0.0f); 62 + for (uint i = 1; i <= segments; i++) { 63 + float t = (float)i / (float)segments; 64 + float2 cur = cubic_eval(c, t); 65 + min_d = min(min_d, dist_to_segment(p, prev, cur)); 66 + prev = cur; 67 + } 68 + return clamp(1.0f - 2.0f * min_d, 0.0f, 1.0f); 69 + } 70 + 71 + kernel void mean_pass( 72 + const device Candidate* candidates [[buffer(0)]], 73 + texture2d<float, access::read> target [[texture(0)]], 74 + device MeanOut* out [[buffer(1)]], 75 + constant CommonParams& params [[buffer(2)]], 76 + uint tid [[thread_index_in_threadgroup]], 77 + uint3 tg [[threadgroup_position_in_grid]] 78 + ) { 79 + constexpr uint TG = 256; 80 + uint ci = tg.x; 81 + Candidate c = candidates[ci]; 82 + 83 + threadgroup float tg_sum_w[TG]; 84 + threadgroup float3 tg_sum_lab[TG]; 85 + 86 + float sum_w = 0.0f; 87 + float3 sum_lab = float3(0.0f); 88 + 89 + uint total = c.bw * c.bh; 90 + for (uint i = tid; i < total; i += TG) { 91 + uint ox = i % c.bw; 92 + uint oy = i / c.bw; 93 + uint x = c.bx + ox; 94 + uint y = c.by + oy; 95 + if (x >= params.width || y >= params.height) { 96 + continue; 97 + } 98 + float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 99 + float cov = coverage_at(c, p); 100 + if (cov <= 0.0f) { 101 + continue; 102 + } 103 + float3 t_lab = target.read(uint2(x, y)).xyz; 104 + sum_w += cov; 105 + sum_lab += t_lab * cov; 106 + } 107 + 108 + tg_sum_w[tid] = sum_w; 109 + tg_sum_lab[tid] = sum_lab; 110 + threadgroup_barrier(mem_flags::mem_threadgroup); 111 + 112 + for (uint stride = TG / 2; stride > 0; stride >>= 1) { 113 + if (tid < stride) { 114 + tg_sum_w[tid] += tg_sum_w[tid + stride]; 115 + tg_sum_lab[tid] += tg_sum_lab[tid + stride]; 116 + } 117 + threadgroup_barrier(mem_flags::mem_threadgroup); 118 + } 119 + 120 + if (tid == 0) { 121 + MeanOut m; 122 + m.sum_w = tg_sum_w[0]; 123 + m.sum_lab = tg_sum_lab[0]; 124 + out[ci] = m; 125 + } 126 + } 127 + 128 + kernel void score_pass( 129 + const device Candidate* candidates [[buffer(0)]], 130 + const device MeanOut* means [[buffer(1)]], 131 + device ScoreOut* out [[buffer(2)]], 132 + texture2d<float, access::read> target [[texture(0)]], 133 + texture2d<float, access::read> canvas [[texture(1)]], 134 + constant CommonParams& params [[buffer(3)]], 135 + uint tid [[thread_index_in_threadgroup]], 136 + uint3 tg [[threadgroup_position_in_grid]] 137 + ) { 138 + constexpr uint TG = 256; 139 + uint ci = tg.x; 140 + Candidate c = candidates[ci]; 141 + MeanOut m = means[ci]; 142 + 143 + float3 stroke = float3(0.0f); 144 + if (m.sum_w > 0.0f) { 145 + stroke = m.sum_lab / m.sum_w; 146 + } 147 + 148 + threadgroup float tg_delta[TG]; 149 + 150 + float delta = 0.0f; 151 + uint total = c.bw * c.bh; 152 + for (uint i = tid; i < total; i += TG) { 153 + uint ox = i % c.bw; 154 + uint oy = i / c.bw; 155 + uint x = c.bx + ox; 156 + uint y = c.by + oy; 157 + if (x >= params.width || y >= params.height) { 158 + continue; 159 + } 160 + float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 161 + float cov = coverage_at(c, p); 162 + if (cov <= 0.0f) { 163 + continue; 164 + } 165 + 166 + float eff_alpha = clamp(params.alpha * cov, 0.0f, 1.0f); 167 + 168 + float3 t_lab = target.read(uint2(x, y)).xyz; 169 + float3 c_lab = canvas.read(uint2(x, y)).xyz; 170 + float3 blended = mix(c_lab, stroke, eff_alpha); 171 + 172 + float3 d_after = blended - t_lab; 173 + float3 d_before = c_lab - t_lab; 174 + float e_after = dot(d_after, d_after); 175 + float e_before = dot(d_before, d_before); 176 + delta += (e_after - e_before); 177 + } 178 + 179 + tg_delta[tid] = delta; 180 + threadgroup_barrier(mem_flags::mem_threadgroup); 181 + 182 + for (uint stride = TG / 2; stride > 0; stride >>= 1) { 183 + if (tid < stride) { 184 + tg_delta[tid] += tg_delta[tid + stride]; 185 + } 186 + threadgroup_barrier(mem_flags::mem_threadgroup); 187 + } 188 + 189 + if (tid == 0) { 190 + ScoreOut s; 191 + s.delta_e2 = tg_delta[0]; 192 + out[ci] = s; 193 + } 194 + } 195 + 196 + kernel void apply_pass( 197 + texture2d<float, access::read_write> canvas [[texture(0)]], 198 + const device Candidate* candidates [[buffer(0)]], 199 + const device MeanOut* means [[buffer(1)]], 200 + constant ApplyParams& params [[buffer(2)]], 201 + uint2 tid [[thread_position_in_grid]] 202 + ) { 203 + uint ci = params.cand_index; 204 + Candidate c = candidates[ci]; 205 + MeanOut m = means[ci]; 206 + if (m.sum_w <= 0.0f) { 207 + return; 208 + } 209 + 210 + uint ox = tid.x; 211 + uint oy = tid.y; 212 + if (ox >= c.bw || oy >= c.bh) { 213 + return; 214 + } 215 + 216 + uint x = c.bx + ox; 217 + uint y = c.by + oy; 218 + 219 + float3 stroke = m.sum_lab / m.sum_w; 220 + float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 221 + float cov = coverage_at(c, p); 222 + if (cov <= 0.0f) { 223 + return; 224 + } 225 + 226 + float eff_alpha = clamp(params.alpha * cov, 0.0f, 1.0f); 227 + float4 old = canvas.read(uint2(x, y)); 228 + float3 blended = mix(old.xyz, stroke, eff_alpha); 229 + canvas.write(float4(blended, 0.0f), uint2(x, y)); 230 + }
+567
src/metal/mod.rs
··· 1 + #[link(name = "CoreGraphics", kind = "framework")] 2 + unsafe extern "C" {} 3 + 4 + use crate::oklab::Oklab; 5 + use objc2::rc::Retained; 6 + use objc2::runtime::ProtocolObject; 7 + use objc2_foundation::{ns_string, NSString}; 8 + use objc2_metal::{ 9 + MTLBuffer, MTLCommandBuffer, MTLCommandEncoder, MTLCommandQueue, MTLComputeCommandEncoder, 10 + MTLComputePipelineState, MTLCreateSystemDefaultDevice, MTLDevice, MTLLibrary, MTLRegion, 11 + MTLResourceOptions, MTLSize, MTLStorageMode, MTLTexture, MTLTextureDescriptor, MTLTextureUsage, 12 + MTLPixelFormat, 13 + }; 14 + use std::ffi::c_void; 15 + use std::ptr::NonNull; 16 + use std::thread; 17 + use std::time::{Duration, Instant}; 18 + 19 + #[repr(C)] 20 + #[derive(Clone, Copy, Debug, Default)] 21 + /// a single bezier candidate and its conservative pixel bbox. 22 + /// 23 + /// `p0..p3` are control points in pixel space. 24 + /// `(bx, by, bw, bh)` bounds the region on the canvas potentially affected by the stroke. 25 + pub struct Candidate { 26 + pub p0: [f32; 2], 27 + pub p1: [f32; 2], 28 + pub p2: [f32; 2], 29 + pub p3: [f32; 2], 30 + pub bx: u32, 31 + pub by: u32, 32 + pub bw: u32, 33 + pub bh: u32, 34 + } 35 + 36 + #[repr(C)] 37 + #[derive(Clone, Copy, Debug, Default)] 38 + /// output payload for the mean kernel. 39 + struct MeanOut { 40 + sum_w: f32, 41 + _pad0: [f32; 3], 42 + sum_lab: [f32; 3], 43 + _pad1: f32, 44 + } 45 + 46 + #[repr(C)] 47 + #[derive(Clone, Copy, Debug, Default)] 48 + /// output payload for the scoring kernel (lower is better). 49 + struct ScoreOut { 50 + delta_e2: f32, 51 + } 52 + 53 + #[repr(C)] 54 + #[derive(Clone, Copy, Debug, Default)] 55 + /// common parameters shared by compute passes. 56 + struct CommonParams { 57 + width: u32, 58 + height: u32, 59 + alpha: f32, 60 + _pad0: u32, 61 + } 62 + 63 + #[repr(C)] 64 + #[derive(Clone, Copy, Debug, Default)] 65 + /// parameters for applying a selected candidate to the canvas. 66 + struct ApplyParams { 67 + cand_index: u32, 68 + alpha: f32, 69 + _pad0: u32, 70 + _pad1: u32, 71 + } 72 + 73 + /// metal compute context and persistent gpu resources for the current run. 74 + pub struct MetalContext { 75 + queue: Retained<ProtocolObject<dyn MTLCommandQueue>>, 76 + mean_pso: Retained<ProtocolObject<dyn MTLComputePipelineState>>, 77 + score_pso: Retained<ProtocolObject<dyn MTLComputePipelineState>>, 78 + apply_pso: Retained<ProtocolObject<dyn MTLComputePipelineState>>, 79 + target: Retained<ProtocolObject<dyn MTLTexture>>, 80 + canvas: Retained<ProtocolObject<dyn MTLTexture>>, 81 + candidates_buf: Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>>, 82 + means_buf: Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>>, 83 + scores_buf: Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>>, 84 + width: u32, 85 + height: u32, 86 + batch_cap: u32, 87 + max_gpu: f32, 88 + } 89 + 90 + impl MetalContext { 91 + pub fn new(width: u32, height: u32, batch_cap: u32, max_gpu: f32) -> Result<Self, String> { 92 + let Some(device) = MTLCreateSystemDefaultDevice() else { 93 + return Err("failed to create metal device".to_string()); 94 + }; 95 + let Some(queue) = device.newCommandQueue() else { 96 + return Err("failed to create metal command queue".to_string()); 97 + }; 98 + 99 + let lib = compile_library(&device)?; 100 + let mean_pso = compile_pipeline(&device, &lib, ns_string!("mean_pass"))?; 101 + let score_pso = compile_pipeline(&device, &lib, ns_string!("score_pass"))?; 102 + let apply_pso = compile_pipeline(&device, &lib, ns_string!("apply_pass"))?; 103 + 104 + let target = make_lab_texture(&device, width, height, true)?; 105 + let canvas = make_lab_texture(&device, width, height, true)?; 106 + 107 + let candidates_bytes = (batch_cap as usize) * std::mem::size_of::<Candidate>(); 108 + let means_bytes = (batch_cap as usize) * std::mem::size_of::<MeanOut>(); 109 + let scores_bytes = (batch_cap as usize) * std::mem::size_of::<ScoreOut>(); 110 + 111 + let opts = MTLResourceOptions::StorageModeShared; 112 + let candidates_buf = device 113 + .newBufferWithLength_options(candidates_bytes as _, opts) 114 + .ok_or_else(|| "failed to allocate candidates buffer".to_string())?; 115 + let means_buf = device 116 + .newBufferWithLength_options(means_bytes as _, opts) 117 + .ok_or_else(|| "failed to allocate means buffer".to_string())?; 118 + let scores_buf = device 119 + .newBufferWithLength_options(scores_bytes as _, opts) 120 + .ok_or_else(|| "failed to allocate scores buffer".to_string())?; 121 + 122 + Ok(Self { 123 + queue, 124 + mean_pso, 125 + score_pso, 126 + apply_pso, 127 + target, 128 + canvas, 129 + candidates_buf, 130 + means_buf, 131 + scores_buf, 132 + width, 133 + height, 134 + batch_cap, 135 + max_gpu, 136 + }) 137 + } 138 + 139 + fn throttle_after(&self, work: Duration) { 140 + if !(0.0 < self.max_gpu && self.max_gpu < 1.0) { 141 + return; 142 + } 143 + let work_s = work.as_secs_f64(); 144 + if work_s <= 0.0 { 145 + return; 146 + } 147 + let idle_s = work_s * (1.0 / (self.max_gpu as f64) - 1.0); 148 + if idle_s <= 0.0 { 149 + return; 150 + } 151 + thread::sleep(Duration::from_secs_f64(idle_s)); 152 + } 153 + 154 + fn commit_wait_throttled(&self, cb: &ProtocolObject<dyn MTLCommandBuffer>) { 155 + let started = Instant::now(); 156 + cb.commit(); 157 + cb.waitUntilCompleted(); 158 + self.throttle_after(started.elapsed()); 159 + } 160 + 161 + pub fn upload_target_and_init_canvas( 162 + &self, 163 + target: &[Oklab], 164 + canvas_fill: Oklab, 165 + ) -> Result<(), String> { 166 + if target.len() != (self.width as usize) * (self.height as usize) { 167 + return Err("target size mismatch".to_string()); 168 + } 169 + 170 + let tgt = pack_oklab_rgba32f(target); 171 + 172 + let mut can = Vec::<f32>::with_capacity(target.len() * 4); 173 + for _ in 0..target.len() { 174 + can.extend_from_slice(&[canvas_fill.l, canvas_fill.a, canvas_fill.b, 0.0]); 175 + } 176 + 177 + unsafe { 178 + write_full_texture_rgba32f(&self.target, self.width, self.height, &tgt)?; 179 + write_full_texture_rgba32f(&self.canvas, self.width, self.height, &can)?; 180 + } 181 + 182 + Ok(()) 183 + } 184 + 185 + pub fn upload_target_and_set_canvas(&self, target: &[Oklab], canvas: &[Oklab]) -> Result<(), String> { 186 + let n = (self.width as usize) * (self.height as usize); 187 + if target.len() != n { 188 + return Err("target size mismatch".to_string()); 189 + } 190 + if canvas.len() != n { 191 + return Err("canvas size mismatch".to_string()); 192 + } 193 + 194 + let tgt = pack_oklab_rgba32f(target); 195 + let can = pack_oklab_rgba32f(canvas); 196 + unsafe { 197 + write_full_texture_rgba32f(&self.target, self.width, self.height, &tgt)?; 198 + write_full_texture_rgba32f(&self.canvas, self.width, self.height, &can)?; 199 + } 200 + Ok(()) 201 + } 202 + 203 + fn run_mean_and_score(&self, candidates: &[Candidate], alpha: f32) -> Result<Vec<bool>, String> { 204 + if candidates.is_empty() { 205 + return Ok(Vec::new()); 206 + } 207 + if candidates.len() > (self.batch_cap as usize) { 208 + return Err("batch exceeds configured capacity".into()); 209 + } 210 + 211 + unsafe { write_candidates(&self.candidates_buf, candidates) }; 212 + 213 + let params = CommonParams { width: self.width, height: self.height, alpha, _pad0: 0 }; 214 + let batch = candidates.len() as u32; 215 + 216 + let cb = new_command_buffer(&self.queue)?; 217 + let enc = new_compute_encoder(&cb)?; 218 + encode_mean(&enc, &self.mean_pso, &self.candidates_buf, &self.target, &self.means_buf, &params, batch); 219 + enc.endEncoding(); 220 + 221 + let enc2 = new_compute_encoder(&cb)?; 222 + encode_score(&enc2, &self.score_pso, &self.candidates_buf, &self.means_buf, &self.scores_buf, &self.target, &self.canvas, &params, batch); 223 + enc2.endEncoding(); 224 + 225 + self.commit_wait_throttled(&cb); 226 + 227 + let scores = unsafe { read_scores(&self.scores_buf, candidates.len()) }; 228 + Ok(scores.iter().map(|s| s.delta_e2 < 0.0).collect()) 229 + } 230 + 231 + fn new_apply_encoder(&self) -> Result<(Retained<ProtocolObject<dyn MTLCommandBuffer>>, Retained<ProtocolObject<dyn MTLComputeCommandEncoder>>), String> { 232 + let cb = new_command_buffer(&self.queue)?; 233 + let enc = new_compute_encoder(&cb)?; 234 + enc.setComputePipelineState(&self.apply_pso); 235 + unsafe { 236 + enc.setBuffer_offset_atIndex(Some(&self.candidates_buf), 0, 0); 237 + enc.setBuffer_offset_atIndex(Some(&self.means_buf), 0, 1); 238 + enc.setTexture_atIndex(Some(&self.canvas), 0); 239 + } 240 + Ok((cb, enc)) 241 + } 242 + 243 + fn dispatch_apply(&self, enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, c: &Candidate, ci: usize, alpha: f32) { 244 + let ap = ApplyParams { cand_index: ci as u32, alpha, _pad0: 0, _pad1: 0 }; 245 + unsafe { 246 + enc.setBytes_length_atIndex( 247 + NonNull::new_unchecked((&ap as *const ApplyParams).cast::<c_void>() as *mut c_void), 248 + std::mem::size_of::<ApplyParams>() as _, 2, 249 + ); 250 + } 251 + let grid = MTLSize { width: c.bw as _, height: c.bh as _, depth: 1 }; 252 + let tg = MTLSize { width: 16, height: 16, depth: 1 }; 253 + enc.dispatchThreads_threadsPerThreadgroup(grid, tg); 254 + } 255 + 256 + pub fn process_batch(&self, candidates: &[Candidate], alpha: f32, apply_limit: u32) -> Result<Vec<bool>, String> { 257 + let raw_accepted = self.run_mean_and_score(candidates, alpha)?; 258 + if raw_accepted.is_empty() { 259 + return Ok(raw_accepted); 260 + } 261 + 262 + let (cb, enc) = self.new_apply_encoder()?; 263 + let mut accepted = vec![false; raw_accepted.len()]; 264 + let mut applied = 0u32; 265 + 266 + for (ci, ok) in raw_accepted.iter().enumerate() { 267 + if *ok && applied < apply_limit { 268 + accepted[ci] = true; 269 + applied += 1; 270 + self.dispatch_apply(&enc, &candidates[ci], ci, alpha); 271 + } 272 + } 273 + 274 + enc.endEncoding(); 275 + self.commit_wait_throttled(&cb); 276 + Ok(accepted) 277 + } 278 + 279 + pub fn process_batch_checkpointed<F>( 280 + &self, 281 + candidates: &[Candidate], 282 + alpha: f32, 283 + apply_limit: u32, 284 + accepted_total_before: u32, 285 + nth: u32, 286 + mut on_checkpoint: F, 287 + ) -> Result<(Vec<bool>, u32), String> 288 + where 289 + F: FnMut(&MetalContext, u32) -> Result<(), String>, 290 + { 291 + if nth == 0 { 292 + return Err("nth must be > 0".into()); 293 + } 294 + 295 + let raw_accepted = self.run_mean_and_score(candidates, alpha)?; 296 + if raw_accepted.is_empty() { 297 + return Ok((raw_accepted, accepted_total_before)); 298 + } 299 + 300 + let apply_indices: Vec<_> = raw_accepted.iter().enumerate() 301 + .filter_map(|(i, ok)| ok.then_some(i)) 302 + .take(apply_limit as usize) 303 + .collect(); 304 + 305 + let mut accepted = vec![false; raw_accepted.len()]; 306 + for &i in &apply_indices { 307 + accepted[i] = true; 308 + } 309 + 310 + if apply_indices.is_empty() { 311 + return Ok((accepted, accepted_total_before)); 312 + } 313 + 314 + let mut accepted_total = accepted_total_before; 315 + let (mut cb, mut enc) = self.new_apply_encoder()?; 316 + 317 + for (pos, &ci) in apply_indices.iter().enumerate() { 318 + accepted_total = accepted_total.saturating_add(1); 319 + self.dispatch_apply(&enc, &candidates[ci], ci, alpha); 320 + 321 + if accepted_total % nth == 0 { 322 + enc.endEncoding(); 323 + self.commit_wait_throttled(&cb); 324 + on_checkpoint(self, accepted_total)?; 325 + 326 + if pos + 1 < apply_indices.len() { 327 + (cb, enc) = self.new_apply_encoder()?; 328 + } 329 + } 330 + } 331 + 332 + if accepted_total % nth != 0 { 333 + enc.endEncoding(); 334 + self.commit_wait_throttled(&cb); 335 + } 336 + 337 + Ok((accepted, accepted_total)) 338 + } 339 + 340 + pub fn read_canvas(&self) -> Result<Vec<Oklab>, String> { 341 + let mut rgba = vec![0.0f32; (self.width as usize) * (self.height as usize) * 4]; 342 + unsafe { read_full_texture_rgba32f(&self.canvas, self.width, self.height, &mut rgba)? }; 343 + 344 + let mut out = Vec::with_capacity((self.width as usize) * (self.height as usize)); 345 + for i in 0..((self.width as usize) * (self.height as usize)) { 346 + out.push(Oklab { 347 + l: rgba[i * 4], 348 + a: rgba[i * 4 + 1], 349 + b: rgba[i * 4 + 2], 350 + }); 351 + } 352 + Ok(out) 353 + } 354 + } 355 + 356 + fn pack_oklab_rgba32f(pixels: &[Oklab]) -> Vec<f32> { 357 + let mut out = Vec::<f32>::with_capacity(pixels.len() * 4); 358 + for p in pixels { 359 + out.extend_from_slice(&[p.l, p.a, p.b, 0.0]); 360 + } 361 + out 362 + } 363 + 364 + fn compile_library(device: &ProtocolObject<dyn MTLDevice>) -> Result<Retained<ProtocolObject<dyn MTLLibrary>>, String> { 365 + let src = include_str!("kernels.metal"); 366 + let ns_src = NSString::from_str(src); 367 + device 368 + .newLibraryWithSource_options_error(&ns_src, None) 369 + .map_err(|e| format!("failed to compile metal library: {}", e)) 370 + } 371 + 372 + fn compile_pipeline( 373 + device: &ProtocolObject<dyn MTLDevice>, 374 + lib: &ProtocolObject<dyn MTLLibrary>, 375 + name: &NSString, 376 + ) -> Result<Retained<ProtocolObject<dyn MTLComputePipelineState>>, String> { 377 + let Some(f) = lib.newFunctionWithName(name) else { 378 + return Err(format!("missing metal function: {}", name)); 379 + }; 380 + device 381 + .newComputePipelineStateWithFunction_error(&f) 382 + .map_err(|e| format!("failed to build compute pipeline: {}", e)) 383 + } 384 + 385 + fn make_lab_texture( 386 + device: &ProtocolObject<dyn MTLDevice>, 387 + width: u32, 388 + height: u32, 389 + cpu_visible: bool, 390 + ) -> Result<Retained<ProtocolObject<dyn MTLTexture>>, String> { 391 + let desc = unsafe { 392 + MTLTextureDescriptor::texture2DDescriptorWithPixelFormat_width_height_mipmapped( 393 + MTLPixelFormat::RGBA32Float, 394 + width as _, 395 + height as _, 396 + false, 397 + ) 398 + }; 399 + desc.setUsage(MTLTextureUsage::ShaderRead | MTLTextureUsage::ShaderWrite); 400 + if cpu_visible { 401 + desc.setStorageMode(MTLStorageMode::Shared); 402 + } 403 + device 404 + .newTextureWithDescriptor(&desc) 405 + .ok_or_else(|| "failed to create texture".to_string()) 406 + } 407 + 408 + fn new_command_buffer( 409 + queue: &ProtocolObject<dyn MTLCommandQueue>, 410 + ) -> Result<Retained<ProtocolObject<dyn MTLCommandBuffer>>, String> { 411 + queue 412 + .commandBuffer() 413 + .ok_or_else(|| "failed to create command buffer".to_string()) 414 + } 415 + 416 + fn new_compute_encoder( 417 + cb: &ProtocolObject<dyn MTLCommandBuffer>, 418 + ) -> Result<Retained<ProtocolObject<dyn MTLComputeCommandEncoder>>, String> { 419 + cb.computeCommandEncoder() 420 + .ok_or_else(|| "failed to create compute encoder".to_string()) 421 + } 422 + 423 + fn encode_mean( 424 + enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, 425 + pso: &ProtocolObject<dyn MTLComputePipelineState>, 426 + candidates: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 427 + target: &ProtocolObject<dyn MTLTexture>, 428 + means: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 429 + params: &CommonParams, 430 + batch: u32, 431 + ) { 432 + enc.setComputePipelineState(pso); 433 + unsafe { 434 + enc.setBuffer_offset_atIndex(Some(candidates), 0, 0); 435 + enc.setTexture_atIndex(Some(target), 0); 436 + enc.setBuffer_offset_atIndex(Some(means), 0, 1); 437 + enc.setBytes_length_atIndex( 438 + NonNull::new_unchecked((params as *const CommonParams).cast::<c_void>() as *mut c_void), 439 + std::mem::size_of::<CommonParams>() as _, 440 + 2, 441 + ); 442 + } 443 + 444 + let tg = MTLSize { 445 + width: 256, 446 + height: 1, 447 + depth: 1, 448 + }; 449 + let groups = MTLSize { 450 + width: batch as _, 451 + height: 1, 452 + depth: 1, 453 + }; 454 + enc.dispatchThreadgroups_threadsPerThreadgroup(groups, tg); 455 + } 456 + 457 + fn encode_score( 458 + enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, 459 + pso: &ProtocolObject<dyn MTLComputePipelineState>, 460 + candidates: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 461 + means: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 462 + scores: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 463 + target: &ProtocolObject<dyn MTLTexture>, 464 + canvas: &ProtocolObject<dyn MTLTexture>, 465 + params: &CommonParams, 466 + batch: u32, 467 + ) { 468 + enc.setComputePipelineState(pso); 469 + unsafe { 470 + enc.setBuffer_offset_atIndex(Some(candidates), 0, 0); 471 + enc.setBuffer_offset_atIndex(Some(means), 0, 1); 472 + enc.setBuffer_offset_atIndex(Some(scores), 0, 2); 473 + enc.setTexture_atIndex(Some(target), 0); 474 + enc.setTexture_atIndex(Some(canvas), 1); 475 + enc.setBytes_length_atIndex( 476 + NonNull::new_unchecked((params as *const CommonParams).cast::<c_void>() as *mut c_void), 477 + std::mem::size_of::<CommonParams>() as _, 478 + 3, 479 + ); 480 + } 481 + 482 + let tg = MTLSize { 483 + width: 256, 484 + height: 1, 485 + depth: 1, 486 + }; 487 + let groups = MTLSize { 488 + width: batch as _, 489 + height: 1, 490 + depth: 1, 491 + }; 492 + enc.dispatchThreadgroups_threadsPerThreadgroup(groups, tg); 493 + } 494 + 495 + unsafe fn write_candidates( 496 + buf: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 497 + candidates: &[Candidate], 498 + ) { 499 + let ptr = buf.contents().as_ptr().cast::<Candidate>(); 500 + unsafe { std::ptr::copy_nonoverlapping(candidates.as_ptr(), ptr, candidates.len()) }; 501 + } 502 + 503 + unsafe fn read_scores( 504 + buf: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 505 + n: usize, 506 + ) -> Vec<ScoreOut> { 507 + let ptr = buf.contents().as_ptr().cast::<ScoreOut>(); 508 + let mut out = vec![ScoreOut::default(); n]; 509 + unsafe { std::ptr::copy_nonoverlapping(ptr, out.as_mut_ptr(), n) }; 510 + out 511 + } 512 + 513 + unsafe fn write_full_texture_rgba32f( 514 + tex: &ProtocolObject<dyn MTLTexture>, 515 + width: u32, 516 + height: u32, 517 + rgba: &[f32], 518 + ) -> Result<(), String> { 519 + if rgba.len() != (width as usize) * (height as usize) * 4 { 520 + return Err("texture upload size mismatch".to_string()); 521 + } 522 + let region = MTLRegion { 523 + origin: objc2_metal::MTLOrigin { x: 0, y: 0, z: 0 }, 524 + size: MTLSize { 525 + width: width as _, 526 + height: height as _, 527 + depth: 1, 528 + }, 529 + }; 530 + unsafe { 531 + tex.replaceRegion_mipmapLevel_withBytes_bytesPerRow( 532 + region, 533 + 0, 534 + NonNull::new_unchecked(rgba.as_ptr().cast::<c_void>() as *mut c_void), 535 + (width as usize * 4 * std::mem::size_of::<f32>()) as _, 536 + ); 537 + } 538 + Ok(()) 539 + } 540 + 541 + unsafe fn read_full_texture_rgba32f( 542 + tex: &ProtocolObject<dyn MTLTexture>, 543 + width: u32, 544 + height: u32, 545 + out: &mut [f32], 546 + ) -> Result<(), String> { 547 + if out.len() != (width as usize) * (height as usize) * 4 { 548 + return Err("texture readback size mismatch".to_string()); 549 + } 550 + let region = MTLRegion { 551 + origin: objc2_metal::MTLOrigin { x: 0, y: 0, z: 0 }, 552 + size: MTLSize { 553 + width: width as _, 554 + height: height as _, 555 + depth: 1, 556 + }, 557 + }; 558 + unsafe { 559 + tex.getBytes_bytesPerRow_fromRegion_mipmapLevel( 560 + NonNull::new_unchecked(out.as_mut_ptr().cast::<c_void>()), 561 + (width as usize * 4 * std::mem::size_of::<f32>()) as _, 562 + region, 563 + 0, 564 + ); 565 + } 566 + Ok(()) 567 + }
+133
src/oklab.rs
··· 1 + use rayon::prelude::*; 2 + 3 + #[derive(Clone, Copy, Debug, Default)] 4 + pub struct Oklab { 5 + pub l: f32, 6 + pub a: f32, 7 + pub b: f32, 8 + } 9 + 10 + #[derive(Clone, Copy, Debug, Default)] 11 + pub struct Srgb { 12 + pub r: f32, 13 + pub g: f32, 14 + pub b: f32, 15 + } 16 + 17 + impl Srgb { 18 + pub fn from_u8(r: u8, g: u8, b: u8) -> Self { 19 + Self { 20 + r: r as f32 / 255.0, 21 + g: g as f32 / 255.0, 22 + b: b as f32 / 255.0, 23 + } 24 + } 25 + 26 + pub fn to_u8(self) -> (u8, u8, u8) { 27 + let clamp = |x: f32| (x.clamp(0.0, 1.0) * 255.0).round() as u8; 28 + (clamp(self.r), clamp(self.g), clamp(self.b)) 29 + } 30 + 31 + fn linearize(c: f32) -> f32 { 32 + if c <= 0.04045 { 33 + c / 12.92 34 + } else { 35 + ((c + 0.055) / 1.055).powf(2.4) 36 + } 37 + } 38 + 39 + fn delinearize(c: f32) -> f32 { 40 + if c <= 0.0031308 { 41 + c * 12.92 42 + } else { 43 + 1.055 * c.powf(1.0 / 2.4) - 0.055 44 + } 45 + } 46 + 47 + pub fn to_linear(self) -> (f32, f32, f32) { 48 + ( 49 + Self::linearize(self.r), 50 + Self::linearize(self.g), 51 + Self::linearize(self.b), 52 + ) 53 + } 54 + 55 + pub fn from_linear(r: f32, g: f32, b: f32) -> Self { 56 + Self { 57 + r: Self::delinearize(r), 58 + g: Self::delinearize(g), 59 + b: Self::delinearize(b), 60 + } 61 + } 62 + } 63 + 64 + impl Oklab { 65 + pub fn from_srgb(srgb: Srgb) -> Self { 66 + // srgb to linear srgb 67 + let (r, g, b) = srgb.to_linear(); 68 + 69 + // linear srgb to lms (approximate cone response) 70 + let l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b; 71 + let m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b; 72 + let s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b; 73 + 74 + // apply non-linearity (cube root) 75 + let l_ = l.cbrt(); 76 + let m_ = m.cbrt(); 77 + let s_ = s.cbrt(); 78 + 79 + // transform to oklab coordinates (matrix m2) 80 + Self { 81 + l: 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, 82 + a: 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, 83 + b: 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, 84 + } 85 + } 86 + 87 + pub fn to_srgb(self) -> Srgb { 88 + // inverse of matrix m2 (oklab -> non-linear lms) 89 + let l_ = self.l + 0.3963377774 * self.a + 0.2158037573 * self.b; 90 + let m_ = self.l - 0.1055613458 * self.a - 0.0638541728 * self.b; 91 + let s_ = self.l - 0.0894841775 * self.a - 1.2914855480 * self.b; 92 + 93 + // revert non-linearity (cube) 94 + let l = l_ * l_ * l_; 95 + let m = m_ * m_ * m_; 96 + let s = s_ * s_ * s_; 97 + 98 + // inverse of matrix m1 (lms -> linear srgb) 99 + let r = 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s; 100 + let g = -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s; 101 + let b = -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s; 102 + 103 + // convert linear srgb to srgb 104 + Srgb::from_linear(r, g, b) 105 + } 106 + } 107 + 108 + pub fn srgb8_to_oklab(r: u8, g: u8, b: u8) -> Oklab { 109 + Oklab::from_srgb(Srgb::from_u8(r, g, b)) 110 + } 111 + 112 + pub fn oklab_to_srgb8(lab: Oklab) -> (u8, u8, u8) { 113 + lab.to_srgb().to_u8() 114 + } 115 + 116 + pub fn rgb8_to_oklab_parallel(rgb: &[u8]) -> Result<Vec<Oklab>, String> { 117 + if rgb.len() % 3 != 0 { 118 + return Err("rgb buffer must be 3 bytes per pixel".into()); 119 + } 120 + Ok(rgb.par_chunks(3).map(|c| srgb8_to_oklab(c[0], c[1], c[2])).collect()) 121 + } 122 + 123 + pub fn avg_oklab_parallel(pixels: &[Oklab]) -> Oklab { 124 + if pixels.is_empty() { 125 + return Oklab::default(); 126 + } 127 + let (l, a, b) = pixels 128 + .par_iter() 129 + .map(|p| (p.l, p.a, p.b)) 130 + .reduce(|| (0.0, 0.0, 0.0), |(l1, a1, b1), (l2, a2, b2)| (l1 + l2, a1 + a2, b1 + b2)); 131 + let n = pixels.len() as f32; 132 + Oklab { l: l / n, a: a / n, b: b / n } 133 + }
+235
src/pipeline.rs
··· 1 + use crate::args::{Bg, Config}; 2 + use crate::fs; 3 + use crate::log; 4 + use crate::metal::MetalContext; 5 + use crate::oklab::{avg_oklab_parallel, oklab_to_srgb8, rgb8_to_oklab_parallel, srgb8_to_oklab, Oklab}; 6 + use crate::rng::Rng; 7 + use crate::sampling; 8 + use image::{GenericImageView, ImageBuffer, Rgb}; 9 + use rayon::prelude::*; 10 + use std::path::{Path, PathBuf}; 11 + use std::time::Instant; 12 + 13 + pub enum ResolvedOutput { 14 + SingleFile(PathBuf), 15 + FrameSequence { dir: PathBuf, nth: u32 }, 16 + } 17 + 18 + impl ResolvedOutput { 19 + pub fn from_config(cfg: &Config, base_path: PathBuf) -> Self { 20 + match cfg.nth { 21 + Some(nth) => Self::FrameSequence { dir: base_path, nth }, 22 + None => Self::SingleFile(base_path), 23 + } 24 + } 25 + 26 + fn prepare_dirs(&self) -> Result<(), String> { 27 + match self { 28 + Self::FrameSequence { dir, .. } => std::fs::create_dir_all(dir) 29 + .map_err(|e| format!("failed to create dir {}: {e}", dir.display())), 30 + Self::SingleFile(file) => { 31 + if let Some(parent) = file.parent().filter(|p| !p.as_os_str().is_empty()) { 32 + std::fs::create_dir_all(parent) 33 + .map_err(|e| format!("failed to create dir {}: {e}", parent.display()))?; 34 + } 35 + Ok(()) 36 + } 37 + } 38 + } 39 + } 40 + 41 + pub fn process_dir(cfg: &Config, input_dir: &Path, out_root: &Path) -> Result<(), String> { 42 + let started = Instant::now(); 43 + let mut files = fs::walk_files_recursive(input_dir)?; 44 + files.sort(); 45 + 46 + let mut processed: u64 = 0; 47 + let mut skipped: u64 = 0; 48 + 49 + for input_file in files { 50 + let rel = match input_file.strip_prefix(input_dir) { 51 + Ok(v) => v, 52 + Err(_) => continue, 53 + }; 54 + 55 + let base_path = match cfg.nth { 56 + Some(_) => out_root.join(rel).with_extension(""), 57 + None => out_root.join(rel).with_extension("png"), 58 + }; 59 + let output = ResolvedOutput::from_config(cfg, base_path); 60 + 61 + let Ok(_) = image::open(&input_file) else { 62 + skipped += 1; 63 + if log::level() >= 2 { 64 + log::l2(format_args!("skip: {}", input_file.display())); 65 + } 66 + continue; 67 + }; 68 + 69 + match process_one(cfg, &input_file, output) { 70 + Ok(final_path) => { 71 + processed += 1; 72 + if log::level() >= 2 { 73 + log::l2(format_args!( 74 + "output: {} (input: {})", 75 + final_path.display(), 76 + input_file.display() 77 + )); 78 + } 79 + } 80 + Err(e) => { 81 + skipped += 1; 82 + log::l1(format_args!("failed: {} ({})", input_file.display(), e)); 83 + } 84 + } 85 + } 86 + 87 + let elapsed_s = started.elapsed().as_secs_f64(); 88 + println!( 89 + "done: processed {} images, skipped {}, elapsed {:.2}s", 90 + processed, skipped, elapsed_s 91 + ); 92 + Ok(()) 93 + } 94 + 95 + pub fn process_one(cfg: &Config, input_path: &Path, output: ResolvedOutput) -> Result<PathBuf, String> { 96 + let started = Instant::now(); 97 + 98 + let img = image::open(input_path) 99 + .map_err(|e| format!("failed to load image {}: {e}", input_path.display()))?; 100 + let (w, h) = img.dimensions(); 101 + let rgb = img.to_rgb8(); 102 + let target = rgb8_to_oklab_parallel(rgb.as_raw())?; 103 + 104 + let current_canvas = if let Some(current_path) = &cfg.current { 105 + let img = image::open(current_path) 106 + .map_err(|e| format!("failed to load current image {}: {e}", current_path))?; 107 + let (cw, ch) = img.dimensions(); 108 + if cw != w || ch != h { 109 + return Err(format!( 110 + "current image size mismatch: got {}x{}, expected {}x{}", 111 + cw, ch, w, h 112 + )); 113 + } 114 + let rgb = img.to_rgb8(); 115 + Some(rgb8_to_oklab_parallel(rgb.as_raw())?) 116 + } else { 117 + None 118 + }; 119 + 120 + log::l2(format_args!( 121 + "input: {} ({}x{}), batch: {}, seed: {}, alpha: {}, max_gpu: {}", 122 + input_path.display(), 123 + w, 124 + h, 125 + cfg.batch, 126 + cfg.seed, 127 + cfg.alpha, 128 + cfg.max_gpu, 129 + )); 130 + 131 + let avg = avg_oklab_parallel(&target); 132 + let bg = match cfg.bg { 133 + Bg::Avg => avg, 134 + Bg::RgbU8 { r, g, b } => srgb8_to_oklab(r, g, b), 135 + }; 136 + 137 + let max_splines = cfg.number.unwrap_or((w as f64 * h as f64).powf(0.7).round() as u32); 138 + let mut rng = Rng::new(cfg.seed); 139 + 140 + let metal = MetalContext::new(w, h, cfg.batch, cfg.max_gpu)?; 141 + match &current_canvas { 142 + Some(canvas) => metal.upload_target_and_set_canvas(&target, canvas)?, 143 + None => metal.upload_target_and_init_canvas(&target, bg)?, 144 + } 145 + output.prepare_dirs()?; 146 + 147 + if let ResolvedOutput::FrameSequence { dir, nth } = &output { 148 + log::l2(format_args!("frames: {} (every {} accepted)", dir.display(), nth)); 149 + } 150 + log::l1(format_args!("target splines: {}", max_splines)); 151 + let mut batch_logger = log::BatchLogger::new(); 152 + 153 + let mut accepted_total: u32 = 0; 154 + let mut consecutive_stagnant_batches: u32 = 0; 155 + let mut batch_idx: u64 = 0; 156 + 157 + while accepted_total < max_splines { 158 + batch_idx += 1; 159 + let remaining = max_splines - accepted_total; 160 + let candidates = sampling::sample_candidates(&mut rng, w, h, cfg.batch); 161 + let accepted_before = accepted_total; 162 + 163 + match &output { 164 + ResolvedOutput::FrameSequence { dir, nth } => { 165 + let (_, accepted_after) = metal.process_batch_checkpointed( 166 + &candidates, cfg.alpha, remaining, accepted_total, *nth, 167 + |metal, total| { 168 + let frame = metal.read_canvas()?; 169 + save_oklab_png(dir.join(format!("frame_{:06}.png", total)), &frame, w, h) 170 + }, 171 + )?; 172 + accepted_total = accepted_after; 173 + } 174 + ResolvedOutput::SingleFile(_) => { 175 + let accepted = metal.process_batch(&candidates, cfg.alpha, remaining)?; 176 + accepted_total += accepted.iter().filter(|v| **v).count() as u32; 177 + } 178 + }; 179 + 180 + let accepted_now = accepted_total - accepted_before; 181 + let accept_threshold = (cfg.batch as f64) * (cfg.min_accept_ratio as f64); 182 + let stagnant = (accepted_now as f64) + 1e-9 < accept_threshold; 183 + consecutive_stagnant_batches = if stagnant { consecutive_stagnant_batches + 1 } else { 0 }; 184 + 185 + batch_logger.log_batch(log::BatchInfo { 186 + batch_idx, 187 + accepted_total, 188 + max_splines, 189 + accepted_now, 190 + batch_size: cfg.batch, 191 + remaining: max_splines.saturating_sub(accepted_total), 192 + consecutive_stagnant: consecutive_stagnant_batches, 193 + max_stagnant_batches: cfg.max_stagnant_batches, 194 + elapsed_s: started.elapsed().as_secs_f64(), 195 + }); 196 + 197 + if consecutive_stagnant_batches >= cfg.max_stagnant_batches { 198 + break; 199 + } 200 + } 201 + 202 + let canvas = metal.read_canvas()?; 203 + let final_path = match &output { 204 + ResolvedOutput::FrameSequence { dir, .. } => { 205 + let path = dir.join("final.png"); 206 + save_oklab_png(&path, &canvas, w, h)?; 207 + path 208 + } 209 + ResolvedOutput::SingleFile(file) => { 210 + save_oklab_png(file, &canvas, w, h)?; 211 + file.clone() 212 + } 213 + }; 214 + 215 + log::l2(format_args!( 216 + "done: {} (accepted: {}, target: {}, elapsed: {:.2}s)", 217 + final_path.display(), accepted_total, max_splines, started.elapsed().as_secs_f64() 218 + )); 219 + Ok(final_path) 220 + } 221 + 222 + fn save_oklab_png(path: impl AsRef<Path>, data: &[Oklab], w: u32, h: u32) -> Result<(), String> { 223 + let path = path.as_ref(); 224 + if data.len() != (w as usize) * (h as usize) { 225 + return Err("output size mismatch".into()); 226 + } 227 + let rgb: Vec<u8> = data 228 + .par_iter() 229 + .flat_map(|lab| { let (r, g, b) = oklab_to_srgb8(*lab); [r, g, b] }) 230 + .collect(); 231 + ImageBuffer::<Rgb<u8>, _>::from_raw(w, h, rgb) 232 + .ok_or("buffer size mismatch")? 233 + .save(path) 234 + .map_err(|e| format!("failed to save: {e}")) 235 + }
+25
src/rng.rs
··· 1 + #[derive(Clone, Copy, Debug)] 2 + pub struct Rng { 3 + state: u64, 4 + } 5 + 6 + impl Rng { 7 + pub fn new(seed: u64) -> Self { 8 + Self { 9 + state: seed ^ 0x9e3779b97f4a7c15, 10 + } 11 + } 12 + 13 + pub fn next_u64(&mut self) -> u64 { 14 + self.state = self.state.wrapping_add(0x9e3779b97f4a7c15); 15 + let mut z = self.state; 16 + z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); 17 + z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); 18 + z ^ (z >> 31) 19 + } 20 + 21 + pub fn next_f32(&mut self) -> f32 { 22 + let bits = (self.next_u64() >> 40) as u32; 23 + (bits as f32) * (1.0 / ((1u32 << 24) as f32)) 24 + } 25 + }
+42
src/sampling.rs
··· 1 + use crate::metal::Candidate; 2 + use crate::rng::Rng; 3 + 4 + pub fn sample_candidates(rng: &mut Rng, w: u32, h: u32, batch: u32) -> Vec<Candidate> { 5 + let max_x = (w.saturating_sub(1)) as f32; 6 + let max_y = (h.saturating_sub(1)) as f32; 7 + let mut sample_pt = || [ 8 + (rng.next_f32() * (max_x + 1.0)).clamp(0.0, max_x), 9 + (rng.next_f32() * (max_y + 1.0)).clamp(0.0, max_y), 10 + ]; 11 + 12 + (0..batch) 13 + .map(|_| { 14 + let [p0, p1, p2, p3] = std::array::from_fn(|_| sample_pt()); 15 + let (bx, by, bw, bh) = bbox_from_points(p0, p1, p2, p3, w, h); 16 + Candidate { p0, p1, p2, p3, bx, by, bw, bh } 17 + }) 18 + .collect() 19 + } 20 + 21 + pub fn bbox_from_points( 22 + p0: [f32; 2], 23 + p1: [f32; 2], 24 + p2: [f32; 2], 25 + p3: [f32; 2], 26 + w: u32, 27 + h: u32, 28 + ) -> (u32, u32, u32, u32) { 29 + let min_x = p0[0].min(p1[0]).min(p2[0]).min(p3[0]); 30 + let max_x = p0[0].max(p1[0]).max(p2[0]).max(p3[0]); 31 + let min_y = p0[1].min(p1[1]).min(p2[1]).min(p3[1]); 32 + let max_y = p0[1].max(p1[1]).max(p2[1]).max(p3[1]); 33 + 34 + let bx = (min_x.floor() as i64 - 1).clamp(0, (w.saturating_sub(1)) as i64) as u32; 35 + let by = (min_y.floor() as i64 - 1).clamp(0, (h.saturating_sub(1)) as i64) as u32; 36 + let ex = (max_x.ceil() as i64 + 1).clamp(0, (w.saturating_sub(1)) as i64) as u32; 37 + let ey = (max_y.ceil() as i64 + 1).clamp(0, (h.saturating_sub(1)) as i64) as u32; 38 + 39 + let bw = (ex - bx + 1).max(1); 40 + let bh = (ey - by + 1).max(1); 41 + (bx, by, bw, bh) 42 + }