don't
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor(jetstream): clean-up client task

Signed-off-by: tjh <x@tjh.dev>

tjh 7cdc24a9 ac684096

+163 -123
+163 -123
crates/jetstream/src/task.rs
··· 4 4 subscriber_options::{SubscribeMethod, SubscriberOptions}, 5 5 }; 6 6 use bytes::Bytes; 7 - use futures_util::{SinkExt, StreamExt as _}; 7 + use futures_util::{SinkExt, StreamExt}; 8 8 use serde::Deserialize; 9 9 use std::{ 10 10 sync::{Arc, Mutex}, 11 11 time::Duration, 12 12 }; 13 - use tokio::time::Instant; 13 + use tokio::time::timeout; 14 14 use tokio_tungstenite::{ 15 15 connect_async, 16 16 tungstenite::{ClientRequestBuilder, Error as TungsteniteError, Message, http::Uri}, ··· 20 20 #[cfg(feature = "zstd")] 21 21 const ZSTD_DICTIONARY: &[u8] = include_bytes!("dictionary"); 22 22 23 - /// Duration since last receipt of a message to consider a subscription broken. 24 - const THRESHOLD: Duration = Duration::from_secs(35); 25 - 26 23 /// Duration to rewind the cursor by when reconnecting a borken subscription. 27 24 const REWIND: Duration = Duration::from_secs(1); 25 + 26 + const RECV_TIMEOUT: Duration = Duration::from_secs(45); 28 27 29 28 #[derive(Debug, thiserror::Error)] 30 29 pub enum JetstreamTaskError { ··· 31 32 OptionsUpdate(TungsteniteError), 32 33 #[error("Failed to send close message: {0}")] 33 34 Close(TungsteniteError), 35 + } 36 + 37 + #[derive(Default)] 38 + struct State { 39 + cursor: Option<u128>, 40 + metrics: Metrics, 34 41 } 35 42 36 43 pub async fn jetstream_subscriber( ··· 48 43 initial_cursor: Option<u128>, 49 44 shutdown: CancellationToken, 50 45 ) { 51 - #[cfg(feature = "zstd")] 52 - let dictionary = zstd::dict::DecoderDictionary::copy(ZSTD_DICTIONARY); 53 - 54 - let mut cursor = initial_cursor; 55 - 56 - // How often to check whether the time elapsed since the last message 57 - // received exceeds THRESHOLD. 58 - let mut timeout = tokio::time::interval(Duration::from_secs(5)); 46 + let mut state = State { 47 + cursor: initial_cursor, 48 + metrics, 49 + }; 59 50 60 51 'outer: loop { 61 - let (subscribe_url, require_hello) = 62 - match options.lock().unwrap().subscribe_url(&instance, &cursor) { 63 - SubscribeMethod::Query(url) => (url, false), 64 - SubscribeMethod::Hello(url) => (url, true), 65 - }; 52 + let (subscribe_url, require_hello) = match options 53 + .lock() 54 + .unwrap() 55 + .subscribe_url(&instance, &state.cursor) 56 + { 57 + SubscribeMethod::Query(url) => (url, false), 58 + SubscribeMethod::Hello(url) => (url, true), 59 + }; 66 60 67 61 tracing::debug!(%subscribe_url, "connecting to jetstream"); 68 62 let uri: Uri = subscribe_url ··· 92 88 None => break, 93 89 }; 94 90 95 - metrics.modify(|mut data| data.connects += 1); 91 + state.metrics.modify(|mut data| data.connects += 1); 96 92 let (mut write, mut read) = socket.split(); 97 93 98 94 if require_hello ··· 102 98 continue; 103 99 } 104 100 105 - let mut last = Instant::now(); 106 101 loop { 107 102 let message = tokio::select! { 108 - Some(Ok(message)) = read.next() => { 109 - last = Instant::now(); 110 - message 111 - }, 103 + Some(Ok(outcome)) = shutdown.run_until_cancelled(handle_read_socket(&mut read, &mut state)) => { 104 + match outcome { 105 + ReadOutcome::Event(message) => message, 106 + ReadOutcome::Ping(payload) => { 107 + if let Err(error) = write.send(Message::Pong(payload)).await { 108 + tracing::error!(?error, "failed to send pong"); 109 + break; 110 + } 111 + state.metrics.modify(|mut data| data.pongs_sent += 1); 112 + continue; 113 + } 114 + ReadOutcome::Timeout => { 115 + tracing::error!("time since last received message exceeds threshold"); 116 + state.metrics.modify(|mut data| data.timeouts += 1); 117 + state.cursor = rewind_cursor(state.cursor); 118 + break; 119 + 120 + }, 121 + ReadOutcome::Closed => { 122 + tracing::error!("socket closed"); 123 + break; 124 + } 125 + } 126 + } 112 127 Ok(command) = client_rx.recv_async() => { 113 128 match command { 114 129 ClientCommand::SubscriberOptionsUpdate(complete) => { ··· 136 113 .map_err(JetstreamTaskError::OptionsUpdate); 137 114 138 115 match (result.is_err(), complete.send(result)) { 139 - (_, Err(_)) => { 140 - // Client is broken. 141 - break 'outer; 142 - }, 143 - (true, _) => { 144 - // Reconnect websocket. 145 - break; 146 - } 116 + (_, Err(_)) => break 'outer, 117 + (true, _) => break, 147 118 (false, _) => continue, 148 119 } 149 120 } ··· 148 131 } 149 132 } 150 133 } 151 - now = timeout.tick() => { 152 - if now.duration_since(last) > THRESHOLD { 153 - metrics.modify(|mut data| data.timeouts += 1); 154 - tracing::error!("time since last received message exceeds threshold"); 155 - // Rewind the cursor a few seconds. 156 - if let Some(v) = cursor.as_mut() { 157 - *v = v.saturating_sub(REWIND.as_micros()); 158 - } 159 - break; 160 - } 161 - continue; 162 - } 163 134 else => break, 164 - }; 165 - 166 - let bytes: Bytes = match message { 167 - #[cfg(feature = "zstd")] 168 - Message::Text(payload) => { 169 - panic!( 170 - "received uncompressed message but zstd feature is enabled: {}", 171 - payload.as_str() 172 - ); 173 - } 174 - #[cfg(not(feature = "zstd"))] 175 - Message::Text(payload) => { 176 - metrics.modify(|mut data| { 177 - data.bytes_received_raw += payload.len(); 178 - data.bytes_received += payload.len(); 179 - }); 180 - payload.into() 181 - } 182 - #[cfg(feature = "zstd")] 183 - Message::Binary(compressed_payload) => { 184 - use bytes::Buf as _; 185 - use std::io::Read as _; 186 - 187 - let compressed_bytes = compressed_payload.len(); 188 - let mut payload = Vec::with_capacity(compressed_payload.len()); 189 - let mut decoder = zstd::Decoder::with_prepared_dictionary( 190 - compressed_payload.reader(), 191 - &dictionary, 192 - ) 193 - .expect("prepared zstd dictionary should be valid"); 194 - 195 - let Ok(_) = decoder.read_to_end(&mut payload) else { 196 - tracing::error!("failed to dezstd message with zstd"); 197 - continue; 198 - }; 199 - 200 - metrics.modify(|mut data| { 201 - data.bytes_received_raw += compressed_bytes; 202 - data.bytes_received += payload.len(); 203 - }); 204 - Bytes::from(payload) 205 - } 206 - #[cfg(not(feature = "zstd"))] 207 - Message::Binary(bytes) => { 208 - tracing::warn!(?bytes, "received unexpected binary message"); 209 - continue; 210 - } 211 - Message::Ping(payload) => { 212 - tracing::trace!(?payload, "received ping, sending pong"); 213 - metrics.modify(|mut data| data.pings_received += 1); 214 - if let Err(error) = write.send(Message::Pong(payload)).await { 215 - tracing::error!(?error, "failed to send pong"); 216 - break; 217 - } 218 - metrics.modify(|mut data| data.pongs_sent += 1); 219 - continue; 220 - } 221 - Message::Pong(payload) => { 222 - tracing::warn!(payload = ?std::str::from_utf8(&payload), "received unexpected pong"); 223 - continue; 224 - } 225 - Message::Frame(frame) => { 226 - tracing::warn!(frame = ?std::str::from_utf8(&frame.into_payload()), "received unexpected frame"); 227 - continue; 228 - } 229 - Message::Close(_) => { 230 - tracing::error!("websocket closed"); 231 - break; 232 - } 233 135 }; 234 136 235 137 #[derive(Deserialize)] ··· 158 222 kind: &'a str, 159 223 } 160 224 161 - let mut new_cursor = cursor; 225 + let mut new_cursor = state.cursor; 162 226 163 - // Deserialize just the event timestamp and kind. 164 - match serde_json::from_slice::<PartialEvent>(&bytes) { 227 + // Deserialize just the event timestamp and event kind. 228 + match serde_json::from_slice::<PartialEvent>(&message) { 165 229 Ok(event) => { 166 - metrics.increment_message_kind(event.kind); 230 + state.metrics.increment_message_kind(event.kind); 167 231 new_cursor.replace(event.time_us.into()); 168 232 } 169 233 Err(error) => { 170 - match std::str::from_utf8(&bytes) { 234 + match std::str::from_utf8(&message) { 171 235 Ok(payload) => { 172 236 tracing::error!(?error, ?payload, "failed to deserialize event") 173 237 } 174 - Err(_) => tracing::error!(?error, ?bytes, "failed to deserialize event"), 238 + Err(_) => tracing::error!(?error, ?message, "failed to deserialize event"), 175 239 } 176 240 break; 177 241 } 178 242 } 179 243 180 - metrics.modify(|mut data| data.messages_received += 1); 181 - if let Err(error) = event_tx.send_async(bytes).await { 244 + state.metrics.modify(|mut data| data.messages_received += 1); 245 + if let Err(error) = event_tx.send_async(message).await { 182 246 let payload = error.into_inner(); 183 247 match std::str::from_utf8(&payload) { 184 248 Ok(payload) => tracing::error!(%payload, "Failed to dispatch event to channel"), ··· 188 252 } 189 253 190 254 // Update the cursor since the message has been dispatched. 191 - cursor = new_cursor; 255 + state.cursor = new_cursor; 192 256 } 193 257 194 - metrics.modify(|mut data| data.disconnects += 1); 258 + state.metrics.modify(|mut data| data.disconnects += 1); 195 259 } 196 260 197 261 tracing::warn!("jetstream subscriber task ended"); 262 + } 263 + 264 + enum ReadOutcome { 265 + Event(Bytes), 266 + Ping(Bytes), 267 + Timeout, 268 + Closed, 269 + } 270 + 271 + async fn handle_read_socket<S>( 272 + stream: &mut S, 273 + state: &mut State, 274 + ) -> Result<ReadOutcome, TungsteniteError> 275 + where 276 + S: StreamExt<Item = Result<Message, TungsteniteError>> + Unpin, 277 + { 278 + #[cfg(feature = "zstd")] 279 + let dictionary = zstd::dict::DecoderDictionary::copy(ZSTD_DICTIONARY); 280 + 281 + loop { 282 + let message = match timeout(RECV_TIMEOUT, stream.next()).await { 283 + Ok(Some(Ok(message))) => message, 284 + Ok(Some(Err(error))) => return Err(error), 285 + Ok(None) => return Ok(ReadOutcome::Closed), 286 + Err(_) => return Ok(ReadOutcome::Timeout), 287 + }; 288 + 289 + let message: Bytes = match message { 290 + #[cfg(feature = "zstd")] 291 + Message::Text(payload) => { 292 + panic!( 293 + "received uncompressed message but zstd feature is enabled: {}", 294 + payload.as_str() 295 + ); 296 + } 297 + #[cfg(not(feature = "zstd"))] 298 + Message::Text(payload) => { 299 + state.metrics.modify(|mut data| { 300 + data.bytes_received_raw += payload.len(); 301 + data.bytes_received += payload.len(); 302 + }); 303 + 304 + payload.into() 305 + } 306 + #[cfg(feature = "zstd")] 307 + Message::Binary(compressed_payload) => { 308 + use bytes::Buf as _; 309 + use std::io::Read as _; 310 + 311 + let compressed_bytes = compressed_payload.len(); 312 + let mut payload = Vec::with_capacity(compressed_payload.len()); 313 + let mut decoder = zstd::Decoder::with_prepared_dictionary( 314 + compressed_payload.reader(), 315 + &dictionary, 316 + ) 317 + .expect("prepared zstd dictionary should be valid"); 318 + 319 + let Ok(_) = decoder.read_to_end(&mut payload) else { 320 + tracing::error!("failed to dezstd message with zstd"); 321 + continue; 322 + }; 323 + 324 + state.metrics.modify(|mut data| { 325 + data.bytes_received_raw += compressed_bytes; 326 + data.bytes_received += payload.len(); 327 + }); 328 + 329 + payload.into() 330 + } 331 + #[cfg(not(feature = "zstd"))] 332 + Message::Binary(bytes) => { 333 + tracing::warn!(?bytes, "received unexpected binary message"); 334 + continue; 335 + } 336 + Message::Ping(payload) => { 337 + tracing::trace!(?payload, "received ping, sending pong"); 338 + state.metrics.modify(|mut data| data.pings_received += 1); 339 + return Ok(ReadOutcome::Ping(payload)); 340 + } 341 + Message::Pong(payload) => { 342 + tracing::warn!(payload = ?std::str::from_utf8(&payload), "received unexpected pong"); 343 + continue; 344 + } 345 + Message::Frame(frame) => { 346 + tracing::warn!(frame = ?std::str::from_utf8(&frame.into_payload()), "received unexpected frame"); 347 + continue; 348 + } 349 + Message::Close(_) => { 350 + tracing::error!("websocket closed"); 351 + break; 352 + } 353 + }; 354 + 355 + return Ok(ReadOutcome::Event(message)); 356 + } 357 + 358 + Ok(ReadOutcome::Closed) 359 + } 360 + 361 + fn rewind_cursor(mut cursor: Option<u128>) -> Option<u128> { 362 + if let Some(value) = &mut cursor { 363 + *value = value.saturating_sub(REWIND.as_micros()) 364 + } 365 + cursor 198 366 } 199 367 200 368 async fn send_options_update<S, E>(