this repo has no description
1use crate::{Graph, NodeId, RelationshipId, Result};
2use crate::core::relationship::Direction;
3use std::collections::{HashMap, HashSet, VecDeque, BinaryHeap};
4use std::cmp::Ordering;
5
6pub mod pathfinding;
7pub mod centrality;
8pub mod community;
9pub mod traversal;
10
11pub use pathfinding::*;
12pub use centrality::*;
13pub use community::*;
14pub use traversal::*;
15
/// Weight function for graph algorithms.
///
/// Called with the id of the relationship being traversed and returns that
/// relationship's cost as an `f64`. The algorithms in this file accept it as
/// `Option<&WeightFn>` and fall back to a weight of `1.0` per relationship
/// when `None` is supplied.
pub type WeightFn = dyn Fn(RelationshipId) -> f64 + Send + Sync;
18
/// Graph statistics structure.
///
/// A plain snapshot of counters. `PartialEq`/`Eq` allow two snapshots to be
/// compared directly and `Default` yields the all-zero snapshot.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct GraphStats {
    /// Number of nodes in the graph.
    pub node_count: u64,
    /// Number of relationships in the graph.
    pub relationship_count: u64,
    /// Number of labels registered in the schema.
    pub label_count: u64,
    /// Number of property keys registered in the schema.
    pub property_key_count: u64,
    /// Number of relationship types registered in the schema.
    pub relationship_type_count: u64,
}
28
29/// Result type for pathfinding algorithms
30#[derive(Debug, Clone)]
31pub struct Path {
32 pub nodes: Vec<NodeId>,
33 pub relationships: Vec<RelationshipId>,
34 pub total_weight: f64,
35}
36
37impl Path {
38 pub fn new() -> Self {
39 Self {
40 nodes: Vec::new(),
41 relationships: Vec::new(),
42 total_weight: 0.0,
43 }
44 }
45
46 pub fn length(&self) -> usize {
47 self.nodes.len().saturating_sub(1)
48 }
49
50 pub fn is_empty(&self) -> bool {
51 self.nodes.is_empty()
52 }
53
54 pub fn add_step(&mut self, node: NodeId, relationship: Option<RelationshipId>, weight: f64) {
55 self.nodes.push(node);
56 if let Some(rel) = relationship {
57 self.relationships.push(rel);
58 }
59 self.total_weight += weight;
60 }
61}
62
63/// Priority queue entry for Dijkstra's algorithm
64#[derive(Debug, Clone)]
65struct DijkstraEntry {
66 node: NodeId,
67 distance: f64,
68 previous: Option<(NodeId, RelationshipId)>,
69}
70
71impl PartialEq for DijkstraEntry {
72 fn eq(&self, other: &Self) -> bool {
73 self.distance == other.distance
74 }
75}
76
77impl Eq for DijkstraEntry {}
78
79impl PartialOrd for DijkstraEntry {
80 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
81 Some(self.cmp(other))
82 }
83}
84
85impl Ord for DijkstraEntry {
86 fn cmp(&self, other: &Self) -> Ordering {
87 // Reverse ordering for min-heap
88 other.distance.partial_cmp(&self.distance).unwrap_or(Ordering::Equal)
89 }
90}
91
/// Graph algorithms implementation.
///
/// Holds a shared borrow of a [`Graph`] for the lifetime `'a`; the methods
/// read nodes and relationships through that reference.
pub struct GraphAlgorithms<'a> {
    /// Graph queried by the algorithms; borrowed, not owned.
    graph: &'a Graph,
}
96
97impl<'a> GraphAlgorithms<'a> {
98 pub fn new(graph: &'a Graph) -> Self {
99 Self { graph }
100 }
101
102 /// Find shortest path between two nodes using Dijkstra's algorithm
103 pub fn shortest_path(
104 &self,
105 start: NodeId,
106 end: NodeId,
107 weight_fn: Option<&WeightFn>,
108 ) -> Result<Option<Path>> {
109 let mut distances: HashMap<NodeId, f64> = HashMap::new();
110 let mut previous: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new();
111 let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new();
112 let mut visited: HashSet<NodeId> = HashSet::new();
113
114 // Initialize start node
115 distances.insert(start, 0.0);
116 heap.push(DijkstraEntry {
117 node: start,
118 distance: 0.0,
119 previous: None,
120 });
121
122 while let Some(current) = heap.pop() {
123 if visited.contains(¤t.node) {
124 continue;
125 }
126
127 visited.insert(current.node);
128
129 // Found target
130 if current.node == end {
131 return Ok(Some(self.reconstruct_path(start, end, &previous)?));
132 }
133
134 // Explore neighbors
135 let relationships = self.graph.get_node_relationships(
136 current.node,
137 Direction::Both,
138 None,
139 );
140
141 for relationship in relationships {
142 let neighbor = if relationship.start_node == current.node {
143 relationship.end_node
144 } else {
145 relationship.start_node
146 };
147
148 if visited.contains(&neighbor) {
149 continue;
150 }
151
152 let weight = weight_fn
153 .map(|f| f(relationship.id))
154 .unwrap_or(1.0);
155
156 let new_distance = current.distance + weight;
157 let current_distance = distances.get(&neighbor).copied().unwrap_or(f64::INFINITY);
158
159 if new_distance < current_distance {
160 distances.insert(neighbor, new_distance);
161 previous.insert(neighbor, (current.node, relationship.id));
162
163 heap.push(DijkstraEntry {
164 node: neighbor,
165 distance: new_distance,
166 previous: Some((current.node, relationship.id)),
167 });
168 }
169 }
170 }
171
172 Ok(None) // No path found
173 }
174
175 /// Find all shortest paths from a source node (single-source shortest path)
176 pub fn shortest_paths_from(
177 &self,
178 start: NodeId,
179 weight_fn: Option<&WeightFn>,
180 ) -> Result<HashMap<NodeId, (f64, Option<Path>)>> {
181 let mut distances: HashMap<NodeId, f64> = HashMap::new();
182 let mut previous: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new();
183 let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new();
184 let mut visited: HashSet<NodeId> = HashSet::new();
185 let mut results: HashMap<NodeId, (f64, Option<Path>)> = HashMap::new();
186
187 // Initialize start node
188 distances.insert(start, 0.0);
189 heap.push(DijkstraEntry {
190 node: start,
191 distance: 0.0,
192 previous: None,
193 });
194
195 while let Some(current) = heap.pop() {
196 if visited.contains(¤t.node) {
197 continue;
198 }
199
200 visited.insert(current.node);
201
202 // Record result for this node
203 let path = if current.node == start {
204 Some(Path::new())
205 } else {
206 Some(self.reconstruct_path(start, current.node, &previous)?)
207 };
208 results.insert(current.node, (current.distance, path));
209
210 // Explore neighbors
211 let relationships = self.graph.get_node_relationships(
212 current.node,
213 Direction::Both,
214 None,
215 );
216
217 for relationship in relationships {
218 let neighbor = if relationship.start_node == current.node {
219 relationship.end_node
220 } else {
221 relationship.start_node
222 };
223
224 if visited.contains(&neighbor) {
225 continue;
226 }
227
228 let weight = weight_fn
229 .map(|f| f(relationship.id))
230 .unwrap_or(1.0);
231
232 let new_distance = current.distance + weight;
233 let current_distance = distances.get(&neighbor).copied().unwrap_or(f64::INFINITY);
234
235 if new_distance < current_distance {
236 distances.insert(neighbor, new_distance);
237 previous.insert(neighbor, (current.node, relationship.id));
238
239 heap.push(DijkstraEntry {
240 node: neighbor,
241 distance: new_distance,
242 previous: Some((current.node, relationship.id)),
243 });
244 }
245 }
246 }
247
248 Ok(results)
249 }
250
251 /// Find k shortest paths between two nodes
252 pub fn k_shortest_paths(
253 &self,
254 start: NodeId,
255 end: NodeId,
256 k: usize,
257 weight_fn: Option<&WeightFn>,
258 ) -> Result<Vec<Path>> {
259 // Yen's algorithm for k-shortest paths
260 let mut paths = Vec::new();
261
262 // Find first shortest path
263 if let Some(first_path) = self.shortest_path(start, end, weight_fn)? {
264 paths.push(first_path);
265 } else {
266 return Ok(paths); // No path exists
267 }
268
269 let mut candidates: BinaryHeap<PathCandidate> = BinaryHeap::new();
270
271 for i in 1..k {
272 if paths.is_empty() {
273 break;
274 }
275
276 let previous_path = &paths[i - 1];
277
278 // Generate candidate paths by deviating from each node in the previous path
279 for j in 0..previous_path.nodes.len() - 1 {
280 let spur_node = previous_path.nodes[j];
281 let root_path = &previous_path.nodes[0..=j];
282
283 // Remove edges that would lead to already found paths
284 let mut removed_edges = HashSet::new();
285 for path in &paths {
286 if path.nodes.len() > j && path.nodes[0..=j] == root_path[..] {
287 if j + 1 < path.relationships.len() {
288 removed_edges.insert(path.relationships[j]);
289 }
290 }
291 }
292
293 // Find shortest path from spur node to end (excluding removed edges)
294 if let Some(spur_path) = self.shortest_path_excluding(
295 spur_node,
296 end,
297 &removed_edges,
298 weight_fn,
299 )? {
300 // Combine root path with spur path
301 let mut full_path = Path::new();
302
303 // Add root path
304 for &node in root_path {
305 full_path.add_step(node, None, 0.0);
306 }
307
308 // Add spur path (skip first node as it's already included)
309 for (idx, &node) in spur_path.nodes.iter().skip(1).enumerate() {
310 let rel = if idx < spur_path.relationships.len() {
311 Some(spur_path.relationships[idx])
312 } else {
313 None
314 };
315 full_path.add_step(node, rel, 0.0);
316 }
317
318 // Recalculate total weight
319 let total_weight = self.calculate_path_weight(&full_path, weight_fn)?;
320 full_path.total_weight = total_weight;
321
322 candidates.push(PathCandidate {
323 path: full_path,
324 weight: total_weight,
325 });
326 }
327 }
328
329 if let Some(best_candidate) = candidates.pop() {
330 paths.push(best_candidate.path);
331 } else {
332 break; // No more candidates
333 }
334 }
335
336 Ok(paths)
337 }
338
339 /// Breadth-first search traversal
340 pub fn bfs(&self, start: NodeId, max_depth: Option<usize>) -> Result<Vec<NodeId>> {
341 let mut visited = HashSet::new();
342 let mut queue = VecDeque::new();
343 let mut result = Vec::new();
344
345 queue.push_back((start, 0));
346 visited.insert(start);
347
348 while let Some((node, depth)) = queue.pop_front() {
349 result.push(node);
350
351 if let Some(max_d) = max_depth {
352 if depth >= max_d {
353 continue;
354 }
355 }
356
357 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
358 for relationship in relationships {
359 let neighbor = if relationship.start_node == node {
360 relationship.end_node
361 } else {
362 relationship.start_node
363 };
364
365 if !visited.contains(&neighbor) {
366 visited.insert(neighbor);
367 queue.push_back((neighbor, depth + 1));
368 }
369 }
370 }
371
372 Ok(result)
373 }
374
375 /// Depth-first search traversal
376 pub fn dfs(&self, start: NodeId, max_depth: Option<usize>) -> Result<Vec<NodeId>> {
377 let mut visited = HashSet::new();
378 let mut result = Vec::new();
379
380 self.dfs_recursive(start, &mut visited, &mut result, 0, max_depth)?;
381
382 Ok(result)
383 }
384
385 /// Find connected components using Union-Find
386 pub fn connected_components(&self) -> Result<Vec<Vec<NodeId>>> {
387 let all_nodes = self.get_all_nodes()?;
388 let mut parent: HashMap<NodeId, NodeId> = HashMap::new();
389 let mut rank: HashMap<NodeId, usize> = HashMap::new();
390
391 // Initialize Union-Find
392 for &node in &all_nodes {
393 parent.insert(node, node);
394 rank.insert(node, 0);
395 }
396
397 // Process all relationships
398 for &node in &all_nodes {
399 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
400 for relationship in relationships {
401 let other = if relationship.start_node == node {
402 relationship.end_node
403 } else {
404 relationship.start_node
405 };
406
407 self.union(&mut parent, &mut rank, node, other);
408 }
409 }
410
411 // Group nodes by their root parent
412 let mut components: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
413 for &node in &all_nodes {
414 let root = self.find(&mut parent, node);
415 components.entry(root).or_insert_with(Vec::new).push(node);
416 }
417
418 Ok(components.into_values().collect())
419 }
420
421 /// Check if the graph has cycles (for directed graphs)
422 pub fn has_cycle(&self) -> Result<bool> {
423 let all_nodes = self.get_all_nodes()?;
424 let mut visited = HashSet::new();
425 let mut rec_stack = HashSet::new();
426
427 for &node in &all_nodes {
428 if !visited.contains(&node) {
429 if self.has_cycle_dfs(node, &mut visited, &mut rec_stack)? {
430 return Ok(true);
431 }
432 }
433 }
434
435 Ok(false)
436 }
437
438 // Helper methods
439
440 fn reconstruct_path(
441 &self,
442 start: NodeId,
443 end: NodeId,
444 previous: &HashMap<NodeId, (NodeId, RelationshipId)>,
445 ) -> Result<Path> {
446 let mut path = Path::new();
447 let mut current = end;
448 let mut nodes = Vec::new();
449 let mut relationships = Vec::new();
450
451 while current != start {
452 nodes.push(current);
453 if let Some(&(prev_node, rel_id)) = previous.get(¤t) {
454 relationships.push(rel_id);
455 current = prev_node;
456 } else {
457 return Err(crate::error::GigabrainError::Algorithm(
458 "Invalid path reconstruction".to_string(),
459 ));
460 }
461 }
462
463 nodes.push(start);
464 nodes.reverse();
465 relationships.reverse();
466
467 path.nodes = nodes;
468 path.relationships = relationships;
469
470 Ok(path)
471 }
472
473 fn shortest_path_excluding(
474 &self,
475 start: NodeId,
476 end: NodeId,
477 excluded_edges: &HashSet<RelationshipId>,
478 weight_fn: Option<&WeightFn>,
479 ) -> Result<Option<Path>> {
480 // Similar to shortest_path but excludes certain edges
481 let mut distances: HashMap<NodeId, f64> = HashMap::new();
482 let mut previous: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new();
483 let mut heap: BinaryHeap<DijkstraEntry> = BinaryHeap::new();
484 let mut visited: HashSet<NodeId> = HashSet::new();
485
486 distances.insert(start, 0.0);
487 heap.push(DijkstraEntry {
488 node: start,
489 distance: 0.0,
490 previous: None,
491 });
492
493 while let Some(current) = heap.pop() {
494 if visited.contains(¤t.node) {
495 continue;
496 }
497
498 visited.insert(current.node);
499
500 if current.node == end {
501 return Ok(Some(self.reconstruct_path(start, end, &previous)?));
502 }
503
504 let relationships = self.graph.get_node_relationships(
505 current.node,
506 Direction::Both,
507 None,
508 );
509
510 for relationship in relationships {
511 if excluded_edges.contains(&relationship.id) {
512 continue; // Skip excluded edges
513 }
514
515 let neighbor = if relationship.start_node == current.node {
516 relationship.end_node
517 } else {
518 relationship.start_node
519 };
520
521 if visited.contains(&neighbor) {
522 continue;
523 }
524
525 let weight = weight_fn
526 .map(|f| f(relationship.id))
527 .unwrap_or(1.0);
528
529 let new_distance = current.distance + weight;
530 let current_distance = distances.get(&neighbor).copied().unwrap_or(f64::INFINITY);
531
532 if new_distance < current_distance {
533 distances.insert(neighbor, new_distance);
534 previous.insert(neighbor, (current.node, relationship.id));
535
536 heap.push(DijkstraEntry {
537 node: neighbor,
538 distance: new_distance,
539 previous: Some((current.node, relationship.id)),
540 });
541 }
542 }
543 }
544
545 Ok(None)
546 }
547
548 fn calculate_path_weight(&self, path: &Path, weight_fn: Option<&WeightFn>) -> Result<f64> {
549 let mut total_weight = 0.0;
550
551 for &rel_id in &path.relationships {
552 let weight = weight_fn
553 .map(|f| f(rel_id))
554 .unwrap_or(1.0);
555 total_weight += weight;
556 }
557
558 Ok(total_weight)
559 }
560
561 fn dfs_recursive(
562 &self,
563 node: NodeId,
564 visited: &mut HashSet<NodeId>,
565 result: &mut Vec<NodeId>,
566 depth: usize,
567 max_depth: Option<usize>,
568 ) -> Result<()> {
569 visited.insert(node);
570 result.push(node);
571
572 if let Some(max_d) = max_depth {
573 if depth >= max_d {
574 return Ok(());
575 }
576 }
577
578 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
579 for relationship in relationships {
580 let neighbor = if relationship.start_node == node {
581 relationship.end_node
582 } else {
583 relationship.start_node
584 };
585
586 if !visited.contains(&neighbor) {
587 self.dfs_recursive(neighbor, visited, result, depth + 1, max_depth)?;
588 }
589 }
590
591 Ok(())
592 }
593
594 fn get_all_nodes(&self) -> Result<Vec<NodeId>> {
595 Ok(self.graph.get_all_nodes())
596 }
597
598 /// Get comprehensive graph statistics
599 pub fn get_stats(graph: &Graph) -> GraphStats {
600 let nodes = graph.get_all_nodes();
601 let node_count = nodes.len() as u64;
602
603 let mut relationship_count = 0u64;
604 for &node in &nodes {
605 let rels = graph.get_node_relationships(node, Direction::Outgoing, None);
606 relationship_count += rels.len() as u64;
607 }
608
609 let schema = graph.schema().read();
610 let label_count = schema.labels.len() as u64;
611 let property_key_count = schema.property_keys.len() as u64;
612 let relationship_type_count = schema.relationship_types.len() as u64;
613
614 GraphStats {
615 node_count,
616 relationship_count,
617 label_count,
618 property_key_count,
619 relationship_type_count,
620 }
621 }
622
623 fn find(&self, parent: &mut HashMap<NodeId, NodeId>, node: NodeId) -> NodeId {
624 let parent_node = parent[&node];
625 if parent_node != node {
626 let root = self.find(parent, parent_node);
627 parent.insert(node, root);
628 }
629 parent[&node]
630 }
631
632 fn union(
633 &self,
634 parent: &mut HashMap<NodeId, NodeId>,
635 rank: &mut HashMap<NodeId, usize>,
636 x: NodeId,
637 y: NodeId,
638 ) {
639 let root_x = self.find(parent, x);
640 let root_y = self.find(parent, y);
641
642 if root_x != root_y {
643 match rank[&root_x].cmp(&rank[&root_y]) {
644 Ordering::Less => {
645 parent.insert(root_x, root_y);
646 }
647 Ordering::Greater => {
648 parent.insert(root_y, root_x);
649 }
650 Ordering::Equal => {
651 parent.insert(root_y, root_x);
652 rank.insert(root_x, rank[&root_x] + 1);
653 }
654 }
655 }
656 }
657
658 fn has_cycle_dfs(
659 &self,
660 node: NodeId,
661 visited: &mut HashSet<NodeId>,
662 rec_stack: &mut HashSet<NodeId>,
663 ) -> Result<bool> {
664 visited.insert(node);
665 rec_stack.insert(node);
666
667 let relationships = self.graph.get_node_relationships(node, Direction::Outgoing, None);
668 for relationship in relationships {
669 let neighbor = relationship.end_node;
670
671 if !visited.contains(&neighbor) {
672 if self.has_cycle_dfs(neighbor, visited, rec_stack)? {
673 return Ok(true);
674 }
675 } else if rec_stack.contains(&neighbor) {
676 return Ok(true);
677 }
678 }
679
680 rec_stack.remove(&node);
681 Ok(false)
682 }
683}
684
685#[derive(Debug, Clone)]
686struct PathCandidate {
687 path: Path,
688 weight: f64,
689}
690
691impl PartialEq for PathCandidate {
692 fn eq(&self, other: &Self) -> bool {
693 self.weight == other.weight
694 }
695}
696
697impl Eq for PathCandidate {}
698
699impl PartialOrd for PathCandidate {
700 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
701 Some(self.cmp(other))
702 }
703}
704
705impl Ord for PathCandidate {
706 fn cmp(&self, other: &Self) -> Ordering {
707 // Reverse ordering for min-heap
708 other.weight.partial_cmp(&self.weight).unwrap_or(Ordering::Equal)
709 }
710}
711
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;
    use std::collections::HashSet;

    /// Builds the six-node fixture used by every test in this module.
    fn create_test_graph() -> Graph {
        let graph = Graph::new();

        // Create a simple graph: A -> B -> C -> D
        //                        |         |
        //                        v         v
        //                        E ------> F

        let node_a = graph.create_node();
        let node_b = graph.create_node();
        let node_c = graph.create_node();
        let node_d = graph.create_node();
        let node_e = graph.create_node();
        let node_f = graph.create_node();

        let schema = graph.schema();
        let mut schema = schema.write();
        let rel_type = schema.get_or_create_relationship_type("CONNECTS");
        // Release the schema write lock before creating relationships.
        drop(schema);

        // Create relationships
        graph.create_relationship(node_a, node_b, rel_type).unwrap();
        graph.create_relationship(node_b, node_c, rel_type).unwrap();
        graph.create_relationship(node_c, node_d, rel_type).unwrap();
        graph.create_relationship(node_a, node_e, rel_type).unwrap();
        graph.create_relationship(node_c, node_f, rel_type).unwrap();
        graph.create_relationship(node_e, node_f, rel_type).unwrap();

        graph
    }

    #[test]
    fn test_bfs_traversal() {
        let graph = create_test_graph();
        let algorithms = GraphAlgorithms::new(&graph);

        // Test BFS from first node. `expect` keeps the test from passing
        // vacuously if the fixture ever comes back empty.
        let nodes: Vec<_> = graph.get_all_nodes();
        let start_node = *nodes.first().expect("test graph must contain nodes");

        let result = algorithms.bfs(start_node, Some(2)).unwrap();
        assert!(!result.is_empty());
        assert_eq!(result[0], start_node);

        let unique: HashSet<_> = result.iter().copied().collect();
        assert_eq!(unique.len(), result.len(), "BFS must not visit a node twice");
    }

    #[test]
    fn test_dfs_traversal() {
        let graph = create_test_graph();
        let algorithms = GraphAlgorithms::new(&graph);

        // Test DFS from first node
        let nodes: Vec<_> = graph.get_all_nodes();
        let start_node = *nodes.first().expect("test graph must contain nodes");

        let result = algorithms.dfs(start_node, Some(3)).unwrap();
        assert!(!result.is_empty());
        assert_eq!(result[0], start_node);

        let unique: HashSet<_> = result.iter().copied().collect();
        assert_eq!(unique.len(), result.len(), "DFS must not visit a node twice");
    }

    #[test]
    fn test_shortest_path() {
        let graph = create_test_graph();
        let algorithms = GraphAlgorithms::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        assert!(nodes.len() >= 2, "test graph must contain at least two nodes");
        let start = nodes[0];
        let end = nodes[1];

        // The fixture is connected when relationships are followed in both
        // directions, so a path between any two of its nodes must exist.
        let path = algorithms
            .shortest_path(start, end, None)
            .unwrap()
            .expect("fixture graph is connected, a path must exist");

        assert!(!path.is_empty());
        assert_eq!(path.nodes[0], start);
        assert_eq!(*path.nodes.last().unwrap(), end);
        assert_eq!(path.relationships.len(), path.length());
    }
}
795}