this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 795 lines 26 kB view raw
1use crate::{Graph, NodeId, RelationshipId, Result}; 2use crate::core::relationship::Direction; 3use std::collections::{HashMap, HashSet, VecDeque, BinaryHeap}; 4use std::cmp::Ordering; 5 6pub mod pathfinding; 7pub mod centrality; 8pub mod community; 9pub mod traversal; 10 11pub use pathfinding::*; 12pub use centrality::*; 13pub use community::*; 14pub use traversal::*; 15 16/// Weight function for graph algorithms 17pub type WeightFn = dyn Fn(RelationshipId) -> f64 + Send + Sync; 18 19/// Graph statistics structure 20#[derive(Debug, Clone)] 21pub struct GraphStats { 22 pub node_count: u64, 23 pub relationship_count: u64, 24 pub label_count: u64, 25 pub property_key_count: u64, 26 pub relationship_type_count: u64, 27} 28 29/// Result type for pathfinding algorithms 30#[derive(Debug, Clone)] 31pub struct Path { 32 pub nodes: Vec<NodeId>, 33 pub relationships: Vec<RelationshipId>, 34 pub total_weight: f64, 35} 36 37impl Path { 38 pub fn new() -> Self { 39 Self { 40 nodes: Vec::new(), 41 relationships: Vec::new(), 42 total_weight: 0.0, 43 } 44 } 45 46 pub fn length(&self) -> usize { 47 self.nodes.len().saturating_sub(1) 48 } 49 50 pub fn is_empty(&self) -> bool { 51 self.nodes.is_empty() 52 } 53 54 pub fn add_step(&mut self, node: NodeId, relationship: Option<RelationshipId>, weight: f64) { 55 self.nodes.push(node); 56 if let Some(rel) = relationship { 57 self.relationships.push(rel); 58 } 59 self.total_weight += weight; 60 } 61} 62 63/// Priority queue entry for Dijkstra's algorithm 64#[derive(Debug, Clone)] 65struct DijkstraEntry { 66 node: NodeId, 67 distance: f64, 68 previous: Option<(NodeId, RelationshipId)>, 69} 70 71impl PartialEq for DijkstraEntry { 72 fn eq(&self, other: &Self) -> bool { 73 self.distance == other.distance 74 } 75} 76 77impl Eq for DijkstraEntry {} 78 79impl PartialOrd for DijkstraEntry { 80 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { 81 Some(self.cmp(other)) 82 } 83} 84 85impl Ord for DijkstraEntry { 86 fn cmp(&self, other: &Self) -> Ordering { 87 // Reverse ordering for min-heap 88 other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal) 89 } 90} 91 92/// Graph algorithms implementation 93pub struct GraphAlgorithms<'a> { 94 graph: &'a Graph, 95} 96 97impl<'a> GraphAlgorithms<'a> { 98 pub fn new(graph: &'a Graph) -> Self { 99 Self { graph } 100 } 101 102 /// Find shortest path between two nodes using Dijkstra's algorithm 103 pub fn shortest_path( 104 &self, 105 start: NodeId, 106 end: NodeId, 107 weight_fn: Option<&WeightFn>, 108 ) -> Result<Option<Path>> { 109 let mut distances: HashMap<NodeId, f64> = HashMap::new(); 110 let mut previous: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new(); 111 let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new(); 112 let mut visited: HashSet<NodeId> = HashSet::new(); 113 114 // Initialize start node 115 distances.insert(start, 0.0); 116 heap.push(DijkstraEntry { 117 node: start, 118 distance: 0.0, 119 previous: None, 120 }); 121 122 while let Some(current) = heap.pop() { 123 if visited.contains(&current.node) { 124 continue; 125 } 126 127 visited.insert(current.node); 128 129 // Found target 130 if current.node == end { 131 return Ok(Some(self.reconstruct_path(start, end, &previous)?)); 132 } 133 134 // Explore neighbors 135 let relationships = self.graph.get_node_relationships( 136 current.node, 137 Direction::Both, 138 None, 139 ); 140 141 for relationship in relationships { 142 let neighbor = if relationship.start_node == current.node { 143 relationship.end_node 144 } else { 145 relationship.start_node 146 }; 147 148 if visited.contains(&neighbor) { 149 continue; 150 } 151 152 let weight = weight_fn 153 .map(|f| f(relationship.id)) 154 .unwrap_or(1.0); 155 156 let new_distance = current.distance + weight; 157 let current_distance = distances.get(&neighbor).copied().unwrap_or(f64::INFINITY); 158 159 if new_distance < current_distance { 160 distances.insert(neighbor, new_distance); 161 previous.insert(neighbor, (current.node, relationship.id)); 162 163 heap.push(DijkstraEntry { 164 node: neighbor, 165 distance: new_distance, 166 previous: Some((current.node, relationship.id)), 167 }); 168 } 169 } 170 } 171 172 Ok(None) // No path found 173 } 174 175 /// Find all shortest paths from a source node (single-source shortest path) 176 pub fn shortest_paths_from( 177 &self, 178 start: NodeId, 179 weight_fn: Option<&WeightFn>, 180 ) -> Result<HashMap<NodeId, (f64, Option<Path>)>> { 181 let mut distances: HashMap<NodeId, f64> = HashMap::new(); 182 let mut previous: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new(); 183 let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new(); 184 let mut visited: HashSet<NodeId> = HashSet::new(); 185 let mut results: HashMap<NodeId, (f64, Option<Path>)> = HashMap::new(); 186 187 // Initialize start node 188 distances.insert(start, 0.0); 189 heap.push(DijkstraEntry { 190 node: start, 191 distance: 0.0, 192 previous: None, 193 }); 194 195 while let Some(current) = heap.pop() { 196 if visited.contains(&current.node) { 197 continue; 198 } 199 200 visited.insert(current.node); 201 202 // Record result for this node 203 let path = if current.node == start { 204 Some(Path::new()) 205 } else { 206 Some(self.reconstruct_path(start, current.node, &previous)?) 207 }; 208 results.insert(current.node, (current.distance, path)); 209 210 // Explore neighbors 211 let relationships = self.graph.get_node_relationships( 212 current.node, 213 Direction::Both, 214 None, 215 ); 216 217 for relationship in relationships { 218 let neighbor = if relationship.start_node == current.node { 219 relationship.end_node 220 } else { 221 relationship.start_node 222 }; 223 224 if visited.contains(&neighbor) { 225 continue; 226 } 227 228 let weight = weight_fn 229 .map(|f| f(relationship.id)) 230 .unwrap_or(1.0); 231 232 let new_distance = current.distance + weight; 233 let current_distance = distances.get(&neighbor).copied().unwrap_or(f64::INFINITY); 234 235 if new_distance < current_distance { 236 distances.insert(neighbor, new_distance); 237 previous.insert(neighbor, (current.node, relationship.id)); 238 239 heap.push(DijkstraEntry { 240 node: neighbor, 241 distance: new_distance, 242 previous: Some((current.node, relationship.id)), 243 }); 244 } 245 } 246 } 247 248 Ok(results) 249 } 250 251 /// Find k shortest paths between two nodes 252 pub fn k_shortest_paths( 253 &self, 254 start: NodeId, 255 end: NodeId, 256 k: usize, 257 weight_fn: Option<&WeightFn>, 258 ) -> Result<Vec<Path>> { 259 // Yen's algorithm for k-shortest paths 260 let mut paths = Vec::new(); 261 262 // Find first shortest path 263 if let Some(first_path) = self.shortest_path(start, end, weight_fn)? { 264 paths.push(first_path); 265 } else { 266 return Ok(paths); // No path exists 267 } 268 269 let mut candidates: BinaryHeap<PathCandidate> = BinaryHeap::new(); 270 271 for i in 1..k { 272 if paths.is_empty() { 273 break; 274 } 275 276 let previous_path = &paths[i - 1]; 277 278 // Generate candidate paths by deviating from each node in the previous path 279 for j in 0..previous_path.nodes.len() - 1 { 280 let spur_node = previous_path.nodes[j]; 281 let root_path = &previous_path.nodes[0..=j]; 282 283 // Remove edges that would lead to already found paths 284 let mut removed_edges = HashSet::new(); 285 for path in &paths { 286 if path.nodes.len() > j && path.nodes[0..=j] == root_path[..] { 287 if j + 1 < path.relationships.len() { 288 removed_edges.insert(path.relationships[j]); 289 } 290 } 291 } 292 293 // Find shortest path from spur node to end (excluding removed edges) 294 if let Some(spur_path) = self.shortest_path_excluding( 295 spur_node, 296 end, 297 &removed_edges, 298 weight_fn, 299 )? { 300 // Combine root path with spur path 301 let mut full_path = Path::new(); 302 303 // Add root path 304 for &node in root_path { 305 full_path.add_step(node, None, 0.0); 306 } 307 308 // Add spur path (skip first node as it's already included) 309 for (idx, &node) in spur_path.nodes.iter().skip(1).enumerate() { 310 let rel = if idx < spur_path.relationships.len() { 311 Some(spur_path.relationships[idx]) 312 } else { 313 None 314 }; 315 full_path.add_step(node, rel, 0.0); 316 } 317 318 // Recalculate total weight 319 let total_weight = self.calculate_path_weight(&full_path, weight_fn)?; 320 full_path.total_weight = total_weight; 321 322 candidates.push(PathCandidate { 323 path: full_path, 324 weight: total_weight, 325 }); 326 } 327 } 328 329 if let Some(best_candidate) = candidates.pop() { 330 paths.push(best_candidate.path); 331 } else { 332 break; // No more candidates 333 } 334 } 335 336 Ok(paths) 337 } 338 339 /// Breadth-first search traversal 340 pub fn bfs(&self, start: NodeId, max_depth: Option<usize>) -> Result<Vec<NodeId>> { 341 let mut visited = HashSet::new(); 342 let mut queue = VecDeque::new(); 343 let mut result = Vec::new(); 344 345 queue.push_back((start, 0)); 346 visited.insert(start); 347 348 while let Some((node, depth)) = queue.pop_front() { 349 result.push(node); 350 351 if let Some(max_d) = max_depth { 352 if depth >= max_d { 353 continue; 354 } 355 } 356 357 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 358 for relationship in relationships { 359 let neighbor = if relationship.start_node == node { 360 relationship.end_node 361 } else { 362 relationship.start_node 363 }; 364 365 if !visited.contains(&neighbor) { 366 visited.insert(neighbor); 367 queue.push_back((neighbor, depth + 1)); 368 } 369 } 370 } 371 372 Ok(result) 373 } 374 375 /// Depth-first search traversal 376 pub fn dfs(&self, start: NodeId, max_depth: Option<usize>) -> Result<Vec<NodeId>> { 377 let mut visited = HashSet::new(); 378 let mut result = Vec::new(); 379 380 self.dfs_recursive(start, &mut visited, &mut result, 0, max_depth)?; 381 382 Ok(result) 383 } 384 385 /// Find connected components using Union-Find 386 pub fn connected_components(&self) -> Result<Vec<Vec<NodeId>>> { 387 let all_nodes = self.get_all_nodes()?; 388 let mut parent: HashMap<NodeId, NodeId> = HashMap::new(); 389 let mut rank: HashMap<NodeId, usize> = HashMap::new(); 390 391 // Initialize Union-Find 392 for &node in &all_nodes { 393 parent.insert(node, node); 394 rank.insert(node, 0); 395 } 396 397 // Process all relationships 398 for &node in &all_nodes { 399 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 400 for relationship in relationships { 401 let other = if relationship.start_node == node { 402 relationship.end_node 403 } else { 404 relationship.start_node 405 }; 406 407 self.union(&mut parent, &mut rank, node, other); 408 } 409 } 410 411 // Group nodes by their root parent 412 let mut components: HashMap<NodeId, Vec<NodeId>> = HashMap::new(); 413 for &node in &all_nodes { 414 let root = self.find(&mut parent, node); 415 components.entry(root).or_insert_with(Vec::new).push(node); 416 } 417 418 Ok(components.into_values().collect()) 419 } 420 421 /// Check if the graph has cycles (for directed graphs) 422 pub fn has_cycle(&self) -> Result<bool> { 423 let all_nodes = self.get_all_nodes()?; 424 let mut visited = HashSet::new(); 425 let mut rec_stack = HashSet::new(); 426 427 for &node in &all_nodes { 428 if !visited.contains(&node) { 429 if self.has_cycle_dfs(node, &mut visited, &mut rec_stack)? { 430 return Ok(true); 431 } 432 } 433 } 434 435 Ok(false) 436 } 437 438 // Helper methods 439 440 fn reconstruct_path( 441 &self, 442 start: NodeId, 443 end: NodeId, 444 previous: &HashMap<NodeId, (NodeId, RelationshipId)>, 445 ) -> Result<Path> { 446 let mut path = Path::new(); 447 let mut current = end; 448 let mut nodes = Vec::new(); 449 let mut relationships = Vec::new(); 450 451 while current != start { 452 nodes.push(current); 453 if let Some(&(prev_node, rel_id)) = previous.get(&current) { 454 relationships.push(rel_id); 455 current = prev_node; 456 } else { 457 return Err(crate::error::GigabrainError::Algorithm( 458 "Invalid path reconstruction".to_string(), 459 )); 460 } 461 } 462 463 nodes.push(start); 464 nodes.reverse(); 465 relationships.reverse(); 466 467 path.nodes = nodes; 468 path.relationships = relationships; 469 470 Ok(path) 471 } 472 473 fn shortest_path_excluding( 474 &self, 475 start: NodeId, 476 end: NodeId, 477 excluded_edges: &HashSet<RelationshipId>, 478 weight_fn: Option<&WeightFn>, 479 ) -> Result<Option<Path>> { 480 // Similar to shortest_path but excludes certain edges 481 let mut distances: HashMap<NodeId, f64> = HashMap::new(); 482 let mut previous: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new(); 483 let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new(); 484 let mut visited: HashSet<NodeId> = HashSet::new(); 485 486 distances.insert(start, 0.0); 487 heap.push(DijkstraEntry { 488 node: start, 489 distance: 0.0, 490 previous: None, 491 }); 492 493 while let Some(current) = heap.pop() { 494 if visited.contains(&current.node) { 495 continue; 496 } 497 498 visited.insert(current.node); 499 500 if current.node == end { 501 return Ok(Some(self.reconstruct_path(start, end, &previous)?)); 502 } 503 504 let relationships = self.graph.get_node_relationships( 505 current.node, 506 Direction::Both, 507 None, 508 ); 509 510 for relationship in relationships { 511 if excluded_edges.contains(&relationship.id) { 512 continue; // Skip excluded edges 513 } 514 515 let neighbor = if relationship.start_node == current.node { 516 relationship.end_node 517 } else { 518 relationship.start_node 519 }; 520 521 if visited.contains(&neighbor) { 522 continue; 523 } 524 525 let weight = weight_fn 526 .map(|f| f(relationship.id)) 527 .unwrap_or(1.0); 528 529 let new_distance = current.distance + weight; 530 let current_distance = distances.get(&neighbor).copied().unwrap_or(f64::INFINITY); 531 532 if new_distance < current_distance { 533 distances.insert(neighbor, new_distance); 534 previous.insert(neighbor, (current.node, relationship.id)); 535 536 heap.push(DijkstraEntry { 537 node: neighbor, 538 distance: new_distance, 539 previous: Some((current.node, relationship.id)), 540 }); 541 } 542 } 543 } 544 545 Ok(None) 546 } 547 548 fn calculate_path_weight(&self, path: &Path, weight_fn: Option<&WeightFn>) -> Result<f64> { 549 let mut total_weight = 0.0; 550 551 for &rel_id in &path.relationships { 552 let weight = weight_fn 553 .map(|f| f(rel_id)) 554 .unwrap_or(1.0); 555 total_weight += weight; 556 } 557 558 Ok(total_weight) 559 } 560 561 fn dfs_recursive( 562 &self, 563 node: NodeId, 564 visited: &mut HashSet<NodeId>, 565 result: &mut Vec<NodeId>, 566 depth: usize, 567 max_depth: Option<usize>, 568 ) -> Result<()> { 569 visited.insert(node); 570 result.push(node); 571 572 if let Some(max_d) = max_depth { 573 if depth >= max_d { 574 return Ok(()); 575 } 576 } 577 578 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 579 for relationship in relationships { 580 let neighbor = if relationship.start_node == node { 581 relationship.end_node 582 } else { 583 relationship.start_node 584 }; 585 586 if !visited.contains(&neighbor) { 587 self.dfs_recursive(neighbor, visited, result, depth + 1, max_depth)?; 588 } 589 } 590 591 Ok(()) 592 } 593 594 fn get_all_nodes(&self) -> Result<Vec<NodeId>> { 595 Ok(self.graph.get_all_nodes()) 596 } 597 598 /// Get comprehensive graph statistics 599 pub fn get_stats(graph: &Graph) -> GraphStats { 600 let nodes = graph.get_all_nodes(); 601 let node_count = nodes.len() as u64; 602 603 let mut relationship_count = 0u64; 604 for &node in &nodes { 605 let rels = graph.get_node_relationships(node, Direction::Outgoing, None); 606 relationship_count += rels.len() as u64; 607 } 608 609 let schema = graph.schema().read(); 610 let label_count = schema.labels.len() as u64; 611 let property_key_count = schema.property_keys.len() as u64; 612 let relationship_type_count = schema.relationship_types.len() as u64; 613 614 GraphStats { 615 node_count, 616 relationship_count, 617 label_count, 618 property_key_count, 619 relationship_type_count, 620 } 621 } 622 623 fn find(&self, parent: &mut HashMap<NodeId, NodeId>, node: NodeId) -> NodeId { 624 let parent_node = parent[&node]; 625 if parent_node != node { 626 let root = self.find(parent, parent_node); 627 parent.insert(node, root); 628 } 629 parent[&node] 630 } 631 632 fn union( 633 &self, 634 parent: &mut HashMap<NodeId, NodeId>, 635 rank: &mut HashMap<NodeId, usize>, 636 x: NodeId, 637 y: NodeId, 638 ) { 639 let root_x = self.find(parent, x); 640 let root_y = self.find(parent, y); 641 642 if root_x != root_y { 643 match rank[&root_x].cmp(&rank[&root_y]) { 644 Ordering::Less => { 645 parent.insert(root_x, root_y); 646 } 647 Ordering::Greater => { 648 parent.insert(root_y, root_x); 649 } 650 Ordering::Equal => { 651 parent.insert(root_y, root_x); 652 rank.insert(root_x, rank[&root_x] + 1); 653 } 654 } 655 } 656 } 657 658 fn has_cycle_dfs( 659 &self, 660 node: NodeId, 661 visited: &mut HashSet<NodeId>, 662 rec_stack: &mut HashSet<NodeId>, 663 ) -> Result<bool> { 664 visited.insert(node); 665 rec_stack.insert(node); 666 667 let relationships = self.graph.get_node_relationships(node, Direction::Outgoing, None); 668 for relationship in relationships { 669 let neighbor = relationship.end_node; 670 671 if !visited.contains(&neighbor) { 672 if self.has_cycle_dfs(neighbor, visited, rec_stack)? { 673 return Ok(true); 674 } 675 } else if rec_stack.contains(&neighbor) { 676 return Ok(true); 677 } 678 } 679 680 rec_stack.remove(&node); 681 Ok(false) 682 } 683} 684 685#[derive(Debug, Clone)] 686struct PathCandidate { 687 path: Path, 688 weight: f64, 689} 690 691impl PartialEq for PathCandidate { 692 fn eq(&self, other: &Self) -> bool { 693 self.weight == other.weight 694 } 695} 696 697impl Eq for PathCandidate {} 698 699impl PartialOrd for PathCandidate { 700 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { 701 Some(self.cmp(other)) 702 } 703} 704 705impl Ord for PathCandidate { 706 fn cmp(&self, other: &Self) -> Ordering { 707 // Reverse ordering for min-heap 708 other.weight.partial_cmp(&self.weight).unwrap_or(Ordering::Equal) 709 } 710} 711 712#[cfg(test)] 713mod tests { 714 use super::*; 715 use crate::Graph; 716 use std::sync::Arc; 717 718 fn create_test_graph() -> Graph { 719 let graph = Graph::new(); 720 721 // Create a simple graph: A -> B -> C -> D 722 // | | 723 // v v 724 // E ------> F 725 726 let node_a = graph.create_node(); 727 let node_b = graph.create_node(); 728 let node_c = graph.create_node(); 729 let node_d = graph.create_node(); 730 let node_e = graph.create_node(); 731 let node_f = graph.create_node(); 732 733 let schema = graph.schema(); 734 let mut schema = schema.write(); 735 let rel_type = schema.get_or_create_relationship_type("CONNECTS"); 736 drop(schema); 737 738 // Create relationships 739 graph.create_relationship(node_a, node_b, rel_type).unwrap(); 740 graph.create_relationship(node_b, node_c, rel_type).unwrap(); 741 graph.create_relationship(node_c, node_d, rel_type).unwrap(); 742 graph.create_relationship(node_a, node_e, rel_type).unwrap(); 743 graph.create_relationship(node_c, node_f, rel_type).unwrap(); 744 graph.create_relationship(node_e, node_f, rel_type).unwrap(); 745 746 graph 747 } 748 749 #[test] 750 fn test_bfs_traversal() { 751 let graph = create_test_graph(); 752 let algorithms = GraphAlgorithms::new(&graph); 753 754 // Test BFS from first node 755 let nodes: Vec<_> = graph.get_all_nodes(); 756 if let Some(&start_node) = nodes.first() { 757 let result = algorithms.bfs(start_node, Some(2)).unwrap(); 758 assert!(!result.is_empty()); 759 assert_eq!(result[0], start_node); 760 } 761 } 762 763 #[test] 764 fn test_dfs_traversal() { 765 let graph = create_test_graph(); 766 let algorithms = GraphAlgorithms::new(&graph); 767 768 // Test DFS from first node 769 let nodes: Vec<_> = graph.get_all_nodes(); 770 if let Some(&start_node) = nodes.first() { 771 let result = algorithms.dfs(start_node, Some(3)).unwrap(); 772 assert!(!result.is_empty()); 773 assert_eq!(result[0], start_node); 774 } 775 } 776 777 #[test] 778 fn test_shortest_path() { 779 let graph = create_test_graph(); 780 let algorithms = GraphAlgorithms::new(&graph); 781 782 let nodes: Vec<_> = graph.get_all_nodes(); 783 if nodes.len() >= 2 { 784 let start = nodes[0]; 785 let end = nodes[1]; 786 787 let path = algorithms.shortest_path(start, end, None).unwrap(); 788 if let Some(path) = path { 789 assert!(!path.is_empty()); 790 assert_eq!(path.nodes[0], start); 791 assert_eq!(*path.nodes.last().unwrap(), end); 792 } 793 } 794 } 795}