//! The world's most clever kitty cat
1#![allow(unused)]
2
3use std::collections::HashMap;
4
5use log::debug;
6use serde::{Deserialize, Serialize};
7use tokio::sync::oneshot;
8
/// Some = Word, None = End Message
pub type Token = Option<String>;
/// Occurrence count for a single edge in the chain.
pub type Weight = u16;

/// Outgoing edges of one token: per-successor weights, plus a cached running
/// total of all weights (u64 so the sum cannot overflow `Weight`).
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Edges(HashMap<Token, Weight>, u64);

/// Markov-chain word graph: every known token mapped to its outgoing edges.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct Brain(HashMap<Token, Edges>);

/// One-shot channel used by `Brain::respond` to signal typing status:
/// `true` once output is being generated, `false` if nothing was produced.
pub type TypingSender = oneshot::Sender<bool>;
20
21pub fn format_token(tok: &Token) -> String {
22 if let Some(w) = tok {
23 w.clone()
24 } else {
25 "~END".to_string()
26 }
27}
28
29impl Edges {
30 fn increment_token(&mut self, tok: &Token) {
31 if let Some(w) = self.0.get_mut(tok) {
32 *w = w.saturating_add(1);
33 } else {
34 self.0.insert(tok.clone(), 1);
35 }
36 self.1 = self.1.saturating_add(1);
37 }
38
39 fn merge_from(&mut self, other: Self) {
40 self.0.reserve(other.0.len());
41 for (k, v) in other.0.into_iter() {
42 if let Some(w) = self.0.get_mut(&k) {
43 *w = w.saturating_add(v);
44 } else {
45 self.0.insert(k, v);
46 }
47 self.1 = self.1.saturating_add(v as u64);
48 }
49 }
50
51 fn sample(&self, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> {
52 let total_dist = if !allow_end && let Some(weight) = self.0.get(&None) {
53 self.1 - *weight as u64
54 } else {
55 self.1
56 };
57 let mut dist_left = rand.f64() * total_dist as f64;
58
59 for (tok, weight) in self.0.iter().filter(|(tok, _)| allow_end || tok.is_some()) {
60 dist_left -= *weight as f64;
61 if dist_left < 0.0 {
62 return Some(tok);
63 }
64 }
65 None
66 }
67
68 pub fn forget(&mut self, token: &Token) {
69 if let Some(w) = self.0.remove(token) {
70 self.1 -= w as u64;
71 }
72 }
73
74 pub fn iter_weights(&self) -> impl Iterator<Item = (&Token, Weight, f64)> {
75 self.0
76 .iter()
77 .map(|(k, v)| (k, *v, (*v as f64) / (self.1 as f64)))
78 }
79}
80
/// Always pass the reply roll in test builds, or when `BINGUS_FORCE_REPLY`
/// was set at *compile* time (`option_env!` is evaluated during the build,
/// not at runtime).
const FORCE_REPLIES: bool = cfg!(test) || (option_env!("BINGUS_FORCE_REPLY").is_some());
82
impl Brain {
    /// Lowercase a word into a `Token` — except URLs, whose case is kept
    /// intact (their paths may be case-sensitive).
    fn normalize_token(word: &str) -> Token {
        let w = if word.starts_with("http://") || word.starts_with("https://") {
            word.to_string()
        } else {
            word.to_ascii_lowercase()
        };
        Some(w)
    }

    /// Split a message into normalized tokens, dropping `<@…>` mention
    /// tokens, and terminate the stream with the END token (`None`).
    fn parse(msg: &str) -> impl Iterator<Item = Token> {
        msg.split_whitespace()
            .filter_map(|w| {
                // Filter out pings, they can get annoying
                if w.starts_with("<@") && w.ends_with(">") {
                    None
                } else {
                    Some(Self::normalize_token(w))
                }
            })
            .chain(std::iter::once(None))
    }

    /// Roll whether to reply at all: ~45% for our own messages, ~80% for
    /// others'. `FORCE_REPLIES` short-circuits the roll to always reply.
    fn should_reply(rand: &mut fastrand::Rng, is_self: bool) -> bool {
        let chance = if is_self { 45 } else { 80 };
        let roll = rand.u8(0..=100);

        (FORCE_REPLIES) || roll <= chance
    }

    /// Normalized final word of `msg`, if any — the seed for a reply chain.
    fn extract_final_word(msg: &str) -> Option<String> {
        msg.split_whitespace()
            .last()
            .and_then(Self::normalize_token)
    }

    /// Uniformly pick any known token; `None` if the brain is empty.
    fn random_token(&self, rand: &mut fastrand::Rng) -> Option<&Token> {
        let len = self.0.len();
        if len == 0 {
            None
        } else {
            let i = rand.usize(..len);
            self.0.keys().nth(i)
        }
    }

    /// Learn from `msg`: bump the edge weight of every adjacent token pair
    /// (including `last word -> END`). Returns `true` if any previously
    /// unknown token was added.
    pub fn ingest(&mut self, msg: &str) -> bool {
        // Using reduce instead of .any here to prevent short-circuiting
        Self::parse(msg)
            .map_windows(|[from, to]| {
                if let Some(edge) = self.0.get_mut(from) {
                    edge.increment_token(to);
                    false
                } else {
                    let new = Edges(HashMap::from_iter([(to.clone(), 1)]), 1);
                    self.0.insert(from.clone(), new);
                    true
                }
            })
            .reduce(|acc, c| acc || c)
            .unwrap_or_default()
    }

    /// Erase a word everywhere: as a node, and as a successor in every other
    /// token's edge set.
    pub fn forget(&mut self, word: &str) {
        let tok = Self::normalize_token(word);

        self.0.remove(&tok);

        for edge in self.0.values_mut() {
            edge.forget(&tok);
        }
    }

    /// Remove the `from -> to` edge. Returns `true` when `from` is a known
    /// token (even if the specific edge was absent).
    pub fn forget_edge(&mut self, from: &str, to: &str) -> bool {
        if let Some(edges) = self.0.get_mut(&Self::normalize_token(from)) {
            edges.forget(&Self::normalize_token(to));
            true
        } else {
            false
        }
    }

    /// Absorb another brain: edge sets for shared tokens are merged, new
    /// tokens are moved over wholesale.
    pub fn merge_from(&mut self, other: Self) {
        for (k, v) in other.0.into_iter() {
            if let Some(edges) = self.0.get_mut(&k) {
                edges.merge_from(v);
            } else {
                self.0.insert(k, v);
            }
        }
    }

    /// Pick the next token of the reply chain after `tok`.
    fn next_from(&self, tok: &Token, rand: &mut fastrand::Rng, allow_end: bool) -> Option<&Token> {
        // Get the edges for the current token
        // If we have that token, sample its edges
        // Otherwise, if we don't know that token, and allow_end is false, try to pick a random token instead
        self.0
            .get(tok)
            .and_then(|edges| edges.sample(rand, allow_end))
            .or_else(|| {
                if allow_end {
                    None
                } else {
                    self.random_token(rand)
                }
            })
    }

    /// Maybe generate a reply to `msg` by walking the chain from its last
    /// word (or a random token if `msg` has none). Returns `None` when the
    /// reply roll fails, the brain is empty, nothing was generated, or the
    /// result is blank / too long. `typing_oneshot` (if provided) receives
    /// `true` on the first generated word, or `false` if none were produced.
    pub fn respond(
        &self,
        msg: &str,
        is_self: bool,
        force_reply: bool,
        mut typing_oneshot: Option<TypingSender>,
    ) -> Option<String> {
        const MAX_TOKENS: usize = 20;

        let mut rng = fastrand::Rng::new();

        // Roll if we should reply
        if !force_reply && !Self::should_reply(&mut rng, is_self) {
            debug!("Failed roll");
            return None;
        }

        // Get the final token
        let last_token = Self::extract_final_word(msg);

        let mut current_token = if let Some(t) = last_token {
            // We found a word at the end of the previous message
            &Some(t)
        } else {
            // We couldn't find a word at the end of the last message, pick a random one
            // If we *still* don't have a token, return early
            self.random_token(&mut rng)?
        };

        let mut chain = Vec::with_capacity(MAX_TOKENS);
        let sep = String::from(" ");

        // Walk the chain. END only becomes a legal next step once we have at
        // least one word. The length is checked before the push, so a reply
        // can reach MAX_TOKENS + 1 words.
        while let Some(next @ Some(s)) = self.next_from(current_token, &mut rng, !chain.is_empty())
            && chain.len() <= MAX_TOKENS
        {
            chain.push(s);
            // First word generated: ask the caller to show a typing indicator.
            // take() guarantees the oneshot fires at most once.
            if let Some(typ) = typing_oneshot.take() {
                typ.send(true).ok();
            }
            current_token = next;
        }

        // Only still Some if the loop never produced a word: signal "no reply".
        if let Some(typ) = typing_oneshot.take() {
            typ.send(false).ok();
        }

        if chain.is_empty() {
            None
        } else {
            let s = chain
                .into_iter()
                .intersperse(&sep)
                .cloned()
                .collect::<String>();
            // Discard blank output and anything at or over 2000 UTF-16 code
            // units (Discord measures message length in UTF-16 units).
            Some(s)
                .filter(|s| !s.trim().is_empty())
                .filter(|s| s.encode_utf16().count() < 2000)
        }
    }

    /// Number of distinct tokens known.
    pub fn word_count(&self) -> usize {
        self.0.len()
    }

    /// Edge set for a word, if it is known and has at least one successor.
    pub fn get_weights(&self, tok: &str) -> Option<&Edges> {
        self.0
            .get(&Self::normalize_token(tok))
            .filter(|e| !e.0.is_empty())
    }

    /// Legacy save encoding: words become `"W-<word>"`, END becomes `"E--"`.
    fn legacy_token_format(tok: &Token) -> String {
        tok.as_ref()
            .map(|s| format!("W-{s}"))
            .unwrap_or_else(|| String::from("E--"))
    }

    /// Export into the legacy nested-string-map save format. The cached edge
    /// totals are not stored; they are recomputed on import.
    pub fn as_legacy_hashmap(&self) -> HashMap<String, HashMap<String, Weight>> {
        self.0
            .iter()
            .map(|(k, v)| {
                let map =
                    v.0.iter()
                        .map(|(t, w)| (Self::legacy_token_format(t), *w))
                        .collect();
                (Self::legacy_token_format(k), map)
            })
            .collect()
    }

    /// Decode a legacy token string: `"E--"` is END; otherwise strip the
    /// `"W-"` prefix, tolerating strings that never had one.
    fn read_legacy_token(s: String) -> Token {
        match s.as_str() {
            "E--" => None,
            word => Some(word.strip_prefix("W-").unwrap_or(word).to_string()),
        }
    }

    /// Import from the legacy nested-string-map format, rebuilding each edge
    /// set's cached total by summing its weights.
    pub fn from_legacy_hashmap(map: HashMap<String, HashMap<String, Weight>>) -> Self {
        Self(
            map.into_iter()
                .map(|(k, v)| {
                    let sum = v.values().map(|w| *w as u64).sum::<u64>();
                    let edges = Edges(
                        v.into_iter()
                            .map(|(t, w)| (Self::read_legacy_token(t), w))
                            .collect(),
                        sum,
                    );
                    (Self::read_legacy_token(k), edges)
                })
                .collect(),
        )
    }
}
304
#[cfg(test)]
mod tests {

    use super::*;
    use std::default::Default;

    // Nightly-only: the unstable `test` crate supplies #[bench] and Bencher.
    extern crate test;

    use test::Bencher;

    // Words are lowercased and the stream ends with the END token (None).
    #[test]
    fn ingest_parse() {
        let tokens = Brain::parse("Hello world").collect::<Vec<_>>();
        assert_eq!(
            tokens,
            vec![Some("hello".to_string()), Some("world".to_string()), None]
        );
    }

    // URLs keep their original case during normalization.
    #[test]
    fn ingest_url() {
        let tokens = Brain::parse("https://example.com/CAPS-PATH").collect::<Vec<_>>();
        assert_eq!(
            tokens,
            vec![Some("https://example.com/CAPS-PATH".to_string()), None]
        );
    }

    // Discord mention tokens (<@id>) are dropped entirely.
    #[test]
    fn ingest_ping() {
        let tokens = Brain::parse("hi <@1234567>").collect::<Vec<_>>();
        assert_eq!(tokens, vec![Some("hi".to_string()), None]);
    }

    #[test]
    fn basic_chain() {
        let mut brain = Brain::default();
        brain.ingest("hello world");
        let hello_edges = brain
            .0
            .get(&Some("hello".to_string()))
            .expect("Hello edges not created");
        assert_eq!(
            hello_edges.0,
            HashMap::from_iter([(Some("world".to_string()), 1)])
        );
        let reply = brain.respond("hello", false, false, None);
        assert_eq!(reply, Some("world".to_string()));
    }

    // Even with a heavily-weighted hello -> END edge, the first step of a
    // reply must never be END, so "world" is always produced.
    #[test]
    fn at_least_1_token() {
        let mut brain = Brain::default();
        brain.ingest("hello world");
        for _ in 0..100 {
            brain.ingest("hello");
        }

        for _ in 0..100 {
            // I'm too lazy to mock fastrand LOL!!
            let reply = brain.respond("hello", false, false, None);
            assert_eq!(reply, Some("world".to_string()));
        }
    }

    #[test]
    fn forget_word() {
        let mut brain = Brain::default();

        brain.ingest("hello world");
        brain.ingest("hello evil world");

        brain.forget("evil");

        assert!(
            !brain.0.contains_key(&Some(String::from("evil"))),
            "Edges still exist for evil"
        );
        let edges = brain
            .0
            .get(&Some(String::from("hello")))
            .expect("No weights for hello");
        assert!(
            !edges.0.contains_key(&Some(String::from("evil"))),
            "Edges for hello still has evil"
        );
        // Cached total must shrink along with the removed weight.
        assert_eq!(edges.1, 1);
    }

    #[test]
    fn forget_edge() {
        let mut brain = Brain::default();

        brain.ingest("hello world");
        brain.ingest("hello evil");
        brain.ingest("evil bad");

        let exists = brain.forget_edge("hello", "evil");

        assert!(exists, "hello -> evil did not exist");

        // Only the edge is removed; "evil" itself stays a known token.
        assert!(
            brain.0.contains_key(&Some(String::from("evil"))),
            "Edges don't exist for evil"
        );
        let edges = brain
            .0
            .get(&Some(String::from("hello")))
            .expect("No weights for hello");
        assert!(
            !edges.0.contains_key(&Some(String::from("evil"))),
            "Edges for hello still has evil"
        );
        assert!(
            edges.0.contains_key(&Some(String::from("world"))),
            "Edges for hello does not have world"
        );
        assert_eq!(edges.1, 1);
    }

    #[test]
    fn none_on_empty() {
        let mut brain = Brain::default();

        let reply = brain.respond("hello", false, false, None);
        assert_eq!(reply, None);
    }

    // Replies at or beyond the 2000-UTF-16-unit limit are discarded.
    #[test]
    fn none_on_long() {
        let mut brain = Brain::default();

        let msg = vec!["a"; 2500].into_iter().collect::<String>();
        let msg = format!("hello {msg}");

        brain.ingest(&msg);

        assert!(brain.respond("hello", false, false, None).is_none())
    }

    // "hello" only leads to END, so the first step falls back to a random
    // token and still produces some reply.
    #[test]
    fn random_on_end() {
        let mut brain = Brain::default();
        brain.ingest("world hello");

        let reply = brain.respond("hello", false, false, None);
        assert!(reply.is_some());
    }

    // Pins the chain cap: length is checked before the push, so a reply can
    // contain MAX_TOKENS + 1 = 21 words.
    #[test]
    fn long_chain() {
        const LETTERS: &str = "abcdefghijklmnopqrstuvwxyz";
        let msg = LETTERS
            .chars()
            .map(|c| c.to_string())
            .intersperse(" ".to_string())
            .collect::<String>();
        let mut brain = Brain::default();
        brain.ingest(&msg);
        let reply = brain.respond("a", false, false, None);
        let expected = LETTERS
            .chars()
            .skip(1)
            .take(21)
            .map(|c| c.to_string())
            .intersperse(" ".to_string())
            .collect::<String>();
        assert_eq!(reply, Some(expected));
    }

    #[test]
    fn merge_brain() {
        let mut brain1 = Brain::default();
        let mut brain2 = Brain::default();

        brain1.ingest("hello world");
        brain2.ingest("hello world");
        brain2.ingest("hello world");
        brain2.ingest("other word");

        brain1.merge_from(brain2);

        // Shared edge weights are summed (1 + 2 = 3).
        let hello_edges = brain1
            .0
            .get(&Some("hello".to_string()))
            .expect("Hello edges not created");
        assert_eq!(
            hello_edges.0,
            HashMap::from_iter([(Some("world".to_string()), 3)])
        );

        // Tokens only known to brain2 are moved over wholesale.
        let new_edges = brain1
            .0
            .get(&Some("other".to_string()))
            .expect("New edges not created");
        assert_eq!(
            new_edges.0,
            HashMap::from_iter([(Some("word".to_string()), 1)])
        );
    }

    #[bench]
    fn bench_learn(b: &mut Bencher) {
        b.iter(|| {
            let mut brain = Brain::default();
            brain.ingest(
                "your name is bingus the discord bot and this message is a test for benchmarking",
            );
        });
    }

    #[bench]
    fn bench_respond(b: &mut Bencher) {
        let mut brain = Brain::default();
        brain.ingest(
            "your name is bingus the discord bot and this message is a test for benchmarking",
        );
        b.iter(|| {
            brain.respond("your", false, true, None);
        });
    }

    // Pulls in the LOREM constant used by the large-ingest benchmark below.
    include!("lorem.rs");

    #[bench]
    fn bench_learn_large(b: &mut Bencher) {
        b.iter(|| {
            let mut brain = Brain::default();
            brain.ingest(LOREM);
        });
    }
}