this repo has no description
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 671 lines 23 kB view raw
1use crate::{Graph, NodeId, RelationshipId, Result}; 2use crate::core::relationship::Direction; 3use std::collections::{HashMap, HashSet, VecDeque}; 4use std::cmp::Ordering; 5 6/// Community detection algorithms 7pub struct CommunityDetection<'a> { 8 graph: &'a Graph, 9} 10 11impl<'a> CommunityDetection<'a> { 12 pub fn new(graph: &'a Graph) -> Self { 13 Self { graph } 14 } 15 16 /// Louvain method for community detection 17 pub fn louvain(&self, nodes: &[NodeId]) -> Result<Vec<Vec<NodeId>>> { 18 let mut communities = self.initialize_communities(nodes)?; 19 let mut improved = true; 20 21 while improved { 22 improved = false; 23 24 for &node in nodes { 25 let current_community = self.find_node_community(&communities, node); 26 let best_community = self.find_best_community_for_node(node, &communities, nodes)?; 27 28 if best_community != current_community { 29 // Move node to best community 30 self.move_node_to_community(&mut communities, node, current_community, best_community); 31 improved = true; 32 } 33 } 34 } 35 36 Ok(communities) 37 } 38 39 /// Label propagation algorithm for community detection 40 pub fn label_propagation( 41 &self, 42 nodes: &[NodeId], 43 max_iterations: usize, 44 ) -> Result<Vec<Vec<NodeId>>> { 45 let mut labels: HashMap<NodeId, usize> = HashMap::new(); 46 47 // Initialize each node with its own label 48 for (i, &node) in nodes.iter().enumerate() { 49 labels.insert(node, i); 50 } 51 52 for _ in 0..max_iterations { 53 let mut changed = false; 54 let mut new_labels = labels.clone(); 55 56 for &node in nodes { 57 let neighbor_labels = self.get_neighbor_labels(node, &labels, nodes)?; 58 59 if let Some(most_frequent_label) = self.most_frequent_label(neighbor_labels) { 60 if labels[&node] != most_frequent_label { 61 new_labels.insert(node, most_frequent_label); 62 changed = true; 63 } 64 } 65 } 66 67 labels = new_labels; 68 69 if !changed { 70 break; 71 } 72 } 73 74 // Convert labels to communities 75 self.labels_to_communities(labels, nodes) 76 } 77 78 /// Girvan-Newman edge betweenness community detection 79 pub fn girvan_newman(&self, nodes: &[NodeId], num_communities: usize) -> Result<Vec<Vec<NodeId>>> { 80 let mut remaining_edges: HashSet<RelationshipId> = HashSet::new(); 81 82 // Collect all edges 83 for &node in nodes { 84 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 85 for relationship in relationships { 86 let other = if relationship.start_node == node { 87 relationship.end_node 88 } else { 89 relationship.start_node 90 }; 91 92 if nodes.contains(&other) { 93 remaining_edges.insert(relationship.id); 94 } 95 } 96 } 97 98 loop { 99 let components = self.find_components_excluding_edges(nodes, &HashSet::new())?; 100 101 if components.len() >= num_communities { 102 return Ok(components.into_iter().take(num_communities).collect()); 103 } 104 105 // Calculate edge betweenness for all remaining edges 106 let edge_betweenness = self.calculate_edge_betweenness(nodes, &remaining_edges)?; 107 108 // Find edge with highest betweenness 109 if let Some((&edge_to_remove, _)) = edge_betweenness.iter() 110 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) { 111 remaining_edges.remove(&edge_to_remove); 112 } else { 113 break; 114 } 115 } 116 117 self.find_components_excluding_edges(nodes, &HashSet::new()) 118 } 119 120 /// Modularity calculation for community quality assessment 121 pub fn modularity(&self, communities: &[Vec<NodeId>], nodes: &[NodeId]) -> Result<f64> { 122 let total_edges = self.count_total_edges(nodes)?; 123 if total_edges == 0 { 124 return Ok(0.0); 125 } 126 127 let mut modularity = 0.0; 128 129 for community in communities { 130 let community_set: HashSet<NodeId> = community.iter().copied().collect(); 131 132 for &node_i in community { 133 for &node_j in community { 134 if node_i <= node_j { 135 continue; // Avoid double counting 136 } 137 138 let a_ij = if self.are_connected(node_i, node_j)? { 1.0 } else { 0.0 }; 139 let k_i = self.get_degree(node_i, nodes)? as f64; 140 let k_j = self.get_degree(node_j, nodes)? as f64; 141 142 modularity += a_ij - (k_i * k_j) / (2.0 * total_edges as f64); 143 } 144 } 145 } 146 147 Ok(modularity / (2.0 * total_edges as f64)) 148 } 149 150 /// Fast greedy modularity optimization 151 pub fn fast_greedy_modularity(&self, nodes: &[NodeId]) -> Result<Vec<Vec<NodeId>>> { 152 let mut communities: Vec<HashSet<NodeId>> = nodes.iter().map(|&node| { 153 let mut set = HashSet::new(); 154 set.insert(node); 155 set 156 }).collect(); 157 158 let mut best_modularity = self.calculate_modularity_for_partition(&communities, nodes)?; 159 let mut best_communities = communities.clone(); 160 161 while communities.len() > 1 { 162 let mut best_merge: Option<(usize, usize)> = None; 163 let mut best_delta_q = f64::NEG_INFINITY; 164 165 // Try all possible merges 166 for i in 0..communities.len() { 167 for j in (i + 1)..communities.len() { 168 let delta_q = self.calculate_merge_delta_q(&communities, i, j, nodes)?; 169 170 if delta_q > best_delta_q { 171 best_delta_q = delta_q; 172 best_merge = Some((i, j)); 173 } 174 } 175 } 176 177 if let Some((i, j)) = best_merge { 178 // Merge communities i and j 179 let community_j = communities.remove(j); 180 communities[i].extend(community_j); 181 182 let new_modularity = self.calculate_modularity_for_partition(&communities, nodes)?; 183 if new_modularity > best_modularity { 184 best_modularity = new_modularity; 185 best_communities = communities.clone(); 186 } 187 } else { 188 break; 189 } 190 } 191 192 Ok(best_communities.into_iter().map(|set| set.into_iter().collect()).collect()) 193 } 194 195 /// Spectral clustering using graph Laplacian 196 pub fn spectral_clustering(&self, nodes: &[NodeId], k: usize) -> Result<Vec<Vec<NodeId>>> { 197 // This is a simplified version - in practice, you'd use proper eigenvalue decomposition 198 // For now, we'll use a heuristic approach based on node connectivity 199 200 let mut communities: Vec<Vec<NodeId>> = Vec::new(); 201 let mut remaining_nodes: HashSet<NodeId> = nodes.iter().copied().collect(); 202 203 for _ in 0..k { 204 if remaining_nodes.is_empty() { 205 break; 206 } 207 208 // Start with a random node 209 let start_node = *remaining_nodes.iter().next().unwrap(); 210 let mut community = vec![start_node]; 211 remaining_nodes.remove(&start_node); 212 213 // Add nodes that are well-connected to the current community 214 let mut added = true; 215 while added && !remaining_nodes.is_empty() { 216 added = false; 217 let mut best_node = None; 218 let mut best_score = 0.0; 219 220 for &candidate in &remaining_nodes { 221 let score = self.calculate_community_affinity(candidate, &community)?; 222 if score > best_score { 223 best_score = score; 224 best_node = Some(candidate); 225 } 226 } 227 228 if let Some(node) = best_node { 229 // Only add node if it has a strong connection to the community 230 // (at least 50% of community members) 231 let min_threshold = 0.5; 232 if best_score >= min_threshold { 233 community.push(node); 234 remaining_nodes.remove(&node); 235 added = true; 236 } 237 } 238 } 239 240 communities.push(community); 241 } 242 243 // Add any remaining nodes to the last community 244 if !remaining_nodes.is_empty() { 245 if let Some(last_community) = communities.last_mut() { 246 last_community.extend(remaining_nodes); 247 } else { 248 communities.push(remaining_nodes.into_iter().collect()); 249 } 250 } 251 252 Ok(communities) 253 } 254 255 // Helper methods 256 257 fn initialize_communities(&self, nodes: &[NodeId]) -> Result<Vec<Vec<NodeId>>> { 258 Ok(nodes.iter().map(|&node| vec![node]).collect()) 259 } 260 261 fn find_node_community(&self, communities: &[Vec<NodeId>], node: NodeId) -> usize { 262 for (i, community) in communities.iter().enumerate() { 263 if community.contains(&node) { 264 return i; 265 } 266 } 267 0 // Fallback 268 } 269 270 fn find_best_community_for_node( 271 &self, 272 node: NodeId, 273 communities: &[Vec<NodeId>], 274 nodes: &[NodeId], 275 ) -> Result<usize> { 276 let mut best_community = 0; 277 let mut best_gain = f64::NEG_INFINITY; 278 279 for (i, community) in communities.iter().enumerate() { 280 let gain = self.calculate_modularity_gain(node, community, nodes)?; 281 if gain > best_gain { 282 best_gain = gain; 283 best_community = i; 284 } 285 } 286 287 Ok(best_community) 288 } 289 290 fn move_node_to_community( 291 &self, 292 communities: &mut Vec<Vec<NodeId>>, 293 node: NodeId, 294 from: usize, 295 to: usize, 296 ) { 297 if from != to { 298 communities[from].retain(|&n| n != node); 299 communities[to].push(node); 300 } 301 } 302 303 fn calculate_modularity_gain( 304 &self, 305 node: NodeId, 306 community: &[NodeId], 307 nodes: &[NodeId], 308 ) -> Result<f64> { 309 let mut internal_connections = 0; 310 let total_edges = self.count_total_edges(nodes)?; 311 312 for &other in community { 313 if other != node && self.are_connected(node, other)? { 314 internal_connections += 1; 315 } 316 } 317 318 let node_degree = self.get_degree(node, nodes)?; 319 let community_degree: usize = community.iter() 320 .map(|&n| self.get_degree(n, nodes).unwrap_or(0)) 321 .sum(); 322 323 if total_edges == 0 { 324 return Ok(0.0); 325 } 326 327 let gain = (internal_connections as f64) - 328 (node_degree as f64 * community_degree as f64) / (2.0 * total_edges as f64); 329 330 Ok(gain) 331 } 332 333 fn get_neighbor_labels( 334 &self, 335 node: NodeId, 336 labels: &HashMap<NodeId, usize>, 337 nodes: &[NodeId], 338 ) -> Result<Vec<usize>> { 339 let mut neighbor_labels = Vec::new(); 340 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 341 342 for relationship in relationships { 343 let neighbor = if relationship.start_node == node { 344 relationship.end_node 345 } else { 346 relationship.start_node 347 }; 348 349 if nodes.contains(&neighbor) { 350 if let Some(&label) = labels.get(&neighbor) { 351 neighbor_labels.push(label); 352 } 353 } 354 } 355 356 Ok(neighbor_labels) 357 } 358 359 fn most_frequent_label(&self, labels: Vec<usize>) -> Option<usize> { 360 if labels.is_empty() { 361 return None; 362 } 363 364 let mut counts: HashMap<usize, usize> = HashMap::new(); 365 for label in labels { 366 *counts.entry(label).or_insert(0) += 1; 367 } 368 369 counts.into_iter() 370 .max_by_key(|(_, count)| *count) 371 .map(|(label, _)| label) 372 } 373 374 fn labels_to_communities( 375 &self, 376 labels: HashMap<NodeId, usize>, 377 nodes: &[NodeId], 378 ) -> Result<Vec<Vec<NodeId>>> { 379 let mut communities: HashMap<usize, Vec<NodeId>> = HashMap::new(); 380 381 for &node in nodes { 382 if let Some(&label) = labels.get(&node) { 383 communities.entry(label).or_insert_with(Vec::new).push(node); 384 } 385 } 386 387 Ok(communities.into_values().collect()) 388 } 389 390 fn find_components_excluding_edges( 391 &self, 392 nodes: &[NodeId], 393 excluded_edges: &HashSet<RelationshipId>, 394 ) -> Result<Vec<Vec<NodeId>>> { 395 let mut visited: HashSet<NodeId> = HashSet::new(); 396 let mut components = Vec::new(); 397 398 for &node in nodes { 399 if !visited.contains(&node) { 400 let component = self.bfs_component(node, nodes, excluded_edges, &mut visited)?; 401 components.push(component); 402 } 403 } 404 405 Ok(components) 406 } 407 408 fn bfs_component( 409 &self, 410 start: NodeId, 411 nodes: &[NodeId], 412 excluded_edges: &HashSet<RelationshipId>, 413 visited: &mut HashSet<NodeId>, 414 ) -> Result<Vec<NodeId>> { 415 let mut component = Vec::new(); 416 let mut queue = VecDeque::new(); 417 418 queue.push_back(start); 419 visited.insert(start); 420 421 while let Some(current) = queue.pop_front() { 422 component.push(current); 423 424 let relationships = self.graph.get_node_relationships(current, Direction::Both, None); 425 for relationship in relationships { 426 if excluded_edges.contains(&relationship.id) { 427 continue; 428 } 429 430 let neighbor = if relationship.start_node == current { 431 relationship.end_node 432 } else { 433 relationship.start_node 434 }; 435 436 if nodes.contains(&neighbor) && !visited.contains(&neighbor) { 437 visited.insert(neighbor); 438 queue.push_back(neighbor); 439 } 440 } 441 } 442 443 Ok(component) 444 } 445 446 fn calculate_edge_betweenness( 447 &self, 448 nodes: &[NodeId], 449 edges: &HashSet<RelationshipId>, 450 ) -> Result<HashMap<RelationshipId, f64>> { 451 let mut betweenness: HashMap<RelationshipId, f64> = HashMap::new(); 452 453 for &edge in edges { 454 betweenness.insert(edge, 0.0); 455 } 456 457 // This is a simplified calculation - proper edge betweenness requires 458 // counting shortest paths that pass through each edge 459 for &source in nodes { 460 for &target in nodes { 461 if source != target { 462 // Count paths from source to target that use each edge 463 // This is a placeholder implementation 464 for &edge in edges { 465 betweenness.insert(edge, betweenness[&edge] + 1.0); 466 } 467 } 468 } 469 } 470 471 Ok(betweenness) 472 } 473 474 fn calculate_modularity_for_partition( 475 &self, 476 communities: &[HashSet<NodeId>], 477 nodes: &[NodeId], 478 ) -> Result<f64> { 479 let communities_vec: Vec<Vec<NodeId>> = communities.iter() 480 .map(|set| set.iter().copied().collect()) 481 .collect(); 482 483 self.modularity(&communities_vec, nodes) 484 } 485 486 fn calculate_merge_delta_q( 487 &self, 488 communities: &[HashSet<NodeId>], 489 i: usize, 490 j: usize, 491 nodes: &[NodeId], 492 ) -> Result<f64> { 493 // Simplified delta Q calculation 494 let mut connections = 0; 495 496 for &node_i in &communities[i] { 497 for &node_j in &communities[j] { 498 if self.are_connected(node_i, node_j)? { 499 connections += 1; 500 } 501 } 502 } 503 504 Ok(connections as f64) 505 } 506 507 fn calculate_community_affinity(&self, node: NodeId, community: &[NodeId]) -> Result<f64> { 508 let mut connections = 0; 509 510 for &community_node in community { 511 if self.are_connected(node, community_node)? { 512 connections += 1; 513 } 514 } 515 516 Ok(connections as f64 / community.len() as f64) 517 } 518 519 fn count_total_edges(&self, nodes: &[NodeId]) -> Result<usize> { 520 let mut edge_count = 0; 521 let mut counted_edges: HashSet<RelationshipId> = HashSet::new(); 522 523 for &node in nodes { 524 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 525 for relationship in relationships { 526 if !counted_edges.contains(&relationship.id) { 527 let other = if relationship.start_node == node { 528 relationship.end_node 529 } else { 530 relationship.start_node 531 }; 532 533 if nodes.contains(&other) { 534 edge_count += 1; 535 counted_edges.insert(relationship.id); 536 } 537 } 538 } 539 } 540 541 Ok(edge_count) 542 } 543 544 fn get_degree(&self, node: NodeId, nodes: &[NodeId]) -> Result<usize> { 545 let relationships = self.graph.get_node_relationships(node, Direction::Both, None); 546 let mut degree = 0; 547 548 for relationship in relationships { 549 let other = if relationship.start_node == node { 550 relationship.end_node 551 } else { 552 relationship.start_node 553 }; 554 555 if nodes.contains(&other) { 556 degree += 1; 557 } 558 } 559 560 Ok(degree) 561 } 562 563 fn are_connected(&self, node1: NodeId, node2: NodeId) -> Result<bool> { 564 let relationships = self.graph.get_node_relationships(node1, Direction::Both, None); 565 566 for relationship in relationships { 567 let other = if relationship.start_node == node1 { 568 relationship.end_node 569 } else { 570 relationship.start_node 571 }; 572 573 if other == node2 { 574 return Ok(true); 575 } 576 } 577 578 Ok(false) 579 } 580} 581 582#[cfg(test)] 583mod tests { 584 use super::*; 585 use crate::Graph; 586 587 fn create_test_graph() -> Graph { 588 let graph = Graph::new(); 589 590 // Create a graph with two clear communities 591 // Community 1: A-B-C (triangle) 592 // Community 2: D-E-F (triangle) 593 // Bridge: C-D 594 595 let node_a = graph.create_node(); 596 let node_b = graph.create_node(); 597 let node_c = graph.create_node(); 598 let node_d = graph.create_node(); 599 let node_e = graph.create_node(); 600 let node_f = graph.create_node(); 601 602 let schema = graph.schema(); 603 let mut schema = schema.write(); 604 let rel_type = schema.get_or_create_relationship_type("CONNECTS"); 605 drop(schema); 606 607 // Community 1 608 graph.create_relationship(node_a, node_b, rel_type).unwrap(); 609 graph.create_relationship(node_b, node_c, rel_type).unwrap(); 610 graph.create_relationship(node_c, node_a, rel_type).unwrap(); 611 612 // Community 2 613 graph.create_relationship(node_d, node_e, rel_type).unwrap(); 614 graph.create_relationship(node_e, node_f, rel_type).unwrap(); 615 graph.create_relationship(node_f, node_d, rel_type).unwrap(); 616 617 // Bridge 618 graph.create_relationship(node_c, node_d, rel_type).unwrap(); 619 620 graph 621 } 622 623 #[test] 624 fn test_label_propagation() { 625 let graph = create_test_graph(); 626 let community_detection = CommunityDetection::new(&graph); 627 628 let nodes: Vec<_> = graph.get_all_nodes(); 629 let communities = community_detection.label_propagation(&nodes, 10).unwrap(); 630 631 assert!(!communities.is_empty()); 632 633 // Each node should be in exactly one community 634 let total_nodes: usize = communities.iter().map(|c| c.len()).sum(); 635 assert_eq!(total_nodes, nodes.len()); 636 } 637 638 #[test] 639 fn test_modularity_calculation() { 640 let graph = create_test_graph(); 641 let community_detection = CommunityDetection::new(&graph); 642 643 let nodes: Vec<_> = graph.get_all_nodes(); 644 645 // Create artificial communities for testing 646 let communities = vec![ 647 nodes[0..3].to_vec(), // First 3 nodes 648 nodes[3..].to_vec(), // Remaining nodes 649 ]; 650 651 let modularity = community_detection.modularity(&communities, &nodes).unwrap(); 652 653 // Modularity should be between -1 and 1 654 assert!(modularity >= -1.0 && modularity <= 1.0); 655 } 656 657 #[test] 658 fn test_spectral_clustering() { 659 let graph = create_test_graph(); 660 let community_detection = CommunityDetection::new(&graph); 661 662 let nodes: Vec<_> = graph.get_all_nodes(); 663 let communities = community_detection.spectral_clustering(&nodes, 2).unwrap(); 664 665 assert_eq!(communities.len(), 2); 666 667 // Each node should be in exactly one community 668 let total_nodes: usize = communities.iter().map(|c| c.len()).sum(); 669 assert_eq!(total_nodes, nodes.len()); 670 } 671}