This repository has no description.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 547 lines 18 kB view raw
1use crate::{Graph, NodeId, RelationshipId, Result}; 2use crate::core::relationship::Direction; 3use super::WeightFn; 4use std::collections::{HashMap, HashSet, VecDeque}; 5 6/// Centrality measures for graph analysis 7pub struct CentralityMeasures<'a> { 8 graph: &'a Graph, 9} 10 11impl<'a> CentralityMeasures<'a> { 12 pub fn new(graph: &'a Graph) -> Self { 13 Self { graph } 14 } 15 16 /// Calculate betweenness centrality for all nodes 17 pub fn betweenness_centrality( 18 &self, 19 nodes: &[NodeId], 20 weight_fn: Option<&WeightFn>, 21 ) -> Result<HashMap<NodeId, f64>> { 22 let mut centrality: HashMap<NodeId, f64> = HashMap::new(); 23 24 // Initialize all centralities to 0 25 for &node in nodes { 26 centrality.insert(node, 0.0); 27 } 28 29 // For each node as source 30 for &source in nodes { 31 let mut stack = Vec::new(); 32 let mut paths: HashMap<NodeId, Vec<NodeId>> = HashMap::new(); 33 let mut sigma: HashMap<NodeId, f64> = HashMap::new(); 34 let mut dist: HashMap<NodeId, f64> = HashMap::new(); 35 let mut delta: HashMap<NodeId, f64> = HashMap::new(); 36 37 // Initialize 38 for &node in nodes { 39 paths.insert(node, Vec::new()); 40 sigma.insert(node, 0.0); 41 dist.insert(node, f64::INFINITY); 42 delta.insert(node, 0.0); 43 } 44 45 sigma.insert(source, 1.0); 46 dist.insert(source, 0.0); 47 48 let mut queue = VecDeque::new(); 49 queue.push_back(source); 50 51 // BFS to find shortest paths 52 while let Some(current) = queue.pop_front() { 53 stack.push(current); 54 55 let relationships = self.graph.get_node_relationships(current, Direction::Both, None); 56 for relationship in relationships { 57 let neighbor = if relationship.start_node == current { 58 relationship.end_node 59 } else { 60 relationship.start_node 61 }; 62 63 if !nodes.contains(&neighbor) { 64 continue; 65 } 66 67 let weight = weight_fn 68 .map(|f| f(relationship.id)) 69 .unwrap_or(1.0); 70 71 let new_dist = dist[&current] + weight; 72 73 // First time we reach this neighbor 74 if dist[&neighbor] == 
f64::INFINITY { 75 queue.push_back(neighbor); 76 dist.insert(neighbor, new_dist); 77 } 78 79 // Shortest path to neighbor via current 80 if (dist[&neighbor] - new_dist).abs() < f64::EPSILON { 81 sigma.insert(neighbor, sigma[&neighbor] + sigma[&current]); 82 paths.get_mut(&neighbor).unwrap().push(current); 83 } 84 } 85 } 86 87 // Accumulate dependencies 88 while let Some(current) = stack.pop() { 89 for &predecessor in &paths[&current] { 90 let contribution = (sigma[&predecessor] / sigma[&current]) * (1.0 + delta[&current]); 91 delta.insert(predecessor, delta[&predecessor] + contribution); 92 } 93 94 if current != source { 95 centrality.insert(current, centrality[&current] + delta[&current]); 96 } 97 } 98 } 99 100 // Normalize (divide by 2 for undirected graphs) 101 let normalization_factor = if nodes.len() > 2 { 102 2.0 * ((nodes.len() - 1) * (nodes.len() - 2)) as f64 103 } else { 104 1.0 105 }; 106 107 for value in centrality.values_mut() { 108 *value /= normalization_factor; 109 } 110 111 Ok(centrality) 112 } 113 114 /// Calculate closeness centrality for all nodes 115 pub fn closeness_centrality( 116 &self, 117 nodes: &[NodeId], 118 weight_fn: Option<&WeightFn>, 119 ) -> Result<HashMap<NodeId, f64>> { 120 let mut centrality: HashMap<NodeId, f64> = HashMap::new(); 121 122 for &source in nodes { 123 let distances = self.single_source_shortest_paths(source, nodes, weight_fn)?; 124 125 let mut total_distance = 0.0; 126 let mut reachable_count = 0; 127 128 for &target in nodes { 129 if target != source { 130 if let Some(distance) = distances.get(&target) { 131 if *distance < f64::INFINITY { 132 total_distance += distance; 133 reachable_count += 1; 134 } 135 } 136 } 137 } 138 139 let closeness = if total_distance > 0.0 && reachable_count > 0 { 140 (reachable_count as f64) / total_distance 141 } else { 142 0.0 143 }; 144 145 centrality.insert(source, closeness); 146 } 147 148 Ok(centrality) 149 } 150 151 /// Calculate degree centrality for all nodes 152 pub fn 
degree_centrality(&self, nodes: &[NodeId]) -> Result<HashMap<NodeId, f64>> { 153 let mut centrality: HashMap<NodeId, f64> = HashMap::new(); 154 let node_count = nodes.len(); 155 156 for &node in nodes { 157 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 158 let degree = relationships.len() as f64; 159 160 // Normalize by the maximum possible degree 161 let normalized_degree = if node_count > 1 { 162 degree / (node_count - 1) as f64 163 } else { 164 0.0 165 }; 166 167 centrality.insert(node, normalized_degree); 168 } 169 170 Ok(centrality) 171 } 172 173 /// Calculate eigenvector centrality using power iteration 174 pub fn eigenvector_centrality( 175 &self, 176 nodes: &[NodeId], 177 max_iterations: usize, 178 tolerance: f64, 179 ) -> Result<HashMap<NodeId, f64>> { 180 let n = nodes.len(); 181 if n == 0 { 182 return Ok(HashMap::new()); 183 } 184 185 // Create adjacency matrix representation 186 let mut adj: HashMap<(NodeId, NodeId), f64> = HashMap::new(); 187 188 for &node in nodes { 189 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 190 for relationship in relationships { 191 let neighbor = if relationship.start_node == node { 192 relationship.end_node 193 } else { 194 relationship.start_node 195 }; 196 197 if nodes.contains(&neighbor) { 198 adj.insert((node, neighbor), 1.0); 199 } 200 } 201 } 202 203 // Initialize eigenvector 204 let mut centrality: HashMap<NodeId, f64> = HashMap::new(); 205 for &node in nodes { 206 centrality.insert(node, 1.0 / (n as f64).sqrt()); 207 } 208 209 // Power iteration 210 for _ in 0..max_iterations { 211 let mut new_centrality: HashMap<NodeId, f64> = HashMap::new(); 212 213 // Matrix-vector multiplication 214 for &node in nodes { 215 let mut sum = 0.0; 216 for &neighbor in nodes { 217 if let Some(weight) = adj.get(&(neighbor, node)) { 218 sum += weight * centrality[&neighbor]; 219 } 220 } 221 new_centrality.insert(node, sum); 222 } 223 224 // Normalize 225 let 
norm: f64 = new_centrality.values().map(|x| x * x).sum::<f64>().sqrt(); 226 if norm > 0.0 { 227 for value in new_centrality.values_mut() { 228 *value /= norm; 229 } 230 } 231 232 // Check convergence 233 let mut converged = true; 234 for &node in nodes { 235 if (new_centrality[&node] - centrality[&node]).abs() > tolerance { 236 converged = false; 237 break; 238 } 239 } 240 241 centrality = new_centrality; 242 243 if converged { 244 break; 245 } 246 } 247 248 Ok(centrality) 249 } 250 251 /// Calculate PageRank centrality 252 pub fn pagerank( 253 &self, 254 nodes: &[NodeId], 255 damping_factor: f64, 256 max_iterations: usize, 257 tolerance: f64, 258 ) -> Result<HashMap<NodeId, f64>> { 259 let n = nodes.len(); 260 if n == 0 { 261 return Ok(HashMap::new()); 262 } 263 264 // Initialize PageRank values 265 let mut pagerank: HashMap<NodeId, f64> = HashMap::new(); 266 let initial_value = 1.0 / n as f64; 267 268 for &node in nodes { 269 pagerank.insert(node, initial_value); 270 } 271 272 // Calculate out-degrees 273 let mut out_degree: HashMap<NodeId, usize> = HashMap::new(); 274 for &node in nodes { 275 let relationships = self.graph.get_node_relationships(node, Direction::Outgoing, None); 276 out_degree.insert(node, relationships.len()); 277 } 278 279 // Power iteration 280 for _ in 0..max_iterations { 281 let mut new_pagerank: HashMap<NodeId, f64> = HashMap::new(); 282 283 // Calculate dangling node contribution (nodes with no outgoing links) 284 let mut dangling_sum = 0.0; 285 for &node in nodes { 286 if out_degree[&node] == 0 { 287 dangling_sum += pagerank[&node]; 288 } 289 } 290 291 for &node in nodes { 292 let mut sum = 0.0; 293 294 // Sum contributions from incoming links 295 let incoming_relationships = self.graph.get_node_relationships(node, Direction::Incoming, None); 296 for relationship in incoming_relationships { 297 let source = relationship.start_node; 298 if nodes.contains(&source) && out_degree[&source] > 0 { 299 sum += pagerank[&source] / 
out_degree[&source] as f64; 300 } 301 } 302 303 // Add dangling node contribution (distributed equally) 304 let dangling_contribution = dangling_sum / n as f64; 305 306 let new_value = (1.0 - damping_factor) / n as f64 + damping_factor * (sum + dangling_contribution); 307 new_pagerank.insert(node, new_value); 308 } 309 310 // Check convergence 311 let mut converged = true; 312 for &node in nodes { 313 if (new_pagerank[&node] - pagerank[&node]).abs() > tolerance { 314 converged = false; 315 break; 316 } 317 } 318 319 pagerank = new_pagerank; 320 321 if converged { 322 break; 323 } 324 } 325 326 Ok(pagerank) 327 } 328 329 /// Calculate clustering coefficient for all nodes 330 pub fn clustering_coefficient(&self, nodes: &[NodeId]) -> Result<HashMap<NodeId, f64>> { 331 let mut clustering: HashMap<NodeId, f64> = HashMap::new(); 332 333 for &node in nodes { 334 let neighbors = self.get_neighbors(node, nodes)?; 335 let degree = neighbors.len(); 336 337 if degree < 2 { 338 clustering.insert(node, 0.0); 339 continue; 340 } 341 342 // Count triangles 343 let mut triangle_count = 0; 344 for i in 0..neighbors.len() { 345 for j in (i + 1)..neighbors.len() { 346 if self.are_connected(neighbors[i], neighbors[j])? 
{ 347 triangle_count += 1; 348 } 349 } 350 } 351 352 let max_triangles = degree * (degree - 1) / 2; 353 let coefficient = if max_triangles > 0 { 354 triangle_count as f64 / max_triangles as f64 355 } else { 356 0.0 357 }; 358 359 clustering.insert(node, coefficient); 360 } 361 362 Ok(clustering) 363 } 364 365 // Helper methods 366 367 fn single_source_shortest_paths( 368 &self, 369 source: NodeId, 370 nodes: &[NodeId], 371 weight_fn: Option<&WeightFn>, 372 ) -> Result<HashMap<NodeId, f64>> { 373 let mut distances: HashMap<NodeId, f64> = HashMap::new(); 374 let mut queue = VecDeque::new(); 375 376 for &node in nodes { 377 distances.insert(node, f64::INFINITY); 378 } 379 380 distances.insert(source, 0.0); 381 queue.push_back(source); 382 383 while let Some(current) = queue.pop_front() { 384 let current_dist = distances[&current]; 385 386 let relationships = self.graph.get_node_relationships(current, Direction::Both, None); 387 for relationship in relationships { 388 let neighbor = if relationship.start_node == current { 389 relationship.end_node 390 } else { 391 relationship.start_node 392 }; 393 394 if !nodes.contains(&neighbor) { 395 continue; 396 } 397 398 let weight = weight_fn 399 .map(|f| f(relationship.id)) 400 .unwrap_or(1.0); 401 402 let new_dist = current_dist + weight; 403 404 if new_dist < distances[&neighbor] { 405 distances.insert(neighbor, new_dist); 406 queue.push_back(neighbor); 407 } 408 } 409 } 410 411 Ok(distances) 412 } 413 414 fn get_neighbors(&self, node: NodeId, nodes: &[NodeId]) -> Result<Vec<NodeId>> { 415 let mut neighbors = Vec::new(); 416 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 417 418 for relationship in relationships { 419 let neighbor = if relationship.start_node == node { 420 relationship.end_node 421 } else { 422 relationship.start_node 423 }; 424 425 if nodes.contains(&neighbor) { 426 neighbors.push(neighbor); 427 } 428 } 429 430 Ok(neighbors) 431 } 432 433 fn are_connected(&self, node1: 
NodeId, node2: NodeId) -> Result<bool> { 434 let relationships = self.graph.get_node_relationships(node1, Direction::Both, None); 435 436 for relationship in relationships { 437 let other = if relationship.start_node == node1 { 438 relationship.end_node 439 } else { 440 relationship.start_node 441 }; 442 443 if other == node2 { 444 return Ok(true); 445 } 446 } 447 448 Ok(false) 449 } 450} 451 452#[cfg(test)] 453mod tests { 454 use super::*; 455 use crate::Graph; 456 457 fn create_test_graph() -> Graph { 458 let graph = Graph::new(); 459 460 // Create a simple graph for testing 461 let node_a = graph.create_node(); 462 let node_b = graph.create_node(); 463 let node_c = graph.create_node(); 464 let node_d = graph.create_node(); 465 466 let schema = graph.schema(); 467 let mut schema = schema.write(); 468 let rel_type = schema.get_or_create_relationship_type("CONNECTS"); 469 drop(schema); 470 471 // Create relationships: A-B-C-D and A-C (creating a triangle and path) 472 graph.create_relationship(node_a, node_b, rel_type).unwrap(); 473 graph.create_relationship(node_b, node_c, rel_type).unwrap(); 474 graph.create_relationship(node_c, node_d, rel_type).unwrap(); 475 graph.create_relationship(node_a, node_c, rel_type).unwrap(); 476 477 graph 478 } 479 480 #[test] 481 fn test_degree_centrality() { 482 let graph = create_test_graph(); 483 let centrality = CentralityMeasures::new(&graph); 484 485 let nodes: Vec<_> = graph.get_all_nodes(); 486 let degree_centrality = centrality.degree_centrality(&nodes).unwrap(); 487 488 assert_eq!(degree_centrality.len(), nodes.len()); 489 490 // All centrality values should be between 0 and 1 491 for value in degree_centrality.values() { 492 assert!(*value >= 0.0 && *value <= 1.0); 493 } 494 } 495 496 #[test] 497 fn test_closeness_centrality() { 498 let graph = create_test_graph(); 499 let centrality = CentralityMeasures::new(&graph); 500 501 let nodes: Vec<_> = graph.get_all_nodes(); 502 let closeness_centrality = 
centrality.closeness_centrality(&nodes, None).unwrap(); 503 504 assert_eq!(closeness_centrality.len(), nodes.len()); 505 506 // All centrality values should be non-negative 507 for value in closeness_centrality.values() { 508 assert!(*value >= 0.0); 509 } 510 } 511 512 #[test] 513 fn test_clustering_coefficient() { 514 let graph = create_test_graph(); 515 let centrality = CentralityMeasures::new(&graph); 516 517 let nodes: Vec<_> = graph.get_all_nodes(); 518 let clustering = centrality.clustering_coefficient(&nodes).unwrap(); 519 520 assert_eq!(clustering.len(), nodes.len()); 521 522 // All clustering coefficients should be between 0 and 1 523 for value in clustering.values() { 524 assert!(*value >= 0.0 && *value <= 1.0); 525 } 526 } 527 528 #[test] 529 fn test_pagerank() { 530 let graph = create_test_graph(); 531 let centrality = CentralityMeasures::new(&graph); 532 533 let nodes: Vec<_> = graph.get_all_nodes(); 534 let pagerank = centrality.pagerank(&nodes, 0.85, 100, 1e-6).unwrap(); 535 536 assert_eq!(pagerank.len(), nodes.len()); 537 538 // PageRank values should sum to approximately 1.0 539 let total: f64 = pagerank.values().sum(); 540 assert!((total - 1.0).abs() < 1e-3); 541 542 // All PageRank values should be positive 543 for value in pagerank.values() { 544 assert!(*value > 0.0); 545 } 546 } 547}