//! A better Rust ATProto crate
1use crate::error::{CodegenError, Result};
2use quote::quote;
3use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
4
5use super::CodeGenerator;
6use super::prettify::{FileOutput, ResolvedImports};
7use super::utils::{make_ident, sanitize_name};
8use crate::ref_utils::NsidPath;
9
10impl<'c> CodeGenerator<'c> {
11 /// Generate all code for the corpus, organized by file
12 /// Returns a map of file paths to FileOutput with reordered tokens
13 pub fn generate_all(&self) -> Result<BTreeMap<std::path::PathBuf, FileOutput>> {
14 let mut file_contents: BTreeMap<std::path::PathBuf, Vec<super::prettify::GeneratedCode>> =
15 BTreeMap::new();
16 let mut file_nsids: BTreeMap<std::path::PathBuf, String> = BTreeMap::new();
17
18 // Step 1: Enumerate local type names per file.
19 // Also collect all file paths so we can determine submodule names.
20 let mut file_local_names: BTreeMap<std::path::PathBuf, HashSet<String>> = BTreeMap::new();
21 let mut all_file_paths: BTreeSet<std::path::PathBuf> = BTreeSet::new();
22 for (nsid, doc) in self.corpus.iter() {
23 let file_path = self.nsid_to_file_path(nsid.as_ref());
24 all_file_paths.insert(file_path.clone());
25 let names = file_local_names.entry(file_path).or_default();
26 for def_name in doc.defs.keys() {
27 names.insert(self.def_to_type_name(nsid.as_ref(), def_name.as_ref()));
28 }
29 }
30
31 // Determine submodule names for each file. If `foo.rs` exists and `foo/bar.rs`
32 // also exists, then `bar` is a submodule of `foo`. These names must be treated
33 // as reserved — importing `use crate::something::bar;` would collide with
34 // `pub mod bar;`.
35 let mut file_submodule_names: BTreeMap<std::path::PathBuf, HashSet<String>> =
36 BTreeMap::new();
37 for file_path in &all_file_paths {
38 // For a file like `app_bsky/feed/post.rs`, the parent module file is
39 // `app_bsky/feed.rs`. Check if this file's parent has an entry.
40 if let Some(parent_dir) = file_path.parent() {
41 let parent_file = parent_dir.with_extension("rs");
42 if all_file_paths.contains(&parent_file)
43 || file_local_names.contains_key(&parent_file)
44 {
45 // This file is a submodule of parent_file.
46 if let Some(stem) = file_path.file_stem().and_then(|s| s.to_str()) {
47 file_submodule_names
48 .entry(parent_file)
49 .or_default()
50 .insert(stem.to_string());
51 }
52 }
53 }
54 }
55
56 // Step 2: Run collection pass and build ResolvedImports for each file.
57 // Multiple NSIDs can map to the same file (e.g. `app.rocksky.album` and
58 // `app.rocksky.album.defs` both output to `album.rs`), so we accumulate
59 // imports per file path before resolving.
60 let mut file_imports_map: BTreeMap<std::path::PathBuf, super::prettify::ImportSet> =
61 BTreeMap::new();
62 for (nsid, doc) in self.corpus.iter() {
63 let file_path = self.nsid_to_file_path(nsid.as_ref());
64 let file_imports = file_imports_map.entry(file_path).or_default();
65 for (def_name, def) in &doc.defs {
66 file_imports.merge(self.collect_def(nsid.as_ref(), def_name.as_ref(), def));
67 }
68 }
69
70 let mut file_resolved: BTreeMap<std::path::PathBuf, ResolvedImports> = BTreeMap::new();
71 for (file_path, file_imports) in &file_imports_map {
72 let local_names = file_local_names.get(file_path).cloned().unwrap_or_default();
73
74 let lexicon_paths: BTreeMap<String, String> = file_imports
75 .lexicon_refs
76 .iter()
77 .filter_map(|ref_str| self.ref_to_crate_path(ref_str))
78 .collect();
79
80 let submodule_names = file_submodule_names
81 .get(file_path)
82 .cloned()
83 .unwrap_or_default();
84 let resolved = ResolvedImports::resolve(
85 file_imports,
86 &local_names,
87 &submodule_names,
88 self.mode,
89 &lexicon_paths,
90 );
91 file_resolved.insert(file_path.clone(), resolved);
92 }
93
94 // Step 3: Generate code for all lexicons
95 for (nsid, doc) in self.corpus.iter() {
96 let file_path = self.nsid_to_file_path(nsid.as_ref());
97
98 // Track which NSID this file is for
99 file_nsids.insert(file_path.clone(), nsid.to_string());
100
101 // Get the per-file ResolvedImports (built in Step 2)
102 let resolved = file_resolved
103 .get(&file_path)
104 .expect("resolved imports built for every file");
105
106 for (_def_name, def) in &doc.defs {
107 let generated =
108 self.generate_def(nsid.as_ref(), _def_name.as_ref(), def, resolved)?;
109 file_contents
110 .entry(file_path.clone())
111 .or_default()
112 .push(generated);
113 }
114 }
115
116 // Combine all tokens for each file using FileOutput::combine for reordering
117 let mut result = BTreeMap::new();
118 for (path, generated_vec) in file_contents {
119 let nsid = file_nsids.get(&path).cloned();
120 let resolved = file_resolved
121 .get(&path)
122 .expect("resolved imports built for every file");
123 let file_output = FileOutput::combine(generated_vec, nsid, resolved);
124 result.insert(path, file_output);
125 }
126
127 Ok(result)
128 }
129
130 /// Generate parent module files with pub mod declarations
131 pub fn generate_module_tree(
132 &self,
133 file_map: &BTreeMap<std::path::PathBuf, FileOutput>,
134 defs_only: &BTreeMap<std::path::PathBuf, FileOutput>,
135 subscription_files: &HashSet<std::path::PathBuf>,
136 ) -> BTreeMap<std::path::PathBuf, FileOutput> {
137 // Track what modules each directory needs to declare
138 // Key: directory path, Value: set of module names (file stems)
139 let mut dir_modules: BTreeMap<std::path::PathBuf, BTreeSet<String>> = BTreeMap::new();
140
141 // Collect all parent directories that have files
142 let mut all_dirs: BTreeSet<std::path::PathBuf> = BTreeSet::new();
143 for path in file_map.keys() {
144 if let Some(parent_dir) = path.parent() {
145 all_dirs.insert(parent_dir.to_path_buf());
146 }
147 }
148
149 for path in file_map.keys() {
150 if let Some(parent_dir) = path.parent() {
151 if let Some(file_stem) = path.file_stem().and_then(|s| s.to_str()) {
152 // Skip mod.rs and lib.rs - they're module files, not modules to declare
153 if file_stem == "mod" || file_stem == "lib" {
154 continue;
155 }
156
157 // Always add the module declaration to parent
158 dir_modules
159 .entry(parent_dir.to_path_buf())
160 .or_default()
161 .insert(file_stem.to_string());
162 }
163 }
164 }
165
166 // Generate module files
167 let mut result = BTreeMap::new();
168
169 for (dir, module_names) in dir_modules {
170 let mod_file_path = if dir.components().count() == 0 {
171 // Root directory -> lib.rs for library crates
172 std::path::PathBuf::from("lib.rs")
173 } else {
174 // Subdirectory: app_bsky/feed -> app_bsky/feed.rs (Rust 2018 style)
175 let dir_name = dir.file_name().and_then(|s| s.to_str()).unwrap_or("mod");
176 let sanitized_dir_name = sanitize_name(dir_name);
177 let mut path = dir
178 .parent()
179 .unwrap_or_else(|| std::path::Path::new(""))
180 .to_path_buf();
181 path.push(format!("{}.rs", sanitized_dir_name));
182 path
183 };
184
185 let is_root = dir.components().count() == 0;
186 let mods: Vec<_> = module_names
187 .iter()
188 .map(|name| {
189 let ident = make_ident(name);
190
191 // Check if this module is a subscription endpoint
192 let mut module_path = dir.clone();
193 module_path.push(format!("{}.rs", name));
194 let is_subscription = subscription_files.contains(&module_path);
195
196 if is_root && name != "builder_types" {
197 // Top-level modules get feature gates (except builder_types which is always needed)
198 quote! {
199 #[cfg(feature = #name)]
200 pub mod #ident;
201 }
202 } else if is_subscription {
203 // Subscription modules get streaming feature gate
204 quote! {
205 #[cfg(feature = "streaming")]
206 pub mod #ident;
207 }
208 } else {
209 quote! { pub mod #ident; }
210 }
211 })
212 .collect();
213
214 // If this file already exists in defs_only (e.g., from defs), merge the content
215 let module_tokens = if is_root {
216 // lib.rs needs extern crate alloc for no_std compatibility
217 quote! {
218 extern crate alloc; #(#mods)*
219 }
220 } else {
221 quote! { #(#mods)* }
222 };
223 if let Some(existing_output) = defs_only.get(&mod_file_path) {
224 // Put module declarations FIRST, then existing defs content
225 let existing_tokens = &existing_output.tokens;
226 let merged_tokens = quote! {
227 #module_tokens
228 #existing_tokens
229 };
230 result.insert(
231 mod_file_path,
232 FileOutput {
233 tokens: merged_tokens,
234 imports: existing_output.imports.clone(),
235 nsid: existing_output.nsid.clone(),
236 },
237 );
238 } else {
239 result.insert(
240 mod_file_path,
241 FileOutput {
242 tokens: module_tokens,
243 imports: Default::default(),
244 nsid: None,
245 },
246 );
247 }
248 }
249
250 result
251 }
252
    /// Write all generated code to disk.
    ///
    /// Pipeline: generate per-lexicon files, add the shared builder types
    /// file, iteratively synthesize parent module files until no new files
    /// appear, then pretty-print each file (with readability blank lines and
    /// a `@generated` header) and write it under `output_dir`.
    ///
    /// # Errors
    ///
    /// Returns an error if code generation fails, if generated tokens cannot
    /// be re-parsed as a `syn::File`, or on any filesystem failure.
    pub fn write_to_disk(&self, output_dir: &std::path::Path) -> Result<()> {
        // Generate all code (defs only)
        let defs_files = self.generate_all()?;
        let mut all_files = defs_files.clone();

        // Generate common builder types (Set, Unset, IsSet, IsUnset)
        let common_types_path = std::path::PathBuf::from("builder_types.rs");
        let common_types_tokens = super::builder_gen::common::generate_common_types();
        all_files.insert(
            common_types_path,
            FileOutput {
                tokens: common_types_tokens,
                imports: Default::default(),
                nsid: None,
            },
        );

        // Get subscription files for feature gating
        let subscription_files = self.subscription_files.borrow();

        // Generate module tree iteratively until no new files appear.
        // Each pass may introduce module files whose own parents then need
        // declarations, so we loop to a fixed point (file count stabilizes).
        loop {
            let module_map =
                self.generate_module_tree(&all_files, &defs_files, &subscription_files);
            let old_count = all_files.len();

            // Merge new module files
            for (path, file_output) in module_map {
                all_files.insert(path, file_output);
            }

            if all_files.len() == old_count {
                // No new files added
                break;
            }
        }

        // Write to disk
        for (path, file_output) in all_files {
            let full_path = output_dir.join(&path);

            // Create parent directories
            if let Some(parent) = full_path.parent() {
                std::fs::create_dir_all(parent)?;
            }

            // Format code: round-trip the token stream through syn so
            // prettyplease can render it. A parse failure means we emitted
            // malformed tokens; dump them to stderr to aid debugging.
            let file: syn::File = syn::parse2(file_output.tokens.clone()).map_err(|e| {
                let tokens = file_output.tokens.to_string();
                eprintln!(
                    "Failed to parse generated tokens for {:?}:\n{}",
                    path, tokens
                );
                CodegenError::TokenParseError {
                    path: path.clone(),
                    source: e,
                    tokens,
                }
            })?;
            let mut formatted = prettyplease::unparse(&file);

            // Add blank lines between top-level items for better readability
            let lines: Vec<&str> = formatted.lines().collect();
            let mut result_lines = Vec::new();

            for (i, line) in lines.iter().enumerate() {
                result_lines.push(*line);

                // Add blank line after closing braces that are at column 0 (top-level items)
                if *line == "}" && i + 1 < lines.len() && !lines[i + 1].is_empty() {
                    result_lines.push("");
                }

                // Add a blank line before an attribute that begins a new item.
                // NOTE(review): when a top-level `}` is immediately followed by
                // `#[...]`, this branch fires in addition to the one above,
                // producing two blank lines — confirm whether that is intended.
                if !line.starts_with("#[") && i + 1 < lines.len() && !lines[i + 1].is_empty() {
                    let next_line = lines[i + 1];
                    // `starts_with("#[")` already implies non-empty, so the
                    // second check here is redundant but harmless.
                    if next_line.starts_with("#[") && !next_line.is_empty() {
                        result_lines.push("");
                    }
                }

                // Add blank line after last pub mod declaration before structs/enums
                if line.starts_with("pub mod ") && i + 1 < lines.len() {
                    let next_line = lines[i + 1];
                    if !next_line.starts_with("pub mod ")
                        && !next_line.starts_with("pub use ")
                        && !next_line.is_empty()
                    {
                        result_lines.push("");
                    }
                }
            }

            formatted = result_lines.join("\n");

            // Add header comment; files generated from a lexicon also record
            // their source NSID in the header.
            let header = if let Some(nsid) = &file_output.nsid {
                format!(
                    "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// Lexicon: {}\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n",
                    nsid
                )
            } else {
                "// @generated by jacquard-lexicon. DO NOT EDIT.\n//\n// This file was automatically generated from Lexicon schemas.\n// Any manual changes will be overwritten on the next regeneration.\n\n".to_string()
            };
            formatted = format!("{}{}", header, formatted);

            // Write file
            std::fs::write(&full_path, formatted)?;
        }

        Ok(())
    }
365
366 /// Get namespace dependencies collected during code generation
367 pub fn get_namespace_dependencies(&self) -> HashMap<String, HashSet<String>> {
368 self.namespace_deps.borrow().clone()
369 }
370
371 /// Generate Cargo.toml features section from namespace dependencies
372 pub fn generate_cargo_features(&self, lib_rs_path: Option<&std::path::Path>) -> String {
373 use std::fmt::Write;
374
375 let deps = self.namespace_deps.borrow();
376 let mut all_namespaces: HashSet<String> = HashSet::new();
377
378 // Collect all namespaces from the corpus (first two segments of each NSID)
379 for (nsid, _doc) in self.corpus.iter() {
380 let nsid_path = NsidPath::parse(nsid.as_str());
381 let namespace = nsid_path.namespace();
382 all_namespaces.insert(namespace);
383 }
384
385 // Also collect existing feature names from lib.rs
386 let mut existing_features = HashSet::new();
387 if let Some(lib_rs) = lib_rs_path {
388 if let Ok(content) = std::fs::read_to_string(lib_rs) {
389 for line in content.lines() {
390 if let Some(feature) = line
391 .trim()
392 .strip_prefix("#[cfg(feature = \"")
393 .and_then(|s| s.strip_suffix("\")]"))
394 {
395 existing_features.insert(feature.to_string());
396 }
397 }
398 }
399 }
400
401 let mut output = String::new();
402 writeln!(&mut output, "# Generated namespace features").unwrap();
403
404 // Convert namespace to feature name (matching module path sanitization)
405 let to_feature_name = |ns: &str| {
406 ns.split('.')
407 .map(|segment| {
408 // Apply same sanitization as module names
409 let mut result = segment.replace('-', "_");
410 // Prefix with underscore if starts with digit
411 if result.chars().next().map_or(false, |c| c.is_ascii_digit()) {
412 result.insert(0, '_');
413 }
414 result
415 })
416 .collect::<Vec<_>>()
417 .join("_")
418 };
419
420 // Collect all feature names (from corpus + existing lib.rs)
421 let mut all_feature_names = HashSet::new();
422 for ns in &all_namespaces {
423 all_feature_names.insert(to_feature_name(ns));
424 }
425 all_feature_names.extend(existing_features);
426
427 // Sort for consistent output
428 let mut feature_names: Vec<_> = all_feature_names.iter().collect();
429 feature_names.sort();
430
431 // Map namespace to feature name for dependency lookup
432 let mut ns_to_feature: HashMap<&str, String> = HashMap::new();
433 for ns in &all_namespaces {
434 ns_to_feature.insert(ns.as_str(), to_feature_name(ns));
435 }
436
437 for feature_name in feature_names {
438 // Find corresponding namespace for this feature (if any) to look up deps
439 let feature_deps: Vec<String> = all_namespaces
440 .iter()
441 .find(|ns| to_feature_name(ns) == *feature_name)
442 .and_then(|ns| deps.get(ns.as_str()))
443 .map(|ns_deps| {
444 let mut dep_features: Vec<_> = ns_deps
445 .iter()
446 .map(|d| format!("\"{}\"", to_feature_name(d)))
447 .collect();
448 dep_features.sort();
449 dep_features
450 })
451 .unwrap_or_default();
452
453 if !feature_deps.is_empty() {
454 writeln!(
455 &mut output,
456 "{} = [{}]",
457 feature_name,
458 feature_deps.join(", ")
459 )
460 .unwrap();
461 } else {
462 writeln!(&mut output, "{} = []", feature_name).unwrap();
463 }
464 }
465
466 output
467 }
468}
469
470#[cfg(test)]
471mod tests {
472 use super::*;
473 use crate::corpus::LexiconCorpus;
474
475 #[test]
476 fn test_enumerate_local_type_names() {
477 // Verifies AC3.4: Local type names are correctly enumerated from the corpus before generation
478 let corpus =
479 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
480 let codegen = CodeGenerator::new(&corpus, "jacquard_api");
481
482 // Generate all - this internally enumerates local type names
483 let result = codegen.generate_all().expect("generate_all");
484
485 // Verify that we got output for multiple files
486 assert!(!result.is_empty(), "Should have generated files");
487
488 // For pub.leaflet.poll.definition (multi-def), verify it's generated
489 let has_poll_defs = result
490 .keys()
491 .any(|path| path.to_string_lossy().contains("poll"));
492 assert!(has_poll_defs, "Should have poll defs in output");
493 }
494
495 #[test]
496 fn test_collection_produces_imports() {
497 // Verifies AC3.4: Collection produces non-empty ImportSet for a file with string types
498 let corpus =
499 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
500 let codegen = CodeGenerator::new(&corpus, "jacquard_api");
501
502 // Get a lexicon with known string types
503 let doc = corpus.get("app.bsky.feed.post").expect("get post");
504 let post_def = doc.defs.get("main").expect("get main def");
505
506 // Collect imports from the post definition
507 let imports = codegen.collect_def("app.bsky.feed.post", "main", post_def);
508
509 // Post should have collected imports for CowStr, Datetime, etc.
510 assert!(
511 imports.common.len() > 0,
512 "Post definition should have collected common types"
513 );
514 assert!(
515 imports.external.len() > 0,
516 "Post definition should have collected external imports (Serialize, Deserialize)"
517 );
518
519 // Verify specific types that we know post uses
520 assert!(
521 imports
522 .common
523 .contains(&crate::codegen::prettify::CommonType::CowStr),
524 "Post should collect CowStr"
525 );
526 assert!(
527 imports
528 .common
529 .contains(&crate::codegen::prettify::CommonType::Datetime),
530 "Post should collect Datetime"
531 );
532 }
533
534 #[test]
535 fn test_resolved_imports_for_collection_collision() {
536 // Verifies AC3.4: ResolvedImports correctly marks Collection as qualified for files defining it
537 let corpus =
538 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
539 let codegen = CodeGenerator::new(&corpus, "jacquard_api");
540
541 // Find a file that might define a Collection type
542 // (This is harder to verify without knowing the exact corpus, but we can verify
543 // that generate_all completes successfully with ResolvedImports built)
544 let result = codegen.generate_all().expect("generate_all");
545
546 // Verify we got output
547 assert!(!result.is_empty(), "Should have generated code");
548
549 // The fact that we generated code successfully means collection and
550 // ResolvedImports::resolve() were executed without errors
551 for (_path, file_output) in result {
552 // Each file output should have imports (from collection)
553 // and internally use ResolvedImports (built in generate_all)
554 // We can't directly inspect ResolvedImports since it's internal to Task 4,
555 // but we verify the output was generated
556 assert!(
557 !file_output.tokens.to_string().is_empty(),
558 "Generated code should not be empty"
559 );
560 }
561 }
562
563 #[test]
564 fn test_local_names_enumeration_accuracy() {
565 // Verifies that local type names are enumerated correctly per file
566 let corpus =
567 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
568 let codegen = CodeGenerator::new(&corpus, "jacquard_api");
569
570 // Generate all
571 let result = codegen.generate_all().expect("generate_all");
572
573 // For a known multi-def lexicon (app.bsky.feed.post), verify it generates
574 let post_file = result.keys().find(|p| p.to_string_lossy().contains("post"));
575 assert!(post_file.is_some(), "post file should exist");
576
577 // The post record has at least "Post" as a type name
578 let generated_code = post_file.and_then(|p| result.get(p));
579 assert!(
580 generated_code.is_some(),
581 "Should have generated code for post"
582 );
583 }
584
585 #[test]
586 fn test_generate_all_runs_collection_without_errors() {
587 // Verifies that generate_all successfully runs the collection pass
588 let corpus =
589 LexiconCorpus::load_from_dir("tests/fixtures/test_lexicons").expect("load corpus");
590 let codegen = CodeGenerator::new(&corpus, "jacquard_api");
591
592 // This should not panic or error - collection pass should run silently
593 let result = codegen.generate_all();
594 assert!(result.is_ok(), "generate_all should complete successfully");
595
596 let files = result.unwrap();
597 assert!(!files.is_empty(), "Should generate at least one file");
598 }
599}