Mirror of https://github.com/roostorg/osprey
github.com/roostorg/osprey
1use crate::coordinator_metrics::OspreyCoordinatorMetrics;
2use crate::priority_queue::ActionAcker;
3use crate::priority_queue::{PriorityQueueReceiver, PriorityQueueSender};
4use crate::proto;
5use anyhow::{anyhow, Context, Result};
6use proto::action_request::ActionRequest;
7use std::sync::Arc;
8use std::{error::Error, io::ErrorKind};
9use tokio::sync::mpsc::{self, Sender};
10use tokio::time::{timeout, Duration, Instant};
11use tokio_stream::{wrappers::ReceiverStream, StreamExt};
12
13use crate::metrics::counters::StaticCounter;
14use crate::metrics::histograms::StaticHistogram;
15
16fn match_for_io_error(err_status: &tonic::Status) -> Option<&std::io::Error> {
17 let mut err: &(dyn Error + 'static) = err_status;
18
19 loop {
20 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
21 return Some(io_err);
22 }
23
24 // h2::Error do not expose std::io::Error with `source()`
25 // https://github.com/hyperium/h2/pull/462
26 if let Some(h2_err) = err.downcast_ref::<h2::Error>() {
27 if let Some(io_err) = h2_err.get_io() {
28 return Some(io_err);
29 }
30 }
31
32 err = match err.source() {
33 Some(err) => err,
34 None => return None,
35 };
36 }
37}
38
39#[derive(Debug)]
40struct OutstandingActionState {
41 action_acker: ActionAcker,
42 send_time: Instant,
43 client_details: proto::ClientDetails,
44}
45
46#[derive(Debug)]
47enum ClientState {
48 NoOutstandingAction,
49 OutstandingAction(OutstandingActionState),
50}
51
52pub struct OspreyCoordinatorServer {
53 priority_queue_receiver: PriorityQueueReceiver,
54 #[allow(unused)]
55 priority_queue_sender: PriorityQueueSender, // TODO: use this for retrying sync actions
56 metrics: Arc<OspreyCoordinatorMetrics>,
57}
58
59impl OspreyCoordinatorServer {
60 pub fn new(
61 priority_queue_sender: PriorityQueueSender,
62 priority_queue_receiver: PriorityQueueReceiver,
63 metrics: Arc<OspreyCoordinatorMetrics>,
64 ) -> OspreyCoordinatorServer {
65 OspreyCoordinatorServer {
66 priority_queue_sender,
67 priority_queue_receiver,
68 metrics,
69 }
70 }
71}
72
73fn handle_action_request(
74 action_request: ActionRequest,
75 current_client_state: ClientState,
76 metrics: Arc<OspreyCoordinatorMetrics>,
77) -> Result<proto::ClientDetails> {
78 match (action_request, current_client_state) {
79 (ActionRequest::Initial(client_details), ClientState::NoOutstandingAction) => {
80 Ok(client_details)
81 }
82 (ActionRequest::Initial(_), ClientState::OutstandingAction(_)) => Err(anyhow!(
83 "got an initial action request while there was an outstanding action"
84 )),
85 (ActionRequest::AckOrNack(ack_or_nack), ClientState::NoOutstandingAction) => Err(anyhow!(
86 "got an {:?} with no outstanding action",
87 ack_or_nack
88 )),
89 (ActionRequest::AckOrNack(ack_or_nack), ClientState::OutstandingAction(state)) => {
90 let duration = Instant::now().duration_since(state.send_time);
91 metrics.action_outstanding_duration.record(duration);
92 state.action_acker.ack_or_nack(
93 ack_or_nack
94 .ack_or_nack
95 .context("no `ack_or_nack` in proto")?,
96 );
97 Ok(state.client_details)
98 }
99 }
100}
101
102enum UpdateClientStateOrDisconnect {
103 UpdateClientState(ClientState),
104 ClientRequestedDisconnect,
105 #[allow(dead_code)]
106 ActionReceiverClosedDisconnect,
107 ActionReceiverTimedOut,
108}
109
110async fn handle_request(
111 client_state: ClientState,
112 sender: &Sender<Result<proto::OspreyCoordinatorAction, tonic::Status>>,
113 request: proto::Request,
114 action_receiver: &PriorityQueueReceiver,
115 metrics: Arc<OspreyCoordinatorMetrics>,
116 receive_timeout: Duration,
117) -> Result<UpdateClientStateOrDisconnect> {
118 match request
119 .request
120 .context("request object missing from proto")?
121 {
122 proto::request::Request::ActionRequest(action_request) => {
123 let action_request = action_request
124 .action_request
125 .context("no `action_request.action_request` in `ActionRequest` proto")?;
126 let client_details =
127 handle_action_request(action_request, client_state, metrics.clone())?;
128 tracing::debug!("awaiting action from priority queue");
129 let priority_queue_receive_start_time = Instant::now();
130 let result = timeout(receive_timeout, action_receiver.recv(metrics.clone())).await;
131 let ackable_action = match result {
132 Ok(Ok(ackable_action)) => ackable_action,
133 Ok(Err(_)) | Err(_) => {
134 tracing::error!(
135 "Took too long to get action from priority queue, disconnecting"
136 );
137 metrics
138 .priority_queue_receive_time
139 .record(Instant::now().duration_since(priority_queue_receive_start_time));
140 return Ok(UpdateClientStateOrDisconnect::ActionReceiverTimedOut);
141 }
142 };
143 metrics
144 .priority_queue_receive_time
145 .record(Instant::now().duration_since(priority_queue_receive_start_time));
146 let (action, action_acker) = ackable_action.into_action();
147 sender.send(Ok(action)).await?;
148 Ok(UpdateClientStateOrDisconnect::UpdateClientState(
149 ClientState::OutstandingAction(OutstandingActionState {
150 action_acker,
151 send_time: Instant::now(),
152 client_details,
153 }),
154 ))
155 }
156 proto::request::Request::Disconnect(disconnect) => {
157 let ack_or_nack = disconnect
158 .ack_or_nack
159 .context("no `ack_or_nack` in `disconnect` proto")?
160 .ack_or_nack
161 .context("no `ack_or_nack.ack_or_nack` in `disconnect` proto")?;
162 if let ClientState::OutstandingAction(state) = client_state {
163 state.action_acker.ack_or_nack(ack_or_nack);
164 }
165 Ok(UpdateClientStateOrDisconnect::ClientRequestedDisconnect)
166 }
167 }
168}
169
170#[tonic::async_trait]
171impl proto::osprey_coordinator_service_server::OspreyCoordinatorService
172 for OspreyCoordinatorServer
173{
174 type OspreyBidirectionalStreamStream =
175 ReceiverStream<Result<proto::OspreyCoordinatorAction, tonic::Status>>;
176
177 async fn osprey_bidirectional_stream(
178 &self,
179 request: tonic::Request<tonic::Streaming<proto::Request>>,
180 ) -> Result<tonic::Response<Self::OspreyBidirectionalStreamStream>, tonic::Status> {
181 tracing::debug!(
182 { connection =? request.metadata() },
183 "New Connection Received"
184 );
185 let mut in_stream = request.into_inner();
186 self.metrics.new_connection_established.incr();
187 let (tx, rx) = mpsc::channel(128);
188 let action_receiver = self.priority_queue_receiver.clone();
189 let metrics = self.metrics.clone();
190 let max_pq_receive_await_time_ms = Duration::from_millis(
191 std::env::var("MAX_PQ_RECEIVE_AWAIT_TIME_MS")
192 .unwrap_or("5000".to_string())
193 .parse::<u64>()
194 .unwrap(),
195 );
196 tokio::spawn(async move {
197 let mut client_state = ClientState::NoOutstandingAction {};
198
199 // TODO: refactor the code to honor the invariants of: the first request should always be an
200 // InitialActionRequest and every request after that will either be an AckingRequest or AckingDisconnect
201 // let initial_request = in_stream.next().await.unwrap().unwrap();
202 // let client_details = match initial_request.request.expect("request must exist") {
203 // proto::request::Request::ActionRequest(action_request) => {
204 // match action_request.action_request.expect("must exist") {
205 // ActionRequest::Initial(client_details) => client_details,
206 // ActionRequest::AckOrNack(_) => unreachable!(),
207 // }
208 // }
209 // proto::request::Request::Disconnect(_) => unreachable!(),
210 // };
211
212 while let Some(result) = in_stream.next().await {
213 match result {
214 Ok(request) => {
215 tracing::debug!({request=?request},"got request");
216 client_state = match handle_request(
217 client_state,
218 &tx,
219 request,
220 &action_receiver,
221 metrics.clone(),
222 max_pq_receive_await_time_ms,
223 )
224 .await
225 {
226 Ok(directive) => match directive {
227 UpdateClientStateOrDisconnect::UpdateClientState(
228 new_client_state,
229 ) => new_client_state,
230 UpdateClientStateOrDisconnect::ClientRequestedDisconnect => {
231 tracing::debug!("client requested a disconnect");
232 metrics.client_disconnected_gracefully.incr();
233 break;
234 }
235 UpdateClientStateOrDisconnect::ActionReceiverClosedDisconnect => {
236 tracing::debug!("disconnecting client because receiver closed");
237 metrics.client_disconnected_receiver_closed.incr();
238 break;
239 }
240 UpdateClientStateOrDisconnect::ActionReceiverTimedOut => {
241 tracing::debug!(
242 "disconnecting client because receiver timed out"
243 );
244 metrics.client_disconnected_receiver_timeout.incr();
245 break;
246 }
247 },
248 Err(error) => {
249 tracing::error!({error=%error},"error in stream");
250 metrics.client_disconnected_stream_error.incr();
251 // commenting this out for now
252 // we might not have to send an aborted when we get an error
253 // tx.send(Err(tonic::Status::new(
254 // tonic::Code::Aborted,
255 // error.to_string(),
256 // )))
257 // .await
258 // .expect("output stream must be open");
259 break;
260 }
261 }
262 }
263 Err(err) => {
264 if let Some(io_err) = match_for_io_error(&err) {
265 if io_err.kind() == ErrorKind::BrokenPipe {
266 tracing::error!("client disconnected: broken pipe");
267 metrics.client_disconnected_broken_pipe.incr();
268 break;
269 }
270 }
271
272 match tx.send(Err(err)).await {
273 Ok(_) => (),
274 Err(_err) => break, // response was dropped
275 }
276 }
277 }
278 }
279
280 tracing::debug!("stream ended");
281 });
282
283 let out_stream = ReceiverStream::new(rx);
284 Ok(tonic::Response::new(out_stream))
285 }
286}
287
#[cfg(test)]
mod tests {

    use crate::coordinator_metrics::OspreyCoordinatorMetrics;
    use crate::metrics::emit_worker::SpawnEmitWorker;
    use crate::metrics::new_client;
    use crate::proto::osprey_coordinator_action::ActionData;
    use crate::proto::osprey_coordinator_action::SecretData;
    use proto::osprey_coordinator_service_server::OspreyCoordinatorService;

    use crate::priority_queue::create_ackable_action_priority_queue;
    use crate::priority_queue::AckableAction;

    use super::*;

    /// Simple golden path test that adds two actions to the queue and asserts that a
    /// properly formed bidirectional streaming request is returned the actions in that
    /// order.
    #[tokio::test]
    async fn golden_path_bidirection_streaming_test() -> Result<()> {
        // `init()` panics when a global subscriber is already installed (for example by
        // another test in the same binary); `try_init()` reports that as an error, which
        // is safe to ignore here.
        let _ = tracing_subscriber::fmt::try_init();

        let (priority_queue_sender, priority_queue_receiver) =
            create_ackable_action_priority_queue();
        let metrics = OspreyCoordinatorMetrics::new();
        let _worker_guard = metrics
            .clone()
            .spawn_emit_worker(new_client("osprey_coordinator").unwrap());

        // Enqueue the first action. The drop guard is kept alive until the end of the test.
        let first_action = proto::OspreyCoordinatorAction {
            ack_id: 1,
            action_id: 1,
            action_name: "test_action".into(),
            timestamp: None,
            action_data: Some(ActionData::JsonActionData(
                "{\"action\": \"test action data 1\"}".into(),
            )),
            secret_data: Some(SecretData::JsonSecretData(
                "{\"secret\": \"test secret data 1\"}".into(),
            )),
        };
        let (ackable_action, _receiver_drop_guard_1) = AckableAction::new(first_action);
        priority_queue_sender
            .send_sync(ackable_action)
            .await
            .unwrap();

        // Enqueue the second action.
        let second_action = proto::OspreyCoordinatorAction {
            ack_id: 2,
            action_id: 2,
            action_name: "test_action".into(),
            timestamp: None,
            action_data: Some(ActionData::JsonActionData(
                "{\"action\": \"test action data 2\"}".into(),
            )),
            secret_data: Some(SecretData::JsonSecretData(
                "{\"secret\": \"test secret data 2\"}".into(),
            )),
        };
        let (ackable_action, _receiver_drop_guard_2) = AckableAction::new(second_action);
        priority_queue_sender
            .send_sync(ackable_action)
            .await
            .unwrap();

        let server = OspreyCoordinatorServer::new(
            priority_queue_sender.clone(),
            priority_queue_receiver,
            metrics.clone(),
        );

        // First request: the client introduces itself and asks for an action.
        let initial_action_request = proto::Request {
            request: Some(proto::request::Request::ActionRequest(
                proto::ActionRequest {
                    action_request: Some(proto::action_request::ActionRequest::Initial(
                        proto::ClientDetails::default(),
                    )),
                },
            )),
        };

        // Second request: the client acks the first action and asks for the next one.
        let acking_action_request = proto::Request {
            request: Some(proto::request::Request::ActionRequest(
                proto::ActionRequest {
                    action_request: Some(proto::action_request::ActionRequest::AckOrNack(
                        proto::AckOrNack {
                            ack_id: 0,
                            ack_or_nack: Some(proto::ack_or_nack::AckOrNack::Ack(proto::Ack {
                                execution_result: None,
                                verdicts: None,
                            })),
                        },
                    )),
                },
            )),
        };

        let req = crate::tonic_mock::streaming_request(vec![
            initial_action_request.clone(),
            acking_action_request.clone(),
        ]);

        let res = server
            .osprey_bidirectional_stream(req)
            .await
            .expect("error in stream");

        println!("finish connection");

        let mut result = Vec::new();
        let mut messages = res.into_inner();
        while let Some(v) = messages.next().await {
            println!("got message: {:?}", &v);
            result.push(v.expect("error from stream"))
        }

        println!("{:?}", result);

        // Check the count first so a short stream fails with a clear message rather than
        // an index-out-of-bounds panic below.
        assert_eq!(
            result.len(),
            2,
            "expected exactly two actions from the stream"
        );

        assert_eq!(result[0].action_id, 1);
        assert_eq!(result[1].action_id, 2);
        assert_eq!(
            result[0].action_data,
            Some(ActionData::JsonActionData(
                "{\"action\": \"test action data 1\"}".into()
            ))
        );
        assert_eq!(
            result[1].action_data,
            Some(ActionData::JsonActionData(
                "{\"action\": \"test action data 2\"}".into()
            ))
        );
        assert_eq!(
            result[0].secret_data,
            Some(SecretData::JsonSecretData(
                "{\"secret\": \"test secret data 1\"}".into()
            ))
        );
        assert_eq!(
            result[1].secret_data,
            Some(SecretData::JsonSecretData(
                "{\"secret\": \"test secret data 2\"}".into()
            ))
        );

        Ok(())
    }
}