Tap drinker
1use std::time::{Duration, SystemTime};
2
3use futures_util::{SinkExt as _, StreamExt};
4use serde::Serialize;
5use tokio::{
6 net::TcpStream,
7 sync::mpsc::{self, error::SendTimeoutError},
8};
9use tokio_tungstenite::{
10 MaybeTlsStream, WebSocketStream,
11 tungstenite::{Bytes, ClientRequestBuilder, Message},
12};
13use tokio_util::sync::{CancellationToken, DropGuard};
14use tracing::Span;
15
16use crate::tap::{TapClient, TapEvent};
17
/// Period of the idle timer used by `channel_task`. When the timer fires a
/// Ping is sent to the server, or the connection is abandoned if the previous
/// Ping is still unanswered.
const TIMEOUT: Duration = Duration::from_secs(10);

/// Longest time `channel_task` waits for space in the event channel before it
/// treats the consumer as stalled and stops.
const DISPATCH_TIMEOUT: Duration = Duration::from_secs(2);
21
22#[derive(Debug, thiserror::Error)]
23#[error("Failed to enqueue acknowledgement for event #{0}")]
24pub struct AckError(u64);
25
26impl From<mpsc::error::SendError<u64>> for AckError {
27 fn from(error: mpsc::error::SendError<u64>) -> Self {
28 Self(error.0)
29 }
30}
31
32pub struct Ack {
33 id: u64,
34 tx: mpsc::Sender<u64>,
35}
36
37impl Ack {
38 /// Acknowledge receipt of the associated event.
39 ///
40 /// Success does *not* mean the Tap server has successfully received the
41 /// acknowledgement, only that the ack has be queued by the client.
42 pub async fn acknowledge(self) -> Result<(), AckError> {
43 self.tx.send(self.id).await?;
44 Ok(())
45 }
46}
47
48/// Messages that are serialized and sent to the Tap server.
49#[derive(Debug, Serialize)]
50#[serde(tag = "type", rename_all = "snake_case")]
51enum ClientMessage {
52 Ack { id: u64 },
53}
54
55impl From<u64> for ClientMessage {
56 fn from(id: u64) -> Self {
57 Self::Ack { id }
58 }
59}
60
/// Consumer side of a Tap event stream.
///
/// Events are produced by the future returned from [`TapChannel::new`] and
/// read with [`TapChannel::recv`].
#[derive(Debug)]
pub struct TapChannel {
    /// Receives each event together with its tracing span and its [`Ack`]
    /// handle.
    rx: mpsc::Receiver<(Span, TapEvent, Ack)>,
    // Never read: the guard is held only so that dropping the `TapChannel`
    // cancels the token whose child was handed to the channel task.
    #[allow(unused)]
    shutdown: DropGuard,
}
67
68impl TapChannel {
69 pub const DEFAULT_CAPACITY: usize = 128;
70
71 pub fn new(
72 tap: &TapClient,
73 capacity: usize,
74 ) -> (
75 Self,
76 impl Future<Output = Result<(), ChannelError>> + Send + 'static,
77 ) {
78 let mut url = tap.url().clone();
79 url.set_path("/channel");
80 url.set_scheme(match url.scheme() {
81 "https" => "wss",
82 "http" => "ws",
83 _ => unreachable!("Tap::new should reject unknown schemes"),
84 })
85 .expect("'http' or 'https' is a valid URL scheme");
86
87 let uri = url
88 .as_str()
89 .parse()
90 .expect("Url has already been validated");
91
92 let mut builder = ClientRequestBuilder::new(uri);
93 for (header_name, header_value) in &tap.headers {
94 builder = builder.with_header(
95 header_name.to_string(),
96 header_value
97 .to_str()
98 .expect("Header value has already been validated"),
99 );
100 }
101
102 let (tx, rx) = mpsc::channel(capacity);
103 let shutdown = CancellationToken::new();
104 let handle = channel_task(builder, tx, shutdown.child_token(), capacity);
105
106 (
107 Self {
108 rx,
109 shutdown: shutdown.drop_guard(),
110 },
111 handle,
112 )
113 }
114}
115
116impl TapChannel {
117 pub async fn recv(&mut self) -> Option<(Span, TapEvent, Ack)> {
118 self.rx.recv().await
119 }
120}
121
/// Fatal errors that end the channel task.
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    /// The WebSocket handshake was rejected with an HTTP client-error status.
    #[error("Client authorization failed")]
    Authorization,
    /// A queued acknowledgement could not be sent while draining the queue
    /// during shutdown. Carries the event id and the underlying error.
    #[error("Failed to send pending Acks: {0:?}: {1}")]
    FailedAck(u64, tokio_tungstenite::tungstenite::Error),
}
129
130async fn channel_task(
131 request_builder: ClientRequestBuilder,
132 event_tx: mpsc::Sender<(Span, TapEvent, Ack)>,
133 shutdown: CancellationToken,
134 capacity: usize,
135) -> Result<(), ChannelError> {
136 #[derive(Debug)]
137 enum Action {
138 Message(Message),
139 Timeout,
140 Ack(u64),
141 ClearAcks,
142 }
143
144 let (ack_tx, mut ack_rx) = mpsc::channel(capacity);
145
146 'outer: while !shutdown.is_cancelled() {
147 let mut ping_inflight = false;
148 let mut recv_timeout = tokio::time::interval(TIMEOUT);
149 recv_timeout.tick().await;
150
151 let request = request_builder.clone();
152 let (mut socket, _) = match tokio_tungstenite::connect_async(request).await {
153 Ok(result) => result,
154 Err(tokio_tungstenite::tungstenite::Error::Http(error))
155 if error.status().is_client_error() =>
156 {
157 tracing::error!(?error, "failed to connect to Tap channel");
158 return Err(ChannelError::Authorization);
159 }
160 Err(error) => {
161 tracing::error!(?error);
162
163 // @TODO Reconnect delay
164
165 continue 'outer;
166 }
167 };
168
169 loop {
170 let action = tokio::select! {
171 Some(Ok(message)) = socket.next() => Action::Message(message),
172 Some(ack) = ack_rx.recv() => Action::Ack(ack),
173 _ = recv_timeout.tick() => Action::Timeout,
174 _ = shutdown.cancelled() => Action::ClearAcks,
175 else => Action::ClearAcks,
176 };
177
178 recv_timeout.reset();
179
180 match action {
181 Action::Message(message) => match message {
182 Message::Text(bytes) => {
183 let event = match serde_json::from_str::<TapEvent>(bytes.as_str()) {
184 Ok(event) => event,
185 Err(error) => {
186 tracing::error!(?error, bytes = %bytes.as_str(), "failed to deserialize event");
187 continue;
188 }
189 };
190
191 let span = tracing::info_span!("event", id = event.id());
192 span.in_scope(|| {
193 tracing::info!(?event, "received event");
194 });
195
196 let ack = Ack {
197 id: event.id(),
198 tx: ack_tx.clone(),
199 };
200
201 let message = (span, event, ack);
202 match event_tx.send_timeout(message, DISPATCH_TIMEOUT).await {
203 Err(SendTimeoutError::Timeout((span, event, _))) => {
204 span.in_scope(|| {
205 tracing::error!(?event, "channel consumer stalled");
206 });
207 break 'outer;
208 }
209 Err(SendTimeoutError::Closed((span, event, _))) => {
210 span.in_scope(|| {
211 tracing::error!(?event, "channel consumer closed");
212 });
213 break 'outer;
214 }
215 Ok(_) => {}
216 }
217 }
218 Message::Binary(_) | Message::Frame(_) => {
219 tracing::error!("unexpected Binary or Frame message from server");
220 break;
221 }
222 Message::Ping(bytes) => {
223 if let Err(error) = socket.send(Message::Pong(bytes)).await {
224 tracing::error!(?error, "failed to send Pong");
225 break;
226 }
227 }
228 Message::Pong(bytes) => {
229 tracing::trace!(?bytes, "received Pong from server");
230 ping_inflight = false;
231 }
232 Message::Close(close_frame) => {
233 tracing::debug!(?close_frame, "received close frame");
234 break;
235 }
236 },
237 Action::Ack(ack) => {
238 if let Err(error) = send_acknowledgement(&mut socket, ack).await {
239 tracing::error!(?error, "failed to send Ack");
240 break;
241 }
242 }
243 Action::Timeout => {
244 if ping_inflight {
245 tracing::error!("missed ping");
246 break;
247 }
248
249 let timestamp = SystemTime::now()
250 .duration_since(SystemTime::UNIX_EPOCH)
251 .expect("system time precedes UNIX epoch")
252 .as_micros();
253
254 let payload = format!("{timestamp}");
255 let payload: Bytes = payload.into();
256 if socket.send(Message::Ping(payload)).await.is_err() {
257 tracing::error!("failed to send Ping to server");
258 break;
259 }
260
261 ping_inflight = true;
262 }
263 Action::ClearAcks => {
264 drop(ack_tx);
265 while let Some(ack) = ack_rx.recv().await {
266 if let Err(error) = send_acknowledgement(&mut socket, ack).await {
267 tracing::error!(?error, "failed to send ack");
268 return Err(ChannelError::FailedAck(ack, error));
269 }
270 }
271
272 break 'outer;
273 }
274 }
275 }
276
277 tracing::warn!("disconnected");
278 }
279
280 tracing::info!("complete");
281 Ok(())
282}
283
284async fn send_acknowledgement(
285 socket: &mut WebSocketStream<MaybeTlsStream<TcpStream>>,
286 ack: u64,
287) -> Result<(), tokio_tungstenite::tungstenite::Error> {
288 tracing::info!(?ack, "sending ack");
289 let message = serde_json::to_string(&ClientMessage::from(ack))
290 .expect("ClientMessage should be serializable");
291
292 socket.send(Message::text(message)).await?;
293
294 Ok(())
295}