The world's most clever kitty cat
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

Use RwLock for brain handle

Ben C 71d0a017 032de680

+73 -56
+2 -2
src/cmd/dump_chain.rs
··· 33 33 let mut brotli_writer = brotli::CompressorWriter::with_params(&mut buf, 4096, &params); 34 34 35 35 if compat.unwrap_or_default() { 36 - let brain = ctx.brain_handle.lock().await; 36 + let brain = ctx.brain_handle.read().await; 37 37 let map = brain.as_legacy_hashmap(); 38 38 drop(brain); 39 39 rmp_serde::encode::write(&mut brotli_writer, &map) 40 40 .context("Failed to legacy encode brain")?; 41 41 } else { 42 - let brain = ctx.brain_handle.lock().await; 42 + let brain = ctx.brain_handle.read().await; 43 43 rmp_serde::encode::write(&mut brotli_writer, &*brain) 44 44 .context("Failed to write serialized brain")?; 45 45 }
+1 -1
src/cmd/load_chain.rs
··· 50 50 }; 51 51 52 52 { 53 - let mut brain = ctx.brain_handle.lock().await; 53 + let mut brain = ctx.brain_handle.write().await; 54 54 brain.merge_from(new_brain); 55 55 ctx.pending_save.store(true, Ordering::Relaxed); 56 56 update_status(&*brain, &ctx.shard_sender).context("Failed to update status")?;
+1 -1
src/cmd/weights.rs
··· 16 16 } 17 17 18 18 async fn get_output(token: &str, brain: &BrainHandle) -> Option<String> { 19 - let brain = brain.lock().await; 19 + let brain = brain.read().await; 20 20 21 21 brain.get_weights(token).map(|edges| { 22 22 let sep = String::from("\n");
+14 -21
src/main.rs
··· 26 26 use log::{debug, error, info, warn}; 27 27 use prelude::*; 28 28 use tokio::{ 29 - sync::Mutex, 29 + sync::RwLock, 30 30 time::{self, Duration}, 31 31 }; 32 32 use twilight_gateway::{ ··· 48 48 status::update_status, 49 49 }; 50 50 51 - pub type BrainHandle = Mutex<Brain>; 51 + pub type BrainHandle = RwLock<Brain>; 52 52 53 53 #[derive(Debug)] 54 54 pub struct BotContext { ··· 64 64 65 65 async fn handle_discord_event(event: Event, ctx: Arc<BotContext>) -> Result { 66 66 match event { 67 - Event::MessageCreate(msg) => handle_discord_message(msg, ctx).await, 67 + Event::MessageCreate(msg) => handle_discord_message(msg, ctx).await.context("While handling a new message"), 68 68 Event::InteractionCreate(mut inter) => { 69 69 if let Some(InteractionData::ApplicationCommand(data)) = 70 70 std::mem::take(&mut inter.0.data) 71 71 { 72 - handle_app_command(*data, ctx, inter.0).await 72 + handle_app_command(*data, ctx, inter.0).await.context("While handling an app command") 73 73 } else { 74 74 Ok(()) 75 75 } 76 76 } 77 77 Event::Ready(ev) => { 78 78 info!("Connected to gateway as {}", ev.user.name); 79 - let brain = ctx.brain_handle.lock().await; 80 - update_status(&*brain, &ctx.shard_sender).context("Failed to update status") 79 + let brain = ctx.brain_handle.read().await; 80 + update_status(&*brain, &ctx.shard_sender).context("Failed to update status on ready") 81 81 } 82 82 _ => Ok(()), 83 83 } ··· 99 99 let mut file = File::create(&ctx.brain_file_path).context("Failed to open brain file")?; 100 100 let params = BrotliEncoderParams::default(); 101 101 let mut brotli_writer = brotli::CompressorWriter::with_params(&mut file, 4096, &params); 102 - let brain = ctx.brain_handle.lock().await; 102 + let brain = ctx.brain_handle.read().await; 103 103 rmp_serde::encode::write(&mut brotli_writer, &*brain) 104 104 .context("Failed to write serialized brain")?; 105 105 debug!("Saved brain file"); ··· 145 145 info!("Creating new brain file at {brain_file_path:?}"); 146 146 Brain::default() 147 147 }; 148 - let brain_handle = Mutex::new(brain); 148 + let brain_handle = RwLock::new(brain); 149 149 150 150 // Init 151 151 let mut shard = Shard::new(ShardId::ONE, token.to_string(), intents); ··· 174 174 pending_save: AtomicBool::new(false), 175 175 }); 176 176 177 - info!("Ensuring brain is writable..."); 178 - save_brain(context.clone()) 179 - .await 180 - .context("Brain file is not writable")?; 181 - info!("Brain file saved"); 182 - 183 177 info!("Registering Commands..."); 184 178 register_all_commands(context.clone()).await?; 185 179 ··· 197 191 Ok(()) = tokio::signal::ctrl_c() => { 198 192 info!("SIGINT: Closing connection and saving"); 199 193 shard.close(CloseFrame::NORMAL); 200 - break; 201 194 } 202 195 _ = interval.tick() => { 203 196 debug!("Save Interval"); ··· 214 207 opt = shard.next_event(EventTypeFlags::all()) => { 215 208 match opt { 216 209 Some(Ok(Event::GatewayClose(_))) | None => { 217 - info!("Disconnected from Discord: Saving brain and exiting"); 210 + info!("Disconnected from Discord"); 218 211 break; 219 212 } 220 213 Some(Ok(event)) => { ··· 233 226 } 234 227 } 235 228 236 - save_brain(context) 237 - .await 238 - .context("Failed to write brain file on exit")?; 239 - 240 - info!("Save Complete, Exiting"); 229 + if context.pending_save.load(Ordering::Relaxed) { 230 + save_brain(context) 231 + .await 232 + .context("Failed to write brain file on exit")?; 233 + } 241 234 242 235 Ok(()) 243 236 }
+55 -31
src/on_message.rs
··· 7 7 use twilight_model::{ 8 8 channel::message::{AllowedMentions, MessageFlags, MessageType}, 9 9 gateway::payload::incoming::MessageCreate, 10 + id::{ 11 + Id, 12 + marker::{ChannelMarker, MessageMarker}, 13 + }, 10 14 }; 11 15 12 16 use crate::{BotContext, prelude::*, status::update_status}; 13 17 14 - pub async fn handle_discord_message(msg: Box<MessageCreate>, ctx: Arc<BotContext>) -> Result { 15 - let channel_id = msg.channel_id.get(); 16 - let is_self = msg.author.id == ctx.self_id; 17 - let is_normal_message = matches!(msg.kind, MessageType::Regular | MessageType::Reply); 18 - let is_ephemeral = msg 19 - .flags 20 - .is_some_and(|flags| flags.contains(MessageFlags::EPHEMERAL)); 21 - let is_dm = msg.guild_id.is_none(); 18 + async fn learn_message(msg: &str, ctx: Arc<BotContext>) -> Result { 19 + let mut brain = ctx.brain_handle.write().await; 20 + let learned_new_word = brain.ingest(&msg); 21 + ctx.pending_save.store(true, Ordering::Relaxed); 22 22 23 - // Should we consider this message at all? 24 - if !is_normal_message || is_ephemeral || is_dm { 25 - return Ok(()); 23 + if learned_new_word { 24 + update_status(&*brain, &ctx.shard_sender).context("Failed to update status")?; 26 25 } 27 26 28 - // Should we learn from this message? (We don't want to learn from ourselves) 29 - if !is_self { 30 - let mut brain = ctx.brain_handle.lock().await; 31 - let learned_new_word = brain.ingest(&msg.content); 32 - ctx.pending_save.store(true, Ordering::Relaxed); 27 + Ok(()) 28 + } 33 29 34 - if learned_new_word { 35 - update_status(&*brain, &ctx.shard_sender).context("Failed to update status")?; 36 - } 37 - } 38 - 39 - // Should Reply to Message? 40 - if !ctx.reply_channels.contains(&channel_id) { 41 - return Ok(()); 42 - } 43 - 30 + async fn reply_message( 31 + msg: &str, 32 + msg_id: Id<MessageMarker>, 33 + channel_id: Id<ChannelMarker>, 34 + is_self: bool, 35 + ctx: &Arc<BotContext>, 36 + ) -> Result { 44 37 let (typ_tx, typ_rx) = tokio::sync::oneshot::channel(); 45 38 let (done_tx, done_rx) = tokio::sync::oneshot::channel(); 46 39 47 40 let ctx_typ = ctx.clone(); 48 - let typ_id = msg.channel_id; 41 + let typ_id = channel_id; 49 42 tokio::spawn(async move { 50 43 if typ_rx.await.ok().is_some_and(|start| start) { 51 44 if let Err(why) = ctx_typ.http.create_typing_trigger(typ_id).await { ··· 55 48 done_tx.send(()).ok(); 56 49 }); 57 50 58 - let brain = ctx.brain_handle.lock().await; 51 + let brain = ctx.brain_handle.read().await; 59 52 if let Some(reply_text) = brain 60 - .respond(&msg.content, is_self, Some(typ_tx)) 53 + .respond(&msg, is_self, Some(typ_tx)) 61 54 .filter(|s| !s.trim().is_empty()) 62 55 { 63 56 drop(brain); ··· 65 58 let allowed_mentions = AllowedMentions::default(); 66 59 let my_msg = ctx 67 60 .http 68 - .create_message(msg.channel_id) 61 + .create_message(channel_id) 69 62 .content(&reply_text) 70 63 .allowed_mentions(Some(&allowed_mentions)); 71 64 72 65 let my_msg = if !is_self { 73 - my_msg.reply(msg.id).fail_if_not_exists(false) 66 + my_msg.reply(msg_id).fail_if_not_exists(false) 74 67 } else { 75 68 my_msg 76 69 }; ··· 80 73 81 74 Ok(()) 82 75 } 76 + 77 + pub async fn handle_discord_message(msg: Box<MessageCreate>, ctx: Arc<BotContext>) -> Result { 78 + let channel_id = msg.channel_id.get(); 79 + let is_self = msg.author.id == ctx.self_id; 80 + let is_normal_message = matches!(msg.kind, MessageType::Regular | MessageType::Reply); 81 + let is_ephemeral = msg 82 + .flags 83 + .is_some_and(|flags| flags.contains(MessageFlags::EPHEMERAL)); 84 + let is_dm = msg.guild_id.is_none(); 85 + 86 + // Should we consider this message at all? 87 + if !is_normal_message || is_ephemeral || is_dm { 88 + return Ok(()); 89 + } 90 + 91 + // Should Reply to Message? 92 + if ctx.reply_channels.contains(&channel_id) { 93 + reply_message(&msg.content, msg.id, msg.channel_id, is_self, &ctx) 94 + .await 95 + .context("Bingus failed to reply to a message")?; 96 + } 97 + 98 + // Should we learn from this message? (We don't want to learn from ourselves) 99 + if !is_self { 100 + learn_message(&msg.content, ctx) 101 + .await 102 + .context("Bingus failed to learn from a message")?; 103 + } 104 + 105 + Ok(()) 106 + }