Mirror of https://github.com/roostorg/osprey
github.com/roostorg/osprey
1use crate::consumer::message_consumer::{ConsumerConfig, ConsumerMessage, MessageConsumer};
2use crate::consumer::message_decoder;
3use crate::coordinator_metrics::OspreyCoordinatorMetrics;
4use crate::metrics::counters::StaticCounter;
5use crate::metrics::histograms::StaticHistogram;
6use crate::priority_queue::{AckOrNack, AckableAction, PriorityQueueSender};
7use crate::signals::exit_signal;
8use crate::snowflake_client::SnowflakeClient;
9use anyhow::Result;
10use async_trait::async_trait;
11use prost_types::Timestamp;
12use rdkafka::config::ClientConfig;
13use rdkafka::consumer::{Consumer, StreamConsumer};
14use rdkafka::error::KafkaError;
15use rdkafka::message::{Headers, Message as KafkaRawMessage};
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::{SystemTime, UNIX_EPOCH};
19use tokio::time::{timeout, Instant};
20
/// A [`MessageConsumer`] implementation backed by an rdkafka `StreamConsumer`.
pub struct KafkaConsumer {
    // Underlying Kafka stream consumer. Auto-commit is disabled at
    // construction time, so offsets are only committed via `ack`.
    consumer: StreamConsumer,
}
24
/// A single record received from Kafka, normalized into the shape the
/// [`ConsumerMessage`] trait expects.
pub struct KafkaMessage {
    // Raw record payload; empty when the record carried no payload.
    data: Vec<u8>,
    // String attributes derived from the record's headers (non-UTF-8 or
    // absent header values become empty strings — see `receive`).
    attributes: HashMap<String, String>,
    // Record timestamp, falling back to local receive time when the
    // broker/producer supplied none.
    timestamp: Timestamp,
    // Identifier of the form "kafka-{partition}-{offset}".
    id: String,
}
31
32impl KafkaMessage {
33 pub fn new(
34 data: Vec<u8>,
35 attributes: HashMap<String, String>,
36 timestamp: Timestamp,
37 id: String,
38 ) -> Self {
39 Self {
40 data,
41 attributes,
42 timestamp,
43 id,
44 }
45 }
46}
47
48impl ConsumerMessage for KafkaMessage {
49 fn data(&self) -> &[u8] {
50 &self.data
51 }
52
53 fn attributes(&self) -> &HashMap<String, String> {
54 &self.attributes
55 }
56
57 fn timestamp(&self) -> Timestamp {
58 self.timestamp.clone()
59 }
60
61 fn id(&self) -> String {
62 self.id.clone()
63 }
64}
65
66#[async_trait]
67impl MessageConsumer for KafkaConsumer {
68 type Message = KafkaMessage;
69 type Error = KafkaError;
70
71 async fn receive(&mut self) -> Result<Self::Message, Self::Error> {
72 let msg = self.consumer.recv().await?;
73
74 let data = msg.payload().unwrap_or(&[]).to_vec();
75
76 let attributes: HashMap<String, String> = msg
77 .headers()
78 .map(|headers| {
79 headers
80 .iter()
81 .filter_map(|header| {
82 let key = header.key.to_string();
83 let value = header
84 .value
85 .and_then(|v| String::from_utf8(v.to_vec()).ok())
86 .unwrap_or_default();
87 Some((key, value))
88 })
89 .collect()
90 })
91 .unwrap_or_default();
92
93 let timestamp_millis = msg.timestamp().to_millis().unwrap_or_else(|| {
94 SystemTime::now()
95 .duration_since(UNIX_EPOCH)
96 .unwrap()
97 .as_millis() as i64
98 });
99
100 let timestamp = Timestamp {
101 seconds: timestamp_millis / 1000,
102 nanos: ((timestamp_millis % 1000) * 1_000_000) as i32,
103 };
104
105 let partition = msg.partition();
106 let offset = msg.offset();
107 let id = format!("kafka-{}-{}", partition, offset);
108
109 Ok(KafkaMessage::new(data, attributes, timestamp, id))
110 }
111
112 async fn ack(&self, _message: &Self::Message) -> Result<(), Self::Error> {
113 self.consumer
114 .commit_consumer_state(rdkafka::consumer::CommitMode::Async)?;
115 Ok(())
116 }
117
118 async fn nack(&self, _message: &Self::Message) -> Result<(), Self::Error> {
119 Ok(())
120 }
121}
122
123impl KafkaConsumer {
124 pub async fn new() -> Result<Self> {
125 let input_topic = std::env::var("OSPREY_KAFKA_INPUT_STREAM_TOPIC")
126 .unwrap_or("osprey.actions_input".to_string());
127 let input_bootstrap_servers =
128 std::env::var("OSPREY_KAFKA_BOOTSTRAP_SERVERS").unwrap_or("localhost:9092".to_string());
129 let group_id = std::env::var("OSPREY_KAFKA_GROUP_ID")
130 .unwrap_or("osprey_coordinator_group".to_string());
131
132 tracing::info!(
133 "Creating Kafka consumer for topic: {} with bootstrap servers: {}",
134 input_topic,
135 input_bootstrap_servers
136 );
137
138 let consumer: StreamConsumer = ClientConfig::new()
139 .set("group.id", &group_id)
140 .set("bootstrap.servers", &input_bootstrap_servers)
141 .set("enable.auto.commit", "false")
142 .set("auto.offset.reset", "earliest")
143 .create::<StreamConsumer>()?;
144
145 consumer.subscribe(&[&input_topic])?;
146
147 Ok(Self { consumer })
148 }
149}
150
/// Main Kafka consumption loop for the coordinator.
///
/// For each received message this:
/// 1. Decodes it into an action — protobuf when the `encoding` attribute is
///    `"proto"`, otherwise msgpack/JSON.
/// 2. Sends the action onto the priority queue, bounded by
///    `config.max_time_to_send_to_async_queue`.
/// 3. Waits for the downstream ack/nack signal, bounded by
///    `config.max_acking_receiver_wait_time`, then acks or nacks the Kafka
///    message accordingly.
///
/// Every failure path (decode error, send error/timeout, dropped or timed-out
/// ack channel) nacks the message and continues with the next one. Returns
/// `Ok(())` when the process-wide exit signal fires, or an error if the
/// consumer cannot be constructed.
pub async fn start_kafka_consumer(
    snowflake_client: Arc<SnowflakeClient>,
    priority_queue_sender: PriorityQueueSender,
    metrics: Arc<OspreyCoordinatorMetrics>,
) -> Result<()> {
    tracing::info!("Kafka consumer starting...");

    let mut consumer = KafkaConsumer::new().await?;
    let config = ConsumerConfig::default();

    loop {
        // Wake on whichever happens first: shutdown signal or next message.
        tokio::select! {
            _ = exit_signal() => {
                tracing::info!("Received exit signal, shutting down Kafka consumer");
                return Ok(());
            }
            message_result = consumer.receive() => {
                // Receive errors are logged and skipped; the loop keeps going.
                let message = match message_result {
                    Ok(msg) => msg,
                    Err(e) => {
                        tracing::error!({error = %e}, "[kafka] error receiving message");
                        continue;
                    }
                };

                // Random correlation id; logged alongside message_id and also
                // handed to the decoder — presumably used to track the action
                // downstream (TODO confirm against decoder usage).
                let ack_id: u64 = rand::Rng::gen(&mut rand::thread_rng());
                let message_id = message.id();

                // Pick the decoder from the `encoding` header; anything other
                // than "proto" (including a missing header) falls back to
                // msgpack/JSON.
                let action = match message.attributes().get("encoding").map(|s| s.as_str()) {
                    Some("proto") => {
                        message_decoder::decode_proto_message(
                            message.data(),
                            ack_id,
                            message.timestamp(),
                            &snowflake_client,
                            &metrics,
                        )
                        .await
                    }
                    _ => {
                        message_decoder::decode_msgpack_json_message(
                            message.data(),
                            ack_id,
                            message.timestamp(),
                            &snowflake_client,
                            &metrics,
                        )
                        .await
                    }
                };

                // Undecodable messages are nacked (a no-op for Kafka — the
                // offset simply isn't committed) and skipped.
                let action = match action {
                    Ok(action) => action,
                    Err(e) => {
                        tracing::error!(
                            {error = %e, ack_id = %ack_id, message_id = %message_id},
                            "[kafka] failed to decode message"
                        );
                        if let Err(nack_err) = consumer.nack(&message).await {
                            tracing::error!(
                                {error = %nack_err, message_id = %message_id},
                                "[kafka] failed to nack message"
                            );
                        }
                        continue;
                    }
                };

                // Wrap the action so the downstream consumer can signal its
                // ack/nack decision back through `acking_receiver`.
                let (ackable_action, acking_receiver) = AckableAction::new(action);

                tracing::debug!(
                    {ack_id = %ack_id, message_id = %message_id},
                    "[kafka] received message"
                );

                // Bounded send onto the priority queue; a slow or full queue
                // must not stall the consumer forever.
                let send_start_time = Instant::now();
                match timeout(
                    config.max_time_to_send_to_async_queue,
                    priority_queue_sender.send_async(ackable_action),
                )
                .await
                {
                    Ok(Ok(())) => {
                        tracing::debug!(
                            {message_id = %message_id, ack_id = %ack_id},
                            "[kafka] sent message to priority queue"
                        );
                        metrics.async_classification_added_to_queue.incr();
                    }
                    // Queue send failed: nack and move on.
                    Ok(Err(e)) => {
                        tracing::error!(
                            {error = %e, message_id = %message_id},
                            "[kafka] priority queue send error"
                        );
                        if let Err(nack_err) = consumer.nack(&message).await {
                            tracing::error!(
                                {error = %nack_err, message_id = %message_id},
                                "[kafka] failed to nack message"
                            );
                        }
                        continue;
                    }
                    // Queue send timed out: nack and move on.
                    Err(_) => {
                        tracing::error!(
                            {message_id = %message_id},
                            "[kafka] sending to priority queue timed out"
                        );
                        if let Err(nack_err) = consumer.nack(&message).await {
                            tracing::error!(
                                {error = %nack_err, message_id = %message_id},
                                "[kafka] failed to nack message"
                            );
                        }
                        continue;
                    }
                }
                // Only successful sends reach this point (failure arms
                // `continue` above), so this records successful-send latency.
                metrics
                    .priority_queue_send_time_async
                    .record(send_start_time.elapsed());

                tracing::debug!(
                    {message_id = %message_id, ack_id = %ack_id},
                    "[kafka] waiting on ack or nack"
                );

                // Bounded wait for the downstream processing result.
                let receive_start_time = Instant::now();
                match timeout(config.max_acking_receiver_wait_time, acking_receiver).await {
                    Ok(Ok(ack_or_nack)) => match ack_or_nack {
                        // Downstream succeeded: commit the Kafka offset.
                        AckOrNack::Ack(_) => {
                            tracing::debug!(
                                {message_id = %message_id, ack_id = %ack_id},
                                "[kafka] acking message"
                            );
                            metrics.async_classification_result_ack.incr();
                            metrics
                                .receiver_ack_time_async
                                .record(receive_start_time.elapsed());

                            if let Err(e) = consumer.ack(&message).await {
                                tracing::error!(
                                    {error = %e, message_id = %message_id},
                                    "[kafka] failed to ack message"
                                );
                            }
                        }
                        // Downstream rejected the message: leave the offset
                        // uncommitted so it can be redelivered.
                        AckOrNack::Nack => {
                            tracing::debug!(
                                {message_id = %message_id, ack_id = %ack_id},
                                "[kafka] nacking message"
                            );
                            metrics.async_classification_result_nack.incr();
                            metrics
                                .receiver_ack_time_async
                                .record(receive_start_time.elapsed());

                            if let Err(e) = consumer.nack(&message).await {
                                tracing::error!(
                                    {error = %e, message_id = %message_id},
                                    "[kafka] failed to nack message"
                                );
                            }
                        }
                    },
                    // The acking sender was dropped without a decision;
                    // treat it as a nack.
                    Ok(Err(recv_error)) => {
                        tracing::error!(
                            {message_id = %message_id, recv_error = %recv_error, ack_id = %ack_id},
                            "[kafka] acking sender dropped"
                        );
                        metrics
                            .receiver_ack_time_async
                            .record(receive_start_time.elapsed());

                        if let Err(e) = consumer.nack(&message).await {
                            tracing::error!(
                                {error = %e, message_id = %message_id},
                                "[kafka] failed to nack message"
                            );
                        }
                    }
                    // No decision arrived in time; treat it as a nack.
                    Err(_) => {
                        tracing::error!(
                            {message_id = %message_id, ack_id = %ack_id},
                            "[kafka] waiting for ack/nack timed out"
                        );
                        metrics
                            .receiver_ack_time_async
                            .record(receive_start_time.elapsed());

                        if let Err(e) = consumer.nack(&message).await {
                            tracing::error!(
                                {error = %e, message_id = %message_id},
                                "[kafka] failed to nack message"
                            );
                        }
                    }
                }
            }
        }
    }
}