A better Rust ATProto crate
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at pretty-codegen 709 lines 20 kB view raw
1//! WebSocket client abstraction 2 3use crate::CowStr; 4use crate::deps::fluent_uri::Uri; 5use crate::stream::StreamError; 6use alloc::boxed::Box; 7use alloc::string::String; 8use alloc::string::ToString; 9use alloc::vec::Vec; 10use bytes::Bytes; 11use core::borrow::Borrow; 12use core::fmt::{self, Display}; 13use core::future::Future; 14use core::ops::Deref; 15use core::pin::Pin; 16use n0_future::Stream; 17 18/// UTF-8 validated bytes for WebSocket text messages 19#[repr(transparent)] 20#[derive(Debug, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)] 21pub struct WsText(Bytes); 22 23impl WsText { 24 /// Create from static string 25 pub const fn from_static(s: &'static str) -> Self { 26 Self(Bytes::from_static(s.as_bytes())) 27 } 28 29 /// Get as string slice 30 pub fn as_str(&self) -> &str { 31 unsafe { core::str::from_utf8_unchecked(&self.0) } 32 } 33 34 /// Create from bytes without validation (caller must ensure UTF-8) 35 /// 36 /// # Safety 37 /// Bytes must be valid UTF-8 38 pub unsafe fn from_bytes_unchecked(bytes: Bytes) -> Self { 39 Self(bytes) 40 } 41 42 /// Convert into underlying bytes 43 pub fn into_bytes(self) -> Bytes { 44 self.0 45 } 46} 47 48impl Deref for WsText { 49 type Target = str; 50 fn deref(&self) -> &str { 51 self.as_str() 52 } 53} 54 55impl AsRef<str> for WsText { 56 fn as_ref(&self) -> &str { 57 self.as_str() 58 } 59} 60 61impl AsRef<[u8]> for WsText { 62 fn as_ref(&self) -> &[u8] { 63 &self.0 64 } 65} 66 67impl AsRef<Bytes> for WsText { 68 fn as_ref(&self) -> &Bytes { 69 &self.0 70 } 71} 72 73impl Borrow<str> for WsText { 74 fn borrow(&self) -> &str { 75 self.as_str() 76 } 77} 78 79impl Display for WsText { 80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 81 Display::fmt(self.as_str(), f) 82 } 83} 84 85impl From<String> for WsText { 86 fn from(s: String) -> Self { 87 Self(Bytes::from(s)) 88 } 89} 90 91impl From<&str> for WsText { 92 fn from(s: &str) -> Self { 93 Self(Bytes::copy_from_slice(s.as_bytes())) 94 } 95} 96 97impl From<&String> for WsText { 98 fn from(s: &String) -> Self { 99 Self::from(s.as_str()) 100 } 101} 102 103impl TryFrom<Bytes> for WsText { 104 type Error = core::str::Utf8Error; 105 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> { 106 core::str::from_utf8(&bytes)?; 107 Ok(Self(bytes)) 108 } 109} 110 111impl TryFrom<Vec<u8>> for WsText { 112 type Error = core::str::Utf8Error; 113 fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> { 114 Self::try_from(Bytes::from(vec)) 115 } 116} 117 118impl From<WsText> for Bytes { 119 fn from(t: WsText) -> Bytes { 120 t.0 121 } 122} 123 124impl Default for WsText { 125 fn default() -> Self { 126 Self(Bytes::new()) 127 } 128} 129 130/// WebSocket close code 131#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 132#[repr(u16)] 133pub enum CloseCode { 134 /// Normal closure 135 Normal = 1000, 136 /// Endpoint going away 137 Away = 1001, 138 /// Protocol error 139 Protocol = 1002, 140 /// Unsupported data 141 Unsupported = 1003, 142 /// Invalid frame payload data 143 Invalid = 1007, 144 /// Policy violation 145 Policy = 1008, 146 /// Message too big 147 Size = 1009, 148 /// Extension negotiation failure 149 Extension = 1010, 150 /// Unexpected condition 151 Error = 1011, 152 /// TLS handshake failure 153 Tls = 1015, 154 /// Other code 155 Other(u16), 156} 157 158impl From<u16> for CloseCode { 159 fn from(code: u16) -> Self { 160 match code { 161 1000 => CloseCode::Normal, 162 1001 => CloseCode::Away, 163 1002 => CloseCode::Protocol, 164 1003 => CloseCode::Unsupported, 165 1007 => CloseCode::Invalid, 166 1008 => CloseCode::Policy, 167 1009 => CloseCode::Size, 168 1010 => CloseCode::Extension, 169 1011 => CloseCode::Error, 170 1015 => CloseCode::Tls, 171 other => CloseCode::Other(other), 172 } 173 } 174} 175 176impl From<CloseCode> for u16 { 177 fn from(code: CloseCode) -> u16 { 178 match code { 179 CloseCode::Normal => 1000, 180 CloseCode::Away => 1001, 181 CloseCode::Protocol => 1002, 182 CloseCode::Unsupported => 1003, 183 CloseCode::Invalid => 1007, 184 CloseCode::Policy => 1008, 185 CloseCode::Size => 1009, 186 CloseCode::Extension => 1010, 187 CloseCode::Error => 1011, 188 CloseCode::Tls => 1015, 189 CloseCode::Other(code) => code, 190 } 191 } 192} 193 194/// WebSocket close frame 195#[derive(Debug, Clone, PartialEq, Eq)] 196pub struct CloseFrame<'a> { 197 /// Close code 198 pub code: CloseCode, 199 /// Close reason text 200 pub reason: CowStr<'a>, 201} 202 203impl<'a> CloseFrame<'a> { 204 /// Create a new close frame 205 pub fn new(code: CloseCode, reason: impl Into<CowStr<'a>>) -> Self { 206 Self { 207 code, 208 reason: reason.into(), 209 } 210 } 211} 212 213/// WebSocket message 214#[derive(Debug, Clone, PartialEq, Eq)] 215pub enum WsMessage { 216 /// Text message (UTF-8) 217 Text(WsText), 218 /// Binary message 219 Binary(Bytes), 220 /// Close frame 221 Close(Option<CloseFrame<'static>>), 222} 223 224impl WsMessage { 225 /// Check if this is a text message 226 pub fn is_text(&self) -> bool { 227 matches!(self, WsMessage::Text(_)) 228 } 229 230 /// Check if this is a binary message 231 pub fn is_binary(&self) -> bool { 232 matches!(self, WsMessage::Binary(_)) 233 } 234 235 /// Check if this is a close message 236 pub fn is_close(&self) -> bool { 237 matches!(self, WsMessage::Close(_)) 238 } 239 240 /// Get as text, if this is a text message 241 pub fn as_text(&self) -> Option<&str> { 242 match self { 243 WsMessage::Text(t) => Some(t.as_str()), 244 _ => None, 245 } 246 } 247 248 /// Get as bytes 249 pub fn as_bytes(&self) -> Option<&[u8]> { 250 match self { 251 WsMessage::Text(t) => Some(t.as_ref()), 252 WsMessage::Binary(b) => Some(b), 253 WsMessage::Close(_) => None, 254 } 255 } 256} 257 258impl From<WsText> for WsMessage { 259 fn from(text: WsText) -> Self { 260 WsMessage::Text(text) 261 } 262} 263 264impl From<String> for WsMessage { 265 fn from(s: String) -> Self { 266 WsMessage::Text(WsText::from(s)) 267 } 268} 269 270impl From<&str> for WsMessage { 271 fn from(s: &str) -> Self { 272 WsMessage::Text(WsText::from(s)) 273 } 274} 275 276impl From<Bytes> for WsMessage { 277 fn from(bytes: Bytes) -> Self { 278 WsMessage::Binary(bytes) 279 } 280} 281 282impl From<Vec<u8>> for WsMessage { 283 fn from(vec: Vec<u8>) -> Self { 284 WsMessage::Binary(Bytes::from(vec)) 285 } 286} 287 288/// WebSocket message stream 289#[cfg(not(target_arch = "wasm32"))] 290pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>>); 291 292/// WebSocket message stream 293#[cfg(target_arch = "wasm32")] 294pub struct WsStream(Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>>); 295 296impl WsStream { 297 /// Create a new message stream 298 #[cfg(not(target_arch = "wasm32"))] 299 pub fn new<S>(stream: S) -> Self 300 where 301 S: Stream<Item = Result<WsMessage, StreamError>> + Send + 'static, 302 { 303 Self(Box::pin(stream)) 304 } 305 306 /// Create a new message stream 307 #[cfg(target_arch = "wasm32")] 308 pub fn new<S>(stream: S) -> Self 309 where 310 S: Stream<Item = Result<WsMessage, StreamError>> + 'static, 311 { 312 Self(Box::pin(stream)) 313 } 314 315 /// Convert into the inner pinned boxed stream 316 #[cfg(not(target_arch = "wasm32"))] 317 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>> + Send>> { 318 self.0 319 } 320 321 /// Convert into the inner pinned boxed stream 322 #[cfg(target_arch = "wasm32")] 323 pub fn into_inner(self) -> Pin<Box<dyn Stream<Item = Result<WsMessage, StreamError>>>> { 324 self.0 325 } 326 327 /// Split this stream into two streams that both receive all messages 328 /// 329 /// Messages are cloned (cheaply via Bytes rc). Spawns a forwarder task. 330 /// Both returned streams will receive all messages from the original stream. 331 /// The forwarder continues as long as at least one stream is alive. 332 /// If the underlying stream errors, both teed streams will end. 333 pub fn tee(self) -> (WsStream, WsStream) { 334 use futures::channel::mpsc; 335 use n0_future::StreamExt as _; 336 337 let (tx1, rx1) = mpsc::unbounded(); 338 let (tx2, rx2) = mpsc::unbounded(); 339 340 n0_future::task::spawn(async move { 341 let mut stream = self.0; 342 while let Some(result) = stream.next().await { 343 match result { 344 Ok(msg) => { 345 // Clone message (cheap - Bytes is rc'd) 346 let msg2 = msg.clone(); 347 348 // Send to both channels, continue if at least one succeeds 349 let send1 = tx1.unbounded_send(Ok(msg)); 350 let send2 = tx2.unbounded_send(Ok(msg2)); 351 352 // Only stop if both channels are closed 353 if send1.is_err() && send2.is_err() { 354 break; 355 } 356 } 357 Err(_e) => { 358 // Underlying stream errored, stop forwarding. 359 // Both channels will close, ending both streams. 360 break; 361 } 362 } 363 } 364 }); 365 366 (WsStream::new(rx1), WsStream::new(rx2)) 367 } 368} 369 370impl fmt::Debug for WsStream { 371 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 372 f.debug_struct("WsStream").finish_non_exhaustive() 373 } 374} 375 376/// WebSocket message sink 377#[cfg(not(target_arch = "wasm32"))] 378pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>>); 379 380/// WebSocket message sink 381#[cfg(target_arch = "wasm32")] 382pub struct WsSink(Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>>); 383 384impl WsSink { 385 /// Create a new message sink 386 #[cfg(not(target_arch = "wasm32"))] 387 pub fn new<S>(sink: S) -> Self 388 where 389 S: n0_future::Sink<WsMessage, Error = StreamError> + Send + 'static, 390 { 391 Self(Box::pin(sink)) 392 } 393 394 /// Create a new message sink 395 #[cfg(target_arch = "wasm32")] 396 pub fn new<S>(sink: S) -> Self 397 where 398 S: n0_future::Sink<WsMessage, Error = StreamError> + 'static, 399 { 400 Self(Box::pin(sink)) 401 } 402 403 /// Convert into the inner boxed sink 404 #[cfg(not(target_arch = "wasm32"))] 405 pub fn into_inner( 406 self, 407 ) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> { 408 self.0 409 } 410 411 /// Convert into the inner boxed sink 412 #[cfg(target_arch = "wasm32")] 413 pub fn into_inner(self) -> Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError>>> { 414 self.0 415 } 416 417 /// get a mutable reference to the inner boxed sink 418 #[cfg(not(target_arch = "wasm32"))] 419 pub fn get_mut( 420 &mut self, 421 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + Send>> { 422 use core::borrow::BorrowMut; 423 424 self.0.borrow_mut() 425 } 426 427 /// get a mutable reference to the inner boxed sink 428 #[cfg(target_arch = "wasm32")] 429 pub fn get_mut( 430 &mut self, 431 ) -> &mut Pin<Box<dyn n0_future::Sink<WsMessage, Error = StreamError> + 'static>> { 432 use core::borrow::BorrowMut; 433 434 self.0.borrow_mut() 435 } 436} 437 438impl fmt::Debug for WsSink { 439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 440 f.debug_struct("WsSink").finish_non_exhaustive() 441 } 442} 443 444/// WebSocket client trait 445#[cfg_attr(not(target_arch = "wasm32"), trait_variant::make(Send))] 446pub trait WebSocketClient: Sync { 447 /// Error type for WebSocket operations 448 type Error: core::error::Error + Send + Sync + 'static; 449 450 /// Connect to a WebSocket endpoint 451 fn connect( 452 &self, 453 uri: Uri<&str>, 454 ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>>; 455 456 /// Connect to a WebSocket endpoint with custom headers 457 /// 458 /// Default implementation ignores headers and calls `connect()`. 459 /// Override this method to support authentication headers for subscriptions. 460 fn connect_with_headers( 461 &self, 462 uri: Uri<&str>, 463 _headers: Vec<(CowStr<'_>, CowStr<'_>)>, 464 ) -> impl Future<Output = Result<WebSocketConnection, Self::Error>> { 465 async move { self.connect(uri).await } 466 } 467} 468 469/// WebSocket connection with bidirectional streams 470pub struct WebSocketConnection { 471 tx: WsSink, 472 rx: WsStream, 473} 474 475impl WebSocketConnection { 476 /// Create a new WebSocket connection 477 pub fn new(tx: WsSink, rx: WsStream) -> Self { 478 Self { tx, rx } 479 } 480 481 /// Get mutable access to the sender 482 pub fn sender_mut(&mut self) -> &mut WsSink { 483 &mut self.tx 484 } 485 486 /// Get mutable access to the receiver 487 pub fn receiver_mut(&mut self) -> &mut WsStream { 488 &mut self.rx 489 } 490 491 /// Get a reference to the receiver 492 pub fn receiver(&self) -> &WsStream { 493 &self.rx 494 } 495 496 /// Get a reference to the sender 497 pub fn sender(&self) -> &WsSink { 498 &self.tx 499 } 500 501 /// Split into sender and receiver 502 pub fn split(self) -> (WsSink, WsStream) { 503 (self.tx, self.rx) 504 } 505 506 /// Check if connection is open (always true for this abstraction) 507 pub fn is_open(&self) -> bool { 508 true 509 } 510} 511 512impl fmt::Debug for WebSocketConnection { 513 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 514 f.debug_struct("WebSocketConnection") 515 .finish_non_exhaustive() 516 } 517} 518 519/// Concrete WebSocket client implementation using tokio-tungstenite-wasm 520pub mod tungstenite_client { 521 use super::*; 522 use crate::IntoStatic; 523 use futures::{SinkExt, StreamExt}; 524 525 /// WebSocket client backed by tokio-tungstenite-wasm 526 #[derive(Debug, Clone, Default)] 527 pub struct TungsteniteClient; 528 529 impl TungsteniteClient { 530 /// Create a new tungstenite WebSocket client 531 pub fn new() -> Self { 532 Self 533 } 534 } 535 536 impl WebSocketClient for TungsteniteClient { 537 type Error = tokio_tungstenite_wasm::Error; 538 539 async fn connect(&self, uri: Uri<&str>) -> Result<WebSocketConnection, Self::Error> { 540 let ws_stream = tokio_tungstenite_wasm::connect(uri.as_str()).await?; 541 542 let (sink, stream) = ws_stream.split(); 543 544 // Convert tungstenite messages to our WsMessage 545 let rx_stream = stream.filter_map(|result| async move { 546 match result { 547 Ok(msg) => match convert_message(msg) { 548 Some(ws_msg) => Some(Ok(ws_msg)), 549 None => None, // Skip ping/pong 550 }, 551 Err(e) => Some(Err(StreamError::transport(e))), 552 } 553 }); 554 555 let rx = WsStream::new(rx_stream); 556 557 // Convert our WsMessage to tungstenite messages 558 let tx_sink = sink.with(|msg: WsMessage| async move { 559 Ok::<_, tokio_tungstenite_wasm::Error>(msg.into()) 560 }); 561 562 let tx_sink_mapped = tx_sink.sink_map_err(|e| StreamError::transport(e)); 563 let tx = WsSink::new(tx_sink_mapped); 564 565 Ok(WebSocketConnection::new(tx, rx)) 566 } 567 } 568 569 /// Convert tokio-tungstenite-wasm Message to our WsMessage 570 /// Returns None for Ping/Pong which we auto-handle 571 fn convert_message(msg: tokio_tungstenite_wasm::Message) -> Option<WsMessage> { 572 use tokio_tungstenite_wasm::Message; 573 574 match msg { 575 Message::Text(vec) => { 576 // tokio-tungstenite-wasm Text contains Vec<u8> (UTF-8 validated) 577 let bytes = Bytes::from(vec); 578 Some(WsMessage::Text(unsafe { 579 WsText::from_bytes_unchecked(bytes) 580 })) 581 } 582 Message::Binary(vec) => Some(WsMessage::Binary(Bytes::from(vec))), 583 Message::Close(frame) => { 584 let close_frame = frame.map(|f| { 585 let code = convert_close_code(f.code); 586 CloseFrame::new(code, CowStr::from(f.reason.into_owned())) 587 }); 588 Some(WsMessage::Close(close_frame)) 589 } 590 } 591 } 592 593 /// Convert tokio-tungstenite-wasm CloseCode to our CloseCode 594 fn convert_close_code(code: tokio_tungstenite_wasm::CloseCode) -> CloseCode { 595 use tokio_tungstenite_wasm::CloseCode as TungsteniteCode; 596 597 match code { 598 TungsteniteCode::Normal => CloseCode::Normal, 599 TungsteniteCode::Away => CloseCode::Away, 600 TungsteniteCode::Protocol => CloseCode::Protocol, 601 TungsteniteCode::Unsupported => CloseCode::Unsupported, 602 TungsteniteCode::Invalid => CloseCode::Invalid, 603 TungsteniteCode::Policy => CloseCode::Policy, 604 TungsteniteCode::Size => CloseCode::Size, 605 TungsteniteCode::Extension => CloseCode::Extension, 606 TungsteniteCode::Error => CloseCode::Error, 607 TungsteniteCode::Tls => CloseCode::Tls, 608 // For other variants, extract raw code 609 other => { 610 let raw: u16 = other.into(); 611 CloseCode::from(raw) 612 } 613 } 614 } 615 616 impl From<WsMessage> for tokio_tungstenite_wasm::Message { 617 fn from(msg: WsMessage) -> Self { 618 use tokio_tungstenite_wasm::Message; 619 620 match msg { 621 WsMessage::Text(text) => { 622 // tokio-tungstenite-wasm Text expects String 623 let bytes = text.into_bytes(); 624 // Safe: WsText is already UTF-8 validated 625 let string = unsafe { String::from_utf8_unchecked(bytes.to_vec()) }; 626 Message::Text(string) 627 } 628 WsMessage::Binary(bytes) => Message::Binary(bytes.to_vec()), 629 WsMessage::Close(frame) => { 630 let close_frame = frame.map(|f| { 631 let code = u16::from(f.code).into(); 632 tokio_tungstenite_wasm::CloseFrame { 633 code, 634 reason: f.reason.into_static().to_string().into(), 635 } 636 }); 637 Message::Close(close_frame) 638 } 639 } 640 } 641 } 642} 643 644#[cfg(test)] 645mod tests { 646 use super::*; 647 648 #[test] 649 fn ws_text_from_string() { 650 let text = WsText::from("hello"); 651 assert_eq!(text.as_str(), "hello"); 652 } 653 654 #[test] 655 fn ws_text_deref() { 656 let text = WsText::from(String::from("world")); 657 assert_eq!(&*text, "world"); 658 } 659 660 #[test] 661 fn ws_text_try_from_bytes() { 662 let bytes = Bytes::from("test"); 663 let text = WsText::try_from(bytes).unwrap(); 664 assert_eq!(text.as_str(), "test"); 665 } 666 667 #[test] 668 fn ws_text_invalid_utf8() { 669 let bytes = Bytes::from(vec![0xFF, 0xFE]); 670 assert!(WsText::try_from(bytes).is_err()); 671 } 672 673 #[test] 674 fn ws_message_text() { 675 let msg = WsMessage::from("hello"); 676 assert!(msg.is_text()); 677 assert_eq!(msg.as_text(), Some("hello")); 678 } 679 680 #[test] 681 fn ws_message_binary() { 682 let msg = WsMessage::from(vec![1, 2, 3]); 683 assert!(msg.is_binary()); 684 assert_eq!(msg.as_bytes(), Some(&[1u8, 2, 3][..])); 685 } 686 687 #[test] 688 fn close_code_conversion() { 689 assert_eq!(u16::from(CloseCode::Normal), 1000); 690 assert_eq!(CloseCode::from(1000), CloseCode::Normal); 691 assert_eq!(CloseCode::from(9999), CloseCode::Other(9999)); 692 } 693 694 #[test] 695 fn websocket_connection_has_tx_and_rx() { 696 use futures::sink::SinkExt; 697 use futures::stream; 698 699 let rx_stream = stream::iter(vec![Ok(WsMessage::from("test"))]); 700 let rx = WsStream::new(rx_stream); 701 702 let drain_sink = futures::sink::drain() 703 .sink_map_err(|_: std::convert::Infallible| StreamError::closed()); 704 let tx = WsSink::new(drain_sink); 705 706 let conn = WebSocketConnection::new(tx, rx); 707 assert!(conn.is_open()); 708 } 709}