this repo has no description
1use crate::{Graph, NodeId, RelationshipId, Result};
2use crate::core::relationship::Direction;
3use std::collections::{HashMap, HashSet, VecDeque};
4use std::cmp::Ordering;
5
/// Community detection algorithms over a borrowed [`Graph`].
///
/// The detector only holds a shared reference to the graph for the lifetime
/// `'a`; every algorithm is exposed as a method on this type and receives the
/// set of nodes to work on as an explicit slice.
pub struct CommunityDetection<'a> {
    /// Graph whose relationships are queried by the algorithms.
    graph: &'a Graph,
}
10
11impl<'a> CommunityDetection<'a> {
12 pub fn new(graph: &'a Graph) -> Self {
13 Self { graph }
14 }
15
16 /// Louvain method for community detection
17 pub fn louvain(&self, nodes: &[NodeId]) -> Result<Vec<Vec<NodeId>>> {
18 let mut communities = self.initialize_communities(nodes)?;
19 let mut improved = true;
20
21 while improved {
22 improved = false;
23
24 for &node in nodes {
25 let current_community = self.find_node_community(&communities, node);
26 let best_community = self.find_best_community_for_node(node, &communities, nodes)?;
27
28 if best_community != current_community {
29 // Move node to best community
30 self.move_node_to_community(&mut communities, node, current_community, best_community);
31 improved = true;
32 }
33 }
34 }
35
36 Ok(communities)
37 }
38
39 /// Label propagation algorithm for community detection
40 pub fn label_propagation(
41 &self,
42 nodes: &[NodeId],
43 max_iterations: usize,
44 ) -> Result<Vec<Vec<NodeId>>> {
45 let mut labels: HashMap<NodeId, usize> = HashMap::new();
46
47 // Initialize each node with its own label
48 for (i, &node) in nodes.iter().enumerate() {
49 labels.insert(node, i);
50 }
51
52 for _ in 0..max_iterations {
53 let mut changed = false;
54 let mut new_labels = labels.clone();
55
56 for &node in nodes {
57 let neighbor_labels = self.get_neighbor_labels(node, &labels, nodes)?;
58
59 if let Some(most_frequent_label) = self.most_frequent_label(neighbor_labels) {
60 if labels[&node] != most_frequent_label {
61 new_labels.insert(node, most_frequent_label);
62 changed = true;
63 }
64 }
65 }
66
67 labels = new_labels;
68
69 if !changed {
70 break;
71 }
72 }
73
74 // Convert labels to communities
75 self.labels_to_communities(labels, nodes)
76 }
77
78 /// Girvan-Newman edge betweenness community detection
79 pub fn girvan_newman(&self, nodes: &[NodeId], num_communities: usize) -> Result<Vec<Vec<NodeId>>> {
80 let mut remaining_edges: HashSet<RelationshipId> = HashSet::new();
81
82 // Collect all edges
83 for &node in nodes {
84 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
85 for relationship in relationships {
86 let other = if relationship.start_node == node {
87 relationship.end_node
88 } else {
89 relationship.start_node
90 };
91
92 if nodes.contains(&other) {
93 remaining_edges.insert(relationship.id);
94 }
95 }
96 }
97
98 loop {
99 let components = self.find_components_excluding_edges(nodes, &HashSet::new())?;
100
101 if components.len() >= num_communities {
102 return Ok(components.into_iter().take(num_communities).collect());
103 }
104
105 // Calculate edge betweenness for all remaining edges
106 let edge_betweenness = self.calculate_edge_betweenness(nodes, &remaining_edges)?;
107
108 // Find edge with highest betweenness
109 if let Some((&edge_to_remove, _)) = edge_betweenness.iter()
110 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) {
111 remaining_edges.remove(&edge_to_remove);
112 } else {
113 break;
114 }
115 }
116
117 self.find_components_excluding_edges(nodes, &HashSet::new())
118 }
119
120 /// Modularity calculation for community quality assessment
121 pub fn modularity(&self, communities: &[Vec<NodeId>], nodes: &[NodeId]) -> Result<f64> {
122 let total_edges = self.count_total_edges(nodes)?;
123 if total_edges == 0 {
124 return Ok(0.0);
125 }
126
127 let mut modularity = 0.0;
128
129 for community in communities {
130 let community_set: HashSet<NodeId> = community.iter().copied().collect();
131
132 for &node_i in community {
133 for &node_j in community {
134 if node_i <= node_j {
135 continue; // Avoid double counting
136 }
137
138 let a_ij = if self.are_connected(node_i, node_j)? { 1.0 } else { 0.0 };
139 let k_i = self.get_degree(node_i, nodes)? as f64;
140 let k_j = self.get_degree(node_j, nodes)? as f64;
141
142 modularity += a_ij - (k_i * k_j) / (2.0 * total_edges as f64);
143 }
144 }
145 }
146
147 Ok(modularity / (2.0 * total_edges as f64))
148 }
149
150 /// Fast greedy modularity optimization
151 pub fn fast_greedy_modularity(&self, nodes: &[NodeId]) -> Result<Vec<Vec<NodeId>>> {
152 let mut communities: Vec<HashSet<NodeId>> = nodes.iter().map(|&node| {
153 let mut set = HashSet::new();
154 set.insert(node);
155 set
156 }).collect();
157
158 let mut best_modularity = self.calculate_modularity_for_partition(&communities, nodes)?;
159 let mut best_communities = communities.clone();
160
161 while communities.len() > 1 {
162 let mut best_merge: Option<(usize, usize)> = None;
163 let mut best_delta_q = f64::NEG_INFINITY;
164
165 // Try all possible merges
166 for i in 0..communities.len() {
167 for j in (i + 1)..communities.len() {
168 let delta_q = self.calculate_merge_delta_q(&communities, i, j, nodes)?;
169
170 if delta_q > best_delta_q {
171 best_delta_q = delta_q;
172 best_merge = Some((i, j));
173 }
174 }
175 }
176
177 if let Some((i, j)) = best_merge {
178 // Merge communities i and j
179 let community_j = communities.remove(j);
180 communities[i].extend(community_j);
181
182 let new_modularity = self.calculate_modularity_for_partition(&communities, nodes)?;
183 if new_modularity > best_modularity {
184 best_modularity = new_modularity;
185 best_communities = communities.clone();
186 }
187 } else {
188 break;
189 }
190 }
191
192 Ok(best_communities.into_iter().map(|set| set.into_iter().collect()).collect())
193 }
194
195 /// Spectral clustering using graph Laplacian
196 pub fn spectral_clustering(&self, nodes: &[NodeId], k: usize) -> Result<Vec<Vec<NodeId>>> {
197 // This is a simplified version - in practice, you'd use proper eigenvalue decomposition
198 // For now, we'll use a heuristic approach based on node connectivity
199
200 let mut communities: Vec<Vec<NodeId>> = Vec::new();
201 let mut remaining_nodes: HashSet<NodeId> = nodes.iter().copied().collect();
202
203 for _ in 0..k {
204 if remaining_nodes.is_empty() {
205 break;
206 }
207
208 // Start with a random node
209 let start_node = *remaining_nodes.iter().next().unwrap();
210 let mut community = vec![start_node];
211 remaining_nodes.remove(&start_node);
212
213 // Add nodes that are well-connected to the current community
214 let mut added = true;
215 while added && !remaining_nodes.is_empty() {
216 added = false;
217 let mut best_node = None;
218 let mut best_score = 0.0;
219
220 for &candidate in &remaining_nodes {
221 let score = self.calculate_community_affinity(candidate, &community)?;
222 if score > best_score {
223 best_score = score;
224 best_node = Some(candidate);
225 }
226 }
227
228 if let Some(node) = best_node {
229 // Only add node if it has a strong connection to the community
230 // (at least 50% of community members)
231 let min_threshold = 0.5;
232 if best_score >= min_threshold {
233 community.push(node);
234 remaining_nodes.remove(&node);
235 added = true;
236 }
237 }
238 }
239
240 communities.push(community);
241 }
242
243 // Add any remaining nodes to the last community
244 if !remaining_nodes.is_empty() {
245 if let Some(last_community) = communities.last_mut() {
246 last_community.extend(remaining_nodes);
247 } else {
248 communities.push(remaining_nodes.into_iter().collect());
249 }
250 }
251
252 Ok(communities)
253 }
254
255 // Helper methods
256
257 fn initialize_communities(&self, nodes: &[NodeId]) -> Result<Vec<Vec<NodeId>>> {
258 Ok(nodes.iter().map(|&node| vec![node]).collect())
259 }
260
261 fn find_node_community(&self, communities: &[Vec<NodeId>], node: NodeId) -> usize {
262 for (i, community) in communities.iter().enumerate() {
263 if community.contains(&node) {
264 return i;
265 }
266 }
267 0 // Fallback
268 }
269
270 fn find_best_community_for_node(
271 &self,
272 node: NodeId,
273 communities: &[Vec<NodeId>],
274 nodes: &[NodeId],
275 ) -> Result<usize> {
276 let mut best_community = 0;
277 let mut best_gain = f64::NEG_INFINITY;
278
279 for (i, community) in communities.iter().enumerate() {
280 let gain = self.calculate_modularity_gain(node, community, nodes)?;
281 if gain > best_gain {
282 best_gain = gain;
283 best_community = i;
284 }
285 }
286
287 Ok(best_community)
288 }
289
290 fn move_node_to_community(
291 &self,
292 communities: &mut Vec<Vec<NodeId>>,
293 node: NodeId,
294 from: usize,
295 to: usize,
296 ) {
297 if from != to {
298 communities[from].retain(|&n| n != node);
299 communities[to].push(node);
300 }
301 }
302
303 fn calculate_modularity_gain(
304 &self,
305 node: NodeId,
306 community: &[NodeId],
307 nodes: &[NodeId],
308 ) -> Result<f64> {
309 let mut internal_connections = 0;
310 let total_edges = self.count_total_edges(nodes)?;
311
312 for &other in community {
313 if other != node && self.are_connected(node, other)? {
314 internal_connections += 1;
315 }
316 }
317
318 let node_degree = self.get_degree(node, nodes)?;
319 let community_degree: usize = community.iter()
320 .map(|&n| self.get_degree(n, nodes).unwrap_or(0))
321 .sum();
322
323 if total_edges == 0 {
324 return Ok(0.0);
325 }
326
327 let gain = (internal_connections as f64) -
328 (node_degree as f64 * community_degree as f64) / (2.0 * total_edges as f64);
329
330 Ok(gain)
331 }
332
333 fn get_neighbor_labels(
334 &self,
335 node: NodeId,
336 labels: &HashMap<NodeId, usize>,
337 nodes: &[NodeId],
338 ) -> Result<Vec<usize>> {
339 let mut neighbor_labels = Vec::new();
340 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
341
342 for relationship in relationships {
343 let neighbor = if relationship.start_node == node {
344 relationship.end_node
345 } else {
346 relationship.start_node
347 };
348
349 if nodes.contains(&neighbor) {
350 if let Some(&label) = labels.get(&neighbor) {
351 neighbor_labels.push(label);
352 }
353 }
354 }
355
356 Ok(neighbor_labels)
357 }
358
359 fn most_frequent_label(&self, labels: Vec<usize>) -> Option<usize> {
360 if labels.is_empty() {
361 return None;
362 }
363
364 let mut counts: HashMap<usize, usize> = HashMap::new();
365 for label in labels {
366 *counts.entry(label).or_insert(0) += 1;
367 }
368
369 counts.into_iter()
370 .max_by_key(|(_, count)| *count)
371 .map(|(label, _)| label)
372 }
373
374 fn labels_to_communities(
375 &self,
376 labels: HashMap<NodeId, usize>,
377 nodes: &[NodeId],
378 ) -> Result<Vec<Vec<NodeId>>> {
379 let mut communities: HashMap<usize, Vec<NodeId>> = HashMap::new();
380
381 for &node in nodes {
382 if let Some(&label) = labels.get(&node) {
383 communities.entry(label).or_insert_with(Vec::new).push(node);
384 }
385 }
386
387 Ok(communities.into_values().collect())
388 }
389
390 fn find_components_excluding_edges(
391 &self,
392 nodes: &[NodeId],
393 excluded_edges: &HashSet<RelationshipId>,
394 ) -> Result<Vec<Vec<NodeId>>> {
395 let mut visited: HashSet<NodeId> = HashSet::new();
396 let mut components = Vec::new();
397
398 for &node in nodes {
399 if !visited.contains(&node) {
400 let component = self.bfs_component(node, nodes, excluded_edges, &mut visited)?;
401 components.push(component);
402 }
403 }
404
405 Ok(components)
406 }
407
408 fn bfs_component(
409 &self,
410 start: NodeId,
411 nodes: &[NodeId],
412 excluded_edges: &HashSet<RelationshipId>,
413 visited: &mut HashSet<NodeId>,
414 ) -> Result<Vec<NodeId>> {
415 let mut component = Vec::new();
416 let mut queue = VecDeque::new();
417
418 queue.push_back(start);
419 visited.insert(start);
420
421 while let Some(current) = queue.pop_front() {
422 component.push(current);
423
424 let relationships = self.graph.get_node_relationships(current, Direction::Both, None);
425 for relationship in relationships {
426 if excluded_edges.contains(&relationship.id) {
427 continue;
428 }
429
430 let neighbor = if relationship.start_node == current {
431 relationship.end_node
432 } else {
433 relationship.start_node
434 };
435
436 if nodes.contains(&neighbor) && !visited.contains(&neighbor) {
437 visited.insert(neighbor);
438 queue.push_back(neighbor);
439 }
440 }
441 }
442
443 Ok(component)
444 }
445
446 fn calculate_edge_betweenness(
447 &self,
448 nodes: &[NodeId],
449 edges: &HashSet<RelationshipId>,
450 ) -> Result<HashMap<RelationshipId, f64>> {
451 let mut betweenness: HashMap<RelationshipId, f64> = HashMap::new();
452
453 for &edge in edges {
454 betweenness.insert(edge, 0.0);
455 }
456
457 // This is a simplified calculation - proper edge betweenness requires
458 // counting shortest paths that pass through each edge
459 for &source in nodes {
460 for &target in nodes {
461 if source != target {
462 // Count paths from source to target that use each edge
463 // This is a placeholder implementation
464 for &edge in edges {
465 betweenness.insert(edge, betweenness[&edge] + 1.0);
466 }
467 }
468 }
469 }
470
471 Ok(betweenness)
472 }
473
474 fn calculate_modularity_for_partition(
475 &self,
476 communities: &[HashSet<NodeId>],
477 nodes: &[NodeId],
478 ) -> Result<f64> {
479 let communities_vec: Vec<Vec<NodeId>> = communities.iter()
480 .map(|set| set.iter().copied().collect())
481 .collect();
482
483 self.modularity(&communities_vec, nodes)
484 }
485
486 fn calculate_merge_delta_q(
487 &self,
488 communities: &[HashSet<NodeId>],
489 i: usize,
490 j: usize,
491 nodes: &[NodeId],
492 ) -> Result<f64> {
493 // Simplified delta Q calculation
494 let mut connections = 0;
495
496 for &node_i in &communities[i] {
497 for &node_j in &communities[j] {
498 if self.are_connected(node_i, node_j)? {
499 connections += 1;
500 }
501 }
502 }
503
504 Ok(connections as f64)
505 }
506
507 fn calculate_community_affinity(&self, node: NodeId, community: &[NodeId]) -> Result<f64> {
508 let mut connections = 0;
509
510 for &community_node in community {
511 if self.are_connected(node, community_node)? {
512 connections += 1;
513 }
514 }
515
516 Ok(connections as f64 / community.len() as f64)
517 }
518
519 fn count_total_edges(&self, nodes: &[NodeId]) -> Result<usize> {
520 let mut edge_count = 0;
521 let mut counted_edges: HashSet<RelationshipId> = HashSet::new();
522
523 for &node in nodes {
524 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
525 for relationship in relationships {
526 if !counted_edges.contains(&relationship.id) {
527 let other = if relationship.start_node == node {
528 relationship.end_node
529 } else {
530 relationship.start_node
531 };
532
533 if nodes.contains(&other) {
534 edge_count += 1;
535 counted_edges.insert(relationship.id);
536 }
537 }
538 }
539 }
540
541 Ok(edge_count)
542 }
543
544 fn get_degree(&self, node: NodeId, nodes: &[NodeId]) -> Result<usize> {
545 let relationships = self.graph.get_node_relationships(node, Direction::Both, None);
546 let mut degree = 0;
547
548 for relationship in relationships {
549 let other = if relationship.start_node == node {
550 relationship.end_node
551 } else {
552 relationship.start_node
553 };
554
555 if nodes.contains(&other) {
556 degree += 1;
557 }
558 }
559
560 Ok(degree)
561 }
562
563 fn are_connected(&self, node1: NodeId, node2: NodeId) -> Result<bool> {
564 let relationships = self.graph.get_node_relationships(node1, Direction::Both, None);
565
566 for relationship in relationships {
567 let other = if relationship.start_node == node1 {
568 relationship.end_node
569 } else {
570 relationship.start_node
571 };
572
573 if other == node2 {
574 return Ok(true);
575 }
576 }
577
578 Ok(false)
579 }
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Graph;

    /// Builds a graph with two triangles (A-B-C and D-E-F) that are joined by
    /// a single bridge relationship C-D.
    fn create_test_graph() -> Graph {
        let graph = Graph::new();

        let node_a = graph.create_node();
        let node_b = graph.create_node();
        let node_c = graph.create_node();
        let node_d = graph.create_node();
        let node_e = graph.create_node();
        let node_f = graph.create_node();

        let schema = graph.schema();
        let mut schema = schema.write();
        let rel_type = schema.get_or_create_relationship_type("CONNECTS");
        // Release the schema lock before relationships are created.
        drop(schema);

        let edges = [
            // First triangle
            (node_a, node_b),
            (node_b, node_c),
            (node_c, node_a),
            // Second triangle
            (node_d, node_e),
            (node_e, node_f),
            (node_f, node_d),
            // Bridge between the triangles
            (node_c, node_d),
        ];
        for &(from, to) in edges.iter() {
            graph.create_relationship(from, to, rel_type).unwrap();
        }

        graph
    }

    #[test]
    fn test_label_propagation() {
        let graph = create_test_graph();
        let detector = CommunityDetection::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let communities = detector.label_propagation(&nodes, 10).unwrap();

        assert!(!communities.is_empty());

        // The community sizes must add up to the number of nodes.
        let assigned: usize = communities.iter().map(Vec::len).sum();
        assert_eq!(assigned, nodes.len());
    }

    #[test]
    fn test_modularity_calculation() {
        let graph = create_test_graph();
        let detector = CommunityDetection::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();

        // Hand-made partition: the first three nodes versus the rest.
        let (head, tail) = (nodes[0..3].to_vec(), nodes[3..].to_vec());
        let communities = vec![head, tail];

        let modularity = detector.modularity(&communities, &nodes).unwrap();

        // Modularity is expected to stay within [-1, 1].
        assert!((-1.0..=1.0).contains(&modularity));
    }

    #[test]
    fn test_spectral_clustering() {
        let graph = create_test_graph();
        let detector = CommunityDetection::new(&graph);

        let nodes: Vec<_> = graph.get_all_nodes();
        let communities = detector.spectral_clustering(&nodes, 2).unwrap();

        assert_eq!(communities.len(), 2);

        // The community sizes must add up to the number of nodes.
        let assigned: usize = communities.iter().map(Vec::len).sum();
        assert_eq!(assigned, nodes.len());
    }
}