This repository has no description.
1use crate::{Graph, NodeId, RelationshipId, Result};
2use crate::core::relationship::Direction;
3use super::{Path, WeightFn};
4use std::collections::{HashMap, HashSet, VecDeque, BinaryHeap};
5use std::cmp::Ordering;
6
7/// A* pathfinding algorithm implementation
8pub struct AStar<'a> {
9 graph: &'a Graph,
10}
11
12impl<'a> AStar<'a> {
13 pub fn new(graph: &'a Graph) -> Self {
14 Self { graph }
15 }
16
17 /// Find shortest path using A* algorithm with heuristic function
18 pub fn find_path<H>(
19 &self,
20 start: NodeId,
21 goal: NodeId,
22 weight_fn: Option<&WeightFn>,
23 heuristic: H,
24 ) -> Result<Option<Path>>
25 where
26 H: Fn(NodeId, NodeId) -> f64,
27 {
28 let mut open_set = BinaryHeap::new();
29 let mut came_from: HashMap<NodeId, (NodeId, RelationshipId)> = HashMap::new();
30 let mut g_score: HashMap<NodeId, f64> = HashMap::new();
31 let mut f_score: HashMap<NodeId, f64> = HashMap::new();
32
33 g_score.insert(start, 0.0);
34 f_score.insert(start, heuristic(start, goal));
35
36 open_set.push(AStarEntry {
37 node: start,
38 f_score: heuristic(start, goal),
39 });
40
41 while let Some(current_entry) = open_set.pop() {
42 let current = current_entry.node;
43
44 if current == goal {
45 return Ok(Some(self.reconstruct_path(start, goal, &came_from)?));
46 }
47
48 let current_g_score = g_score.get(¤t).copied().unwrap_or(f64::INFINITY);
49
50 let relationships = self.graph.get_node_relationships(current, Direction::Both, None);
51 for relationship in relationships {
52 let neighbor = if relationship.start_node == current {
53 relationship.end_node
54 } else {
55 relationship.start_node
56 };
57
58 let edge_weight = weight_fn
59 .map(|f| f(relationship.id))
60 .unwrap_or(1.0);
61
62 let tentative_g_score = current_g_score + edge_weight;
63 let neighbor_g_score = g_score.get(&neighbor).copied().unwrap_or(f64::INFINITY);
64
65 if tentative_g_score < neighbor_g_score {
66 came_from.insert(neighbor, (current, relationship.id));
67 g_score.insert(neighbor, tentative_g_score);
68 let new_f_score = tentative_g_score + heuristic(neighbor, goal);
69 f_score.insert(neighbor, new_f_score);
70
71 open_set.push(AStarEntry {
72 node: neighbor,
73 f_score: new_f_score,
74 });
75 }
76 }
77 }
78
79 Ok(None) // No path found
80 }
81
82 fn reconstruct_path(
83 &self,
84 start: NodeId,
85 goal: NodeId,
86 came_from: &HashMap<NodeId, (NodeId, RelationshipId)>,
87 ) -> Result<Path> {
88 let mut path = Path::new();
89 let mut current = goal;
90 let mut nodes = Vec::new();
91 let mut relationships = Vec::new();
92
93 while current != start {
94 nodes.push(current);
95 if let Some(&(prev_node, rel_id)) = came_from.get(¤t) {
96 relationships.push(rel_id);
97 current = prev_node;
98 } else {
99 return Err(crate::error::GigabrainError::Algorithm(
100 "Invalid path reconstruction in A*".to_string(),
101 ));
102 }
103 }
104
105 nodes.push(start);
106 nodes.reverse();
107 relationships.reverse();
108
109 path.nodes = nodes;
110 path.relationships = relationships;
111
112 Ok(path)
113 }
114}
115
116#[derive(Debug, Clone)]
117struct AStarEntry {
118 node: NodeId,
119 f_score: f64,
120}
121
122impl PartialEq for AStarEntry {
123 fn eq(&self, other: &Self) -> bool {
124 self.f_score == other.f_score
125 }
126}
127
128impl Eq for AStarEntry {}
129
130impl PartialOrd for AStarEntry {
131 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
132 Some(self.cmp(other))
133 }
134}
135
136impl Ord for AStarEntry {
137 fn cmp(&self, other: &Self) -> Ordering {
138 // Reverse ordering for min-heap
139 other.f_score.partial_cmp(&self.f_score).unwrap_or(Ordering::Equal)
140 }
141}
142
143/// Bidirectional search implementation
144pub struct BidirectionalSearch<'a> {
145 graph: &'a Graph,
146}
147
148impl<'a> BidirectionalSearch<'a> {
149 pub fn new(graph: &'a Graph) -> Self {
150 Self { graph }
151 }
152
153 /// Find shortest path using bidirectional search
154 pub fn find_path(
155 &self,
156 start: NodeId,
157 goal: NodeId,
158 weight_fn: Option<&WeightFn>,
159 ) -> Result<Option<Path>> {
160 let mut forward_visited: HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)> = HashMap::new();
161 let mut backward_visited: HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)> = HashMap::new();
162
163 let mut forward_queue = VecDeque::new();
164 let mut backward_queue = VecDeque::new();
165
166 forward_visited.insert(start, (0.0, None));
167 backward_visited.insert(goal, (0.0, None));
168
169 forward_queue.push_back(start);
170 backward_queue.push_back(goal);
171
172 let mut meeting_point = None;
173 let mut min_distance = f64::INFINITY;
174
175 while !forward_queue.is_empty() || !backward_queue.is_empty() {
176 // Expand forward search
177 if !forward_queue.is_empty() {
178 let current = forward_queue.pop_front().unwrap();
179 let (current_dist, _) = forward_visited[¤t];
180
181 let relationships = self.graph.get_node_relationships(current, Direction::Both, None);
182 for relationship in relationships {
183 let neighbor = if relationship.start_node == current {
184 relationship.end_node
185 } else {
186 relationship.start_node
187 };
188
189 let edge_weight = weight_fn
190 .map(|f| f(relationship.id))
191 .unwrap_or(1.0);
192
193 let new_dist = current_dist + edge_weight;
194
195 let should_update = forward_visited
196 .get(&neighbor)
197 .map_or(true, |(dist, _)| new_dist < *dist);
198
199 if should_update {
200 forward_visited.insert(neighbor, (new_dist, Some((current, relationship.id))));
201 forward_queue.push_back(neighbor);
202
203 // Check if we've met the backward search
204 if let Some((backward_dist, _)) = backward_visited.get(&neighbor) {
205 let total_dist = new_dist + backward_dist;
206 if total_dist < min_distance {
207 min_distance = total_dist;
208 meeting_point = Some(neighbor);
209 }
210 }
211 }
212 }
213 }
214
215 // Expand backward search
216 if !backward_queue.is_empty() {
217 let current = backward_queue.pop_front().unwrap();
218 let (current_dist, _) = backward_visited[¤t];
219
220 let relationships = self.graph.get_node_relationships(current, Direction::Both, None);
221 for relationship in relationships {
222 let neighbor = if relationship.start_node == current {
223 relationship.end_node
224 } else {
225 relationship.start_node
226 };
227
228 let edge_weight = weight_fn
229 .map(|f| f(relationship.id))
230 .unwrap_or(1.0);
231
232 let new_dist = current_dist + edge_weight;
233
234 let should_update = backward_visited
235 .get(&neighbor)
236 .map_or(true, |(dist, _)| new_dist < *dist);
237
238 if should_update {
239 backward_visited.insert(neighbor, (new_dist, Some((current, relationship.id))));
240 backward_queue.push_back(neighbor);
241
242 // Check if we've met the forward search
243 if let Some((forward_dist, _)) = forward_visited.get(&neighbor) {
244 let total_dist = forward_dist + new_dist;
245 if total_dist < min_distance {
246 min_distance = total_dist;
247 meeting_point = Some(neighbor);
248 }
249 }
250 }
251 }
252 }
253 }
254
255 if let Some(meeting) = meeting_point {
256 Ok(Some(self.reconstruct_bidirectional_path(
257 start,
258 goal,
259 meeting,
260 &forward_visited,
261 &backward_visited,
262 )?))
263 } else {
264 Ok(None)
265 }
266 }
267
268 fn reconstruct_bidirectional_path(
269 &self,
270 start: NodeId,
271 goal: NodeId,
272 meeting: NodeId,
273 forward_visited: &HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)>,
274 backward_visited: &HashMap<NodeId, (f64, Option<(NodeId, RelationshipId)>)>,
275 ) -> Result<Path> {
276 let mut path = Path::new();
277
278 // Build forward path from start to meeting point
279 let mut forward_nodes = Vec::new();
280 let mut forward_rels = Vec::new();
281 let mut current = meeting;
282
283 while current != start {
284 forward_nodes.push(current);
285 if let Some((_, Some((prev, rel)))) = forward_visited.get(¤t) {
286 forward_rels.push(*rel);
287 current = *prev;
288 } else {
289 return Err(crate::error::GigabrainError::Algorithm(
290 "Invalid forward path reconstruction".to_string(),
291 ));
292 }
293 }
294 forward_nodes.push(start);
295 forward_nodes.reverse();
296 forward_rels.reverse();
297
298 // Build backward path from meeting point to goal
299 let mut backward_nodes = Vec::new();
300 let mut backward_rels = Vec::new();
301 current = meeting;
302
303 while current != goal {
304 if let Some((_, Some((next, rel)))) = backward_visited.get(¤t) {
305 backward_nodes.push(*next);
306 backward_rels.push(*rel);
307 current = *next;
308 } else {
309 return Err(crate::error::GigabrainError::Algorithm(
310 "Invalid backward path reconstruction".to_string(),
311 ));
312 }
313 }
314
315 // Combine paths
316 path.nodes = forward_nodes;
317 path.nodes.extend(backward_nodes);
318 path.relationships = forward_rels;
319 path.relationships.extend(backward_rels);
320
321 Ok(path)
322 }
323}
324
325/// All-pairs shortest paths using Floyd-Warshall algorithm
326pub struct FloydWarshall<'a> {
327 graph: &'a Graph,
328}
329
330impl<'a> FloydWarshall<'a> {
331 pub fn new(graph: &'a Graph) -> Self {
332 Self { graph }
333 }
334
335 /// Compute all-pairs shortest paths
336 pub fn compute_all_pairs(
337 &self,
338 nodes: &[NodeId],
339 weight_fn: Option<&WeightFn>,
340 ) -> Result<HashMap<(NodeId, NodeId), Option<Path>>> {
341 let n = nodes.len();
342 let mut dist: HashMap<(NodeId, NodeId), f64> = HashMap::new();
343 let mut next: HashMap<(NodeId, NodeId), Option<NodeId>> = HashMap::new();
344
345 // Initialize distances
346 for &i in nodes {
347 for &j in nodes {
348 if i == j {
349 dist.insert((i, j), 0.0);
350 } else {
351 dist.insert((i, j), f64::INFINITY);
352 }
353 next.insert((i, j), None);
354 }
355 }
356
357 // Set distances for direct edges
358 for &node in nodes {
359 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
360 for relationship in relationships {
361 let neighbor = if relationship.start_node == node {
362 relationship.end_node
363 } else {
364 relationship.start_node
365 };
366
367 if nodes.contains(&neighbor) {
368 let weight = weight_fn
369 .map(|f| f(relationship.id))
370 .unwrap_or(1.0);
371
372 dist.insert((node, neighbor), weight);
373 next.insert((node, neighbor), Some(neighbor));
374 }
375 }
376 }
377
378 // Floyd-Warshall algorithm
379 for &k in nodes {
380 for &i in nodes {
381 for &j in nodes {
382 let dist_ik = dist[&(i, k)];
383 let dist_kj = dist[&(k, j)];
384 let dist_ij = dist[&(i, j)];
385
386 if dist_ik + dist_kj < dist_ij {
387 dist.insert((i, j), dist_ik + dist_kj);
388 next.insert((i, j), next[&(i, k)]);
389 }
390 }
391 }
392 }
393
394 // Reconstruct paths
395 let mut result = HashMap::new();
396 for &i in nodes {
397 for &j in nodes {
398 if i != j && dist[&(i, j)] != f64::INFINITY {
399 result.insert((i, j), Some(self.reconstruct_floyd_warshall_path(i, j, &next)?));
400 } else {
401 result.insert((i, j), None);
402 }
403 }
404 }
405
406 Ok(result)
407 }
408
409 fn reconstruct_floyd_warshall_path(
410 &self,
411 start: NodeId,
412 end: NodeId,
413 next: &HashMap<(NodeId, NodeId), Option<NodeId>>,
414 ) -> Result<Path> {
415 let mut path = Path::new();
416 let mut current = start;
417
418 path.nodes.push(current);
419
420 while current != end {
421 if let Some(Some(next_node)) = next.get(&(current, end)) {
422 path.nodes.push(*next_node);
423 current = *next_node;
424 } else {
425 return Err(crate::error::GigabrainError::Algorithm(
426 "Invalid Floyd-Warshall path reconstruction".to_string(),
427 ));
428 }
429 }
430
431 Ok(path)
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;

    /// Builds the chain `a -> b -> c` and returns the graph together with the
    /// node ids in chain order. Tests therefore do not rely on the order
    /// returned by `Graph::get_all_nodes`.
    fn create_test_graph() -> (Graph, [NodeId; 3]) {
        let graph = Graph::new();

        // Create nodes
        let node_a = graph.create_node();
        let node_b = graph.create_node();
        let node_c = graph.create_node();

        let schema = graph.schema();
        let mut schema = schema.write();
        let rel_type = schema.get_or_create_relationship_type("CONNECTS");
        drop(schema);

        // Create relationships: A -> B -> C
        graph.create_relationship(node_a, node_b, rel_type).unwrap();
        graph.create_relationship(node_b, node_c, rel_type).unwrap();

        (graph, [node_a, node_b, node_c])
    }

    /// Heuristic that always returns 0, making A* equivalent to Dijkstra.
    fn zero_heuristic(_: NodeId, _: NodeId) -> f64 {
        0.0
    }

    #[test]
    fn test_a_star_pathfinding() {
        let (graph, [a, b, c]) = create_test_graph();
        let astar = AStar::new(&graph);

        let path = astar
            .find_path(a, c, None, zero_heuristic)
            .unwrap()
            .expect("c should be reachable from a");
        assert_eq!(path.nodes, vec![a, b, c]);
        assert_eq!(path.relationships.len(), 2);

        // Relationships are traversed in both directions.
        let reverse = astar
            .find_path(c, a, None, zero_heuristic)
            .unwrap()
            .expect("a should be reachable from c");
        assert_eq!(reverse.nodes, vec![c, b, a]);

        let trivial = astar
            .find_path(a, a, None, zero_heuristic)
            .unwrap()
            .expect("a path from a node to itself");
        assert_eq!(trivial.nodes, vec![a]);
        assert!(trivial.relationships.is_empty());
    }

    #[test]
    fn test_bidirectional_search() {
        let (graph, [a, b, c]) = create_test_graph();
        let bidirectional = BidirectionalSearch::new(&graph);

        let path = bidirectional
            .find_path(a, c, None)
            .unwrap()
            .expect("c should be reachable from a");
        assert_eq!(path.nodes, vec![a, b, c]);
        assert_eq!(path.relationships.len(), 2);

        let reverse = bidirectional
            .find_path(c, a, None)
            .unwrap()
            .expect("a should be reachable from c");
        assert_eq!(reverse.nodes, vec![c, b, a]);
    }

    #[test]
    fn test_floyd_warshall() {
        let (graph, [a, b, c]) = create_test_graph();
        let floyd_warshall = FloydWarshall::new(&graph);
        let nodes = [a, b, c];

        let all_pairs = floyd_warshall.compute_all_pairs(&nodes, None).unwrap();

        // Every ordered pair has an entry.
        for &i in &nodes {
            for &j in &nodes {
                assert!(all_pairs.contains_key(&(i, j)));
            }
        }

        let path = all_pairs[&(a, c)]
            .as_ref()
            .expect("c should be reachable from a");
        assert_eq!(path.nodes, vec![a, b, c]);

        // The diagonal is reported as "no path".
        assert!(all_pairs[&(a, a)].is_none());
    }

    #[test]
    fn test_unreachable_goal_yields_none() {
        let (graph, [a, _, _]) = create_test_graph();
        let isolated = graph.create_node();

        let astar = AStar::new(&graph);
        assert!(astar
            .find_path(a, isolated, None, zero_heuristic)
            .unwrap()
            .is_none());

        let bidirectional = BidirectionalSearch::new(&graph);
        assert!(bidirectional.find_path(a, isolated, None).unwrap().is_none());

        let all_pairs = FloydWarshall::new(&graph)
            .compute_all_pairs(&[a, isolated], None)
            .unwrap();
        assert!(all_pairs[&(a, isolated)].is_none());
    }
}