this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 518 lines 18 kB view raw
1use crate::{Graph, NodeId, RelationshipId, Result}; 2use crate::core::relationship::Direction; 3use super::{Path, WeightFn}; 4use std::collections::{HashMap, HashSet, VecDeque, BinaryHeap}; 5use std::cmp::Ordering; 6 7/// A* pathfinding algorithm implementation 8pub struct AStar<'a> { 9 graph: &'a Graph, 10} 11 12impl<'a> AStar<'a> { 13 pub fn new(graph: &'a Graph) -> Self { 14 Self { graph } 15 } 16 17 /// Find shortest path using A* algorithm with heuristic function 18 pub fn find_path<H>( 19 &self, 20 start: NodeId, 21 goal: NodeId, 22 weight_fn: Option<&WeightFn>, 23 heuristic: H, 24 ) -> Result<Option<Path>> 25 where 26 H: Fn(NodeId, NodeId) -> f64, 27 { 28 let mut open_set = BinaryHeap::new(); 29 let mut came_from: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new(); 30 let mut g_score: HashMap<NodeId, f64> = HashMap::new(); 31 let mut f_score: HashMap<NodeId, f64> = HashMap::new(); 32 33 g_score.insert(start, 0.0); 34 f_score.insert(start, heuristic(start, goal)); 35 36 open_set.push(AStarEntry { 37 node: start, 38 f_score: heuristic(start, goal), 39 }); 40 41 while let Some(current_entry) = open_set.pop() { 42 let current = current_entry.node; 43 44 if current == goal { 45 return Ok(Some(self.reconstruct_path(start, goal, &came_from)?)); 46 } 47 48 let current_g_score = g_score.get(&current).copied().unwrap_or(f64::INFINITY); 49 50 let relationships = self.graph.get_node_relationships(current, Direction::Both, None); 51 for relationship in relationships { 52 let neighbor = if relationship.start_node == current { 53 relationship.end_node 54 } else { 55 relationship.start_node 56 }; 57 58 let edge_weight = weight_fn 59 .map(|f| f(relationship.id)) 60 .unwrap_or(1.0); 61 62 let tentative_g_score = current_g_score + edge_weight; 63 let neighbor_g_score = g_score.get(&neighbor).copied().unwrap_or(f64::INFINITY); 64 65 if tentative_g_score < neighbor_g_score { 66 came_from.insert(neighbor, (current, relationship.id)); 67 g_score.insert(neighbor, tentative_g_score); 68 let new_f_score = tentative_g_score + heuristic(neighbor, goal); 69 f_score.insert(neighbor, new_f_score); 70 71 open_set.push(AStarEntry { 72 node: neighbor, 73 f_score: new_f_score, 74 }); 75 } 76 } 77 } 78 79 Ok(None) // No path found 80 } 81 82 fn reconstruct_path( 83 &self, 84 start: NodeId, 85 goal: NodeId, 86 came_from: &HashMap<NodeId, (NodeId, RelationshipId)>, 87 ) -> Result<Path> { 88 let mut path = Path::new(); 89 let mut current = goal; 90 let mut nodes = Vec::new(); 91 let mut relationships = Vec::new(); 92 93 while current != start { 94 nodes.push(current); 95 if let Some(&(prev_node, rel_id)) = came_from.get(&current) { 96 relationships.push(rel_id); 97 current = prev_node; 98 } else { 99 return Err(crate::error::GigabrainError::Algorithm( 100 "Invalid path reconstruction in A*".to_string(), 101 )); 102 } 103 } 104 105 nodes.push(start); 106 nodes.reverse(); 107 relationships.reverse(); 108 109 path.nodes = nodes; 110 path.relationships = relationships; 111 112 Ok(path) 113 } 114} 115 116#[derive(Debug, Clone)] 117struct AStarEntry { 118 node: NodeId, 119 f_score: f64, 120} 121 122impl PartialEq for AStarEntry { 123 fn eq(&self, other: &Self) -> bool { 124 self.f_score == other.f_score 125 } 126} 127 128impl Eq for AStarEntry {} 129 130impl PartialOrd for AStarEntry { 131 fn partial_cmp(&self, other: &Self) -> Option<Ordering> { 132 Some(self.cmp(other)) 133 } 134} 135 136impl Ord for AStarEntry { 137 fn cmp(&self, other: &Self) -> Ordering { 138 // Reverse ordering for min-heap 139 other.f_score.partial_cmp(&self.f_score).unwrap_or(Ordering::Equal) 140 } 141} 142 143/// Bidirectional search implementation 144pub struct BidirectionalSearch<'a> { 145 graph: &'a Graph, 146} 147 148impl<'a> BidirectionalSearch<'a> { 149 pub fn new(graph: &'a Graph) -> Self { 150 Self { graph } 151 } 152 153 /// Find shortest path using bidirectional search 154 pub fn find_path( 155 &self, 156 start: NodeId, 157 goal: NodeId, 158 weight_fn: Option<&WeightFn>, 159 ) -> Result<Option<Path>> { 160 let mut forward_visited: HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)> = HashMap::new(); 161 let mut backward_visited: HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)> = HashMap::new(); 162 163 let mut forward_queue = VecDeque::new(); 164 let mut backward_queue = VecDeque::new(); 165 166 forward_visited.insert(start, (0.0, None)); 167 backward_visited.insert(goal, (0.0, None)); 168 169 forward_queue.push_back(start); 170 backward_queue.push_back(goal); 171 172 let mut meeting_point = None; 173 let mut min_distance = f64::INFINITY; 174 175 while !forward_queue.is_empty() || !backward_queue.is_empty() { 176 // Expand forward search 177 if !forward_queue.is_empty() { 178 let current = forward_queue.pop_front().unwrap(); 179 let (current_dist, _) = forward_visited[&current]; 180 181 let relationships = self.graph.get_node_relationships(current, Direction::Both, None); 182 for relationship in relationships { 183 let neighbor = if relationship.start_node == current { 184 relationship.end_node 185 } else { 186 relationship.start_node 187 }; 188 189 let edge_weight = weight_fn 190 .map(|f| f(relationship.id)) 191 .unwrap_or(1.0); 192 193 let new_dist = current_dist + edge_weight; 194 195 let should_update = forward_visited 196 .get(&neighbor) 197 .map_or(true, |(dist, _)| new_dist < *dist); 198 199 if should_update { 200 forward_visited.insert(neighbor, (new_dist, Some((current, relationship.id)))); 201 forward_queue.push_back(neighbor); 202 203 // Check if we've met the backward search 204 if let Some((backward_dist, _)) = backward_visited.get(&neighbor) { 205 let total_dist = new_dist + backward_dist; 206 if total_dist < min_distance { 207 min_distance = total_dist; 208 meeting_point = Some(neighbor); 209 } 210 } 211 } 212 } 213 } 214 215 // Expand backward search 216 if !backward_queue.is_empty() { 217 let current = backward_queue.pop_front().unwrap(); 218 let (current_dist, _) = backward_visited[&current]; 219 220 let relationships = self.graph.get_node_relationships(current, Direction::Both, None); 221 for relationship in relationships { 222 let neighbor = if relationship.start_node == current { 223 relationship.end_node 224 } else { 225 relationship.start_node 226 }; 227 228 let edge_weight = weight_fn 229 .map(|f| f(relationship.id)) 230 .unwrap_or(1.0); 231 232 let new_dist = current_dist + edge_weight; 233 234 let should_update = backward_visited 235 .get(&neighbor) 236 .map_or(true, |(dist, _)| new_dist < *dist); 237 238 if should_update { 239 backward_visited.insert(neighbor, (new_dist, Some((current, relationship.id)))); 240 backward_queue.push_back(neighbor); 241 242 // Check if we've met the forward search 243 if let Some((forward_dist, _)) = forward_visited.get(&neighbor) { 244 let total_dist = forward_dist + new_dist; 245 if total_dist < min_distance { 246 min_distance = total_dist; 247 meeting_point = Some(neighbor); 248 } 249 } 250 } 251 } 252 } 253 } 254 255 if let Some(meeting) = meeting_point { 256 Ok(Some(self.reconstruct_bidirectional_path( 257 start, 258 goal, 259 meeting, 260 &forward_visited, 261 &backward_visited, 262 )?)) 263 } else { 264 Ok(None) 265 } 266 } 267 268 fn reconstruct_bidirectional_path( 269 &self, 270 start: NodeId, 271 goal: NodeId, 272 meeting: NodeId, 273 forward_visited: &HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)>, 274 backward_visited: &HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)>, 275 ) -> Result<Path> { 276 let mut path = Path::new(); 277 278 // Build forward path from start to meeting point 279 let mut forward_nodes = Vec::new(); 280 let mut forward_rels = Vec::new(); 281 let mut current = meeting; 282 283 while current != start { 284 forward_nodes.push(current); 285 if let Some((_, Some((prev, rel)))) = forward_visited.get(&current) { 286 forward_rels.push(*rel); 287 current = *prev; 288 } else { 289 return Err(crate::error::GigabrainError::Algorithm( 290 "Invalid forward path reconstruction".to_string(), 291 )); 292 } 293 } 294 forward_nodes.push(start); 295 forward_nodes.reverse(); 296 forward_rels.reverse(); 297 298 // Build backward path from meeting point to goal 299 let mut backward_nodes = Vec::new(); 300 let mut backward_rels = Vec::new(); 301 current = meeting; 302 303 while current != goal { 304 if let Some((_, Some((next, rel)))) = backward_visited.get(&current) { 305 backward_nodes.push(*next); 306 backward_rels.push(*rel); 307 current = *next; 308 } else { 309 return Err(crate::error::GigabrainError::Algorithm( 310 "Invalid backward path reconstruction".to_string(), 311 )); 312 } 313 } 314 315 // Combine paths 316 path.nodes = forward_nodes; 317 path.nodes.extend(backward_nodes); 318 path.relationships = forward_rels; 319 path.relationships.extend(backward_rels); 320 321 Ok(path) 322 } 323} 324 325/// All-pairs shortest paths using Floyd-Warshall algorithm 326pub struct FloydWarshall<'a> { 327 graph: &'a Graph, 328} 329 330impl<'a> FloydWarshall<'a> { 331 pub fn new(graph: &'a Graph) -> Self { 332 Self { graph } 333 } 334 335 /// Compute all-pairs shortest paths 336 pub fn compute_all_pairs( 337 &self, 338 nodes: &[NodeId], 339 weight_fn: Option<&WeightFn>, 340 ) -> Result<HashMap<(NodeId, NodeId), Option<Path>>> { 341 let n = nodes.len(); 342 let mut dist: HashMap<(NodeId, NodeId), f64> = HashMap::new(); 343 let mut next: HashMap<(NodeId, NodeId), Option<NodeId>> = HashMap::new(); 344 345 // Initialize distances 346 for &i in nodes { 347 for &j in nodes { 348 if i == j { 349 dist.insert((i, j), 0.0); 350 } else { 351 dist.insert((i, j), f64::INFINITY); 352 } 353 next.insert((i, j), None); 354 } 355 } 356 357 // Set distances for direct edges 358 for &node in nodes { 359 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 360 for relationship in relationships { 361 let neighbor = if relationship.start_node == node { 362 relationship.end_node 363 } else { 364 relationship.start_node 365 }; 366 367 if nodes.contains(&neighbor) { 368 let weight = weight_fn 369 .map(|f| f(relationship.id)) 370 .unwrap_or(1.0); 371 372 dist.insert((node, neighbor), weight); 373 next.insert((node, neighbor), Some(neighbor)); 374 } 375 } 376 } 377 378 // Floyd-Warshall algorithm 379 for &k in nodes { 380 for &i in nodes { 381 for &j in nodes { 382 let dist_ik = dist[&(i, k)]; 383 let dist_kj = dist[&(k, j)]; 384 let dist_ij = dist[&(i, j)]; 385 386 if dist_ik + dist_kj < dist_ij { 387 dist.insert((i, j), dist_ik + dist_kj); 388 next.insert((i, j), next[&(i, k)]); 389 } 390 } 391 } 392 } 393 394 // Reconstruct paths 395 let mut result = HashMap::new(); 396 for &i in nodes { 397 for &j in nodes { 398 if i != j && dist[&(i, j)] != f64::INFINITY { 399 result.insert((i, j), Some(self.reconstruct_floyd_warshall_path(i, j, &next)?)); 400 } else { 401 result.insert((i, j), None); 402 } 403 } 404 } 405 406 Ok(result) 407 } 408 409 fn reconstruct_floyd_warshall_path( 410 &self, 411 start: NodeId, 412 end: NodeId, 413 next: &HashMap<(NodeId, NodeId), Option<NodeId>>, 414 ) -> Result<Path> { 415 let mut path = Path::new(); 416 let mut current = start; 417 418 path.nodes.push(current); 419 420 while current != end { 421 if let Some(Some(next_node)) = next.get(&(current, end)) { 422 path.nodes.push(*next_node); 423 current = *next_node; 424 } else { 425 return Err(crate::error::GigabrainError::Algorithm( 426 "Invalid Floyd-Warshall path reconstruction".to_string(), 427 )); 428 } 429 } 430 431 Ok(path) 432 } 433} 434 435#[cfg(test)] 436mod tests { 437 use super::*; 438 use crate::Graph; 439 440 fn create_test_graph() -> Graph { 441 let graph = Graph::new(); 442 443 // Create nodes 444 let node_a = graph.create_node(); 445 let node_b = graph.create_node(); 446 let node_c = graph.create_node(); 447 448 let schema = graph.schema(); 449 let mut schema = schema.write(); 450 let rel_type = schema.get_or_create_relationship_type("CONNECTS"); 451 drop(schema); 452 453 // Create relationships: A -> B -> C 454 graph.create_relationship(node_a, node_b, rel_type).unwrap(); 455 graph.create_relationship(node_b, node_c, rel_type).unwrap(); 456 457 graph 458 } 459 460 #[test] 461 fn test_a_star_pathfinding() { 462 let graph = create_test_graph(); 463 let astar = AStar::new(&graph); 464 465 let nodes: Vec<_> = graph.get_all_nodes(); 466 if nodes.len() >= 2 { 467 let start = nodes[0]; 468 let goal = nodes[nodes.len() - 1]; 469 470 // Simple heuristic (always returns 0, making it equivalent to Dijkstra) 471 let heuristic = |_: NodeId, _: NodeId| 0.0; 472 473 let path = astar.find_path(start, goal, None, heuristic).unwrap(); 474 if let Some(path) = path { 475 assert!(!path.is_empty()); 476 assert_eq!(path.nodes[0], start); 477 assert_eq!(*path.nodes.last().unwrap(), goal); 478 } 479 } 480 } 481 482 #[test] 483 fn test_bidirectional_search() { 484 let graph = create_test_graph(); 485 let bidirectional = BidirectionalSearch::new(&graph); 486 487 let nodes: Vec<_> = graph.get_all_nodes(); 488 if nodes.len() >= 2 { 489 let start = nodes[0]; 490 let goal = nodes[nodes.len() - 1]; 491 492 let path = bidirectional.find_path(start, goal, None).unwrap(); 493 if let Some(path) = path { 494 assert!(!path.is_empty()); 495 assert_eq!(path.nodes[0], start); 496 assert_eq!(*path.nodes.last().unwrap(), goal); 497 } 498 } 499 } 500 501 #[test] 502 fn test_floyd_warshall() { 503 let graph = create_test_graph(); 504 let floyd_warshall = FloydWarshall::new(&graph); 505 506 let nodes: Vec<_> = graph.get_all_nodes(); 507 if nodes.len() >= 2 { 508 let all_pairs = floyd_warshall.compute_all_pairs(&nodes, None).unwrap(); 509 510 // Check that we have results for all pairs 511 for &i in &nodes { 512 for &j in &nodes { 513 assert!(all_pairs.contains_key(&(i, j))); 514 } 515 } 516 } 517 } 518}