iterative image reconstruction using random cubic bézier strokes, accelerated on metal
metal graphics rust
7
fork

Configure Feed

Select the types of activity you want to include in your feed.

perform opinionated formatting

+700 -1327
+1 -1
Cargo.lock
··· 269 269 270 270 [[package]] 271 271 name = "splined" 272 - version = "0.1.0" 272 + version = "0.1.1" 273 273 dependencies = [ 274 274 "image", 275 275 "objc2",
+2 -2
Cargo.toml
··· 1 1 [package] 2 2 name = "splined" 3 - version = "0.1.0" 3 + version = "0.1.1" 4 4 edition = "2024" 5 5 6 6 [dependencies] ··· 12 12 13 13 [profile.release] 14 14 lto = true 15 - codegen-units = 1 15 + codegen-units = 1
+4
changelog
··· 1 + version 0.1.1 2 + 2026-01-23 3 + opinionated formatting 4 + 1 5 version 0.1.0 2 6 2026-01-15 3 7 first release
-1
rustfmt.toml
··· 1 - tab_spaces = 3
+67 -169
src/args.rs
··· 29 29 " splined in.png --nth 50 -o frames/", 30 30 " splined images/ -o results/ -n 5000 -b 64 -s 42 --nth 50", 31 31 "", 32 - ] 33 - .join("\n") 34 - } 32 + ].join("\n") } 35 33 36 34 #[derive(Clone, Debug)] 37 - pub struct Config { 38 - pub input: String, 39 - pub number: Option<u32>, 40 - pub batch: u32, 41 - pub min_accept_ratio: f32, 42 - pub seed: u64, 43 - pub max_gpu: f32, 44 - pub log: u8, 45 - pub output: String, 46 - pub output_provided: bool, 47 - pub current: Option<String>, 48 - pub nth: Option<u32>, 49 - pub bg: Bg, 50 - pub alpha: f32, 51 - pub max_stagnant_batches: u32, 52 - } 35 + pub struct Config { pub input : String 36 + , pub number : Option<u32> 37 + , pub batch : u32 38 + , pub min_accept_ratio : f32 39 + , pub seed : u64 40 + , pub max_gpu : f32 41 + , pub log : u8 42 + , pub output : String 43 + , pub output_provided : bool 44 + , pub current : Option<String> 45 + , pub nth : Option<u32> 46 + , pub bg : Bg 47 + , pub alpha : f32 48 + , pub max_stagnant_batches : u32 } 53 49 54 50 #[derive(Clone, Copy, Debug)] 55 - pub enum Bg { 56 - Avg, 57 - RgbU8 { r: u8, g: u8, b: u8 }, 58 - } 51 + pub enum Bg { Avg 52 + , RgbU8 { r: u8, g: u8, b: u8 } } 59 53 60 54 impl Config { 61 - pub fn parse(argv: &[String]) -> Result<Self, String> { 62 - if argv.len() < 2 { 63 - return Err(usage()); 64 - } 55 + pub fn parse(argv: &[String]) -> Result<Self, String> { 56 + if argv.len() < 2 { return Err(usage()); } 65 57 66 - let mut input: Option<String> = None; 67 - let mut number: Option<u32> = None; 68 - let mut batch: u32 = 32; 69 - let mut min_accept_ratio: Option<f32> = None; 70 - let mut seed: u64 = 0; 71 - let mut max_gpu: f32 = 1.0; 72 - let mut log: u8 = 1; 73 - let mut output: String = "output.png".to_string(); 74 - let mut output_provided: bool = false; 75 - let mut current: Option<String> = None; 76 - let mut nth: Option<u32> = None; 77 - let mut bg: Bg = Bg::Avg; 78 - let mut alpha: f32 = 1.0; 79 - let mut max_stagnant_batches: u32 = 10; 58 + let mut input : Option<String> = None; 59 + let mut number : Option<u32> = None; 60 + let mut batch : u32 = 32; 61 + let mut min_accept_ratio : Option<f32> = None; 62 + let mut seed : u64 = 0; 63 + let mut max_gpu : f32 = 1.0; 64 + let mut log : u8 = 1; 65 + let mut output : String = "output.png".to_string(); 66 + let mut output_provided : bool = false; 67 + let mut current : Option<String> = None; 68 + let mut nth : Option<u32> = None; 69 + let mut bg : Bg = Bg::Avg; 70 + let mut alpha : f32 = 1.0; 71 + let mut max_stagnant_batches : u32 = 10; 80 72 81 - let mut i = 1usize; 82 - while i < argv.len() { 83 - let a = argv[i].as_str(); 84 - if !a.starts_with('-') && input.is_none() { 85 - input = Some(argv[i].clone()); 86 - i += 1; 87 - continue; 88 - } 73 + let mut i = 1usize; 74 + while i < argv.len() { 75 + let a = argv[i].as_str(); 76 + if !a.starts_with('-') && input.is_none() { input = Some(argv[i].clone()); i += 1; continue; } 89 77 90 78 match a { 91 - "-n" | "--number" => { 92 - i += 1; 93 - number = Some(parse_val(argv, i, a)?); 94 - i += 1; 95 - } 96 - "-b" | "--batch" => { 97 - i += 1; 98 - batch = parse_val(argv, i, a)?; 99 - i += 1; 100 - } 101 - "--min-accept-ratio" => { 102 - i += 1; 103 - min_accept_ratio = Some(parse_val(argv, i, a)?); 104 - i += 1; 105 - } 106 - "-s" | "--seed" => { 107 - i += 1; 108 - seed = parse_val(argv, i, a)?; 109 - i += 1; 110 - } 111 - "--max-gpu" => { 112 - i += 1; 113 - max_gpu = parse_val(argv, i, a)?; 114 - i += 1; 115 - } 116 - "-l" | "--log" => { 117 - i += 1; 118 - log = parse_val(argv, i, a)?; 119 - i += 1; 120 - } 121 - "-o" | "--output" => { 122 - i += 1; 123 - output = parse_val(argv, i, a)?; 124 - output_provided = true; 125 - i += 1; 126 - } 127 - "-c" | "--current" => { 128 - i += 1; 129 - current = Some(parse_val(argv, i, a)?); 130 - i += 1; 131 - } 132 - "--nth" => { 133 - i += 1; 134 - nth = Some(parse_val(argv, i, a)?); 135 - i += 1; 136 - } 137 - "--bg" => { 138 - i += 1; 139 - bg = parse_bg(argv, i, a)?; 140 - i += 1; 141 - } 142 - "-a" | "--alpha" => { 143 - i += 1; 144 - alpha = parse_val(argv, i, a)?; 145 - i += 1; 146 - } 147 - "--max-stagnant-batches" => { 148 - i += 1; 149 - max_stagnant_batches = parse_val(argv, i, a)?; 150 - i += 1; 151 - } 152 - "-h" | "--help" => return Err(usage()), 153 - _ => return Err(format!("unknown arg: {a}\n\n{}", usage())), 154 - } 155 - } 79 + "-n" | "--number" => { i += 1; number = Some(parse_val(argv, i, a)?); i += 1; } 80 + "-b" | "--batch" => { i += 1; batch = parse_val(argv, i, a)?; i += 1; } 81 + "--min-accept-ratio" => { i += 1; min_accept_ratio = Some(parse_val(argv, i, a)?); i += 1; } 82 + "-s" | "--seed" => { i += 1; seed = parse_val(argv, i, a)?; i += 1; } 83 + "--max-gpu" => { i += 1; max_gpu = parse_val(argv, i, a)?; i += 1; } 84 + "-l" | "--log" => { i += 1; log = parse_val(argv, i, a)?; i += 1; } 85 + "-o" | "--output" => { i += 1; output = parse_val(argv, i, a)?; output_provided = true; i += 1; } 86 + "-c" | "--current" => { i += 1; current = Some(parse_val(argv, i, a)?); i += 1; } 87 + "--nth" => { i += 1; nth = Some(parse_val(argv, i, a)?); i += 1; } 88 + "--bg" => { i += 1; bg = parse_bg(argv, i, a)?; i += 1; } 89 + "-a" | "--alpha" => { i += 1; alpha = parse_val(argv, i, a)?; i += 1; } 90 + "--max-stagnant-batches" => { i += 1; max_stagnant_batches = parse_val(argv, i, a)?; i += 1; } 91 + "-h" | "--help" => return Err(usage()), 92 + _ => return Err(format!("unknown arg: {a}\n\n{}", usage())), } } 156 93 157 - let Some(input) = input else { 158 - return Err(usage()); 159 - }; 94 + let Some(input) = input else { return Err(usage()); }; 160 95 161 - if batch == 0 { 162 - return Err("batch must be > 0".to_string()); 163 - } 164 - let min_accept_ratio = min_accept_ratio.unwrap_or(0.02); 165 - if !(0.0..=1.0).contains(&min_accept_ratio) { 166 - return Err("min-accept-ratio must be in [0, 1]".to_string()); 167 - } 168 - if log > 3 { 169 - return Err("log must be in [0, 3]".to_string()); 170 - } 171 - if !(0.0 < max_gpu && max_gpu <= 1.0) { 172 - return Err("max-gpu must be in (0, 1]".to_string()); 173 - } 174 - if let Some(nth) = nth { 175 - if nth == 0 { 176 - return Err("nth must be > 0".to_string()); 177 - } 178 - } 179 - if !(0.0..=1.0).contains(&alpha) { 180 - return Err("alpha must be in [0, 1]".to_string()); 181 - } 182 - if max_stagnant_batches == 0 { 183 - return Err("max-stagnant-batches must be > 0".to_string()); 184 - } 96 + if batch == 0 { return Err("batch must be > 0".to_string()); } 97 + 98 + let min_accept_ratio = min_accept_ratio.unwrap_or(0.02); 99 + 100 + if !(0.0..=1.0).contains(&min_accept_ratio) { return Err("min-accept-ratio must be in [0, 1]".to_string()); } 101 + if log > 3 { return Err("log must be in [0, 3]".to_string()); } 102 + if !(0.0 < max_gpu && max_gpu <= 1.0) { return Err("max-gpu must be in (0, 1]".to_string()); } 103 + if let Some(nth) = nth { if nth == 0 { return Err("nth must be > 0".to_string()); } } 104 + if !(0.0..=1.0).contains(&alpha) { return Err("alpha must be in [0, 1]".to_string()); } 105 + if max_stagnant_batches == 0 { return Err("max-stagnant-batches must be > 0".to_string()); } 185 106 186 - Ok(Self { 187 - input, 188 - number, 189 - batch, 190 - min_accept_ratio, 191 - seed, 192 - max_gpu, 193 - log, 194 - output, 195 - output_provided, 196 - current, 197 - nth, 198 - bg, 199 - alpha, 200 - max_stagnant_batches, 201 - }) 202 - } 203 - } 107 + Ok(Self { input, number, batch, min_accept_ratio, seed, max_gpu, log, output, output_provided, current, nth, bg, alpha, max_stagnant_batches, }) } } 204 108 205 109 fn parse_val<T: std::str::FromStr>(argv: &[String], idx: usize, flag: &str) -> Result<T, String> { 206 110 let s = argv.get(idx).ok_or_else(|| format!("missing value for {flag}"))?; 207 - s.parse::<T>().map_err(|_| format!("invalid value for {flag}: {s}")) 208 - } 111 + s.parse::<T>().map_err(|_| format!("invalid value for {flag}: {s}")) } 209 112 210 113 fn parse_bg(argv: &[String], idx: usize, flag: &str) -> Result<Bg, String> { 211 114 let s = argv.get(idx).ok_or_else(|| format!("missing value for {flag}"))?; 212 - if s == "avg" { 213 - return Ok(Bg::Avg); 214 - } 115 + if s == "avg" { return Ok(Bg::Avg); } 215 116 let parts: Vec<_> = s.split(',').collect(); 216 - if parts.len() != 3 { 217 - return Err(format!("invalid bg, expected avg or r,g,b: {s}")); 218 - } 117 + if parts.len() != 3 { return Err(format!("invalid bg, expected avg or r,g,b: {s}")); } 219 118 let parse = |i: usize| parts[i].parse::<u8>().map_err(|_| format!("invalid bg component: {s}")); 220 - Ok(Bg::RgbU8 { r: parse(0)?, g: parse(1)?, b: parse(2)? }) 221 - } 119 + Ok(Bg::RgbU8 { r: parse(0)?, g: parse(1)?, b: parse(2)? }) }
+15 -30
src/fs.rs
··· 2 2 use std::path::{Path, PathBuf}; 3 3 4 4 pub fn validate_and_prepare_dir_output(cfg: &args::Config) -> Result<PathBuf, String> { 5 - if !cfg.output_provided { 6 - return Err("when input is a directory, -o/--output must be provided".into()); 7 - } 5 + if !cfg.output_provided { return Err("when input is a directory, -o/--output must be provided".into()); } 8 6 9 - let out_root = PathBuf::from(&cfg.output); 10 - if out_root.exists() && !out_root.is_dir() { 11 - return Err(format!("output must be a directory, got file: {}", out_root.display())); 12 - } 13 - std::fs::create_dir_all(&out_root) 14 - .map_err(|e| format!("failed to create output dir {}: {e}", out_root.display()))?; 15 - Ok(out_root) 16 - } 7 + let out_root = PathBuf::from(&cfg.output); 8 + if out_root.exists() && !out_root.is_dir() { return Err(format!("output must be a directory, got file: {}", out_root.display())); } 9 + std::fs::create_dir_all(&out_root).map_err(|e| format!("failed to create output dir {}: {e}", out_root.display()))?; 10 + Ok(out_root) } 17 11 18 12 pub fn walk_files_recursive(root: &Path) -> Result<Vec<PathBuf>, String> { 19 - let mut out = Vec::new(); 20 - let mut stack = vec![root.to_path_buf()]; 13 + let mut out = Vec::new(); 14 + let mut stack = vec![root.to_path_buf()]; 15 + 16 + while let Some(dir) = stack.pop() { 17 + let rd = std::fs::read_dir(&dir).map_err(|e| format!("failed to read dir {}: {e}", dir.display()))?; 21 18 22 - while let Some(dir) = stack.pop() { 23 - let rd = std::fs::read_dir(&dir) 24 - .map_err(|e| format!("failed to read dir {}: {e}", dir.display()))?; 25 - for entry in rd { 26 - let entry = entry.map_err(|e| format!("failed to read dir entry: {e}"))?; 27 - let ty = entry 28 - .file_type() 29 - .map_err(|e| format!("failed to stat {}: {e}", entry.path().display()))?; 30 - if ty.is_dir() { 31 - stack.push(entry.path()); 32 - } else if ty.is_file() { 33 - out.push(entry.path()); 34 - } 35 - } 36 - } 19 + for entry in rd { 20 + let entry = entry.map_err(|e| format!("failed to read dir entry: {e}"))?; 21 + let ty = entry.file_type().map_err(|e| format!("failed to stat {}: {e}", entry.path().display()))?; 22 + if ty.is_dir() { stack.push(entry.path()); } else if ty.is_file() { out.push(entry.path()); } } } 37 23 38 - Ok(out) 39 - } 24 + Ok(out) }
+55 -95
src/log.rs
··· 3 3 4 4 static LEVEL: AtomicU8 = AtomicU8::new(1); 5 5 6 - pub fn set_level(level: u8) { 7 - LEVEL.store(level.min(3), Ordering::Relaxed); 8 - } 6 + pub fn set_level(level: u8) { LEVEL.store(level.min(3), Ordering::Relaxed); } 9 7 10 - pub fn level() -> u8 { 11 - LEVEL.load(Ordering::Relaxed) 12 - } 8 + pub fn level() -> u8 { LEVEL.load(Ordering::Relaxed) } 13 9 14 - pub fn l1(args: Arguments) { 15 - emit(1, args); 16 - } 10 + pub fn l1(args: Arguments) { emit(1, args); } 11 + pub fn l2(args: Arguments) { emit(2, args); } 17 12 18 - pub fn l2(args: Arguments) { 19 - emit(2, args); 20 - } 13 + fn emit(min: u8, args: Arguments) { if level() >= min { eprintln!("{args}"); } } 21 14 22 - fn emit(min: u8, args: Arguments) { 23 - if level() >= min { 24 - eprintln!("{args}"); 25 - } 26 - } 15 + pub struct BatchInfo { pub batch_idx : u64 16 + , pub accepted_total : u32 17 + , pub max_splines : u32 18 + , pub accepted_now : u32 19 + , pub batch_size : u32 20 + , pub remaining : u32 21 + , pub consecutive_stagnant : u32 22 + , pub max_stagnant_batches : u32 23 + , pub elapsed_s : f64 } 27 24 28 - pub struct BatchInfo { 29 - pub batch_idx: u64, 30 - pub accepted_total: u32, 31 - pub max_splines: u32, 32 - pub accepted_now: u32, 33 - pub batch_size: u32, 34 - pub remaining: u32, 35 - pub consecutive_stagnant: u32, 36 - pub max_stagnant_batches: u32, 37 - pub elapsed_s: f64, 38 - } 39 - 40 - pub struct BatchLogger { 41 - rate_avg: Option<f64>, 42 - accept_ratio_avg: Option<f64>, 43 - alpha: f64, 44 - last_elapsed: f64, 45 - } 25 + pub struct BatchLogger { rate_avg : Option<f64> 26 + , accept_ratio_avg : Option<f64> 27 + , alpha : f64 28 + , last_elapsed : f64 } 46 29 47 30 impl BatchLogger { 48 - pub fn new() -> Self { 49 - Self { 50 - rate_avg: None, 51 - accept_ratio_avg: None, 52 - alpha: 0.25, 53 - last_elapsed: 0.0, 54 - } 55 - } 31 + pub fn new() -> Self { Self { rate_avg: None, accept_ratio_avg: None, alpha: 0.25, last_elapsed: 0.0, } } 56 32 57 - pub fn log_batch(&mut self, info: BatchInfo) { 58 - let stop = info.consecutive_stagnant >= info.max_stagnant_batches; 59 - let delta = (info.elapsed_s - self.last_elapsed).max(0.0); 60 - self.last_elapsed = info.elapsed_s; 33 + pub fn log_batch(&mut self, info: BatchInfo) { 34 + let stop = info.consecutive_stagnant >= info.max_stagnant_batches; 35 + let delta = (info.elapsed_s - self.last_elapsed).max(0.0); 36 + self.last_elapsed = info.elapsed_s; 61 37 62 - let rate_sample = if delta > 0.0 { 63 - info.accepted_now as f64 / delta 64 - } else { 65 - 0.0 66 - }; 67 - let ratio_sample = if info.batch_size > 0 { 68 - info.accepted_now as f64 / info.batch_size as f64 69 - } else { 70 - 0.0 71 - }; 38 + let rate_sample = if delta > 0.0 { info.accepted_now as f64 / delta } else { 0.0 }; 39 + let ratio_sample = if info.batch_size > 0 { info.accepted_now as f64 / info.batch_size as f64 } else { 0.0 }; 72 40 73 - self.rate_avg = Some(self.update_ema(self.rate_avg, rate_sample)); 74 - self.accept_ratio_avg = Some(self.update_ema(self.accept_ratio_avg, ratio_sample)); 41 + self.rate_avg = Some(self.update_ema(self.rate_avg, rate_sample)); 42 + self.accept_ratio_avg = Some(self.update_ema(self.accept_ratio_avg, ratio_sample)); 75 43 76 - let rate_avg = self.rate_avg.unwrap_or(0.0); 77 - let accept_avg = self.accept_ratio_avg.unwrap_or(0.0); 44 + let rate_avg = self.rate_avg.unwrap_or(0.0); 45 + let accept_avg = self.accept_ratio_avg.unwrap_or(0.0); 78 46 79 - self.emit(&info, ratio_sample, rate_avg, accept_avg, stop); 80 - } 47 + self.emit(&info, ratio_sample, rate_avg, accept_avg, stop); } 81 48 82 - fn update_ema(&self, current: Option<f64>, sample: f64) -> f64 { 83 - match current { 84 - Some(prev) => prev * (1.0 - self.alpha) + sample * self.alpha, 85 - None => sample, 86 - } 87 - } 49 + fn update_ema(&self, current: Option<f64>, sample: f64) -> f64 { 50 + match current { 51 + Some(prev) => prev * (1.0 - self.alpha) + sample * self.alpha, 52 + None => sample, } } 53 + 54 + fn emit(&self, info: &BatchInfo, batch_ratio: f64, rate_avg: f64, accept_avg: f64, stop: bool) { 55 + let lv = level(); let stop_note = if stop { " stop:stagnant" } else { "" }; 56 + 57 + if lv >= 3 { 58 + eprintln!( 59 + "b:{} t:{:.1}s acc:{}/{} (+{}, {:.3}) stagnant:{}/{} rem:{} avg_rate:{:.2}/s avg_accept:{:.3}{}", 60 + info.batch_idx, info.elapsed_s, info.accepted_total, info.max_splines, 61 + info.accepted_now, batch_ratio, info.consecutive_stagnant, info.max_stagnant_batches, 62 + info.remaining, rate_avg, accept_avg, stop_note ); } 88 63 89 - fn emit(&self, info: &BatchInfo, batch_ratio: f64, rate_avg: f64, accept_avg: f64, stop: bool) { 90 - let lv = level(); 91 - let stop_note = if stop { " stop:stagnant" } else { "" }; 64 + else if lv >= 2 { 65 + eprintln!( 66 + "b:{} t:{:.1}s acc:{}/{} (+{}, {:.3}) stagnant:{}/{} avg_rate:{:.2}/s avg_accept:{:.3}{}", 67 + info.batch_idx, info.elapsed_s, info.accepted_total, info.max_splines, 68 + info.accepted_now, batch_ratio, info.consecutive_stagnant, info.max_stagnant_batches, 69 + rate_avg, accept_avg, stop_note ); } 92 70 93 - if lv >= 3 { 94 - eprintln!( 95 - "b:{} t:{:.1}s acc:{}/{} (+{}, {:.3}) stagnant:{}/{} rem:{} avg_rate:{:.2}/s avg_accept:{:.3}{}", 96 - info.batch_idx, info.elapsed_s, info.accepted_total, info.max_splines, 97 - info.accepted_now, batch_ratio, info.consecutive_stagnant, info.max_stagnant_batches, 98 - info.remaining, rate_avg, accept_avg, stop_note 99 - ); 100 - } else if lv >= 2 { 101 - eprintln!( 102 - "b:{} t:{:.1}s acc:{}/{} (+{}, {:.3}) stagnant:{}/{} avg_rate:{:.2}/s avg_accept:{:.3}{}", 103 - info.batch_idx, info.elapsed_s, info.accepted_total, info.max_splines, 104 - info.accepted_now, batch_ratio, info.consecutive_stagnant, info.max_stagnant_batches, 105 - rate_avg, accept_avg, stop_note 106 - ); 107 - } else if lv >= 1 && (info.accepted_now > 0 || stop) { 108 - eprintln!( 109 - "{}/{} (+{}) stagnant:{}/{}{}", 110 - info.accepted_total, info.max_splines, info.accepted_now, 111 - info.consecutive_stagnant, info.max_stagnant_batches, stop_note 112 - ); 113 - } 114 - } 115 - } 71 + else if lv >= 1 && (info.accepted_now > 0 || stop) { 72 + eprintln!( 73 + "{}/{} (+{}) stagnant:{}/{}{}", 74 + info.accepted_total, info.max_splines, info.accepted_now, 75 + info.consecutive_stagnant, info.max_stagnant_batches, stop_note ); } } }
+9 -14
src/main.rs
··· 13 13 14 14 fn main() -> ExitCode { 15 15 match real_main() { 16 - Ok(()) => ExitCode::SUCCESS, 17 - Err(msg) => { 18 - eprintln!("{msg}"); 19 - ExitCode::FAILURE 20 - } 21 - } 22 - } 16 + Ok(()) => ExitCode::SUCCESS, 17 + Err(msg) => { eprintln!("{msg}"); ExitCode::FAILURE } } } 23 18 24 19 fn real_main() -> Result<(), String> { 25 20 let args = std::env::args().collect::<Vec<_>>(); 26 21 let cfg = args::Config::parse(&args)?; 22 + 27 23 log::set_level(cfg.log); 24 + 28 25 let input_path = Path::new(&cfg.input); 26 + 29 27 if input_path.is_dir() { 30 - if cfg.current.is_some() { 31 - return Err("-c/--current is not supported when input is a directory".into()); 32 - } 28 + if cfg.current.is_some() { return Err("-c/--current is not supported when input is a directory".into()); } 33 29 let out_root = fs::validate_and_prepare_dir_output(&cfg)?; 34 - return process_dir(&cfg, input_path, &out_root); 35 - } 30 + return process_dir(&cfg, input_path, &out_root); } 36 31 37 32 let output = ResolvedOutput::from_config(&cfg, PathBuf::from(&cfg.output)); 38 33 let final_path = process_one(&cfg, input_path, output)?; 39 34 println!("output: {}", final_path.display()); 40 - Ok(()) 41 - } 35 + 36 + Ok(()) }
+139 -192
src/metal/kernels.metal
··· 1 1 #include <metal_stdlib> 2 2 using namespace metal; 3 3 4 - struct Candidate { 5 - float2 p0; 6 - float2 p1; 7 - float2 p2; 8 - float2 p3; 9 - uint bx; 10 - uint by; 11 - uint bw; 12 - uint bh; 13 - }; 4 + struct Candidate { float2 p0 5 + ; float2 p1 6 + ; float2 p2 7 + ; float2 p3 8 + ; uint bx 9 + ; uint by 10 + ; uint bw 11 + ; uint bh; }; 14 12 15 - struct MeanOut { 16 - float sum_w; 17 - float3 sum_lab; 18 - }; 13 + struct MeanOut { float sum_w 14 + ; float3 sum_lab; }; 19 15 20 - struct ScoreOut { 21 - float delta_e2; 22 - }; 16 + struct ScoreOut { float delta_e2; }; 23 17 24 - struct CommonParams { 25 - uint width; 26 - uint height; 27 - float alpha; 28 - uint _pad0; 29 - }; 18 + struct CommonParams { uint width 19 + ; uint height 20 + ; float alpha 21 + ; uint _pad0; }; 30 22 31 - struct ApplyParams { 32 - uint cand_index; 33 - float alpha; 34 - uint _pad0; 35 - uint _pad1; 36 - }; 23 + struct ApplyParams { uint cand_index 24 + ; float alpha 25 + ; uint _pad0 26 + ; uint _pad1; }; 37 27 38 28 static inline float2 cubic_eval(thread const Candidate& c, float t) { 39 - float u = 1.0f - t; 40 - float b0 = u * u * u; 41 - float b1 = 3.0f * u * u * t; 42 - float b2 = 3.0f * u * t * t; 43 - float b3 = t * t * t; 44 - return c.p0 * b0 + c.p1 * b1 + c.p2 * b2 + c.p3 * b3; 45 - } 29 + float u = 1.0f - t; 30 + float b0 = u * u * u; 31 + float b1 = 3.0f * u * u * t; 32 + float b2 = 3.0f * u * t * t; 33 + float b3 = t * t * t; 34 + return c.p0 * b0 + c.p1 * b1 + c.p2 * b2 + c.p3 * b3; } 46 35 47 36 static inline float dist_to_segment(float2 p, float2 a, float2 b) { 48 - float2 ab = b - a; 49 - float denom = dot(ab, ab); 50 - if (denom <= 1e-12f) { 51 - return length(p - a); 52 - } 53 - float t = clamp(dot(p - a, ab) / denom, 0.0f, 1.0f); 54 - float2 q = a + t * ab; 55 - return length(p - q); 56 - } 37 + float2 ab = b - a; 38 + float denom = dot(ab, ab); 39 + if (denom <= 1e-12f) { return length(p - a); } 40 + float t = clamp(dot(p - a, ab) / denom, 0.0f, 1.0f); 41 + float2 q = a + t * ab; 42 + return length(p - q); } 57 43 58 44 static inline float coverage_at(thread const Candidate& c, float2 p) { 59 - constexpr uint segments = 32; 60 - float min_d = 1e9f; 61 - float2 prev = cubic_eval(c, 0.0f); 62 - for (uint i = 1; i <= segments; i++) { 63 - float t = (float)i / (float)segments; 64 - float2 cur = cubic_eval(c, t); 65 - min_d = min(min_d, dist_to_segment(p, prev, cur)); 66 - prev = cur; 67 - } 68 - return clamp(1.0f - 2.0f * min_d, 0.0f, 1.0f); 69 - } 45 + constexpr uint segments = 32; 46 + float min_d = 1e9f; 47 + float2 prev = cubic_eval(c, 0.0f); 48 + for (uint i = 1; i <= segments; i++) { 49 + float t = (float)i / (float)segments; 50 + float2 cur = cubic_eval(c, t); 51 + min_d = min(min_d, dist_to_segment(p, prev, cur)); 52 + prev = cur; } 53 + return clamp(1.0f - 2.0f * min_d, 0.0f, 1.0f); } 70 54 71 55 kernel void mean_pass( 72 - const device Candidate* candidates [[buffer(0)]], 73 - texture2d<float, access::read> target [[texture(0)]], 74 - device MeanOut* out [[buffer(1)]], 75 - constant CommonParams& params [[buffer(2)]], 76 - uint tid [[thread_index_in_threadgroup]], 77 - uint3 tg [[threadgroup_position_in_grid]] 78 - ) { 79 - constexpr uint TG = 256; 80 - uint ci = tg.x; 81 - Candidate c = candidates[ci]; 56 + const device Candidate* candidates [[buffer(0)]], 57 + texture2d<float, access::read> target [[texture(0)]], 58 + device MeanOut* out [[buffer(1)]], 59 + constant CommonParams& params [[buffer(2)]], 60 + uint tid [[thread_index_in_threadgroup]], 61 + uint3 tg [[threadgroup_position_in_grid]] ) { 62 + 63 + constexpr uint TG = 256; 64 + uint ci = tg.x; 65 + Candidate c = candidates[ci]; 82 66 83 - threadgroup float tg_sum_w[TG]; 84 - threadgroup float3 tg_sum_lab[TG]; 67 + threadgroup float tg_sum_w[TG]; 68 + threadgroup float3 tg_sum_lab[TG]; 85 69 86 - float sum_w = 0.0f; 87 - float3 sum_lab = float3(0.0f); 70 + float sum_w = 0.0f; 71 + float3 sum_lab = float3(0.0f); 88 72 89 - uint total = c.bw * c.bh; 90 - for (uint i = tid; i < total; i += TG) { 91 - uint ox = i % c.bw; 92 - uint oy = i / c.bw; 93 - uint x = c.bx + ox; 94 - uint y = c.by + oy; 95 - if (x >= params.width || y >= params.height) { 96 - continue; 97 - } 98 - float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 99 - float cov = coverage_at(c, p); 100 - if (cov <= 0.0f) { 101 - continue; 102 - } 103 - float3 t_lab = target.read(uint2(x, y)).xyz; 104 - sum_w += cov; 105 - sum_lab += t_lab * cov; 106 - } 73 + uint total = c.bw * c.bh; 74 + for (uint i = tid; i < total; i += TG) { 75 + uint ox = i % c.bw; 76 + uint oy = i / c.bw; 77 + uint x = c.bx + ox; 78 + uint y = c.by + oy; 79 + if (x >= params.width || y >= params.height) { continue; } 80 + float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 81 + float cov = coverage_at(c, p); 82 + if (cov <= 0.0f) { continue; } 83 + float3 t_lab = target.read(uint2(x, y)).xyz; 84 + sum_w += cov; 85 + sum_lab += t_lab * cov; } 107 86 108 - tg_sum_w[tid] = sum_w; 109 - tg_sum_lab[tid] = sum_lab; 110 - threadgroup_barrier(mem_flags::mem_threadgroup); 87 + tg_sum_w[tid] = sum_w; 88 + tg_sum_lab[tid] = sum_lab; 89 + threadgroup_barrier(mem_flags::mem_threadgroup); 111 90 112 - for (uint stride = TG / 2; stride > 0; stride >>= 1) { 113 - if (tid < stride) { 114 - tg_sum_w[tid] += tg_sum_w[tid + stride]; 115 - tg_sum_lab[tid] += tg_sum_lab[tid + stride]; 116 - } 117 - threadgroup_barrier(mem_flags::mem_threadgroup); 118 - } 91 + for (uint stride = TG / 2; stride > 0; stride >>= 1) { 92 + if (tid < stride) { tg_sum_w[tid] += tg_sum_w[tid + stride]; tg_sum_lab[tid] += tg_sum_lab[tid + stride]; } 93 + threadgroup_barrier(mem_flags::mem_threadgroup); } 119 94 120 - if (tid == 0) { 121 - MeanOut m; 122 - m.sum_w = tg_sum_w[0]; 123 - m.sum_lab = tg_sum_lab[0]; 124 - out[ci] = m; 125 - } 126 - } 95 + if (tid == 0) { MeanOut m; m.sum_w = tg_sum_w[0]; m.sum_lab = tg_sum_lab[0]; out[ci] = m; } } 127 96 128 97 kernel void score_pass( 129 - const device Candidate* candidates [[buffer(0)]], 130 - const device MeanOut* means [[buffer(1)]], 131 - device ScoreOut* out [[buffer(2)]], 132 - texture2d<float, access::read> target [[texture(0)]], 133 - texture2d<float, access::read> canvas [[texture(1)]], 134 - constant CommonParams& params [[buffer(3)]], 135 - uint tid [[thread_index_in_threadgroup]], 136 - uint3 tg [[threadgroup_position_in_grid]] 137 - ) { 138 - constexpr uint TG = 256; 139 - uint ci = tg.x; 140 - Candidate c = candidates[ci]; 141 - MeanOut m = means[ci]; 98 + const device Candidate* candidates [[buffer(0)]], 99 + const device MeanOut* means [[buffer(1)]], 100 + device ScoreOut* out [[buffer(2)]], 101 + texture2d<float, access::read> target [[texture(0)]], 102 + texture2d<float, access::read> canvas [[texture(1)]], 103 + constant CommonParams& params [[buffer(3)]], 104 + uint tid [[thread_index_in_threadgroup]], 105 + uint3 tg [[threadgroup_position_in_grid]] ) { 106 + 107 + constexpr uint TG = 256; 108 + uint ci = tg.x; 109 + Candidate c = candidates[ci]; 110 + MeanOut m = means[ci]; 142 111 143 - float3 stroke = float3(0.0f); 144 - if (m.sum_w > 0.0f) { 145 - stroke = m.sum_lab / m.sum_w; 146 - } 112 + float3 stroke = float3(0.0f); 113 + if (m.sum_w > 0.0f) { stroke = m.sum_lab / m.sum_w; } 147 114 148 - threadgroup float tg_delta[TG]; 115 + threadgroup float tg_delta[TG]; 149 116 150 - float delta = 0.0f; 151 - uint total = c.bw * c.bh; 152 - for (uint i = tid; i < total; i += TG) { 153 - uint ox = i % c.bw; 154 - uint oy = i / c.bw; 155 - uint x = c.bx + ox; 156 - uint y = c.by + oy; 157 - if (x >= params.width || y >= params.height) { 158 - continue; 159 - } 160 - float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 161 - float cov = coverage_at(c, p); 162 - if (cov <= 0.0f) { 163 - continue; 164 - } 117 + float delta = 0.0f; 118 + uint total = c.bw * c.bh; 119 + for (uint i = tid; i < total; i += TG) { 120 + uint ox = i % c.bw; 121 + uint oy = i / c.bw; 122 + uint x = c.bx + ox; 123 + uint y = c.by + oy; 124 + if (x >= params.width || y >= params.height) { continue; } 125 + float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 126 + float cov = coverage_at(c, p); 127 + if (cov <= 0.0f) { continue; } 165 128 166 - float eff_alpha = clamp(params.alpha * cov, 0.0f, 1.0f); 129 + float eff_alpha = clamp(params.alpha * cov, 0.0f, 1.0f); 167 130 168 - float3 t_lab = target.read(uint2(x, y)).xyz; 169 - float3 c_lab = canvas.read(uint2(x, y)).xyz; 170 - float3 blended = mix(c_lab, stroke, eff_alpha); 131 + float3 t_lab = target.read(uint2(x, y)).xyz; 132 + float3 c_lab = canvas.read(uint2(x, y)).xyz; 133 + float3 blended = mix(c_lab, stroke, eff_alpha); 171 134 172 - float3 d_after = blended - t_lab; 173 - float3 d_before = c_lab - t_lab; 174 - float e_after = dot(d_after, d_after); 175 - float e_before = dot(d_before, d_before); 176 - delta += (e_after - e_before); 177 - } 135 + float3 d_after = blended - t_lab; 136 + float3 d_before = c_lab - t_lab; 137 + float e_after = dot(d_after, d_after); 138 + float e_before = dot(d_before, d_before); 139 + delta += (e_after - e_before); } 178 140 179 - tg_delta[tid] = delta; 180 - threadgroup_barrier(mem_flags::mem_threadgroup); 141 + tg_delta[tid] = delta; 142 + threadgroup_barrier(mem_flags::mem_threadgroup); 181 143 182 - for (uint stride = TG / 2; stride > 0; stride >>= 1) { 183 - if (tid < stride) { 184 - tg_delta[tid] += tg_delta[tid + stride]; 185 - } 186 - threadgroup_barrier(mem_flags::mem_threadgroup); 187 - } 144 + for (uint stride = TG / 2; stride > 0; stride >>= 1) { 145 + if (tid < stride) { tg_delta[tid] += tg_delta[tid + stride]; } 146 + threadgroup_barrier(mem_flags::mem_threadgroup); } 188 147 189 - if (tid == 0) { 190 - ScoreOut s; 191 - s.delta_e2 = tg_delta[0]; 192 - out[ci] = s; 193 - } 194 - } 148 + if (tid == 0) { ScoreOut s; s.delta_e2 = tg_delta[0]; out[ci] = s; } } 195 149 196 150 kernel void apply_pass( 197 - texture2d<float, access::read_write> canvas [[texture(0)]], 198 - const device Candidate* candidates [[buffer(0)]], 199 - const device MeanOut* means [[buffer(1)]], 200 - constant ApplyParams& params [[buffer(2)]], 201 - uint2 tid [[thread_position_in_grid]] 202 - ) { 203 - uint ci = params.cand_index; 204 - Candidate c = candidates[ci]; 205 - MeanOut m = means[ci]; 206 - if (m.sum_w <= 0.0f) { 207 - return; 208 - } 151 + texture2d<float, access::read_write> canvas [[texture(0)]], 152 + const device Candidate* candidates [[buffer(0)]], 153 + const device MeanOut* means [[buffer(1)]], 154 + constant ApplyParams& params [[buffer(2)]], 155 + uint2 tid [[thread_position_in_grid]] ) { 209 156 210 - uint ox = tid.x; 211 - uint oy = tid.y; 212 - if (ox >= c.bw || oy >= c.bh) { 213 - return; 214 - } 157 + uint ci = params.cand_index; 158 + Candidate c = candidates[ci]; 159 + MeanOut m = means[ci]; 160 + if (m.sum_w <= 0.0f) { return; } 161 + 162 + uint ox = tid.x; 163 + uint oy = tid.y; 164 + if (ox >= c.bw || oy >= c.bh) { return; } 215 165 216 - uint x = c.bx + ox; 217 - uint y = c.by + oy; 166 + uint x = c.bx + ox; 167 + uint y = c.by + oy; 218 168 219 - float3 stroke = m.sum_lab / m.sum_w; 220 - float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 221 - float cov = coverage_at(c, p); 222 - if (cov <= 0.0f) { 223 - return; 224 - } 169 + float3 stroke = m.sum_lab / m.sum_w; 170 + float2 p = float2((float)x + 0.5f, (float)y + 0.5f); 171 + float cov = coverage_at(c, p); 172 + if (cov <= 0.0f) { return; } 225 173 226 - float eff_alpha = clamp(params.alpha * cov, 0.0f, 1.0f); 227 - float4 old = canvas.read(uint2(x, y)); 228 - float3 blended = mix(old.xyz, stroke, eff_alpha); 229 - canvas.write(float4(blended, 0.0f), uint2(x, y)); 230 - } 174 + float eff_alpha = clamp(params.alpha * cov, 0.0f, 1.0f); 175 + float4 old = canvas.read(uint2(x, y)); 176 + float3 blended = mix(old.xyz, stroke, eff_alpha); 177 + canvas.write(float4(blended, 0.0f), uint2(x, y)); }
+232 -468
src/metal/mod.rs
··· 5 5 use objc2::rc::Retained; 6 6 use objc2::runtime::ProtocolObject; 7 7 use objc2_foundation::{ns_string, NSString}; 8 - use objc2_metal::{ 9 - MTLBuffer, MTLCommandBuffer, MTLCommandEncoder, MTLCommandQueue, MTLComputeCommandEncoder, 10 - MTLComputePipelineState, MTLCreateSystemDefaultDevice, MTLDevice, MTLLibrary, MTLRegion, 11 - MTLResourceOptions, MTLSize, MTLStorageMode, MTLTexture, MTLTextureDescriptor, MTLTextureUsage, 12 - MTLPixelFormat, 13 - }; 8 + use objc2_metal::{ MTLBuffer, MTLCommandBuffer, MTLCommandEncoder, MTLCommandQueue, MTLComputeCommandEncoder 9 + , MTLComputePipelineState, MTLCreateSystemDefaultDevice, MTLDevice, MTLLibrary, MTLRegion 10 + , MTLResourceOptions, MTLSize, MTLStorageMode, MTLTexture, MTLTextureDescriptor, MTLTextureUsage 11 + , MTLPixelFormat }; 14 12 use std::ffi::c_void; 15 13 use std::ptr::NonNull; 16 14 use std::thread; ··· 18 16 19 17 #[repr(C)] 20 18 #[derive(Clone, Copy, Debug, Default)] 21 - /// a single bezier candidate and its conservative pixel bbox. 22 - /// 23 - /// `p0..p3` are control points in pixel space. 24 - /// `(bx, by, bw, bh)` bounds the region on the canvas potentially affected by the stroke. 25 - pub struct Candidate { 26 - pub p0: [f32; 2], 27 - pub p1: [f32; 2], 28 - pub p2: [f32; 2], 29 - pub p3: [f32; 2], 30 - pub bx: u32, 31 - pub by: u32, 32 - pub bw: u32, 33 - pub bh: u32, 34 - } 19 + pub struct Candidate { pub p0: [f32; 2], pub p1: [f32; 2], pub p2: [f32; 2], pub p3: [f32; 2] 20 + , pub bx: u32, pub by: u32, pub bw: u32, pub bh: u32 } 35 21 36 22 #[repr(C)] 37 23 #[derive(Clone, Copy, Debug, Default)] 38 24 /// output payload for the mean kernel. 39 - struct MeanOut { 40 - sum_w: f32, 41 - _pad0: [f32; 3], 42 - sum_lab: [f32; 3], 43 - _pad1: f32, 44 - } 25 + struct MeanOut { sum_w: f32, _pad0: [f32; 3], sum_lab: [f32; 3], _pad1: f32, } 45 26 46 27 #[repr(C)] 47 28 #[derive(Clone, Copy, Debug, Default)] 48 29 /// output payload for the scoring kernel (lower is better). 49 - struct ScoreOut { 50 - delta_e2: f32, 51 - } 30 + struct ScoreOut { delta_e2: f32, } 52 31 53 32 #[repr(C)] 54 33 #[derive(Clone, Copy, Debug, Default)] 55 34 /// common parameters shared by compute passes. 56 - struct CommonParams { 57 - width: u32, 58 - height: u32, 59 - alpha: f32, 60 - _pad0: u32, 61 - } 35 + struct CommonParams { width: u32, height: u32, alpha: f32, _pad0: u32, } 62 36 63 37 #[repr(C)] 64 38 #[derive(Clone, Copy, Debug, Default)] 65 39 /// parameters for applying a selected candidate to the canvas. 66 - struct ApplyParams { 67 - cand_index: u32, 68 - alpha: f32, 69 - _pad0: u32, 70 - _pad1: u32, 71 - } 40 + struct ApplyParams { cand_index: u32, alpha: f32, _pad0: u32, _pad1: u32, } 72 41 73 42 /// metal compute context and persistent gpu resources for the current run. 74 - pub struct MetalContext { 75 - queue: Retained<ProtocolObject<dyn MTLCommandQueue>>, 76 - mean_pso: Retained<ProtocolObject<dyn MTLComputePipelineState>>, 77 - score_pso: Retained<ProtocolObject<dyn MTLComputePipelineState>>, 78 - apply_pso: Retained<ProtocolObject<dyn MTLComputePipelineState>>, 79 - target: Retained<ProtocolObject<dyn MTLTexture>>, 80 - canvas: Retained<ProtocolObject<dyn MTLTexture>>, 81 - candidates_buf: Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>>, 82 - means_buf: Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>>, 83 - scores_buf: Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>>, 84 - width: u32, 85 - height: u32, 86 - batch_cap: u32, 87 - max_gpu: f32, 88 - } 89 - 43 + pub struct MetalContext { queue : Retained<ProtocolObject<dyn MTLCommandQueue>> 44 + , mean_pso : Retained<ProtocolObject<dyn MTLComputePipelineState>> 45 + , score_pso : Retained<ProtocolObject<dyn MTLComputePipelineState>> 46 + , apply_pso : Retained<ProtocolObject<dyn MTLComputePipelineState>> 47 + , target : Retained<ProtocolObject<dyn MTLTexture>> 48 + , canvas : Retained<ProtocolObject<dyn MTLTexture>> 49 + , candidates_buf : Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>> 50 + , means_buf : Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>> 51 + , scores_buf : Retained<ProtocolObject<dyn objc2_metal::MTLBuffer>> 52 + , width : u32 53 + , height : u32 54 + , batch_cap : u32 55 + , max_gpu : f32 } 90 56 impl MetalContext { 91 - pub fn new(width: u32, height: u32, batch_cap: u32, max_gpu: f32) -> Result<Self, String> { 92 - let Some(device) = MTLCreateSystemDefaultDevice() else { 93 - return Err("failed to create metal device".to_string()); 94 - }; 95 - let Some(queue) = device.newCommandQueue() else { 96 - return Err("failed to create metal command queue".to_string()); 97 - }; 98 57 99 - let lib = compile_library(&device)?; 100 - let mean_pso = compile_pipeline(&device, &lib, ns_string!("mean_pass"))?; 101 - let score_pso = compile_pipeline(&device, &lib, ns_string!("score_pass"))?; 102 - let apply_pso = compile_pipeline(&device, &lib, ns_string!("apply_pass"))?; 58 + pub fn new(width: u32, height: u32, batch_cap: u32, max_gpu: f32) -> Result<Self, String> { 59 + let Some(device) = MTLCreateSystemDefaultDevice() else { return Err("failed to create metal device".to_string()); }; 60 + let Some(queue) = device.newCommandQueue() else { return Err("failed to create metal command queue".to_string()); }; 103 61 104 - let target = make_lab_texture(&device, width, height, true)?; 105 - let canvas = make_lab_texture(&device, width, height, true)?; 62 + let lib = compile_library(&device)?; 63 + let mean_pso = compile_pipeline(&device, &lib, ns_string!("mean_pass"))?; 64 + let score_pso = compile_pipeline(&device, &lib, ns_string!("score_pass"))?; 65 + let apply_pso = compile_pipeline(&device, &lib, ns_string!("apply_pass"))?; 106 66 107 - let candidates_bytes = (batch_cap as usize) * std::mem::size_of::<Candidate>(); 108 - let means_bytes = (batch_cap as usize) * std::mem::size_of::<MeanOut>(); 109 - let scores_bytes = (batch_cap as usize) * std::mem::size_of::<ScoreOut>(); 67 + let target = make_lab_texture(&device, width, height, true)?; 68 + let canvas = make_lab_texture(&device, width, height, true)?; 110 69 111 - let opts = MTLResourceOptions::StorageModeShared; 112 - let candidates_buf = device 113 - .newBufferWithLength_options(candidates_bytes as _, opts) 114 - .ok_or_else(|| "failed to allocate candidates buffer".to_string())?; 115 - let means_buf = device 116 - .newBufferWithLength_options(means_bytes as _, opts) 117 - .ok_or_else(|| "failed to allocate means buffer".to_string())?; 118 - let scores_buf = device 119 - .newBufferWithLength_options(scores_bytes as _, opts) 120 - .ok_or_else(|| "failed to allocate scores buffer".to_string())?; 70 + let candidates_bytes = (batch_cap as usize) * std::mem::size_of::<Candidate>(); 71 + let means_bytes = (batch_cap as usize) * std::mem::size_of::<MeanOut>(); 72 + let scores_bytes = (batch_cap as usize) * std::mem::size_of::<ScoreOut>(); 121 73 122 - Ok(Self { 123 - queue, 124 - mean_pso, 125 - score_pso, 126 - apply_pso, 127 - target, 128 - canvas, 129 - candidates_buf, 130 - means_buf, 131 - scores_buf, 132 - width, 133 - height, 134 - batch_cap, 135 - max_gpu, 136 - }) 137 - } 74 + let opts = MTLResourceOptions::StorageModeShared; 75 + let candidates_buf = device.newBufferWithLength_options(candidates_bytes as _, opts).ok_or_else(|| "failed to allocate candidates buffer".to_string())?; 76 + let means_buf = device.newBufferWithLength_options(means_bytes as _, opts).ok_or_else(|| "failed to allocate means buffer".to_string())?; 77 + let scores_buf = device.newBufferWithLength_options(scores_bytes as _, opts).ok_or_else(|| "failed to allocate scores buffer".to_string())?; 138 78 139 - fn throttle_after(&self, work: Duration) { 140 - if !(0.0 < self.max_gpu && self.max_gpu < 1.0) { 141 - return; 142 - } 143 - let work_s = work.as_secs_f64(); 144 - if work_s <= 0.0 { 145 - return; 146 - } 147 - let idle_s = work_s * (1.0 / (self.max_gpu as f64) - 1.0); 148 - if idle_s <= 0.0 { 149 - return; 150 - } 151 - thread::sleep(Duration::from_secs_f64(idle_s)); 152 - } 79 + Ok(Self { queue, mean_pso, score_pso, apply_pso, target, canvas, candidates_buf, means_buf, scores_buf, width, height, batch_cap, max_gpu, }) } 153 80 154 - fn commit_wait_throttled(&self, cb: &ProtocolObject<dyn MTLCommandBuffer>) { 155 - let started = Instant::now(); 156 - cb.commit(); 157 - cb.waitUntilCompleted(); 158 - self.throttle_after(started.elapsed()); 159 - } 81 + fn throttle_after(&self, work: Duration) { 82 + if !(0.0 < self.max_gpu && self.max_gpu < 1.0) { return; } 160 83 161 - pub fn upload_target_and_init_canvas( 162 - &self, 163 - target: &[Oklab], 164 - canvas_fill: Oklab, 165 - ) -> Result<(), String> { 166 - if target.len() != (self.width as usize) * (self.height as usize) { 167 - return Err("target size mismatch".to_string()); 168 - } 84 + let work_s = work.as_secs_f64(); 85 + if work_s <= 0.0 { return; } 169 86 170 - let tgt = pack_oklab_rgba32f(target); 87 + let idle_s = work_s * (1.0 / (self.max_gpu as f64) - 1.0); 88 + if idle_s <= 0.0 { return; } 171 89 172 - let mut can = Vec::<f32>::with_capacity(target.len() * 4); 173 - for _ in 0..target.len() { 174 - can.extend_from_slice(&[canvas_fill.l, canvas_fill.a, canvas_fill.b, 0.0]); 175 - } 90 + thread::sleep(Duration::from_secs_f64(idle_s)); } 176 91 177 - unsafe { 178 - write_full_texture_rgba32f(&self.target, self.width, self.height, &tgt)?; 179 - write_full_texture_rgba32f(&self.canvas, self.width, self.height, &can)?; 180 - } 92 + fn commit_wait_throttled(&self, cb: &ProtocolObject<dyn MTLCommandBuffer>) { 93 + let started = Instant::now(); 94 + cb.commit(); 95 + cb.waitUntilCompleted(); 96 + self.throttle_after(started.elapsed()); } 181 97 182 - Ok(()) 183 - } 98 + pub fn upload_target_and_init_canvas(&self, target: &[Oklab], canvas_fill: Oklab) -> Result<(), String> { 99 + if target.len() != (self.width as usize) * (self.height as usize) { return Err("target size mismatch".to_string()); } 184 100 185 - pub fn upload_target_and_set_canvas(&self, target: &[Oklab], canvas: &[Oklab]) -> Result<(), String> { 186 - let n = (self.width as usize) * (self.height as usize); 187 - if target.len() != n { 188 - return Err("target size mismatch".to_string()); 189 - } 190 - if canvas.len() != n { 191 - return Err("canvas size mismatch".to_string()); 192 - } 101 + let tgt = pack_oklab_rgba32f(target); 193 102 194 - let tgt = pack_oklab_rgba32f(target); 195 - let can = pack_oklab_rgba32f(canvas); 196 - unsafe { 197 - write_full_texture_rgba32f(&self.target, self.width, self.height, &tgt)?; 198 - write_full_texture_rgba32f(&self.canvas, self.width, self.height, &can)?; 199 - } 200 - Ok(()) 201 - } 103 + let mut can = Vec::<f32>::with_capacity(target.len() * 4); 104 + for _ in 0..target.len() { can.extend_from_slice(&[canvas_fill.l, canvas_fill.a, canvas_fill.b, 0.0]); } 202 105 203 - fn run_mean_and_score(&self, candidates: &[Candidate], alpha: f32) -> Result<Vec<bool>, String> { 204 - if candidates.is_empty() { 205 - return Ok(Vec::new()); 206 - } 207 - if candidates.len() > (self.batch_cap as usize) { 208 - return Err("batch exceeds configured capacity".into()); 209 - } 106 + unsafe { write_full_texture_rgba32f(&self.target, self.width, self.height, &tgt)?; 107 + write_full_texture_rgba32f(&self.canvas, self.width, self.height, &can)?; } 210 108 211 - unsafe { write_candidates(&self.candidates_buf, candidates) }; 109 + Ok(()) } 212 110 213 - let params = CommonParams { width: self.width, height: self.height, alpha, _pad0: 0 }; 214 - let batch = candidates.len() as u32; 111 + pub fn upload_target_and_set_canvas(&self, target: &[Oklab], canvas: &[Oklab]) -> Result<(), String> { 112 + let n = (self.width as usize) * (self.height as usize); 113 + if target.len() != n { return Err("target size mismatch".to_string()); } 114 + if canvas.len() != n { return Err("canvas size mismatch".to_string()); } 215 115 216 - let cb = new_command_buffer(&self.queue)?; 217 - let enc = new_compute_encoder(&cb)?; 218 - encode_mean(&enc, &self.mean_pso, &self.candidates_buf, &self.target, &self.means_buf, &params, batch); 219 - enc.endEncoding(); 116 + let tgt = pack_oklab_rgba32f(target); 117 + let can = pack_oklab_rgba32f(canvas); 118 + unsafe { write_full_texture_rgba32f(&self.target, self.width, self.height, &tgt)?; 119 + write_full_texture_rgba32f(&self.canvas, self.width, self.height, &can)?; } 120 + Ok(()) } 220 121 221 - let enc2 = new_compute_encoder(&cb)?; 222 - encode_score(&enc2, &self.score_pso, &self.candidates_buf, &self.means_buf, &self.scores_buf, &self.target, &self.canvas, &params, batch); 223 - enc2.endEncoding(); 122 + fn run_mean_and_score(&self, candidates: &[Candidate], alpha: f32) -> Result<Vec<bool>, String> { 123 + if candidates.is_empty() { return Ok(Vec::new()); } 124 + if candidates.len() > (self.batch_cap as usize) { return Err("batch exceeds configured capacity".into()); } 224 125 225 - self.commit_wait_throttled(&cb); 126 + unsafe { write_candidates(&self.candidates_buf, candidates) }; 226 127 227 - let scores = unsafe { read_scores(&self.scores_buf, candidates.len()) }; 228 - Ok(scores.iter().map(|s| s.delta_e2 < 0.0).collect()) 229 - } 128 + let params = CommonParams { width: self.width, height: self.height, alpha, _pad0: 0 }; 129 + let batch = candidates.len() as u32; 230 130 231 - fn new_apply_encoder(&self) -> Result<(Retained<ProtocolObject<dyn MTLCommandBuffer>>, Retained<ProtocolObject<dyn MTLComputeCommandEncoder>>), String> { 232 - let cb = new_command_buffer(&self.queue)?; 233 - let enc = new_compute_encoder(&cb)?; 234 - enc.setComputePipelineState(&self.apply_pso); 235 - unsafe { 236 - enc.setBuffer_offset_atIndex(Some(&self.candidates_buf), 0, 0); 237 - enc.setBuffer_offset_atIndex(Some(&self.means_buf), 0, 1); 238 - enc.setTexture_atIndex(Some(&self.canvas), 0); 239 - } 240 - Ok((cb, enc)) 241 - } 131 + let cb = new_command_buffer(&self.queue)?; 132 + let enc = new_compute_encoder(&cb)?; 133 + encode_mean(&enc, &self.mean_pso, &self.candidates_buf, &self.target, &self.means_buf, &params, batch); 134 + enc.endEncoding(); 242 135 243 - fn dispatch_apply(&self, enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, c: &Candidate, ci: usize, alpha: f32) { 244 - let ap = ApplyParams { cand_index: ci as u32, alpha, _pad0: 0, _pad1: 0 }; 245 - unsafe { 246 - enc.setBytes_length_atIndex( 247 - NonNull::new_unchecked((&ap as *const ApplyParams).cast::<c_void>() as *mut c_void), 248 - std::mem::size_of::<ApplyParams>() as _, 2, 249 - ); 250 - } 251 - let grid = MTLSize { width: c.bw as _, height: c.bh as _, depth: 1 }; 252 - let tg = MTLSize { width: 16, height: 16, depth: 1 }; 253 - enc.dispatchThreads_threadsPerThreadgroup(grid, tg); 254 - } 136 + let enc2 = new_compute_encoder(&cb)?; 137 + encode_score(&enc2, &self.score_pso, &self.candidates_buf, &self.means_buf, &self.scores_buf, &self.target, &self.canvas, &params, batch); 138 + enc2.endEncoding(); 255 139 256 - pub fn process_batch(&self, candidates: &[Candidate], alpha: f32, apply_limit: u32) -> Result<Vec<bool>, String> { 257 - let raw_accepted = self.run_mean_and_score(candidates, alpha)?; 258 - if raw_accepted.is_empty() { 259 - return Ok(raw_accepted); 260 - } 140 + self.commit_wait_throttled(&cb); 261 141 262 - let (cb, enc) = self.new_apply_encoder()?; 263 - let mut accepted = vec![false; raw_accepted.len()]; 264 - let mut applied = 0u32; 142 + let scores = unsafe { read_scores(&self.scores_buf, candidates.len()) }; 143 + Ok(scores.iter().map(|s| s.delta_e2 < 0.0).collect()) } 265 144 266 - for (ci, ok) in raw_accepted.iter().enumerate() { 267 - if *ok && applied < apply_limit { 268 - accepted[ci] = true; 269 - applied += 1; 270 - self.dispatch_apply(&enc, &candidates[ci], ci, alpha); 271 - } 272 - } 145 + fn new_apply_encoder(&self) -> Result<(Retained<ProtocolObject<dyn MTLCommandBuffer>>, Retained<ProtocolObject<dyn MTLComputeCommandEncoder>>), String> { 146 + let cb = new_command_buffer(&self.queue)?; 147 + let enc = new_compute_encoder(&cb)?; 148 + enc.setComputePipelineState(&self.apply_pso); 149 + unsafe { enc.setBuffer_offset_atIndex(Some(&self.candidates_buf), 0, 0); 150 + enc.setBuffer_offset_atIndex(Some(&self.means_buf), 0, 1); 151 + enc.setTexture_atIndex(Some(&self.canvas), 0); } 152 + Ok((cb, enc)) } 273 153 274 - enc.endEncoding(); 275 - self.commit_wait_throttled(&cb); 276 - Ok(accepted) 277 - } 154 + fn dispatch_apply(&self, enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, c: &Candidate, ci: usize, alpha: f32) { 155 + let ap = ApplyParams { cand_index: ci as u32, alpha, _pad0: 0, _pad1: 0 }; 156 + unsafe { enc.setBytes_length_atIndex( 157 + NonNull::new_unchecked((&ap as *const ApplyParams).cast::<c_void>() as *mut c_void), 158 + std::mem::size_of::<ApplyParams>() as _, 2, ); } 159 + let grid = MTLSize { width: c.bw as _, height: c.bh as _, depth: 1 }; 160 + let tg = MTLSize { width: 16, height: 16, depth: 1 }; 161 + enc.dispatchThreads_threadsPerThreadgroup(grid, tg); } 278 162 279 - pub fn process_batch_checkpointed<F>( 280 - &self, 281 - candidates: &[Candidate], 282 - alpha: f32, 283 - apply_limit: u32, 284 - accepted_total_before: u32, 285 - nth: u32, 286 - mut on_checkpoint: F, 287 - ) -> Result<(Vec<bool>, u32), String> 288 - where 289 - F: FnMut(&MetalContext, u32) -> Result<(), String>, 290 - { 291 - if nth == 0 { 292 - return Err("nth must be > 0".into()); 293 - } 163 + pub fn process_batch(&self, candidates: &[Candidate], alpha: f32, apply_limit: u32) -> Result<Vec<bool>, String> { 164 + let raw_accepted = self.run_mean_and_score(candidates, alpha)?; 165 + if raw_accepted.is_empty() { return Ok(raw_accepted); } 294 166 295 - let raw_accepted = self.run_mean_and_score(candidates, alpha)?; 296 - if raw_accepted.is_empty() { 297 - return Ok((raw_accepted, accepted_total_before)); 298 - } 167 + let (cb, enc) = self.new_apply_encoder()?; 168 + let mut accepted = vec![false; raw_accepted.len()]; 169 + let mut applied = 0u32; 299 170 300 - let apply_indices: Vec<_> = raw_accepted.iter().enumerate() 301 - .filter_map(|(i, ok)| ok.then_some(i)) 302 - .take(apply_limit as usize) 303 - .collect(); 171 + for (ci, ok) in raw_accepted.iter().enumerate() { 172 + if *ok && applied < apply_limit { accepted[ci] = true; applied += 1; self.dispatch_apply(&enc, &candidates[ci], ci, alpha); } } 304 173 305 - let mut accepted = vec![false; raw_accepted.len()]; 306 - for &i in &apply_indices { 307 - accepted[i] = true; 308 - } 174 + enc.endEncoding(); 175 + self.commit_wait_throttled(&cb); 176 + Ok(accepted) } 309 177 310 - if apply_indices.is_empty() { 311 - return Ok((accepted, accepted_total_before)); 312 - } 178 + pub fn process_batch_checkpointed<F>( 179 + &self, 180 + candidates: &[Candidate], 181 + alpha: f32, 182 + apply_limit: u32, 183 + accepted_total_before: u32, 184 + nth: u32, 185 + mut on_checkpoint: F ) -> Result<(Vec<bool>, u32), String> where F: FnMut(&MetalContext, u32) -> Result<(), String>, { 313 186 314 - let mut accepted_total = accepted_total_before; 315 - let (mut cb, mut enc) = self.new_apply_encoder()?; 187 + if nth == 0 { return Err("nth must be > 0".into()); } 316 188 317 - for (pos, &ci) in apply_indices.iter().enumerate() { 318 - accepted_total = accepted_total.saturating_add(1); 319 - self.dispatch_apply(&enc, &candidates[ci], ci, alpha); 189 + let raw_accepted = self.run_mean_and_score(candidates, alpha)?; 190 + if raw_accepted.is_empty() { return Ok((raw_accepted, accepted_total_before)); } 320 191 321 - if accepted_total % nth == 0 { 322 - enc.endEncoding(); 323 - self.commit_wait_throttled(&cb); 324 - on_checkpoint(self, accepted_total)?; 192 + let apply_indices: Vec<_> = raw_accepted.iter().enumerate().filter_map(|(i, ok)| ok.then_some(i)).take(apply_limit as usize).collect(); 325 193 326 - if pos + 1 < apply_indices.len() { 327 - (cb, enc) = self.new_apply_encoder()?; 328 - } 329 - } 330 - } 194 + let mut accepted = vec![false; raw_accepted.len()]; 195 + for &i in &apply_indices { accepted[i] = true; } 331 196 332 - if accepted_total % nth != 0 { 333 - enc.endEncoding(); 334 - self.commit_wait_throttled(&cb); 335 - } 197 + if apply_indices.is_empty() { return Ok((accepted, accepted_total_before)); } 336 198 337 - Ok((accepted, accepted_total)) 338 - } 199 + let mut accepted_total = accepted_total_before; 200 + let (mut cb, mut enc) = self.new_apply_encoder()?; 339 201 340 - pub fn read_canvas(&self) -> Result<Vec<Oklab>, String> { 341 - let mut rgba = vec![0.0f32; (self.width as usize) * (self.height as usize) * 4]; 342 - unsafe { read_full_texture_rgba32f(&self.canvas, self.width, self.height, &mut rgba)? }; 202 + for (pos, &ci) in apply_indices.iter().enumerate() { 203 + accepted_total = accepted_total.saturating_add(1); 204 + self.dispatch_apply(&enc, &candidates[ci], ci, alpha); 343 205 344 - let mut out = Vec::with_capacity((self.width as usize) * (self.height as usize)); 345 - for i in 0..((self.width as usize) * (self.height as usize)) { 346 - out.push(Oklab { 347 - l: rgba[i * 4], 348 - a: rgba[i * 4 + 1], 349 - b: rgba[i * 4 + 2], 350 - }); 351 - } 352 - Ok(out) 353 - } 354 - } 206 + if accepted_total % nth == 0 { 207 + enc.endEncoding(); 208 + self.commit_wait_throttled(&cb); 209 + on_checkpoint(self, accepted_total)?; 210 + 211 + if pos + 1 < apply_indices.len() { (cb, enc) = self.new_apply_encoder()?; } } } 212 + 213 + if accepted_total % nth != 0 { enc.endEncoding(); self.commit_wait_throttled(&cb); } 214 + 215 + Ok((accepted, accepted_total)) } 216 + 217 + pub fn read_canvas(&self) -> Result<Vec<Oklab>, String> { 218 + let mut rgba = vec![0.0f32; (self.width as usize) * (self.height as usize) * 4]; 219 + unsafe { read_full_texture_rgba32f(&self.canvas, self.width, self.height, &mut rgba)? }; 220 + 221 + let mut out = Vec::with_capacity((self.width as usize) * (self.height as usize)); 222 + for i in 0..((self.width as usize) * (self.height as usize)) { out.push(Oklab { l: rgba[i * 4], a: rgba[i * 4 + 1], b: rgba[i * 4 + 2], }); } 223 + Ok(out) } } 355 224 356 225 fn pack_oklab_rgba32f(pixels: &[Oklab]) -> Vec<f32> { 357 226 let mut out = Vec::<f32>::with_capacity(pixels.len() * 4); 358 - for p in pixels { 359 - out.extend_from_slice(&[p.l, p.a, p.b, 0.0]); 360 - } 361 - out 362 - } 227 + for p in pixels { out.extend_from_slice(&[p.l, p.a, p.b, 0.0]); } 228 + out } 363 229 364 230 fn compile_library(device: &ProtocolObject<dyn MTLDevice>) -> Result<Retained<ProtocolObject<dyn MTLLibrary>>, String> { 365 231 let src = include_str!("kernels.metal"); 366 232 let ns_src = NSString::from_str(src); 367 - device 368 - .newLibraryWithSource_options_error(&ns_src, None) 369 - .map_err(|e| format!("failed to compile metal library: {}", e)) 370 - } 233 + device.newLibraryWithSource_options_error(&ns_src, None).map_err(|e| format!("failed to compile metal library: {}", e)) } 371 234 372 - fn compile_pipeline( 373 - device: &ProtocolObject<dyn MTLDevice>, 374 - lib: &ProtocolObject<dyn MTLLibrary>, 375 - name: &NSString, 376 - ) -> Result<Retained<ProtocolObject<dyn MTLComputePipelineState>>, String> { 377 - let Some(f) = lib.newFunctionWithName(name) else { 378 - return Err(format!("missing metal function: {}", name)); 379 - }; 380 - device 381 - .newComputePipelineStateWithFunction_error(&f) 382 - .map_err(|e| format!("failed to build compute pipeline: {}", e)) 383 - } 235 + fn compile_pipeline(device: &ProtocolObject<dyn MTLDevice>, lib: &ProtocolObject<dyn MTLLibrary>, name: &NSString) -> Result<Retained<ProtocolObject<dyn MTLComputePipelineState>>, String> { 236 + let Some(f) = lib.newFunctionWithName(name) else { return Err(format!("missing metal function: {}", name)); }; 237 + device.newComputePipelineStateWithFunction_error(&f).map_err(|e| format!("failed to build compute pipeline: {}", e)) } 384 238 385 - fn make_lab_texture( 386 - device: &ProtocolObject<dyn MTLDevice>, 387 - width: u32, 388 - height: u32, 389 - cpu_visible: bool, 390 - ) -> Result<Retained<ProtocolObject<dyn MTLTexture>>, String> { 391 - let desc = unsafe { 392 - MTLTextureDescriptor::texture2DDescriptorWithPixelFormat_width_height_mipmapped( 393 - MTLPixelFormat::RGBA32Float, 394 - width as _, 395 - height as _, 396 - false, 397 - ) 398 - }; 239 + fn make_lab_texture(device: &ProtocolObject<dyn MTLDevice>, width: u32, height: u32, cpu_visible: bool) -> Result<Retained<ProtocolObject<dyn MTLTexture>>, String> { 240 + let desc = unsafe { MTLTextureDescriptor::texture2DDescriptorWithPixelFormat_width_height_mipmapped( 241 + MTLPixelFormat::RGBA32Float, 242 + width as _, 243 + height as _, 244 + false, ) }; 399 245 desc.setUsage(MTLTextureUsage::ShaderRead | MTLTextureUsage::ShaderWrite); 400 - if cpu_visible { 401 - desc.setStorageMode(MTLStorageMode::Shared); 402 - } 403 - device 404 - .newTextureWithDescriptor(&desc) 405 - .ok_or_else(|| "failed to create texture".to_string()) 406 - } 246 + if cpu_visible { desc.setStorageMode(MTLStorageMode::Shared); } 247 + device.newTextureWithDescriptor(&desc).ok_or_else(|| "failed to create texture".to_string()) } 407 248 408 - fn new_command_buffer( 409 - queue: &ProtocolObject<dyn MTLCommandQueue>, 410 - ) -> Result<Retained<ProtocolObject<dyn MTLCommandBuffer>>, String> { 411 - queue 412 - .commandBuffer() 413 - .ok_or_else(|| "failed to create command buffer".to_string()) 414 - } 249 + fn new_command_buffer(queue: &ProtocolObject<dyn MTLCommandQueue>) -> Result<Retained<ProtocolObject<dyn MTLCommandBuffer>>, String> { 250 + queue.commandBuffer().ok_or_else(|| "failed to create command buffer".to_string()) } 415 251 416 - fn new_compute_encoder( 417 - cb: &ProtocolObject<dyn MTLCommandBuffer>, 418 - ) -> Result<Retained<ProtocolObject<dyn MTLComputeCommandEncoder>>, String> { 419 - cb.computeCommandEncoder() 420 - .ok_or_else(|| "failed to create compute encoder".to_string()) 421 - } 252 + fn new_compute_encoder(cb: &ProtocolObject<dyn MTLCommandBuffer>) -> Result<Retained<ProtocolObject<dyn MTLComputeCommandEncoder>>, String> { 253 + cb.computeCommandEncoder() .ok_or_else(|| "failed to create compute encoder".to_string()) } 422 254 423 255 fn encode_mean( 424 - enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, 425 - pso: &ProtocolObject<dyn MTLComputePipelineState>, 256 + enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, 257 + pso: &ProtocolObject<dyn MTLComputePipelineState>, 426 258 candidates: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 427 - target: &ProtocolObject<dyn MTLTexture>, 428 - means: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 429 - params: &CommonParams, 430 - batch: u32, 431 - ) { 259 + target: &ProtocolObject<dyn MTLTexture>, 260 + means: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 261 + params: &CommonParams, 262 + batch: u32 ) { 263 + 432 264 enc.setComputePipelineState(pso); 433 - unsafe { 434 - enc.setBuffer_offset_atIndex(Some(candidates), 0, 0); 435 - enc.setTexture_atIndex(Some(target), 0); 436 - enc.setBuffer_offset_atIndex(Some(means), 0, 1); 437 - enc.setBytes_length_atIndex( 438 - NonNull::new_unchecked((params as *const CommonParams).cast::<c_void>() as *mut c_void), 439 - std::mem::size_of::<CommonParams>() as _, 440 - 2, 441 - ); 442 - } 265 + unsafe { enc.setBuffer_offset_atIndex(Some(candidates), 0, 0); 266 + enc.setTexture_atIndex(Some(target), 0); 267 + enc.setBuffer_offset_atIndex(Some(means), 0, 1); 268 + enc.setBytes_length_atIndex( 269 + NonNull::new_unchecked((params as *const CommonParams).cast::<c_void>() as *mut c_void), 270 + std::mem::size_of::<CommonParams>() as _, 271 + 2); } 443 272 444 - let tg = MTLSize { 445 - width: 256, 446 - height: 1, 447 - depth: 1, 448 - }; 449 - let groups = MTLSize { 450 - width: batch as _, 451 - height: 1, 452 - depth: 1, 453 - }; 454 - enc.dispatchThreadgroups_threadsPerThreadgroup(groups, tg); 455 - } 273 + let tg = MTLSize { width: 256, height: 1, depth: 1 }; 274 + let groups = MTLSize { width: batch as _, height: 1, depth: 1 }; 275 + enc.dispatchThreadgroups_threadsPerThreadgroup(groups, tg); } 456 276 457 277 fn encode_score( 458 - enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, 459 - pso: &ProtocolObject<dyn MTLComputePipelineState>, 278 + enc: &ProtocolObject<dyn MTLComputeCommandEncoder>, 279 + pso: &ProtocolObject<dyn MTLComputePipelineState>, 460 280 candidates: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 461 - means: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 462 - scores: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 463 - target: &ProtocolObject<dyn MTLTexture>, 464 - canvas: &ProtocolObject<dyn MTLTexture>, 465 - params: &CommonParams, 466 - batch: u32, 467 - ) { 281 + means: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 282 + scores: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 283 + target: &ProtocolObject<dyn MTLTexture>, 284 + canvas: &ProtocolObject<dyn MTLTexture>, 285 + params: &CommonParams, 286 + batch: u32, ) { 287 + 468 288 enc.setComputePipelineState(pso); 469 - unsafe { 470 - enc.setBuffer_offset_atIndex(Some(candidates), 0, 0); 471 - enc.setBuffer_offset_atIndex(Some(means), 0, 1); 472 - enc.setBuffer_offset_atIndex(Some(scores), 0, 2); 473 - enc.setTexture_atIndex(Some(target), 0); 474 - enc.setTexture_atIndex(Some(canvas), 1); 475 - enc.setBytes_length_atIndex( 476 - NonNull::new_unchecked((params as *const CommonParams).cast::<c_void>() as *mut c_void), 477 - std::mem::size_of::<CommonParams>() as _, 478 - 3, 479 - ); 480 - } 289 + unsafe { enc.setBuffer_offset_atIndex(Some(candidates), 0, 0); 290 + enc.setBuffer_offset_atIndex(Some(means), 0, 1); 291 + enc.setBuffer_offset_atIndex(Some(scores), 0, 2); 292 + enc.setTexture_atIndex(Some(target), 0); 293 + enc.setTexture_atIndex(Some(canvas), 1); 294 + enc.setBytes_length_atIndex( 295 + NonNull::new_unchecked((params as *const CommonParams).cast::<c_void>() as *mut c_void), 296 + std::mem::size_of::<CommonParams>() as _, 297 + 3, ); } 481 298 482 - let tg = MTLSize { 483 - width: 256, 484 - height: 1, 485 - depth: 1, 486 - }; 487 - let groups = MTLSize { 488 - width: batch as _, 489 - height: 1, 490 - depth: 1, 491 - }; 492 - enc.dispatchThreadgroups_threadsPerThreadgroup(groups, tg); 493 - } 299 + let tg = MTLSize { width: 256, height: 1, depth: 1 }; 300 + let groups = MTLSize { width: batch as _, height: 1, depth: 1 }; 301 + enc.dispatchThreadgroups_threadsPerThreadgroup(groups, tg); } 494 302 495 - unsafe fn write_candidates( 496 - buf: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 497 - candidates: &[Candidate], 498 - ) { 499 - let ptr = buf.contents().as_ptr().cast::<Candidate>(); 500 - unsafe { std::ptr::copy_nonoverlapping(candidates.as_ptr(), ptr, candidates.len()) }; 501 - } 303 + unsafe fn write_candidates(buf: &ProtocolObject<dyn objc2_metal::MTLBuffer>, candidates: &[Candidate]) { 304 + let ptr = buf.contents().as_ptr().cast::<Candidate>(); 305 + unsafe { std::ptr::copy_nonoverlapping(candidates.as_ptr(), ptr, candidates.len()) }; } 502 306 503 - unsafe fn read_scores( 504 - buf: &ProtocolObject<dyn objc2_metal::MTLBuffer>, 505 - n: usize, 506 - ) -> Vec<ScoreOut> { 507 - let ptr = buf.contents().as_ptr().cast::<ScoreOut>(); 508 - let mut out = vec![ScoreOut::default(); n]; 509 - unsafe { std::ptr::copy_nonoverlapping(ptr, out.as_mut_ptr(), n) }; 510 - out 511 - } 307 + unsafe fn read_scores(buf: &ProtocolObject<dyn objc2_metal::MTLBuffer>, n: usize) -> Vec<ScoreOut> { 308 + let ptr = buf.contents().as_ptr().cast::<ScoreOut>(); 309 + let mut out = vec![ScoreOut::default(); n]; 310 + unsafe { std::ptr::copy_nonoverlapping(ptr, out.as_mut_ptr(), n) }; 311 + out } 512 312 513 - unsafe fn write_full_texture_rgba32f( 514 - tex: &ProtocolObject<dyn MTLTexture>, 515 - width: u32, 516 - height: u32, 517 - rgba: &[f32], 518 - ) -> Result<(), String> { 519 - if rgba.len() != (width as usize) * (height as usize) * 4 { 520 - return Err("texture upload size mismatch".to_string()); 521 - } 522 - let region = MTLRegion { 523 - origin: objc2_metal::MTLOrigin { x: 0, y: 0, z: 0 }, 524 - size: MTLSize { 525 - width: width as _, 526 - height: height as _, 527 - depth: 1, 528 - }, 529 - }; 530 - unsafe { 531 - tex.replaceRegion_mipmapLevel_withBytes_bytesPerRow( 532 - region, 533 - 0, 534 - NonNull::new_unchecked(rgba.as_ptr().cast::<c_void>() as *mut c_void), 535 - (width as usize * 4 * std::mem::size_of::<f32>()) as _, 536 - ); 537 - } 538 - Ok(()) 539 - } 313 + unsafe fn write_full_texture_rgba32f(tex: &ProtocolObject<dyn MTLTexture>, width: u32, height: u32, rgba: &[f32]) -> Result<(), String> { 314 + if rgba.len() != (width as usize) * (height as usize) * 4 { return Err("texture upload size mismatch".to_string()); } 315 + let region = MTLRegion { origin: objc2_metal::MTLOrigin { x: 0, y: 0, z: 0 }, size: MTLSize { width: width as _, height: height as _, depth: 1, }, }; 316 + unsafe { tex.replaceRegion_mipmapLevel_withBytes_bytesPerRow( 317 + region, 318 + 0, 319 + NonNull::new_unchecked(rgba.as_ptr().cast::<c_void>() as *mut c_void), 320 + (width as usize * 4 * std::mem::size_of::<f32>()) as _, ); } 321 + Ok(()) } 540 322 541 - unsafe fn read_full_texture_rgba32f( 542 - tex: &ProtocolObject<dyn MTLTexture>, 543 - width: u32, 544 - height: u32, 545 - out: &mut [f32], 546 - ) -> Result<(), String> { 547 - if out.len() != (width as usize) * (height as usize) * 4 { 548 - return Err("texture readback size mismatch".to_string()); 549 - } 550 - let region = MTLRegion { 551 - origin: objc2_metal::MTLOrigin { x: 0, y: 0, z: 0 }, 552 - size: MTLSize { 553 - width: width as _, 554 - height: height as _, 555 - depth: 1, 556 - }, 557 - }; 558 - unsafe { 559 - tex.getBytes_bytesPerRow_fromRegion_mipmapLevel( 560 - NonNull::new_unchecked(out.as_mut_ptr().cast::<c_void>()), 561 - (width as usize * 4 * std::mem::size_of::<f32>()) as _, 562 - region, 563 - 0, 564 - ); 565 - } 566 - Ok(()) 567 - } 323 + unsafe fn read_full_texture_rgba32f(tex: &ProtocolObject<dyn MTLTexture>, width: u32, height: u32, out: &mut [f32]) -> Result<(), String> { 324 + if out.len() != (width as usize) * (height as usize) * 4 { return Err("texture readback size mismatch".to_string()); } 325 + let region = MTLRegion { origin: objc2_metal::MTLOrigin { x: 0, y: 0, z: 0 }, size: MTLSize { width: width as _, height: height as _, depth: 1, }, }; 326 + unsafe { tex.getBytes_bytesPerRow_fromRegion_mipmapLevel( 327 + NonNull::new_unchecked(out.as_mut_ptr().cast::<c_void>()), 328 + (width as usize * 4 * std::mem::size_of::<f32>()) as _, 329 + region, 330 + 0, ); } 331 + Ok(()) }
+46 -111
src/oklab.rs
··· 1 1 use rayon::prelude::*; 2 2 3 3 #[derive(Clone, Copy, Debug, Default)] 4 - pub struct Oklab { 5 - pub l: f32, 6 - pub a: f32, 7 - pub b: f32, 8 - } 9 - 4 + pub struct Oklab { pub l: f32, pub a: f32, pub b: f32, } 10 5 #[derive(Clone, Copy, Debug, Default)] 11 - pub struct Srgb { 12 - pub r: f32, 13 - pub g: f32, 14 - pub b: f32, 15 - } 6 + pub struct Srgb { pub r: f32, pub g: f32, pub b: f32, } 16 7 17 8 impl Srgb { 18 - pub fn from_u8(r: u8, g: u8, b: u8) -> Self { 19 - Self { 20 - r: r as f32 / 255.0, 21 - g: g as f32 / 255.0, 22 - b: b as f32 / 255.0, 23 - } 24 - } 9 + pub fn from_u8(r: u8, g: u8, b: u8) -> Self { Self { r: r as f32 / 255.0, g: g as f32 / 255.0, b: b as f32 / 255.0, } } 10 + pub fn to_u8(self) -> (u8, u8, u8) { let clamp = |x: f32| (x.clamp(0.0, 1.0) * 255.0).round() as u8; (clamp(self.r), clamp(self.g), clamp(self.b)) } 25 11 26 - pub fn to_u8(self) -> (u8, u8, u8) { 27 - let clamp = |x: f32| (x.clamp(0.0, 1.0) * 255.0).round() as u8; 28 - (clamp(self.r), clamp(self.g), clamp(self.b)) 29 - } 12 + fn linearize(c: f32) -> f32 { if c <= 0.04045 { c / 12.92 } else { ((c + 0.055) / 1.055).powf(2.4) } } 13 + fn delinearize(c: f32) -> f32 { if c <= 0.0031308 { c * 12.92 } else { 1.055 * c.powf(1.0 / 2.4) - 0.055 } } 30 14 31 - fn linearize(c: f32) -> f32 { 32 - if c <= 0.04045 { 33 - c / 12.92 34 - } else { 35 - ((c + 0.055) / 1.055).powf(2.4) 36 - } 37 - } 38 - 39 - fn delinearize(c: f32) -> f32 { 40 - if c <= 0.0031308 { 41 - c * 12.92 42 - } else { 43 - 1.055 * c.powf(1.0 / 2.4) - 0.055 44 - } 45 - } 46 - 47 - pub fn to_linear(self) -> (f32, f32, f32) { 48 - ( 49 - Self::linearize(self.r), 50 - Self::linearize(self.g), 51 - Self::linearize(self.b), 52 - ) 53 - } 54 - 55 - pub fn from_linear(r: f32, g: f32, b: f32) -> Self { 56 - Self { 57 - r: Self::delinearize(r), 58 - g: Self::delinearize(g), 59 - b: Self::delinearize(b), 60 - } 61 - } 62 - } 15 + pub fn to_linear(self) -> (f32, f32, f32) { ( Self::linearize(self.r), Self::linearize(self.g), Self::linearize(self.b), ) } 16 + pub fn from_linear(r: f32, g: f32, b: f32) -> Self { Self { r: Self::delinearize(r), g: Self::delinearize(g), b: Self::delinearize(b), } } } 63 17 64 18 impl Oklab { 65 - pub fn from_srgb(srgb: Srgb) -> Self { 66 - // srgb to linear srgb 67 - let (r, g, b) = srgb.to_linear(); 19 + pub fn from_srgb(srgb: Srgb) -> Self { 20 + // srgb to linear srgb 21 + let (r, g, b) = srgb.to_linear(); 68 22 69 - // linear srgb to lms (approximate cone response) 70 - let l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b; 71 - let m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b; 72 - let s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b; 23 + // linear srgb to lms (approximate cone response) 24 + let l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b; 25 + let m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b; 26 + let s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b; 73 27 74 - // apply non-linearity (cube root) 75 - let l_ = l.cbrt(); 76 - let m_ = m.cbrt(); 77 - let s_ = s.cbrt(); 28 + // apply non-linearity (cube root) 29 + let l_ = l.cbrt(); 30 + let m_ = m.cbrt(); 31 + let s_ = s.cbrt(); 78 32 79 - // transform to oklab coordinates (matrix m2) 80 - Self { 81 - l: 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_, 82 - a: 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_, 83 - b: 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_, 84 - } 85 - } 33 + // transform to oklab coordinates (matrix m2) 34 + Self { l: 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_ 35 + , a: 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_ 36 + , b: 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_ } } 86 37 87 - pub fn to_srgb(self) -> Srgb { 88 - // inverse of matrix m2 (oklab -> non-linear lms) 89 - let l_ = self.l + 0.3963377774 * self.a + 0.2158037573 * self.b; 90 - let m_ = self.l - 0.1055613458 * self.a - 0.0638541728 * self.b; 91 - let s_ = self.l - 0.0894841775 * self.a - 1.2914855480 * self.b; 38 + pub fn to_srgb(self) -> Srgb { 39 + // inverse of matrix m2 (oklab -> non-linear lms) 40 + let l_ = self.l + 0.3963377774 * self.a + 0.2158037573 * self.b; 41 + let m_ = self.l - 0.1055613458 * self.a - 0.0638541728 * self.b; 42 + let s_ = self.l - 0.0894841775 * self.a - 1.2914855480 * self.b; 92 43 93 - // revert non-linearity (cube) 94 - let l = l_ * l_ * l_; 95 - let m = m_ * m_ * m_; 96 - let s = s_ * s_ * s_; 97 - 98 - // inverse of matrix m1 (lms -> linear srgb) 99 - let r = 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s; 100 - let g = -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s; 101 - let b = -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s; 44 + // revert non-linearity (cube) 45 + let l = l_ * l_ * l_; 46 + let m = m_ * m_ * m_; 47 + let s = s_ * s_ * s_; 102 48 103 - // convert linear srgb to srgb 104 - Srgb::from_linear(r, g, b) 105 - } 106 - } 49 + // inverse of matrix m1 (lms -> linear srgb) 50 + let r = 4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s; 51 + let g = -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s; 52 + let b = -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s; 107 53 108 - pub fn srgb8_to_oklab(r: u8, g: u8, b: u8) -> Oklab { 109 - Oklab::from_srgb(Srgb::from_u8(r, g, b)) 110 - } 54 + // convert linear srgb to srgb 55 + Srgb::from_linear(r, g, b) } } 111 56 112 - pub fn oklab_to_srgb8(lab: Oklab) -> (u8, u8, u8) { 113 - lab.to_srgb().to_u8() 114 - } 57 + pub fn srgb8_to_oklab(r: u8, g: u8, b: u8) -> Oklab { Oklab::from_srgb(Srgb::from_u8(r, g, b)) } 58 + pub fn oklab_to_srgb8(lab: Oklab) -> (u8, u8, u8) { lab.to_srgb().to_u8() } 115 59 116 60 pub fn rgb8_to_oklab_parallel(rgb: &[u8]) -> Result<Vec<Oklab>, String> { 117 - if rgb.len() % 3 != 0 { 118 - return Err("rgb buffer must be 3 bytes per pixel".into()); 119 - } 120 - Ok(rgb.par_chunks(3).map(|c| srgb8_to_oklab(c[0], c[1], c[2])).collect()) 121 - } 61 + if rgb.len() % 3 != 0 { return Err("rgb buffer must be 3 bytes per pixel".into()); } 62 + Ok(rgb.par_chunks(3).map(|c| srgb8_to_oklab(c[0], c[1], c[2])).collect()) } 122 63 123 64 pub fn avg_oklab_parallel(pixels: &[Oklab]) -> Oklab { 124 - if pixels.is_empty() { 125 - return Oklab::default(); 126 - } 127 - let (l, a, b) = pixels 128 - .par_iter() 129 - .map(|p| (p.l, p.a, p.b)) 130 - .reduce(|| (0.0, 0.0, 0.0), |(l1, a1, b1), (l2, a2, b2)| (l1 + l2, a1 + a2, b1 + b2)); 131 - let n = pixels.len() as f32; 132 - Oklab { l: l / n, a: a / n, b: b / n } 133 - } 65 + if pixels.is_empty() { return Oklab::default(); } 66 + let (l, a, b) = pixels.par_iter().map(|p| (p.l, p.a, p.b)).reduce(|| (0.0, 0.0, 0.0), |(l1, a1, b1), (l2, a2, b2)| (l1 + l2, a1 + a2, b1 + b2)); 67 + let n = pixels.len() as f32; 68 + Oklab { l: l / n, a: a / n, b: b / n } }
+102 -190
src/pipeline.rs
··· 11 11 use std::time::Instant; 12 12 13 13 pub enum ResolvedOutput { 14 - SingleFile(PathBuf), 15 - FrameSequence { dir: PathBuf, nth: u32 }, 16 - } 14 + SingleFile(PathBuf), 15 + FrameSequence { dir: PathBuf, nth: u32 }, } 17 16 18 17 impl ResolvedOutput { 19 - pub fn from_config(cfg: &Config, base_path: PathBuf) -> Self { 20 - match cfg.nth { 21 - Some(nth) => Self::FrameSequence { dir: base_path, nth }, 22 - None => Self::SingleFile(base_path), 23 - } 24 - } 18 + pub fn from_config(cfg: &Config, base_path: PathBuf) -> Self { 19 + match cfg.nth { 20 + Some(nth) => Self::FrameSequence { dir: base_path, nth }, 21 + None => Self::SingleFile(base_path), } } 25 22 26 - fn prepare_dirs(&self) -> Result<(), String> { 27 - match self { 28 - Self::FrameSequence { dir, .. } => std::fs::create_dir_all(dir) 29 - .map_err(|e| format!("failed to create dir {}: {e}", dir.display())), 30 - Self::SingleFile(file) => { 31 - if let Some(parent) = file.parent().filter(|p| !p.as_os_str().is_empty()) { 32 - std::fs::create_dir_all(parent) 33 - .map_err(|e| format!("failed to create dir {}: {e}", parent.display()))?; 34 - } 35 - Ok(()) 36 - } 37 - } 38 - } 39 - } 23 + fn prepare_dirs(&self) -> Result<(), String> { 24 + match self { 25 + Self::FrameSequence { dir, .. } => std::fs::create_dir_all(dir).map_err(|e| format!("failed to create dir {}: {e}", dir.display())), 26 + Self::SingleFile(file) => { 27 + if let Some(parent) = file.parent().filter(|p| !p.as_os_str().is_empty()) { std::fs::create_dir_all(parent).map_err(|e| format!("failed to create dir {}: {e}", parent.display()))?; } 28 + Ok(()) } } } } 40 29 41 30 pub fn process_dir(cfg: &Config, input_dir: &Path, out_root: &Path) -> Result<(), String> { 42 - let started = Instant::now(); 43 - let mut files = fs::walk_files_recursive(input_dir)?; 44 - files.sort(); 31 + let started = Instant::now(); 32 + let mut files = fs::walk_files_recursive(input_dir)?; 33 + files.sort(); 45 34 46 - let mut processed: u64 = 0; 47 - let mut skipped: u64 = 0; 35 + let mut processed: u64 = 0; 36 + let mut skipped: u64 = 0; 48 37 49 - for input_file in files { 50 - let rel = match input_file.strip_prefix(input_dir) { 51 - Ok(v) => v, 52 - Err(_) => continue, 53 - }; 38 + for input_file in files { 39 + let rel = match input_file.strip_prefix(input_dir) { Ok(v) => v 40 + , Err(_) => continue, }; 54 41 55 - let base_path = match cfg.nth { 56 - Some(_) => out_root.join(rel).with_extension(""), 57 - None => out_root.join(rel).with_extension("png"), 58 - }; 59 - let output = ResolvedOutput::from_config(cfg, base_path); 42 + let base_path = match cfg.nth { Some(_) => out_root.join(rel).with_extension("") 43 + , None => out_root.join(rel).with_extension("png"), }; 60 44 61 - let Ok(_) = image::open(&input_file) else { 62 - skipped += 1; 63 - if log::level() >= 2 { 64 - log::l2(format_args!("skip: {}", input_file.display())); 65 - } 66 - continue; 67 - }; 45 + let output = ResolvedOutput::from_config(cfg, base_path); 68 46 69 - match process_one(cfg, &input_file, output) { 70 - Ok(final_path) => { 71 - processed += 1; 72 - if log::level() >= 2 { 73 - log::l2(format_args!( 74 - "output: {} (input: {})", 75 - final_path.display(), 76 - input_file.display() 77 - )); 78 - } 79 - } 80 - Err(e) => { 47 + let Ok(_) = image::open(&input_file) else { 81 48 skipped += 1; 82 - log::l1(format_args!("failed: {} ({})", input_file.display(), e)); 83 - } 84 - } 85 - } 49 + if log::level() >= 2 { log::l2(format_args!("skip: {}", input_file.display())); } 50 + continue; }; 51 + 52 + match process_one(cfg, &input_file, output) { 53 + Ok(final_path) => { processed += 1; 54 + if log::level() >= 2 { log::l2(format_args!( "output: {} (input: {})", final_path.display(), input_file.display() )); } } 55 + Err(e) => { skipped += 1; log::l1(format_args!("failed: {} ({})", input_file.display(), e)); } } } 86 56 87 - let elapsed_s = started.elapsed().as_secs_f64(); 88 - println!( 89 - "done: processed {} images, skipped {}, elapsed {:.2}s", 90 - processed, skipped, elapsed_s 91 - ); 92 - Ok(()) 93 - } 57 + let elapsed_s = started.elapsed().as_secs_f64(); 58 + println!( "done: processed {} images, skipped {}, elapsed {:.2}s", processed, skipped, elapsed_s ); 59 + Ok(()) } 94 60 95 61 pub fn process_one(cfg: &Config, input_path: &Path, output: ResolvedOutput) -> Result<PathBuf, String> { 96 - let started = Instant::now(); 62 + let started = Instant::now(); 97 63 98 - let img = image::open(input_path) 99 - .map_err(|e| format!("failed to load image {}: {e}", input_path.display()))?; 100 - let (w, h) = img.dimensions(); 101 - let rgb = img.to_rgb8(); 102 - let target = rgb8_to_oklab_parallel(rgb.as_raw())?; 64 + let img = image::open(input_path).map_err(|e| format!("failed to load image {}: {e}", input_path.display()))?; 65 + let (w, h) = img.dimensions(); 66 + let rgb = img.to_rgb8(); 67 + let target = rgb8_to_oklab_parallel(rgb.as_raw())?; 103 68 104 - let current_canvas = if let Some(current_path) = &cfg.current { 105 - let img = image::open(current_path) 106 - .map_err(|e| format!("failed to load current image {}: {e}", current_path))?; 107 - let (cw, ch) = img.dimensions(); 108 - if cw != w || ch != h { 109 - return Err(format!( 110 - "current image size mismatch: got {}x{}, expected {}x{}", 111 - cw, ch, w, h 112 - )); 113 - } 114 - let rgb = img.to_rgb8(); 115 - Some(rgb8_to_oklab_parallel(rgb.as_raw())?) 116 - } else { 117 - None 118 - }; 69 + let current_canvas = if let Some(current_path) = &cfg.current { 70 + let img = image::open(current_path).map_err(|e| format!("failed to load current image {}: {e}", current_path))?; 71 + let (cw, ch) = img.dimensions(); 119 72 120 - log::l2(format_args!( 121 - "input: {} ({}x{}), batch: {}, seed: {}, alpha: {}, max_gpu: {}", 122 - input_path.display(), 123 - w, 124 - h, 125 - cfg.batch, 126 - cfg.seed, 127 - cfg.alpha, 128 - cfg.max_gpu, 129 - )); 73 + if cw != w || ch != h { return Err(format!("current image size mismatch: got {}x{}, expected {}x{}", cw, ch, w, h )); } 130 74 131 - let avg = avg_oklab_parallel(&target); 132 - let bg = match cfg.bg { 133 - Bg::Avg => avg, 134 - Bg::RgbU8 { r, g, b } => srgb8_to_oklab(r, g, b), 135 - }; 75 + let rgb = img.to_rgb8(); 76 + Some(rgb8_to_oklab_parallel(rgb.as_raw())?) } else { None }; 136 77 137 - let max_splines = cfg.number.unwrap_or((w as f64 * h as f64).powf(0.7).round() as u32); 138 - let mut rng = Rng::new(cfg.seed); 78 + log::l2(format_args!("input: {} ({}x{}), batch: {}, seed: {}, alpha: {}, max_gpu: {}", input_path.display(), w, h, cfg.batch, cfg.seed, cfg.alpha, cfg.max_gpu)); 139 79 140 - let metal = MetalContext::new(w, h, cfg.batch, cfg.max_gpu)?; 141 - match &current_canvas { 142 - Some(canvas) => metal.upload_target_and_set_canvas(&target, canvas)?, 143 - None => metal.upload_target_and_init_canvas(&target, bg)?, 144 - } 145 - output.prepare_dirs()?; 80 + let avg = avg_oklab_parallel(&target); 81 + let bg = match cfg.bg { Bg::Avg => avg 82 + , Bg::RgbU8 { r, g, b } => srgb8_to_oklab(r, g, b), }; 146 83 147 - if let ResolvedOutput::FrameSequence { dir, nth } = &output { 148 - log::l2(format_args!("frames: {} (every {} accepted)", dir.display(), nth)); 149 - } 150 - log::l1(format_args!("target splines: {}", max_splines)); 151 - let mut batch_logger = log::BatchLogger::new(); 84 + let max_splines = cfg.number.unwrap_or((w as f64 * h as f64).powf(0.7).round() as u32); 85 + let mut rng = Rng::new(cfg.seed); 152 86 153 - let mut accepted_total: u32 = 0; 154 - let mut consecutive_stagnant_batches: u32 = 0; 155 - let mut batch_idx: u64 = 0; 87 + let metal = MetalContext::new(w, h, cfg.batch, cfg.max_gpu)?; 88 + match &current_canvas { Some(canvas) => metal.upload_target_and_set_canvas(&target, canvas)? 89 + , None => metal.upload_target_and_init_canvas(&target, bg)?, } 90 + output.prepare_dirs()?; 156 91 157 - while accepted_total < max_splines { 158 - batch_idx += 1; 159 - let remaining = max_splines - accepted_total; 160 - let candidates = sampling::sample_candidates(&mut rng, w, h, cfg.batch); 161 - let accepted_before = accepted_total; 92 + if let ResolvedOutput::FrameSequence { dir, nth } = &output { log::l2(format_args!("frames: {} (every {} accepted)", dir.display(), nth)); } 93 + log::l1(format_args!("target splines: {}", max_splines)); 94 + let mut batch_logger = log::BatchLogger::new(); 162 95 163 - match &output { 164 - ResolvedOutput::FrameSequence { dir, nth } => { 165 - let (_, accepted_after) = metal.process_batch_checkpointed( 166 - &candidates, cfg.alpha, remaining, accepted_total, *nth, 167 - |metal, total| { 168 - let frame = metal.read_canvas()?; 169 - save_oklab_png(dir.join(format!("frame_{:06}.png", total)), &frame, w, h) 170 - }, 171 - )?; 172 - accepted_total = accepted_after; 173 - } 174 - ResolvedOutput::SingleFile(_) => { 175 - let accepted = metal.process_batch(&candidates, cfg.alpha, remaining)?; 176 - accepted_total += accepted.iter().filter(|v| **v).count() as u32; 177 - } 178 - }; 96 + let mut accepted_total: u32 = 0; 97 + let mut consecutive_stagnant_batches: u32 = 0; 98 + let mut batch_idx: u64 = 0; 179 99 180 - let accepted_now = accepted_total - accepted_before; 181 - let accept_threshold = (cfg.batch as f64) * (cfg.min_accept_ratio as f64); 182 - let stagnant = (accepted_now as f64) + 1e-9 < accept_threshold; 183 - consecutive_stagnant_batches = if stagnant { consecutive_stagnant_batches + 1 } else { 0 }; 100 + while accepted_total < max_splines { 101 + batch_idx += 1; 102 + let remaining = max_splines - accepted_total; 103 + let candidates = sampling::sample_candidates(&mut rng, w, h, cfg.batch); 104 + let accepted_before = accepted_total; 184 105 185 - batch_logger.log_batch(log::BatchInfo { 186 - batch_idx, 187 - accepted_total, 188 - max_splines, 189 - accepted_now, 190 - batch_size: cfg.batch, 191 - remaining: max_splines.saturating_sub(accepted_total), 192 - consecutive_stagnant: consecutive_stagnant_batches, 193 - max_stagnant_batches: cfg.max_stagnant_batches, 194 - elapsed_s: started.elapsed().as_secs_f64(), 195 - }); 106 + match &output { 107 + ResolvedOutput::FrameSequence { dir, nth } => { let (_, accepted_after) = metal.process_batch_checkpointed( 108 + &candidates, cfg.alpha, remaining, accepted_total, *nth, 109 + |metal, total| { let frame = metal.read_canvas()?; 110 + save_oklab_png(dir.join(format!("frame_{:06}.png", total)), &frame, w, h) })?; 111 + accepted_total = accepted_after; } 112 + ResolvedOutput::SingleFile(_) => { let accepted = metal.process_batch(&candidates, cfg.alpha, remaining)?; 113 + accepted_total += accepted.iter().filter(|v| **v).count() as u32; } }; 114 + 115 + let accepted_now = accepted_total - accepted_before; 116 + let accept_threshold = (cfg.batch as f64) * (cfg.min_accept_ratio as f64); 117 + let stagnant = (accepted_now as f64) + 1e-9 < accept_threshold; 118 + consecutive_stagnant_batches = if stagnant { consecutive_stagnant_batches + 1 } else { 0 }; 119 + 120 + batch_logger.log_batch(log::BatchInfo { batch_idx 121 + , accepted_total 122 + , max_splines 123 + , accepted_now 124 + , batch_size: cfg.batch 125 + , remaining: max_splines.saturating_sub(accepted_total) 126 + , consecutive_stagnant: consecutive_stagnant_batches 127 + , max_stagnant_batches: cfg.max_stagnant_batches 128 + , elapsed_s: started.elapsed().as_secs_f64() }); 196 129 197 - if consecutive_stagnant_batches >= cfg.max_stagnant_batches { 198 - break; 199 - } 200 - } 130 + if consecutive_stagnant_batches >= cfg.max_stagnant_batches { break; } } 201 131 202 - let canvas = metal.read_canvas()?; 203 - let final_path = match &output { 204 - ResolvedOutput::FrameSequence { dir, .. } => { 205 - let path = dir.join("final.png"); 206 - save_oklab_png(&path, &canvas, w, h)?; 207 - path 208 - } 209 - ResolvedOutput::SingleFile(file) => { 210 - save_oklab_png(file, &canvas, w, h)?; 211 - file.clone() 212 - } 213 - }; 132 + let canvas = metal.read_canvas()?; 133 + let final_path = match &output { 134 + ResolvedOutput::FrameSequence { dir, .. } => { let path = dir.join("final.png"); 135 + save_oklab_png(&path, &canvas, w, h)?; 136 + path } 137 + ResolvedOutput::SingleFile(file) => { save_oklab_png(file, &canvas, w, h)?; 138 + file.clone() } }; 214 139 215 - log::l2(format_args!( 216 - "done: {} (accepted: {}, target: {}, elapsed: {:.2}s)", 217 - final_path.display(), accepted_total, max_splines, started.elapsed().as_secs_f64() 218 - )); 219 - Ok(final_path) 220 - } 140 + log::l2(format_args!("done: {} (accepted: {}, target: {}, elapsed: {:.2}s)", final_path.display(), accepted_total, max_splines, started.elapsed().as_secs_f64())); 141 + Ok(final_path) } 221 142 222 143 fn save_oklab_png(path: impl AsRef<Path>, data: &[Oklab], w: u32, h: u32) -> Result<(), String> { 223 144 let path = path.as_ref(); 224 - if data.len() != (w as usize) * (h as usize) { 225 - return Err("output size mismatch".into()); 226 - } 227 - let rgb: Vec<u8> = data 228 - .par_iter() 229 - .flat_map(|lab| { let (r, g, b) = oklab_to_srgb8(*lab); [r, g, b] }) 230 - .collect(); 231 - ImageBuffer::<Rgb<u8>, _>::from_raw(w, h, rgb) 232 - .ok_or("buffer size mismatch")? 233 - .save(path) 234 - .map_err(|e| format!("failed to save: {e}")) 235 - } 145 + if data.len() != (w as usize) * (h as usize) { return Err("output size mismatch".into()); } 146 + let rgb: Vec<u8> = data.par_iter().flat_map(|lab| { let (r, g, b) = oklab_to_srgb8(*lab); [r, g, b] }).collect(); 147 + ImageBuffer::<Rgb<u8>, _>::from_raw(w, h, rgb).ok_or("buffer size mismatch")?.save(path).map_err(|e| format!("failed to save: {e}")) }
+9 -20
src/rng.rs
··· 1 1 #[derive(Clone, Copy, Debug)] 2 - pub struct Rng { 3 - state: u64, 4 - } 2 + pub struct Rng { state: u64, } 5 3 6 4 impl Rng { 7 - pub fn new(seed: u64) -> Self { 8 - Self { 9 - state: seed ^ 0x9e3779b97f4a7c15, 10 - } 11 - } 5 + pub fn new(seed: u64) -> Self { Self { state: seed ^ 0x9e3779b97f4a7c15, } } 12 6 13 - pub fn next_u64(&mut self) -> u64 { 14 - self.state = self.state.wrapping_add(0x9e3779b97f4a7c15); 15 - let mut z = self.state; 16 - z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); 17 - z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); 18 - z ^ (z >> 31) 19 - } 7 + pub fn next_u64(&mut self) -> u64 { 8 + self.state = self.state.wrapping_add(0x9e3779b97f4a7c15); 9 + let mut z = self.state; 10 + z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9); 11 + z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb); 12 + z ^ (z >> 31) } 20 13 21 - pub fn next_f32(&mut self) -> f32 { 22 - let bits = (self.next_u64() >> 40) as u32; 23 - (bits as f32) * (1.0 / ((1u32 << 24) as f32)) 24 - } 25 - } 14 + pub fn next_f32(&mut self) -> f32 { let bits = (self.next_u64() >> 40) as u32; (bits as f32) * (1.0 / ((1u32 << 24) as f32)) } }
+19 -34
src/sampling.rs
··· 2 2 use crate::rng::Rng; 3 3 4 4 pub fn sample_candidates(rng: &mut Rng, w: u32, h: u32, batch: u32) -> Vec<Candidate> { 5 - let max_x = (w.saturating_sub(1)) as f32; 6 - let max_y = (h.saturating_sub(1)) as f32; 7 - let mut sample_pt = || [ 8 - (rng.next_f32() * (max_x + 1.0)).clamp(0.0, max_x), 9 - (rng.next_f32() * (max_y + 1.0)).clamp(0.0, max_y), 10 - ]; 5 + let max_x = (w.saturating_sub(1)) as f32; 6 + let max_y = (h.saturating_sub(1)) as f32; 7 + let mut sample_pt = || [ (rng.next_f32() * (max_x + 1.0)).clamp(0.0, max_x), (rng.next_f32() * (max_y + 1.0)).clamp(0.0, max_y), ]; 11 8 12 - (0..batch) 13 - .map(|_| { 14 - let [p0, p1, p2, p3] = std::array::from_fn(|_| sample_pt()); 15 - let (bx, by, bw, bh) = bbox_from_points(p0, p1, p2, p3, w, h); 16 - Candidate { p0, p1, p2, p3, bx, by, bw, bh } 17 - }) 18 - .collect() 19 - } 9 + (0..batch).map(|_| { let [p0, p1, p2, p3] = std::array::from_fn(|_| sample_pt()); 10 + let (bx, by, bw, bh) = bbox_from_points(p0, p1, p2, p3, w, h); 11 + Candidate { p0, p1, p2, p3, bx, by, bw, bh } }).collect() } 20 12 21 - pub fn bbox_from_points( 22 - p0: [f32; 2], 23 - p1: [f32; 2], 24 - p2: [f32; 2], 25 - p3: [f32; 2], 26 - w: u32, 27 - h: u32, 28 - ) -> (u32, u32, u32, u32) { 29 - let min_x = p0[0].min(p1[0]).min(p2[0]).min(p3[0]); 30 - let max_x = p0[0].max(p1[0]).max(p2[0]).max(p3[0]); 31 - let min_y = p0[1].min(p1[1]).min(p2[1]).min(p3[1]); 32 - let max_y = p0[1].max(p1[1]).max(p2[1]).max(p3[1]); 13 + pub fn bbox_from_points( p0: [f32; 2], p1: [f32; 2], p2: [f32; 2], p3: [f32; 2], 14 + w: u32, h: u32 ) -> (u32, u32, u32, u32) { 15 + let min_x = p0[0].min(p1[0]).min(p2[0]).min(p3[0]); 16 + let max_x = p0[0].max(p1[0]).max(p2[0]).max(p3[0]); 17 + let min_y = p0[1].min(p1[1]).min(p2[1]).min(p3[1]); 18 + let max_y = p0[1].max(p1[1]).max(p2[1]).max(p3[1]); 33 19 34 - let bx = (min_x.floor() as i64 - 1).clamp(0, (w.saturating_sub(1)) as i64) as u32; 35 - let by = (min_y.floor() as i64 - 1).clamp(0, (h.saturating_sub(1)) as i64) as u32; 36 - let ex = (max_x.ceil() as i64 + 1).clamp(0, (w.saturating_sub(1)) as i64) as u32; 37 - let ey = (max_y.ceil() as i64 + 1).clamp(0, (h.saturating_sub(1)) as i64) as u32; 20 + let bx = (min_x.floor() as i64 - 1).clamp(0, (w.saturating_sub(1)) as i64) as u32; 21 + let by = (min_y.floor() as i64 - 1).clamp(0, (h.saturating_sub(1)) as i64) as u32; 22 + let ex = (max_x.ceil() as i64 + 1).clamp(0, (w.saturating_sub(1)) as i64) as u32; 23 + let ey = (max_y.ceil() as i64 + 1).clamp(0, (h.saturating_sub(1)) as i64) as u32; 38 24 39 - let bw = (ex - bx + 1).max(1); 40 - let bh = (ey - by + 1).max(1); 41 - (bx, by, bw, bh) 42 - } 25 + let bw = (ex - bx + 1).max(1); 26 + let bh = (ey - by + 1).max(1); 27 + (bx, by, bw, bh) }