A better Rust ATProto crate
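
The file below is the corpus-level generation driver from jacquard-lexicon's `codegen` module: it enumerates each output file's local type names, resolves per-file imports, generates code for every lexicon definition, builds the `pub mod` tree with per-namespace feature gates, writes the formatted sources to disk, and emits a matching Cargo feature section for Cargo.toml.
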
use crate::error::{CodegenError, Result};
use quote::quote;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};

use super::CodeGenerator;
use super::prettify::{FileOutput, ResolvedImports};
use super::utils::{make_ident, sanitize_name};
use crate::ref_utils::NsidPath;

impl<'c> CodeGenerator<'c> {
    /// Generate all code for the corpus, organized by file
    /// Returns a map of file paths to FileOutput with reordered tokens
    pub fn generate_all(&self) -> Result<BTreeMap<std::path::PathBuf, FileOutput>> {
        let mut file_contents: BTreeMap<std::path::PathBuf, Vec<super::prettify::GeneratedCode>> =
            BTreeMap::new();
        let mut file_nsids: BTreeMap<std::path::PathBuf, String> = BTreeMap::new();

        // Step 1: Enumerate local type names per file.
        // Also collect all file paths so we can determine submodule names.
        let mut file_local_names: BTreeMap<std::path::PathBuf, HashSet<String>> = BTreeMap::new();
        let mut all_file_paths: BTreeSet<std::path::PathBuf> = BTreeSet::new();
        for (nsid, doc) in self.corpus.iter() {
            let file_path = self.nsid_to_file_path(nsid.as_ref());
            all_file_paths.insert(file_path.clone());
            let names = file_local_names.entry(file_path).or_default();
            for def_name in doc.defs.keys() {
                names.insert(self.def_to_type_name(nsid.as_ref(), def_name.as_ref()));
            }
        }

        // Determine submodule names for each file. If `foo.rs` exists and `foo/bar.rs`
        // also exists, then `bar` is a submodule of `foo`. These names must be treated
        // as reserved — importing `use crate::something::bar;` would collide with
        // `pub mod bar;`.
        let mut file_submodule_names: BTreeMap<std::path::PathBuf, HashSet<String>> =
            BTreeMap::new();
        for file_path in &all_file_paths {
            // For a file like `app_bsky/feed/post.rs`, the parent module file is
            // `app_bsky/feed.rs`. Check if this file's parent has an entry.
            if let Some(parent_dir) = file_path.parent() {
                let parent_file = parent_dir.with_extension("rs");
                if all_file_paths.contains(&parent_file)
                    || file_local_names.contains_key(&parent_file)
                {
                    // This file is a submodule of parent_file.
                    if let Some(stem) = file_path.file_stem().and_then(|s| s.to_str()) {
                        file_submodule_names
                            .entry(parent_file)
                            .or_default()
                            .insert(stem.to_string());
                    }
                }
            }
        }
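
        // Illustration (hypothetical layout): with `app_bsky/feed.rs` and
        // `app_bsky/feed/post.rs` both in the corpus, `post` is recorded as a
        // submodule name of `feed.rs`, so the import resolver must qualify a
        // colliding reference instead of emitting `use ...::post;` next to
        // `pub mod post;`.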

        // Step 2: Run collection pass and build ResolvedImports for each file.
        // Multiple NSIDs can map to the same file (e.g. `app.rocksky.album` and
        // `app.rocksky.album.defs` both output to `album.rs`), so we accumulate
        // imports per file path before resolving.
        let mut file_imports_map: BTreeMap<std::path::PathBuf, super::prettify::ImportSet> =
            BTreeMap::new();
        for (nsid, doc) in self.corpus.iter() {
            let file_path = self.nsid_to_file_path(nsid.as_ref());
            let file_imports = file_imports_map.entry(file_path).or_default();
            for (def_name, def) in &doc.defs {
                file_imports.merge(self.collect_def(nsid.as_ref(), def_name.as_ref(), def));
            }
        }

        let mut file_resolved: BTreeMap<std::path::PathBuf, ResolvedImports> = BTreeMap::new();
        for (file_path, file_imports) in &file_imports_map {
            let local_names = file_local_names.get(file_path).cloned().unwrap_or_default();

            let lexicon_paths: BTreeMap<String, String> = file_imports
                .lexicon_refs
                .iter()
                .filter_map(|ref_str| self.ref_to_crate_path(ref_str))
                .collect();

            let submodule_names = file_submodule_names
                .get(file_path)
                .cloned()
                .unwrap_or_default();
            let resolved = ResolvedImports::resolve(
                file_imports,
                &local_names,
                &submodule_names,
                self.mode,
                &lexicon_paths,
            );
            file_resolved.insert(file_path.clone(), resolved);
        }

        // Step 3: Generate code for all lexicons
        for (nsid, doc) in self.corpus.iter() {
            let file_path = self.nsid_to_file_path(nsid.as_ref());

            // Track which NSID this file is for
            file_nsids.insert(file_path.clone(), nsid.to_string());

            // Get the per-file ResolvedImports (built in Step 2)
            let resolved = file_resolved
                .get(&file_path)
                .expect("resolved imports built for every file");

            for (_def_name, def) in &doc.defs {
                let generated =
                    self.generate_def(nsid.as_ref(), _def_name.as_ref(), def, resolved)?;
                file_contents
                    .entry(file_path.clone())
                    .or_default()
                    .push(generated);
            }
        }

        // Combine all tokens for each file using FileOutput::combine for reordering
        let mut result = BTreeMap::new();
        for (path, generated_vec) in file_contents {
            let nsid = file_nsids.get(&path).cloned();
            let resolved = file_resolved
                .get(&path)
                .expect("resolved imports built for every file");
            let file_output = FileOutput::combine(generated_vec, nsid, resolved);
            result.insert(path, file_output);
        }

        Ok(result)
    }

    /// Generate parent module files with pub mod declarations
    pub fn generate_module_tree(
        &self,
        file_map: &BTreeMap<std::path::PathBuf, FileOutput>,
        defs_only: &BTreeMap<std::path::PathBuf, FileOutput>,
        subscription_files: &HashSet<std::path::PathBuf>,
    ) -> BTreeMap<std::path::PathBuf, FileOutput> {
        // Track what modules each directory needs to declare
        // Key: directory path, Value: set of module names (file stems)
        let mut dir_modules: BTreeMap<std::path::PathBuf, BTreeSet<String>> = BTreeMap::new();

        // Collect all parent directories that have files
        let mut all_dirs: BTreeSet<std::path::PathBuf> = BTreeSet::new();
        for path in file_map.keys() {
            if let Some(parent_dir) = path.parent() {
                all_dirs.insert(parent_dir.to_path_buf());
            }
        }

        for path in file_map.keys() {
            if let Some(parent_dir) = path.parent() {
                if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) {
                    // Skip mod.rs and lib.rs - they're module files, not modules to declare
                    if file_stem == "mod" || file_stem == "lib" {
                        continue;
                    }

                    // Always add the module declaration to parent
                    dir_modules
                        .entry(parent_dir.to_path_buf())
                        .or_default()
                        .insert(file_stem.to_string());
                }
            }
        }

        // Generate module files
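        // For illustration (assuming an `app.bsky.*` corpus), the root `lib.rs` built
        // below comes out roughly as:
        //     extern crate alloc;
        //     #[cfg(feature = "app_bsky")]
        //     pub mod app_bsky;
        // while non-root subscription modules are gated behind
        // `#[cfg(feature = "streaming")]` instead of a namespace feature.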
        let mut result = BTreeMap::new();

        for (dir, module_names) in dir_modules {
            let mod_file_path = if dir.components().count() == 0 {
                // Root directory -> lib.rs for library crates
                std::path::PathBuf::from("lib.rs")
            } else {
                // Subdirectory: app_bsky/feed -> app_bsky/feed.rs (Rust 2018 style)
                let dir_name = dir.file_name().and_then(|s| s.to_str()).unwrap_or("mod");
                let sanitized_dir_name = sanitize_name(dir_name);
                let mut path = dir
                    .parent()
                    .unwrap_or_else(|| std::path::Path::new(""))
                    .to_path_buf();
                path.push(format!("{}.rs", sanitized_dir_name));
                path
            };

            let is_root = dir.components().count() == 0;
            let mods: Vec<_> = module_names
                .iter()
                .map(|name| {
                    let ident = make_ident(name);

                    // Check if this module is a subscription endpoint
                    let mut module_path = dir.clone();
                    module_path.push(format!("{}.rs", name));
                    let is_subscription = subscription_files.contains(&module_path);

                    if is_root && name != "builder_types" {
                        // Top-level modules get feature gates (except builder_types which is always needed)
                        quote! {
                            #[cfg(feature = #name)]
                            pub mod #ident;
                        }
                    } else if is_subscription {
                        // Subscription modules get streaming feature gate
                        quote! {
                            #[cfg(feature = "streaming")]
                            pub mod #ident;
                        }
                    } else {
                        quote! { pub mod #ident; }
                    }
                })
                .collect();

            // If this file already exists in defs_only (e.g., from defs), merge the content
            let module_tokens = if is_root {
                // lib.rs needs extern crate alloc for no_std compatibility
                quote! {
                    extern crate alloc; #(#mods)*
                }
            } else {
                quote! { #(#mods)* }
            };
            if let Some(existing_output) = defs_only.get(&mod_file_path) {
                // Put module declarations FIRST, then existing defs content
                let existing_tokens = &existing_output.tokens;
                let merged_tokens = quote! {
                    #module_tokens
                    #existing_tokens
                };
                result.insert(
                    mod_file_path,
                    FileOutput {
                        tokens: merged_tokens,
                        imports: existing_output.imports.clone(),
                        nsid: existing_output.nsid.clone(),
                    },
                );
            } else {
                result.insert(
                    mod_file_path,
                    FileOutput {
                        tokens: module_tokens,
                        imports: Default::default(),
                        nsid: None,
                    },
                );
            }
        }

        result
    }

    /// Write all generated code to disk
    pub fn write_to_disk(&self, output_dir: &std::path::Path) -> Result<()> {
        // Generate all code (defs only)
        let defs_files = self.generate_all()?;
        let mut all_files = defs_files.clone();

        // Generate common builder types (Set, Unset, IsSet, IsUnset)
        let common_types_path = std::path::PathBuf::from("builder_types.rs");
        let common_types_tokens = super::builder_gen::common::generate_common_types();
        all_files.insert(
            common_types_path,
            FileOutput {
                tokens: common_types_tokens,
                imports: Default::default(),
                nsid: None,
            },
        );

        // Get subscription files for feature gating
        let subscription_files = self.subscription_files.borrow();

        // Generate module tree iteratively until no new files appear
        loop {
            let module_map =
                self.generate_module_tree(&all_files, &defs_files, &subscription_files);
            let old_count = all_files.len();

            // Merge new module files
            for (path, file_output) in module_map {
                all_files.insert(path, file_output);
            }

            if all_files.len() == old_count {
                // No new files added
                break;
            }
        }

        // Write to disk
        for (path, file_output) in all_files {
            let full_path = output_dir.join(&path);

            // Create parent directories
            if let Some(parent) = full_path.parent() {
                std::fs::create_dir_all(parent)?;
            }

            // Format code
            let file: syn::File = syn::parse2(file_output.tokens.clone()).map_err(|e| {
                let tokens = file_output.tokens.to_string();
                eprintln!(
                    "Failed to parse generated tokens for {:?}:\n{}",
                    path, tokens
                );
                CodegenError::TokenParseError {
                    path: path.clone(),
                    source: e,
                    tokens,
                }
            })?;
            let mut formatted = prettyplease::unparse(&file);

            // Add blank lines between top-level items for better readability
            let lines: Vec<&str> = formatted.lines().collect();
            let mut result_lines = Vec::new();

            for (i, line) in lines.iter().enumerate() {
                result_lines.push(*line);

                // Add blank line after closing braces that are at column 0 (top-level items)
                if *line == "}" && i + 1 < lines.len() && !lines[i + 1].is_empty() {
                    result_lines.push("");
                }

                if !line.starts_with("#[") && i + 1 < lines.len() && !lines[i + 1].is_empty() {
                    let next_line = lines[i + 1];
                    if next_line.starts_with("#[") && !next_line.is_empty() {
                        result_lines.push("");
                    }
                }

                // Add blank line after last pub mod declaration before structs/enums
                if line.starts_with("pub mod ") && i + 1 < lines.len() {
                    let next_line = lines[i + 1];
                    if !next_line.starts_with("pub mod ")
                        && !next_line.starts_with("pub use ")
                        && !next_line.is_empty()
                    {
                        result_lines.push("");
                    }
                }
            }

            formatted = result_lines.join("\n");

            // Add header comment
            let header = if let Some(nsid) = &file_output.nsid {
                format!(
                    "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// Lexicon: {}\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n",
                    nsid
                )
            } else {
                "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n".to_string()
            };
            formatted = format!("{}{}", header, formatted);

            // Write file
            std::fs::write(&full_path, formatted)?;
        }

        Ok(())
    }

    /// Get namespace dependencies collected during code generation
    pub fn get_namespace_dependencies(&self) -> HashMap<String, HashSet<String>> {
        self.namespace_deps.borrow().clone()
    }

    /// Generate Cargo.toml features section from namespace dependencies
    pub fn generate_cargo_features(&self, lib_rs_path: Option<&std::path::Path>) -> String {
        use std::fmt::Write;

        let deps = self.namespace_deps.borrow();
        let mut all_namespaces: HashSet<String> = HashSet::new();

        // Collect all namespaces from the corpus (first two segments of each NSID)
        for (nsid, _doc) in self.corpus.iter() {
            let nsid_path = NsidPath::parse(nsid.as_str());
            let namespace = nsid_path.namespace();
            all_namespaces.insert(namespace);
        }

        // Also collect existing feature names from lib.rs
        let mut existing_features = HashSet::new();
        if let Some(lib_rs) = lib_rs_path {
            if let Ok(content) = std::fs::read_to_string(lib_rs) {
                for line in content.lines() {
                    if let Some(feature) = line
                        .trim()
                        .strip_prefix("#[cfg(feature = \"")
                        .and_then(|s| s.strip_suffix("\")]"))
                    {
                        existing_features.insert(feature.to_string());
                    }
                }
            }
        }

        let mut output = String::new();
        writeln!(&mut output, "# Generated namespace features").unwrap();

        // Convert namespace to feature name (matching module path sanitization)
        let to_feature_name = |ns: &str| {
            ns.split('.')
                .map(|segment| {
                    // Apply same sanitization as module names
                    let mut result = segment.replace('-', "_");
                    // Prefix with underscore if starts with digit
                    if result.chars().next().map_or(false, |c| c.is_ascii_digit()) {
                        result.insert(0, '_');
                    }
                    result
                })
                .collect::<Vec<_>>()
                .join("_")
        };

        // Collect all feature names (from corpus + existing lib.rs)
        let mut all_feature_names = HashSet::new();
        for ns in &all_namespaces {
            all_feature_names.insert(to_feature_name(ns));
        }
        all_feature_names.extend(existing_features);

        // Sort for consistent output
        let mut feature_names: Vec<_> = all_feature_names.iter().collect();
        feature_names.sort();

        // Map namespace to feature name for dependency lookup
        let mut ns_to_feature: HashMap<&str, String> = HashMap::new();
        for ns in &all_namespaces {
            ns_to_feature.insert(ns.as_str(), to_feature_name(ns));
        }
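
        // Illustration (hypothetical namespaces): the section written below comes out as
        //     # Generated namespace features
        //     app_bsky = ["com_atproto"]
        //     com_atproto = []
        // with one line per feature and dependency features quoted and sorted.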
456 "{} = [{}]", 457 feature_name, 458 feature_deps.join(", ") 459 ) 460 .unwrap(); 461 } else { 462 writeln!(&mut output, "{} = []", feature_name).unwrap(); 463 } 464 } 465 466 output 467 } 468} 469 470#[cfg(test)] 471mod tests { 472 use super::*; 473 use crate::corpus::LexiconCorpus; 474 475 #[test] 476 fn test_enumerate_local_type_names() { 477 // Verifies AC3.4: Local type names are correctly enumerated from the corpus before generation 478 let corpus = 479 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 480 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 481 482 // Generate all - this internally enumerates local type names 483 let result = codegen.generate_all().expect("generate_all"); 484 485 // Verify that we got output for multiple files 486 assert!(!result.is_empty(), "Should have generated files"); 487 488 // For pub.leaflet.poll.definition (multi-def), verify it's generated 489 let has_poll_defs = result 490 .keys() 491 .any(|path| path.to_string_lossy().contains("poll")); 492 assert!(has_poll_defs, "Should have poll defs in output"); 493 } 494 495 #[test] 496 fn test_collection_produces_imports() { 497 // Verifies AC3.4: Collection produces non-empty ImportSet for a file with string types 498 let corpus = 499 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus"); 500 let codegen = CodeGenerator::new(&corpus, "jacquard_api"); 501 502 // Get a lexicon with known string types 503 let doc = corpus.get("app.bsky.feed.post").expect("get post"); 504 let post_def = doc.defs.get("main").expect("get main def"); 505 506 // Collect imports from the post definition 507 let imports = codegen.collect_def("app.bsky.feed.post", "main", post_def); 508 509 // Post should have collected imports for CowStr, Datetime, etc. 
        assert!(
            imports.common.len() > 0,
            "Post definition should have collected common types"
        );
        assert!(
            imports.external.len() > 0,
            "Post definition should have collected external imports (Serialize, Deserialize)"
        );

        // Verify specific types that we know post uses
        assert!(
            imports
                .common
                .contains(&crate::codegen::prettify::CommonType::CowStr),
            "Post should collect CowStr"
        );
        assert!(
            imports
                .common
                .contains(&crate::codegen::prettify::CommonType::Datetime),
            "Post should collect Datetime"
        );
    }

    #[test]
    fn test_resolved_imports_for_collection_collision() {
        // Verifies AC3.4: ResolvedImports correctly marks Collection as qualified for files defining it
        let corpus =
            LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
        let codegen = CodeGenerator::new(&corpus, "jacquard_api");

        // Find a file that might define a Collection type
        // (This is harder to verify without knowing the exact corpus, but we can verify
        // that generate_all completes successfully with ResolvedImports built)
        let result = codegen.generate_all().expect("generate_all");

        // Verify we got output
        assert!(!result.is_empty(), "Should have generated code");

        // The fact that we generated code successfully means collection and
        // ResolvedImports::resolve() were executed without errors
        for (_path, file_output) in result {
            // Each file output should have imports (from collection)
            // and internally use ResolvedImports (built in generate_all)
            // We can't directly inspect ResolvedImports since it's internal to Task 4,
            // but we verify the output was generated
            assert!(
                !file_output.tokens.to_string().is_empty(),
                "Generated code should not be empty"
            );
        }
    }

    #[test]
    fn test_local_names_enumeration_accuracy() {
        // Verifies that local type names are enumerated correctly per file
        let corpus =
            LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
        let codegen = CodeGenerator::new(&corpus, "jacquard_api");

        // Generate all
        let result = codegen.generate_all().expect("generate_all");

        // For a known multi-def lexicon (app.bsky.feed.post), verify it generates
        let post_file = result.keys().find(|p| p.to_string_lossy().contains("post"));
        assert!(post_file.is_some(), "post file should exist");

        // The post record has at least "Post" as a type name
        let generated_code = post_file.and_then(|p| result.get(p));
        assert!(
            generated_code.is_some(),
            "Should have generated code for post"
        );
    }

    #[test]
    fn test_generate_all_runs_collection_without_errors() {
        // Verifies that generate_all successfully runs the collection pass
        let corpus =
            LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
        let codegen = CodeGenerator::new(&corpus, "jacquard_api");

        // This should not panic or error - collection pass should run silently
        let result = codegen.generate_all();
        assert!(result.is_ok(), "generate_all should complete successfully");

        let files = result.unwrap();
        assert!(!files.is_empty(), "Should generate at least one file");
    }
}
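
For context, here is a minimal driver sketch showing how this generator might be invoked end to end. It is based only on the constructors and methods visible above and in the tests; the import paths and output locations are assumptions, not something this file specifies.

use jacquard_lexicon::codegen::CodeGenerator; // assumed public re-export path
use jacquard_lexicon::corpus::LexiconCorpus; // assumed public re-export path

fn main() {
    // Load a directory of Lexicon schemas into a corpus (path is illustrative).
    let corpus = LexiconCorpus::load_from_dir("lexicons").expect("load corpus");
    // "jacquard_api" matches the crate name used by the tests in this file.
    let codegen = CodeGenerator::new(&corpus, "jacquard_api");

    // Write the generated module tree (lib.rs, namespace modules, builder_types.rs)
    // under a hypothetical output directory...
    codegen
        .write_to_disk(std::path::Path::new("jacquard-api/src"))
        .expect("write generated code");

    // ...then print the namespace feature section to paste into Cargo.toml,
    // pointing at the just-written lib.rs so existing feature gates are picked up.
    let features = codegen
        .generate_cargo_features(Some(std::path::Path::new("jacquard-api/src/lib.rs")));
    println!("{features}");
}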