Mirror of https://github.com/roostorg/osprey
github.com/roostorg/osprey
1use std::sync::Arc;
2
3use crate::etcd::Client;
4use crate::etcd_watcherd::{RecursiveKeyWatchEvents, Watcher};
5use anyhow::Result;
6
7use base64::{engine::general_purpose::STANDARD as BASE64_ENGINE, Engine};
8
9use prost::Message;
10
/// Knows how to handle updates from etcd config changes.
///
/// Handlers passed to `run_etcd_watcher` must also be `Send + Sync + 'static`, because
/// updates after the initial sync are delivered from a spawned background task.
pub trait KeyHandler {
    /// Invoked whenever a configuration change to `key` has been detected.
    ///
    /// `value` is the raw string contents now stored at `key`, or `None` when `key` has
    /// been deleted.
    ///
    /// Calls are made synchronously, one event at a time, so a slow implementation
    /// delays the processing of subsequent updates.
    fn handle_key_updated(&self, key: &str, value: Option<&str>);
}
17
18/// Validate a disconfig after the value from etcd event update has been successfully
19/// decoded into a Protobuf message object. The default implementation returns the decoded protobuf object.
20/// The type placehoder `T` is the Protobuf message object that the disconfig will be decoded into.
21/// The `validate` method can be overridden to perform customized validation on the decoded Protobuf object.
22pub trait HandleDisconfigUpdated {
23 type Disconfig: Message + Default;
24 type Error;
25
26 /// If customizing the validation, return `None` if the decoded Protobuf object is invalid.
27 fn validate(proto: Self::Disconfig) -> Result<Self::Disconfig, Self::Error> {
28 Ok(proto)
29 }
30
31 /// Invoked after the in-memory ArcSwap wrapped Disconfig has been updated.
32 fn after_update(&self) {}
33}
34
35/// Creates a (non-secure) Etcd client to watch for any value changes recursively under the `config_key_root` path.
36/// This watcher is kept alive in the background.
37/// Updates to the values under the `config_key_root` are sent to the [KeyHandler] to handle.
38pub async fn run_etcd_watcher<T: KeyHandler + Send + Sync + 'static>(
39 config_key_root: impl Into<String>,
40 key_handler: Arc<T>,
41) -> Result<()> {
42 let etcd_client = Client::from_etcd_peers()?;
43 // let etcd_client = Arc::new(etcd_client);
44 let watcher = Watcher::new(etcd_client);
45 let response = watcher.watch_key_recursive(config_key_root).await?;
46
47 let mut events = response.events();
48 // Handle full sync event before resolving the future.
49 if let RecursiveKeyWatchEvents::FullSync { items } = events.initial_event() {
50 for (key, value) in items.iter() {
51 key_handler.handle_key_updated(key, Some(value));
52 }
53 }
54
55 // Create background task to monitor for changes
56 tokio::spawn(async move {
57 loop {
58 let event = events.next().await;
59 match event {
60 RecursiveKeyWatchEvents::FullSync { items } => {
61 for (key, value) in items.iter() {
62 key_handler.handle_key_updated(key, Some(value));
63 }
64 }
65 RecursiveKeyWatchEvents::SyncOne { key, value } => {
66 key_handler.handle_key_updated(&key, Some(&value));
67 }
68 RecursiveKeyWatchEvents::DeleteOne { key, prev_value: _ } => {
69 key_handler.handle_key_updated(&key, None);
70 }
71 }
72 }
73 });
74
75 Ok(())
76}
77
/// Characters stripped from both ends of a raw etcd value before it is parsed as a number.
///
/// Covers JSON-style double quotes (values stored as `"0.5"`) and ASCII whitespace. This
/// includes the trailing newline left when a value is written from a file or from
/// stdin; without it such a value would silently fail to parse and fall back to the
/// default.
const PATTERN: &[char] = &['"', ' ', '\t', '\r', '\n'];
79
80pub fn convert_to_bool(value: Option<&str>) -> bool {
81 value.map_or(false, |v| {
82 v.trim_matches(PATTERN).parse::<f32>().unwrap_or_default() > 0.0
83 })
84}
85
86pub fn convert_to_float(value: Option<&str>) -> f32 {
87 value.map_or(0.0, |v| {
88 v.trim_matches(PATTERN).parse::<f32>().unwrap_or_default()
89 })
90}
91
92pub fn convert_to_usize(value: Option<&str>) -> usize {
93 value.map_or(0, |v| {
94 v.trim_matches(PATTERN).parse::<usize>().unwrap_or_default()
95 })
96}
97
98pub fn convert_to_u64(value: Option<&str>) -> u64 {
99 value.map_or(0, |v| {
100 v.trim_matches(PATTERN).parse::<u64>().unwrap_or_default()
101 })
102}
103
104pub fn convert_to_u32(value: Option<&str>) -> u32 {
105 value.map_or(0, |v| {
106 v.trim_matches(PATTERN).parse::<u32>().unwrap_or_default()
107 })
108}
109
110pub fn base64_to_proto<T: Message + Default>(value: Option<&str>) -> Result<Option<T>> {
111 let parsed_str = match value {
112 Some(value) => match BASE64_ENGINE.decode(value.as_bytes()) {
113 Ok(binary) => binary,
114 Err(e) => return Err(anyhow::anyhow!("Failed to decode base64 value: {:?}", e)),
115 },
116 None => return Ok(None),
117 };
118
119 let proto = T::decode(parsed_str.as_slice())
120 .map_err(|e| anyhow::anyhow!("Failed to decode protobuf: {:?}", e))?;
121 Ok(Some(proto))
122}
123
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_convert_to_float() {
        assert_eq!(0.0, convert_to_float(None));
        assert_eq!(0.0, convert_to_float(Some("")));
        assert_eq!(0.0, convert_to_float(Some(" ")));

        assert_eq!(0.0, convert_to_float(Some("0")));
        assert_eq!(0.0, convert_to_float(Some("0.0")));
        assert_eq!(0.0, convert_to_float(Some(" 0.0 ")));
        assert_eq!(0.1234, convert_to_float(Some("0.1234")));
        assert_eq!(0.1234, convert_to_float(Some("\" 0.1234 \"")));

        assert_eq!(1.0, convert_to_float(Some("1.0")));
        assert_eq!(1.0, convert_to_float(Some(" 1.0 ")));
    }

    #[test]
    fn test_convert_to_bool() {
        assert!(!convert_to_bool(None));
        assert!(!convert_to_bool(Some("")));
        assert!(!convert_to_bool(Some(" ")));

        assert!(!convert_to_bool(Some("0")));
        assert!(!convert_to_bool(Some("0.0")));
        assert!(!convert_to_bool(Some(" 0.0 ")));
        assert!(convert_to_bool(Some("0.1234")));
        assert!(convert_to_bool(Some("\" 0.1234 \"")));

        assert!(convert_to_bool(Some("1.0")));
        assert!(convert_to_bool(Some(" 1.0 ")));
    }

    #[test]
    fn test_convert_to_usize() {
        assert_eq!(0, convert_to_usize(None));
        assert_eq!(0, convert_to_usize(Some("")));
        assert_eq!(0, convert_to_usize(Some(" ")));

        assert_eq!(0, convert_to_usize(Some("0")));
        assert_eq!(0, convert_to_usize(Some("0.0")));
        assert_eq!(0, convert_to_usize(Some(" 0.0 ")));
        assert_eq!(0, convert_to_usize(Some("1.0")));
        assert_eq!(0, convert_to_usize(Some("-1")));
        assert_eq!(1234, convert_to_usize(Some("1234")));
        assert_eq!(1234, convert_to_usize(Some("\" 1234 \"")));

        assert_eq!(1, convert_to_usize(Some("1")));
        assert_eq!(1, convert_to_usize(Some(" 1 ")));
    }

    #[test]
    fn test_convert_to_u64() {
        assert_eq!(0, convert_to_u64(None));
        assert_eq!(0, convert_to_u64(Some("")));
        assert_eq!(0, convert_to_u64(Some(" ")));

        assert_eq!(0, convert_to_u64(Some("0")));
        assert_eq!(0, convert_to_u64(Some("0.0")));
        assert_eq!(0, convert_to_u64(Some(" 0.0 ")));
        assert_eq!(0, convert_to_u64(Some("1.0")));
        assert_eq!(0, convert_to_u64(Some("-1")));
        assert_eq!(1234, convert_to_u64(Some("1234")));
        assert_eq!(1234, convert_to_u64(Some("\" 1234 \"")));
        // Values beyond `u32`/`i32` range must still round-trip exactly.
        assert_eq!(
            1227039953469571133,
            convert_to_u64(Some("1227039953469571133"))
        );
    }

    #[test]
    fn test_convert_to_u32() {
        assert_eq!(0, convert_to_u32(None));
        assert_eq!(0, convert_to_u32(Some("")));
        assert_eq!(0, convert_to_u32(Some(" ")));

        assert_eq!(0, convert_to_u32(Some("0")));
        assert_eq!(0, convert_to_u32(Some("1.0")));
        assert_eq!(0, convert_to_u32(Some("-1")));
        assert_eq!(1234, convert_to_u32(Some("1234")));
        assert_eq!(1234, convert_to_u32(Some("\" 1234 \"")));

        assert_eq!(u32::MAX, convert_to_u32(Some("4294967295")));
        // Values that overflow `u32` fall back to 0 rather than saturating.
        assert_eq!(0, convert_to_u32(Some("4294967296")));
    }

    #[test]
    fn test_base64_to_proto() {
        // `prost` implements `Message` for `String` (the `StringValue` wrapper type),
        // which provides a message type without any generated code.
        assert!(base64_to_proto::<String>(None).unwrap().is_none());

        let encoded = BASE64_ENGINE.encode("hello".to_string().encode_to_vec());
        assert_eq!(
            Some("hello".to_string()),
            base64_to_proto::<String>(Some(&encoded)).unwrap()
        );

        assert!(base64_to_proto::<String>(Some("not valid base64!")).is_err());
    }
}