this repo has no description
use crate::{Graph, NodeId, RelationshipId, Result};
use crate::core::relationship::Direction;
use super::WeightFn;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
5
6/// Centrality measures for graph analysis
7pub struct CentralityMeasures<'a> {
8 graph: &'a Graph,
9}
10
11impl<'a> CentralityMeasures<'a> {
12 pub fn new(graph: &'a Graph) -> Self {
13 Self { graph }
14 }
15
16 /// Calculate betweenness centrality for all nodes
17 pub fn betweenness_centrality(
18 &self,
19 nodes: &[NodeId],
20 weight_fn: Option<&WeightFn>,
21 ) -> Result<HashMap<NodeId, f64>> {
22 let mut centrality: HashMap<NodeId, f64> = HashMap::new();
23
24 // Initialize all centralities to 0
25 for &node in nodes {
26 centrality.insert(node, 0.0);
27 }
28
29 // For each node as source
30 for &source in nodes {
31 let mut stack = Vec::new();
32 let mut paths: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
33 let mut sigma: HashMap<NodeId, f64> = HashMap::new();
34 let mut dist: HashMap<NodeId, f64> = HashMap::new();
35 let mut delta: HashMap<NodeId, f64> = HashMap::new();
36
37 // Initialize
38 for &node in nodes {
39 paths.insert(node, Vec::new());
40 sigma.insert(node, 0.0);
41 dist.insert(node, f64::INFINITY);
42 delta.insert(node, 0.0);
43 }
44
45 sigma.insert(source, 1.0);
46 dist.insert(source, 0.0);
47
48 let mut queue = VecDeque::new();
49 queue.push_back(source);
50
51 // BFS to find shortest paths
52 while let Some(current) = queue.pop_front() {
53 stack.push(current);
54
55 let relationships = self.graph.get_node_relationships(current, Direction::Both, None);
56 for relationship in relationships {
57 let neighbor = if relationship.start_node == current {
58 relationship.end_node
59 } else {
60 relationship.start_node
61 };
62
63 if !nodes.contains(&neighbor) {
64 continue;
65 }
66
67 let weight = weight_fn
68 .map(|f| f(relationship.id))
69 .unwrap_or(1.0);
70
71 let new_dist = dist[¤t] + weight;
72
73 // First time we reach this neighbor
74 if dist[&neighbor] == f64::INFINITY {
75 queue.push_back(neighbor);
76 dist.insert(neighbor, new_dist);
77 }
78
79 // Shortest path to neighbor via current
80 if (dist[&neighbor] - new_dist).abs() < f64::EPSILON {
81 sigma.insert(neighbor, sigma[&neighbor] + sigma[¤t]);
82 paths.get_mut(&neighbor).unwrap().push(current);
83 }
84 }
85 }
86
87 // Accumulate dependencies
88 while let Some(current) = stack.pop() {
89 for &predecessor in &paths[¤t] {
90 let contribution = (sigma[&predecessor] / sigma[¤t]) * (1.0 + delta[¤t]);
91 delta.insert(predecessor, delta[&predecessor] + contribution);
92 }
93
94 if current != source {
95 centrality.insert(current, centrality[¤t] + delta[¤t]);
96 }
97 }
98 }
99
100 // Normalize (divide by 2 for undirected graphs)
101 let normalization_factor = if nodes.len() > 2 {
102 2.0 * ((nodes.len() - 1) * (nodes.len() - 2)) as f64
103 } else {
104 1.0
105 };
106
107 for value in centrality.values_mut() {
108 *value /= normalization_factor;
109 }
110
111 Ok(centrality)
112 }
113
114 /// Calculate closeness centrality for all nodes
115 pub fn closeness_centrality(
116 &self,
117 nodes: &[NodeId],
118 weight_fn: Option<&WeightFn>,
119 ) -> Result<HashMap<NodeId, f64>> {
120 let mut centrality: HashMap<NodeId, f64> = HashMap::new();
121
122 for &source in nodes {
123 let distances = self.single_source_shortest_paths(source, nodes, weight_fn)?;
124
125 let mut total_distance = 0.0;
126 let mut reachable_count = 0;
127
128 for &target in nodes {
129 if target != source {
130 if let Some(distance) = distances.get(&target) {
131 if *distance < f64::INFINITY {
132 total_distance += distance;
133 reachable_count += 1;
134 }
135 }
136 }
137 }
138
139 let closeness = if total_distance > 0.0 && reachable_count > 0 {
140 (reachable_count as f64) / total_distance
141 } else {
142 0.0
143 };
144
145 centrality.insert(source, closeness);
146 }
147
148 Ok(centrality)
149 }
150
151 /// Calculate degree centrality for all nodes
152 pub fn degree_centrality(&self, nodes: &[NodeId]) -> Result<HashMap<NodeId, f64>> {
153 let mut centrality: HashMap<NodeId, f64> = HashMap::new();
154 let node_count = nodes.len();
155
156 for &node in nodes {
157 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
158 let degree = relationships.len() as f64;
159
160 // Normalize by the maximum possible degree
161 let normalized_degree = if node_count > 1 {
162 degree / (node_count - 1) as f64
163 } else {
164 0.0
165 };
166
167 centrality.insert(node, normalized_degree);
168 }
169
170 Ok(centrality)
171 }
172
173 /// Calculate eigenvector centrality using power iteration
174 pub fn eigenvector_centrality(
175 &self,
176 nodes: &[NodeId],
177 max_iterations: usize,
178 tolerance: f64,
179 ) -> Result<HashMap<NodeId, f64>> {
180 let n = nodes.len();
181 if n == 0 {
182 return Ok(HashMap::new());
183 }
184
185 // Create adjacency matrix representation
186 let mut adj: HashMap<(NodeId, NodeId), f64> = HashMap::new();
187
188 for &node in nodes {
189 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
190 for relationship in relationships {
191 let neighbor = if relationship.start_node == node {
192 relationship.end_node
193 } else {
194 relationship.start_node
195 };
196
197 if nodes.contains(&neighbor) {
198 adj.insert((node, neighbor), 1.0);
199 }
200 }
201 }
202
203 // Initialize eigenvector
204 let mut centrality: HashMap<NodeId, f64> = HashMap::new();
205 for &node in nodes {
206 centrality.insert(node, 1.0 / (n as f64).sqrt());
207 }
208
209 // Power iteration
210 for _ in 0..max_iterations {
211 let mut new_centrality: HashMap<NodeId, f64> = HashMap::new();
212
213 // Matrix-vector multiplication
214 for &node in nodes {
215 let mut sum = 0.0;
216 for &neighbor in nodes {
217 if let Some(weight) = adj.get(&(neighbor, node)) {
218 sum += weight * centrality[&neighbor];
219 }
220 }
221 new_centrality.insert(node, sum);
222 }
223
224 // Normalize
225 let norm: f64 = new_centrality.values().map(|x| x * x).sum::<f64>().sqrt();
226 if norm > 0.0 {
227 for value in new_centrality.values_mut() {
228 *value /= norm;
229 }
230 }
231
232 // Check convergence
233 let mut converged = true;
234 for &node in nodes {
235 if (new_centrality[&node] - centrality[&node]).abs() > tolerance {
236 converged = false;
237 break;
238 }
239 }
240
241 centrality = new_centrality;
242
243 if converged {
244 break;
245 }
246 }
247
248 Ok(centrality)
249 }
250
251 /// Calculate PageRank centrality
252 pub fn pagerank(
253 &self,
254 nodes: &[NodeId],
255 damping_factor: f64,
256 max_iterations: usize,
257 tolerance: f64,
258 ) -> Result<HashMap<NodeId, f64>> {
259 let n = nodes.len();
260 if n == 0 {
261 return Ok(HashMap::new());
262 }
263
264 // Initialize PageRank values
265 let mut pagerank: HashMap<NodeId, f64> = HashMap::new();
266 let initial_value = 1.0 / n as f64;
267
268 for &node in nodes {
269 pagerank.insert(node, initial_value);
270 }
271
272 // Calculate out-degrees
273 let mut out_degree: HashMap<NodeId, usize> = HashMap::new();
274 for &node in nodes {
275 let relationships = self.graph.get_node_relationships(node, Direction::Outgoing, None);
276 out_degree.insert(node, relationships.len());
277 }
278
279 // Power iteration
280 for _ in 0..max_iterations {
281 let mut new_pagerank: HashMap<NodeId, f64> = HashMap::new();
282
283 // Calculate dangling node contribution (nodes with no outgoing links)
284 let mut dangling_sum = 0.0;
285 for &node in nodes {
286 if out_degree[&node] == 0 {
287 dangling_sum += pagerank[&node];
288 }
289 }
290
291 for &node in nodes {
292 let mut sum = 0.0;
293
294 // Sum contributions from incoming links
295 let incoming_relationships = self.graph.get_node_relationships(node, Direction::Incoming, None);
296 for relationship in incoming_relationships {
297 let source = relationship.start_node;
298 if nodes.contains(&source) && out_degree[&source] > 0 {
299 sum += pagerank[&source] / out_degree[&source] as f64;
300 }
301 }
302
303 // Add dangling node contribution (distributed equally)
304 let dangling_contribution = dangling_sum / n as f64;
305
306 let new_value = (1.0 - damping_factor) / n as f64 + damping_factor * (sum + dangling_contribution);
307 new_pagerank.insert(node, new_value);
308 }
309
310 // Check convergence
311 let mut converged = true;
312 for &node in nodes {
313 if (new_pagerank[&node] - pagerank[&node]).abs() > tolerance {
314 converged = false;
315 break;
316 }
317 }
318
319 pagerank = new_pagerank;
320
321 if converged {
322 break;
323 }
324 }
325
326 Ok(pagerank)
327 }
328
329 /// Calculate clustering coefficient for all nodes
330 pub fn clustering_coefficient(&self, nodes: &[NodeId]) -> Result<HashMap<NodeId, f64>> {
331 let mut clustering: HashMap<NodeId, f64> = HashMap::new();
332
333 for &node in nodes {
334 let neighbors = self.get_neighbors(node, nodes)?;
335 let degree = neighbors.len();
336
337 if degree < 2 {
338 clustering.insert(node, 0.0);
339 continue;
340 }
341
342 // Count triangles
343 let mut triangle_count = 0;
344 for i in 0..neighbors.len() {
345 for j in (i + 1)..neighbors.len() {
346 if self.are_connected(neighbors[i], neighbors[j])? {
347 triangle_count += 1;
348 }
349 }
350 }
351
352 let max_triangles = degree * (degree - 1) / 2;
353 let coefficient = if max_triangles > 0 {
354 triangle_count as f64 / max_triangles as f64
355 } else {
356 0.0
357 };
358
359 clustering.insert(node, coefficient);
360 }
361
362 Ok(clustering)
363 }
364
365 // Helper methods
366
367 fn single_source_shortest_paths(
368 &self,
369 source: NodeId,
370 nodes: &[NodeId],
371 weight_fn: Option<&WeightFn>,
372 ) -> Result<HashMap<NodeId, f64>> {
373 let mut distances: HashMap<NodeId, f64> = HashMap::new();
374 let mut queue = VecDeque::new();
375
376 for &node in nodes {
377 distances.insert(node, f64::INFINITY);
378 }
379
380 distances.insert(source, 0.0);
381 queue.push_back(source);
382
383 while let Some(current) = queue.pop_front() {
384 let current_dist = distances[¤t];
385
386 let relationships = self.graph.get_node_relationships(current, Direction::Both, None);
387 for relationship in relationships {
388 let neighbor = if relationship.start_node == current {
389 relationship.end_node
390 } else {
391 relationship.start_node
392 };
393
394 if !nodes.contains(&neighbor) {
395 continue;
396 }
397
398 let weight = weight_fn
399 .map(|f| f(relationship.id))
400 .unwrap_or(1.0);
401
402 let new_dist = current_dist + weight;
403
404 if new_dist < distances[&neighbor] {
405 distances.insert(neighbor, new_dist);
406 queue.push_back(neighbor);
407 }
408 }
409 }
410
411 Ok(distances)
412 }
413
414 fn get_neighbors(&self, node: NodeId, nodes: &[NodeId]) -> Result<Vec<NodeId>> {
415 let mut neighbors = Vec::new();
416 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
417
418 for relationship in relationships {
419 let neighbor = if relationship.start_node == node {
420 relationship.end_node
421 } else {
422 relationship.start_node
423 };
424
425 if nodes.contains(&neighbor) {
426 neighbors.push(neighbor);
427 }
428 }
429
430 Ok(neighbors)
431 }
432
433 fn are_connected(&self, node1: NodeId, node2: NodeId) -> Result<bool> {
434 let relationships = self.graph.get_node_relationships(node1, Direction::Both, None);
435
436 for relationship in relationships {
437 let other = if relationship.start_node == node1 {
438 relationship.end_node
439 } else {
440 relationship.start_node
441 };
442
443 if other == node2 {
444 return Ok(true);
445 }
446 }
447
448 Ok(false)
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;

    /// Builds a four-node graph with relationships A-B, B-C, C-D and A-C,
    /// i.e. a triangle (A, B, C) with a tail node D attached to C.
    fn create_test_graph() -> Graph {
        let graph = Graph::new();

        let node_a = graph.create_node();
        let node_b = graph.create_node();
        let node_c = graph.create_node();
        let node_d = graph.create_node();

        // Release the schema write guard before creating relationships.
        let schema_handle = graph.schema();
        let mut schema_guard = schema_handle.write();
        let rel_type = schema_guard.get_or_create_relationship_type("CONNECTS");
        drop(schema_guard);

        let edges = [
            (node_a, node_b),
            (node_b, node_c),
            (node_c, node_d),
            (node_a, node_c),
        ];
        for &(from, to) in &edges {
            graph.create_relationship(from, to, rel_type).unwrap();
        }

        graph
    }

    #[test]
    fn test_degree_centrality() {
        let graph = create_test_graph();
        let measures = CentralityMeasures::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let scores = measures.degree_centrality(&nodes).unwrap();

        assert_eq!(scores.len(), nodes.len());

        // All centrality values should be between 0 and 1
        assert!(scores.values().all(|value| *value >= 0.0 && *value <= 1.0));
    }

    #[test]
    fn test_closeness_centrality() {
        let graph = create_test_graph();
        let measures = CentralityMeasures::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let scores = measures.closeness_centrality(&nodes, None).unwrap();

        assert_eq!(scores.len(), nodes.len());

        // All centrality values should be non-negative
        assert!(scores.values().all(|value| *value >= 0.0));
    }

    #[test]
    fn test_betweenness_centrality() {
        let graph = create_test_graph();
        let measures = CentralityMeasures::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let scores = measures.betweenness_centrality(&nodes, None).unwrap();

        assert_eq!(scores.len(), nodes.len());

        // Scores are finite, non-negative and never above 1
        assert!(scores
            .values()
            .all(|value| value.is_finite() && *value >= 0.0 && *value <= 1.0));

        // The node joining the triangle to the tail lies on shortest paths,
        // so at least one score must be strictly positive.
        assert!(scores.values().any(|value| *value > 0.0));
    }

    #[test]
    fn test_eigenvector_centrality() {
        let graph = create_test_graph();
        let measures = CentralityMeasures::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let scores = measures.eigenvector_centrality(&nodes, 100, 1e-6).unwrap();

        assert_eq!(scores.len(), nodes.len());

        // Scores form a non-negative vector of unit Euclidean length
        assert!(scores.values().all(|value| *value >= 0.0 && *value <= 1.0));
        let norm: f64 = scores.values().map(|value| value * value).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_clustering_coefficient() {
        let graph = create_test_graph();
        let measures = CentralityMeasures::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let coefficients = measures.clustering_coefficient(&nodes).unwrap();

        assert_eq!(coefficients.len(), nodes.len());

        // All clustering coefficients should be between 0 and 1
        assert!(coefficients
            .values()
            .all(|value| *value >= 0.0 && *value <= 1.0));
    }

    #[test]
    fn test_pagerank() {
        let graph = create_test_graph();
        let measures = CentralityMeasures::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let ranks = measures.pagerank(&nodes, 0.85, 100, 1e-6).unwrap();

        assert_eq!(ranks.len(), nodes.len());

        // PageRank values should sum to approximately 1.0
        let total: f64 = ranks.values().sum();
        assert!((total - 1.0).abs() < 1e-3);

        // All PageRank values should be positive
        assert!(ranks.values().all(|value| *value > 0.0));
    }
}