The world's most clever kitty cat
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 536 lines 15 kB view raw
1#![allow(unused)] 2 3use std::collections::HashMap; 4 5use log::debug; 6use serde::{Deserialize, Serialize}; 7use tokio::sync::oneshot; 8 9/// Some = Word, None = End Message 10pub type Token = Option<String>; 11pub type Weight = u16; 12 13#[derive(Default, Debug, Clone, Serialize, Deserialize)] 14pub struct Edges(HashMap<Token, Weight>, u64); 15 16#[derive(Default, Debug, Clone, Serialize, Deserialize)] 17pub struct Brain(HashMap<Token, Edges>); 18 19pub type TypingSender = oneshot::Sender<bool>; 20 21pub fn format_token(tok: &Token) -> String { 22 if let Some(w) = tok { 23 w.clone() 24 } else { 25 "~END".to_string() 26 } 27} 28 29impl Edges { 30 fn increment_token(&mut self, tok: &Token) { 31 if let Some(w) = self.0.get_mut(tok) { 32 *w = w.saturating_add(1); 33 } else { 34 self.0.insert(tok.clone(), 1); 35 } 36 self.1 = self.1.saturating_add(1); 37 } 38 39 fn merge_from(&mut self, other: Self) { 40 self.0.reserve(other.0.len()); 41 for (k, v) in other.0.into_iter() { 42 if let Some(w) = self.0.get_mut(&k) { 43 *w = w.saturating_add(v); 44 } else { 45 self.0.insert(k, v); 46 } 47 self.1 = self.1.saturating_add(v as u64); 48 } 49 } 50 51 fn sample(&self, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> { 52 let total_dist = if !allow_end && let Some(weight) = self.0.get(&None) { 53 self.1 - *weight as u64 54 } else { 55 self.1 56 }; 57 let mut dist_left = rand.f64() * total_dist as f64; 58 59 for (tok, weight) in self.0.iter().filter(|(tok, _)| allow_end || tok.is_some()) { 60 dist_left -= *weight as f64; 61 if dist_left < 0.0 { 62 return Some(tok); 63 } 64 } 65 None 66 } 67 68 pub fn forget(&mut self, token: &Token) { 69 if let Some(w) = self.0.remove(token) { 70 self.1 -= w as u64; 71 } 72 } 73 74 pub fn iter_weights(&self) -> impl Iterator<Item = (&Token, Weight, f64)> { 75 self.0 76 .iter() 77 .map(|(k, v)| (k, *v, (*v as f64) / (self.1 as f64))) 78 } 79} 80 81const FORCE_REPLIES: bool = cfg!(test) || (option_env!("BINGUS_FORCE_REPLY").is_some()); 
impl Brain {
    /// Canonical form of a word: URLs keep their case (paths are
    /// case-sensitive), everything else is lowercased.
    fn normalize_token(word: &str) -> Token {
        let w = if word.starts_with("http://") || word.starts_with("https://") {
            word.to_string()
        } else {
            word.to_ascii_lowercase()
        };
        Some(w)
    }

    /// Tokenize a message into normalized word tokens, terminated by the
    /// end-of-message marker (`None`).
    fn parse(msg: &str) -> impl Iterator<Item = Token> {
        msg.split_whitespace()
            .filter_map(|w| {
                // Filter out pings (e.g. <@1234>), they can get annoying
                if w.starts_with("<@") && w.ends_with(">") {
                    None
                } else {
                    Some(Self::normalize_token(w))
                }
            })
            .chain(std::iter::once(None))
    }

    /// Roll whether to reply: 45% chance for our own messages, 80% otherwise.
    /// FORCE_REPLIES (tests / BINGUS_FORCE_REPLY builds) always passes.
    fn should_reply(rand: &mut fastrand::Rng, is_self: bool) -> bool {
        let chance = if is_self { 45 } else { 80 };
        let roll = rand.u8(0..=100);

        (FORCE_REPLIES) || roll <= chance
    }

    /// Normalized last word of a message, if any — used to seed a response.
    fn extract_final_word(msg: &str) -> Option<String> {
        msg.split_whitespace()
            .last()
            .and_then(Self::normalize_token)
    }

    /// Uniformly random known token, or None if the brain is empty.
    fn random_token(&self, rand: &mut fastrand::Rng) -> Option<&Token> {
        let len = self.0.len();
        if len == 0 {
            None
        } else {
            let i = rand.usize(..len);
            // O(len) walk; acceptable since this is an occasional fallback.
            self.0.keys().nth(i)
        }
    }

    /// Learn from a message: every adjacent token pair (including word ->
    /// end-of-message) increments an edge weight. Returns true if at least
    /// one previously-unknown token was added.
    pub fn ingest(&mut self, msg: &str) -> bool {
        // Using reduce instead of .any here to prevent short-circuiting —
        // every window must be ingested even after the first `true`.
        Self::parse(msg)
            .map_windows(|[from, to]| {
                if let Some(edge) = self.0.get_mut(from) {
                    edge.increment_token(to);
                    false
                } else {
                    let new = Edges(HashMap::from_iter([(to.clone(), 1)]), 1);
                    self.0.insert(from.clone(), new);
                    true
                }
            })
            .reduce(|acc, c| acc || c)
            .unwrap_or_default()
    }

    /// Erase a word entirely: drop its own edge set and remove it from every
    /// other token's edges (keeping cached totals consistent via Edges::forget).
    pub fn forget(&mut self, word: &str) {
        let tok = Self::normalize_token(word);

        self.0.remove(&tok);

        for edge in self.0.values_mut() {
            edge.forget(&tok);
        }
    }

    /// Remove the single edge `from -> to`. Returns false if `from` is unknown.
    pub fn forget_edge(&mut self, from: &str, to: &str) -> bool {
        if let Some(edges) = self.0.get_mut(&Self::normalize_token(from)) {
            edges.forget(&Self::normalize_token(to));
            true
        } else {
            false
        }
    }

    /// Merge another brain into this one, summing weights for shared tokens.
    pub fn merge_from(&mut self, other: Self) {
        for (k, v) in other.0.into_iter() {
            if let Some(edges) = self.0.get_mut(&k) {
                edges.merge_from(v);
            } else {
                self.0.insert(k, v);
            }
        }
    }

    /// Pick the token following `tok`.
    fn next_from(&self, tok: &Token, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> {
        // Get the edges for the current token.
        // If we have that token, sample its edges.
        // Otherwise, if we don't know that token, and allow_end is false,
        // try to pick a random token instead so the chain can still start.
        self.0
            .get(tok)
            .and_then(|edges| edges.sample(rand, allow_end))
            .or_else(|| {
                if allow_end {
                    None
                } else {
                    self.random_token(rand)
                }
            })
    }

    /// Generate a reply to `msg` by walking the chain from its final word.
    ///
    /// * `is_self` lowers the reply probability (see `should_reply`).
    /// * `force_reply` skips the probability roll entirely.
    /// * `typing_oneshot`, if given, receives `true` once generation produces
    ///   its first token and `false` when generation finishes.
    ///
    /// Returns `None` on a failed roll, an empty brain, an empty chain, a
    /// whitespace-only result, or a result of 2000+ UTF-16 code units
    /// (Discord's message length limit is counted in UTF-16).
    pub fn respond(
        &self,
        msg: &str,
        is_self: bool,
        force_reply: bool,
        mut typing_oneshot: Option<TypingSender>,
    ) -> Option<String> {
        const MAX_TOKENS: usize = 20;

        let mut rng = fastrand::Rng::new();

        // Roll if we should reply
        if !force_reply && !Self::should_reply(&mut rng, is_self) {
            debug!("Failed roll");
            return None;
        }

        // Get the final token
        let last_token = Self::extract_final_word(msg);

        let mut current_token = if let Some(t) = last_token {
            // We found a word at the end of the previous message
            // (temporary lifetime extension keeps &Some(t) alive here)
            &Some(t)
        } else {
            // We couldn't find a word at the end of the last message, pick a random one
            // If we *still* don't have a token, return early
            self.random_token(&mut rng)?
        };

        let mut chain = Vec::with_capacity(MAX_TOKENS);
        let sep = String::from(" ");

        // The end token is only allowed once the chain is non-empty, so a
        // reply always has at least one word. `next @ Some(s)` binds both the
        // whole token (to advance the walk) and its word (to push).
        // Note: the <= bound means the chain may reach MAX_TOKENS + 1 words.
        while let Some(next @ Some(s)) = self.next_from(current_token, &mut rng, !chain.is_empty())
            && chain.len() <= MAX_TOKENS
        {
            chain.push(s);
            // Signal "typing" exactly once, on the first produced token.
            if let Some(typ) = typing_oneshot.take() {
                typ.send(true).ok();
            }
            current_token = next;
        }

        // If we never produced a token, this sends `false` instead.
        if let Some(typ) = typing_oneshot.take() {
            typ.send(false).ok();
        }

        if chain.is_empty() {
            None
        } else {
            let s = chain
                .into_iter()
                .intersperse(&sep)
                .cloned()
                .collect::<String>();
            Some(s)
                .filter(|s| !s.trim().is_empty())
                .filter(|s| s.encode_utf16().count() < 2000)
        }
    }

    /// Number of distinct tokens known to the brain.
    pub fn word_count(&self) -> usize {
        self.0.len()
    }

    /// The edge set for a word, or None if unknown or empty.
    pub fn get_weights(&self, tok: &str) -> Option<&Edges> {
        self.0
            .get(&Self::normalize_token(tok))
            .filter(|e| !e.0.is_empty())
    }

    /// Legacy on-disk token encoding: words become `W-<word>`, the
    /// end-of-message marker becomes `E--`.
    fn legacy_token_format(tok: &Token) -> String {
        tok.as_ref()
            .map(|s| format!("W-{s}"))
            .unwrap_or_else(|| String::from("E--"))
    }

    /// Export the brain as the legacy nested-HashMap format (weights only;
    /// cached totals are recomputed on import).
    pub fn as_legacy_hashmap(&self) -> HashMap<String, HashMap<String, Weight>> {
        self.0
            .iter()
            .map(|(k, v)| {
                let map =
                    v.0.iter()
                        .map(|(t, w)| (Self::legacy_token_format(t), *w))
                        .collect();
                (Self::legacy_token_format(k), map)
            })
            .collect()
    }

    /// Inverse of `legacy_token_format`; tolerates a missing `W-` prefix.
    fn read_legacy_token(s: String) -> Token {
        match s.as_str() {
            "E--" => None,
            word => Some(word.strip_prefix("W-").unwrap_or(word).to_string()),
        }
    }

    /// Import a brain from the legacy format, rebuilding each cached total
    /// as the sum of that token's edge weights.
    pub fn from_legacy_hashmap(map: HashMap<String, HashMap<String, Weight>>) -> Self {
        Self(
            map.into_iter()
                .map(|(k, v)| {
                    let sum = v.values().map(|w| *w as u64).sum::<u64>();
                    let edges = Edges(
                        v.into_iter()
                            .map(|(t, w)| (Self::read_legacy_token(t), w))
                            .collect(),
                        sum,
                    );
                    (Self::read_legacy_token(k), edges)
                })
                .collect(),
        )
    }
}

#[cfg(test)]
mod tests {
308 use super::*; 309 use std::default::Default; 310 311 extern crate test; 312 313 use test::Bencher; 314 315 #[test] 316 fn ingest_parse() { 317 let tokens = Brain::parse("Hello world").collect::<Vec<_>>(); 318 assert_eq!( 319 tokens, 320 vec![Some("hello".to_string()), Some("world".to_string()), None] 321 ); 322 } 323 324 #[test] 325 fn ingest_url() { 326 let tokens = Brain::parse("https://example.com/CAPS-PATH").collect::<Vec<_>>(); 327 assert_eq!( 328 tokens, 329 vec![Some("https://example.com/CAPS-PATH".to_string()), None] 330 ); 331 } 332 333 #[test] 334 fn ingest_ping() { 335 let tokens = Brain::parse("hi <@1234567>").collect::<Vec<_>>(); 336 assert_eq!(tokens, vec![Some("hi".to_string()), None]); 337 } 338 339 #[test] 340 fn basic_chain() { 341 let mut brain = Brain::default(); 342 brain.ingest("hello world"); 343 let hello_edges = brain 344 .0 345 .get(&Some("hello".to_string())) 346 .expect("Hello edges not created"); 347 assert_eq!( 348 hello_edges.0, 349 HashMap::from_iter([(Some("world".to_string()), 1)]) 350 ); 351 let reply = brain.respond("hello", false, false, None); 352 assert_eq!(reply, Some("world".to_string())); 353 } 354 355 #[test] 356 fn at_least_1_token() { 357 let mut brain = Brain::default(); 358 brain.ingest("hello world"); 359 for _ in 0..100 { 360 brain.ingest("hello"); 361 } 362 363 for _ in 0..100 { 364 // I'm too lazy to mock lazyrand LOL!! 
365 let reply = brain.respond("hello", false, false, None); 366 assert_eq!(reply, Some("world".to_string())); 367 } 368 } 369 370 #[test] 371 fn forget_word() { 372 let mut brain = Brain::default(); 373 374 brain.ingest("hello world"); 375 brain.ingest("hello evil world"); 376 377 brain.forget("evil"); 378 379 assert!( 380 !brain.0.contains_key(&Some(String::from("evil"))), 381 "Edges still exist for evil" 382 ); 383 let edges = brain 384 .0 385 .get(&Some(String::from("hello"))) 386 .expect("No weights for hello"); 387 assert!( 388 !edges.0.contains_key(&Some(String::from("evil"))), 389 "Edges for hello still has evil" 390 ); 391 assert_eq!(edges.1, 1); 392 } 393 394 #[test] 395 fn forget_edge() { 396 let mut brain = Brain::default(); 397 398 brain.ingest("hello world"); 399 brain.ingest("hello evil"); 400 brain.ingest("evil bad"); 401 402 let exists = brain.forget_edge("hello", "evil"); 403 404 assert!(exists, "hello -> evil did not exist"); 405 406 assert!( 407 brain.0.contains_key(&Some(String::from("evil"))), 408 "Edges don't exist for evil" 409 ); 410 let edges = brain 411 .0 412 .get(&Some(String::from("hello"))) 413 .expect("No weights for hello"); 414 assert!( 415 !edges.0.contains_key(&Some(String::from("evil"))), 416 "Edges for hello still has evil" 417 ); 418 assert!( 419 edges.0.contains_key(&Some(String::from("world"))), 420 "Edges for hello does not have world" 421 ); 422 assert_eq!(edges.1, 1); 423 } 424 425 #[test] 426 fn none_on_empty() { 427 let mut brain = Brain::default(); 428 429 let reply = brain.respond("hello", false, false, None); 430 assert_eq!(reply, None); 431 } 432 433 #[test] 434 fn none_on_long() { 435 let mut brain = Brain::default(); 436 437 let msg = vec!["a"; 2500].into_iter().collect::<String>(); 438 let msg = format!("hello {msg}"); 439 440 brain.ingest(&msg); 441 442 assert!(brain.respond("hello", false, false, None).is_none()) 443 } 444 445 #[test] 446 fn random_on_end() { 447 let mut brain = Brain::default(); 448 
brain.ingest("world hello"); 449 450 let reply = brain.respond("hello", false, false, None); 451 assert!(reply.is_some()); 452 } 453 454 #[test] 455 fn long_chain() { 456 const LETTERS: &str = "abcdefghijklmnopqrstuvwxyz"; 457 let msg = LETTERS 458 .chars() 459 .map(|c| c.to_string()) 460 .intersperse(" ".to_string()) 461 .collect::<String>(); 462 let mut brain = Brain::default(); 463 brain.ingest(&msg); 464 let reply = brain.respond("a", false, false, None); 465 let expected = LETTERS 466 .chars() 467 .skip(1) 468 .take(21) 469 .map(|c| c.to_string()) 470 .intersperse(" ".to_string()) 471 .collect::<String>(); 472 assert_eq!(reply, Some(expected)); 473 } 474 475 #[test] 476 fn merge_brain() { 477 let mut brain1 = Brain::default(); 478 let mut brain2 = Brain::default(); 479 480 brain1.ingest("hello world"); 481 brain2.ingest("hello world"); 482 brain2.ingest("hello world"); 483 brain2.ingest("other word"); 484 485 brain1.merge_from(brain2); 486 487 let hello_edges = brain1 488 .0 489 .get(&Some("hello".to_string())) 490 .expect("Hello edges not created"); 491 assert_eq!( 492 hello_edges.0, 493 HashMap::from_iter([(Some("world".to_string()), 3)]) 494 ); 495 496 let new_edges = brain1 497 .0 498 .get(&Some("other".to_string())) 499 .expect("New edges not created"); 500 assert_eq!( 501 new_edges.0, 502 HashMap::from_iter([(Some("word".to_string()), 1)]) 503 ); 504 } 505 506 #[bench] 507 fn bench_learn(b: &mut Bencher) { 508 b.iter(|| { 509 let mut brain = Brain::default(); 510 brain.ingest( 511 "your name is bingus the discord bot and this message is a test for benchmarking", 512 ); 513 }); 514 } 515 516 #[bench] 517 fn bench_respond(b: &mut Bencher) { 518 let mut brain = Brain::default(); 519 brain.ingest( 520 "your name is bingus the discord bot and this message is a test for benchmarking", 521 ); 522 b.iter(|| { 523 brain.respond("your", false, true, None); 524 }); 525 } 526 527 include!("lorem.rs"); 528 529 #[bench] 530 fn bench_learn_large(b: &mut Bencher) { 531 
b.iter(|| { 532 let mut brain = Brain::default(); 533 brain.ingest(LOREM); 534 }); 535 } 536}