use crate::{Graph, NodeId, RelationshipId, Result}; use crate::core::relationship::Direction; use super::{Path, WeightFn}; use std::collections::{HashMap, HashSet, VecDeque, BinaryHeap}; use std::cmp::Ordering; /// A* pathfinding algorithm implementation pub struct AStar<'a> { graph: &'a Graph, } impl<'a> AStar<'a> { pub fn new(graph: &'a Graph) -> Self { Self { graph } } /// Find shortest path using A* algorithm with heuristic function pub fn find_path( &self, start: NodeId, goal: NodeId, weight_fn: Option<&WeightFn>, heuristic: H, ) -> Result> where H: Fn(NodeId, NodeId) -> f64, { let mut open_set = BinaryHeap::new(); let mut came_from: HashMap = HashMap::new(); let mut g_score: HashMap = HashMap::new(); let mut f_score: HashMap = HashMap::new(); g_score.insert(start, 0.0); f_score.insert(start, heuristic(start, goal)); open_set.push(AStarEntry { node: start, f_score: heuristic(start, goal), }); while let Some(current_entry) = open_set.pop() { let current = current_entry.node; if current == goal { return Ok(Some(self.reconstruct_path(start, goal, &came_from)?)); } let current_g_score = g_score.get(¤t).copied().unwrap_or(f64::INFINITY); let relationships = self.graph.get_node_relationships(current, Direction::Both, None); for relationship in relationships { let neighbor = if relationship.start_node == current { relationship.end_node } else { relationship.start_node }; let edge_weight = weight_fn .map(|f| f(relationship.id)) .unwrap_or(1.0); let tentative_g_score = current_g_score + edge_weight; let neighbor_g_score = g_score.get(&neighbor).copied().unwrap_or(f64::INFINITY); if tentative_g_score < neighbor_g_score { came_from.insert(neighbor, (current, relationship.id)); g_score.insert(neighbor, tentative_g_score); let new_f_score = tentative_g_score + heuristic(neighbor, goal); f_score.insert(neighbor, new_f_score); open_set.push(AStarEntry { node: neighbor, f_score: new_f_score, }); } } } Ok(None) // No path found } fn reconstruct_path( &self, start: NodeId, goal: NodeId, came_from: &HashMap, ) -> Result { let mut path = Path::new(); let mut current = goal; let mut nodes = Vec::new(); let mut relationships = Vec::new(); while current != start { nodes.push(current); if let Some(&(prev_node, rel_id)) = came_from.get(¤t) { relationships.push(rel_id); current = prev_node; } else { return Err(crate::error::GigabrainError::Algorithm( "Invalid path reconstruction in A*".to_string(), )); } } nodes.push(start); nodes.reverse(); relationships.reverse(); path.nodes = nodes; path.relationships = relationships; Ok(path) } } #[derive(Debug, Clone)] struct AStarEntry { node: NodeId, f_score: f64, } impl PartialEq for AStarEntry { fn eq(&self, other: &Self) -> bool { self.f_score == other.f_score } } impl Eq for AStarEntry {} impl PartialOrd for AStarEntry { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for AStarEntry { fn cmp(&self, other: &Self) -> Ordering { // Reverse ordering for min-heap other.f_score.partial_cmp(&self.f_score).unwrap_or(Ordering::Equal) } } /// Bidirectional search implementation pub struct BidirectionalSearch<'a> { graph: &'a Graph, } impl<'a> BidirectionalSearch<'a> { pub fn new(graph: &'a Graph) -> Self { Self { graph } } /// Find shortest path using bidirectional search pub fn find_path( &self, start: NodeId, goal: NodeId, weight_fn: Option<&WeightFn>, ) -> Result> { let mut forward_visited: HashMap)> = HashMap::new(); let mut backward_visited: HashMap)> = HashMap::new(); let mut forward_queue = VecDeque::new(); let mut backward_queue = VecDeque::new(); forward_visited.insert(start, (0.0, None)); backward_visited.insert(goal, (0.0, None)); forward_queue.push_back(start); backward_queue.push_back(goal); let mut meeting_point = None; let mut min_distance = f64::INFINITY; while !forward_queue.is_empty() || !backward_queue.is_empty() { // Expand forward search if !forward_queue.is_empty() { let current = forward_queue.pop_front().unwrap(); let (current_dist, _) = forward_visited[¤t]; let relationships = self.graph.get_node_relationships(current, Direction::Both, None); for relationship in relationships { let neighbor = if relationship.start_node == current { relationship.end_node } else { relationship.start_node }; let edge_weight = weight_fn .map(|f| f(relationship.id)) .unwrap_or(1.0); let new_dist = current_dist + edge_weight; let should_update = forward_visited .get(&neighbor) .map_or(true, |(dist, _)| new_dist < *dist); if should_update { forward_visited.insert(neighbor, (new_dist, Some((current, relationship.id)))); forward_queue.push_back(neighbor); // Check if we've met the backward search if let Some((backward_dist, _)) = backward_visited.get(&neighbor) { let total_dist = new_dist + backward_dist; if total_dist < min_distance { min_distance = total_dist; meeting_point = Some(neighbor); } } } } } // Expand backward search if !backward_queue.is_empty() { let current = backward_queue.pop_front().unwrap(); let (current_dist, _) = backward_visited[¤t]; let relationships = self.graph.get_node_relationships(current, Direction::Both, None); for relationship in relationships { let neighbor = if relationship.start_node == current { relationship.end_node } else { relationship.start_node }; let edge_weight = weight_fn .map(|f| f(relationship.id)) .unwrap_or(1.0); let new_dist = current_dist + edge_weight; let should_update = backward_visited .get(&neighbor) .map_or(true, |(dist, _)| new_dist < *dist); if should_update { backward_visited.insert(neighbor, (new_dist, Some((current, relationship.id)))); backward_queue.push_back(neighbor); // Check if we've met the forward search if let Some((forward_dist, _)) = forward_visited.get(&neighbor) { let total_dist = forward_dist + new_dist; if total_dist < min_distance { min_distance = total_dist; meeting_point = Some(neighbor); } } } } } } if let Some(meeting) = meeting_point { Ok(Some(self.reconstruct_bidirectional_path( start, goal, meeting, &forward_visited, &backward_visited, )?)) } else { Ok(None) } } fn reconstruct_bidirectional_path( &self, start: NodeId, goal: NodeId, meeting: NodeId, forward_visited: &HashMap)>, backward_visited: &HashMap)>, ) -> Result { let mut path = Path::new(); // Build forward path from start to meeting point let mut forward_nodes = Vec::new(); let mut forward_rels = Vec::new(); let mut current = meeting; while current != start { forward_nodes.push(current); if let Some((_, Some((prev, rel)))) = forward_visited.get(¤t) { forward_rels.push(*rel); current = *prev; } else { return Err(crate::error::GigabrainError::Algorithm( "Invalid forward path reconstruction".to_string(), )); } } forward_nodes.push(start); forward_nodes.reverse(); forward_rels.reverse(); // Build backward path from meeting point to goal let mut backward_nodes = Vec::new(); let mut backward_rels = Vec::new(); current = meeting; while current != goal { if let Some((_, Some((next, rel)))) = backward_visited.get(¤t) { backward_nodes.push(*next); backward_rels.push(*rel); current = *next; } else { return Err(crate::error::GigabrainError::Algorithm( "Invalid backward path reconstruction".to_string(), )); } } // Combine paths path.nodes = forward_nodes; path.nodes.extend(backward_nodes); path.relationships = forward_rels; path.relationships.extend(backward_rels); Ok(path) } } /// All-pairs shortest paths using Floyd-Warshall algorithm pub struct FloydWarshall<'a> { graph: &'a Graph, } impl<'a> FloydWarshall<'a> { pub fn new(graph: &'a Graph) -> Self { Self { graph } } /// Compute all-pairs shortest paths pub fn compute_all_pairs( &self, nodes: &[NodeId], weight_fn: Option<&WeightFn>, ) -> Result>> { let n = nodes.len(); let mut dist: HashMap<(NodeId, NodeId), f64> = HashMap::new(); let mut next: HashMap<(NodeId, NodeId), Option> = HashMap::new(); // Initialize distances for &i in nodes { for &j in nodes { if i == j { dist.insert((i, j), 0.0); } else { dist.insert((i, j), f64::INFINITY); } next.insert((i, j), None); } } // Set distances for direct edges for &node in nodes { let relationships = self.graph.get_node_relationships(node, Direction::Both, None); for relationship in relationships { let neighbor = if relationship.start_node == node { relationship.end_node } else { relationship.start_node }; if nodes.contains(&neighbor) { let weight = weight_fn .map(|f| f(relationship.id)) .unwrap_or(1.0); dist.insert((node, neighbor), weight); next.insert((node, neighbor), Some(neighbor)); } } } // Floyd-Warshall algorithm for &k in nodes { for &i in nodes { for &j in nodes { let dist_ik = dist[&(i, k)]; let dist_kj = dist[&(k, j)]; let dist_ij = dist[&(i, j)]; if dist_ik + dist_kj < dist_ij { dist.insert((i, j), dist_ik + dist_kj); next.insert((i, j), next[&(i, k)]); } } } } // Reconstruct paths let mut result = HashMap::new(); for &i in nodes { for &j in nodes { if i != j && dist[&(i, j)] != f64::INFINITY { result.insert((i, j), Some(self.reconstruct_floyd_warshall_path(i, j, &next)?)); } else { result.insert((i, j), None); } } } Ok(result) } fn reconstruct_floyd_warshall_path( &self, start: NodeId, end: NodeId, next: &HashMap<(NodeId, NodeId), Option>, ) -> Result { let mut path = Path::new(); let mut current = start; path.nodes.push(current); while current != end { if let Some(Some(next_node)) = next.get(&(current, end)) { path.nodes.push(*next_node); current = *next_node; } else { return Err(crate::error::GigabrainError::Algorithm( "Invalid Floyd-Warshall path reconstruction".to_string(), )); } } Ok(path) } } #[cfg(test)] mod tests { use super::*; use crate::Graph; fn create_test_graph() -> Graph { let graph = Graph::new(); // Create nodes let node_a = graph.create_node(); let node_b = graph.create_node(); let node_c = graph.create_node(); let schema = graph.schema(); let mut schema = schema.write(); let rel_type = schema.get_or_create_relationship_type("CONNECTS"); drop(schema); // Create relationships: A -> B -> C graph.create_relationship(node_a, node_b, rel_type).unwrap(); graph.create_relationship(node_b, node_c, rel_type).unwrap(); graph } #[test] fn test_a_star_pathfinding() { let graph = create_test_graph(); let astar = AStar::new(&graph); let nodes: Vec<_> = graph.get_all_nodes(); if nodes.len() >= 2 { let start = nodes[0]; let goal = nodes[nodes.len() - 1]; // Simple heuristic (always returns 0, making it equivalent to Dijkstra) let heuristic = |_: NodeId, _: NodeId| 0.0; let path = astar.find_path(start, goal, None, heuristic).unwrap(); if let Some(path) = path { assert!(!path.is_empty()); assert_eq!(path.nodes[0], start); assert_eq!(*path.nodes.last().unwrap(), goal); } } } #[test] fn test_bidirectional_search() { let graph = create_test_graph(); let bidirectional = BidirectionalSearch::new(&graph); let nodes: Vec<_> = graph.get_all_nodes(); if nodes.len() >= 2 { let start = nodes[0]; let goal = nodes[nodes.len() - 1]; let path = bidirectional.find_path(start, goal, None).unwrap(); if let Some(path) = path { assert!(!path.is_empty()); assert_eq!(path.nodes[0], start); assert_eq!(*path.nodes.last().unwrap(), goal); } } } #[test] fn test_floyd_warshall() { let graph = create_test_graph(); let floyd_warshall = FloydWarshall::new(&graph); let nodes: Vec<_> = graph.get_all_nodes(); if nodes.len() >= 2 { let all_pairs = floyd_warshall.compute_all_pairs(&nodes, None).unwrap(); // Check that we have results for all pairs for &i in &nodes { for &j in &nodes { assert!(all_pairs.contains_key(&(i, j))); } } } } }