A very fast AT Protocol indexer with flexible filtering, XRPC queries, a cursor-backed event stream, and more, built on fjall.
rust
fjall
at-protocol
atproto
indexer
1#![allow(dead_code)]
2
3use std::{hash::Hash, time::Duration};
4
5use jacquard_common::{deps::fluent_uri, types::string::Handle};
6use rand::RngExt;
7use reqwest::StatusCode;
8use serde::{Deserialize, Deserializer, Serializer};
9use tokio::sync::watch;
10use tracing::info;
11use url::Url;
12
13use crate::{db::types::DidKey, types::RepoStatus};
14
15pub mod throttle;
16
17#[allow(dead_code)]
18/// checks if the error contains a hyper / std io timeout error
19pub fn is_timeout(err: &dyn std::error::Error) -> bool {
20 let mut source = err.source();
21
22 while let Some(err) = source {
23 if let Some(hyper_err) = err.downcast_ref::<hyper::Error>() {
24 if hyper_err.is_timeout() {
25 return true;
26 }
27 }
28 if let Some(io) = err.downcast_ref::<std::io::Error>() {
29 if io.kind() == std::io::ErrorKind::TimedOut {
30 return true;
31 }
32 }
33 source = err.source();
34 }
35
36 false
37}
38
39pub fn is_tls_cert_error(io_err: &std::io::Error) -> bool {
40 let Some(inner) = io_err.get_ref() else {
41 return false;
42 };
43 if let Some(rustls_err) = inner.downcast_ref::<rustls::Error>() {
44 return is_tls_error_their_fault(rustls_err);
45 }
46 if let Some(nested_io) = inner.downcast_ref::<std::io::Error>() {
47 return is_tls_cert_error(nested_io);
48 }
49 false
50}
51
/// tells whether an io error kind is one we attribute to the remote host.
///
/// returns `true` for connection refused / reset / aborted, host or network unreachable,
/// timeouts and unexpected eof; `false` for every other kind.
pub fn is_io_error_their_fault(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    // some of these maybe our fault, but lets assume we have working networking
    // if its our fault chances are most of the other hosts will also fail, which will be in the logs
    // we log the error anyway so it should be easy to tell if something is going bad
    const REMOTE_KINDS: [ErrorKind; 7] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::HostUnreachable,
        ErrorKind::NetworkUnreachable,
        ErrorKind::ConnectionReset,
        ErrorKind::ConnectionAborted,
        ErrorKind::TimedOut,
        ErrorKind::UnexpectedEof,
    ];

    REMOTE_KINDS.contains(&e.kind())
}
68
69pub fn is_tls_error_their_fault(e: &rustls::Error) -> bool {
70 use rustls::Error::*;
71
72 matches!(
73 *e,
74 InvalidCertificate(_)
75 | PeerMisbehaved(_)
76 | InconsistentKeys(_)
77 | InappropriateMessage { .. }
78 | InappropriateHandshakeMessage { .. }
79 | InvalidMessage(_)
80 | NoCertificatesPresented
81 | UnsupportedNameType
82 | DecryptError
83 | AlertReceived(_)
84 | PeerIncompatible(_)
85 | InvalidCertRevocationList(_)
86 | InvalidEncryptedClientHello(_)
87 | PeerSentOversizedRecord
88 | NoApplicationProtocol // this is not exhaustive, so remember to look at rustls::Error on version changes
89 )
90}
91
// use this for public (unauth) xrpc errors
/// tells whether an http status code is one we blame on the remote.
///
/// `true` for every 1xx, 3xx and 5xx code, plus 401, 403, 404, 410 and 436.
pub fn is_status_their_fault(status: u16) -> bool {
    matches!(
        status,
        100..=199 // informational, why are we here?
        | 300..=399 // any 3xx error doesnt make sense in the context of a pds / relay
        | 500..=599 // server error :>
        | 401 // UNAUTHORIZED: sob
        | 403 // FORBIDDEN: sob
        | 404 // NOT FOUND: we know its not our fault because we use known xrpcs..
        | 410 // GONE: sob
        | 436 // nonstandard code, apparently used by some domain parking service
    )
}
106
/// outcome of [`RetryWithBackoff::retry`] when the operation does not succeed.
///
/// `E` is the error type of the retried operation.
pub enum RetryOutcome<E> {
    /// ratelimited after exhausting all retries
    ///
    /// carries no payload: the error that triggered the final ratelimit is dropped.
    Ratelimited,
    /// non-ratelimit failure, carrying the last error
    Failed(E),
}
114
115/// extension trait that adds `.retry()` to async `FnMut` closures.
116///
117/// `on_ratelimit` receives the error and current attempt number.
118/// returning `Some(duration)` signals a transient failure and provides the backoff;
119/// returning `None` signals a terminal failure.
120pub trait RetryWithBackoff<T, E, Fut>: FnMut() -> Fut
121where
122 Fut: Future<Output = Result<T, E>>,
123{
124 #[allow(async_fn_in_trait)]
125 async fn retry(
126 &mut self,
127 max_retries: u32,
128 on_ratelimit: impl Fn(&E, u32) -> Option<Duration>,
129 ) -> Result<T, RetryOutcome<E>> {
130 let mut attempt = 0u32;
131 loop {
132 match self().await {
133 Ok(val) => return Ok(val),
134 Err(e) => match on_ratelimit(&e, attempt) {
135 Some(_) if attempt >= max_retries => return Err(RetryOutcome::Ratelimited),
136 Some(backoff) => {
137 // jitter the backoff
138 let backoff = rand::rng().random_range((backoff / 2)..backoff);
139 tokio::time::sleep(backoff).await;
140 attempt += 1;
141 }
142 None => return Err(RetryOutcome::Failed(e)),
143 },
144 }
145 }
146 }
147}
148
// blanket impl: every `FnMut() -> Fut` whose future yields `Result<T, E>` picks up `.retry()`
// from the trait's default method, so the body is intentionally empty.
impl<T, E, F, Fut> RetryWithBackoff<T, E, Fut> for F
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
{
}
155
/// extension trait that adds `.wait_enabled()` to `watch::Receiver<bool>`.
///
/// waits until the value becomes `true`, logging once when paused and once when resumed.
pub trait WatchEnabledExt {
    /// resolves once the watched flag is `true`.
    ///
    /// `component` is the name that shows up in the paused / resumed log lines.
    #[allow(async_fn_in_trait)]
    async fn wait_enabled(&mut self, component: &'static str);
}
163
164impl WatchEnabledExt for watch::Receiver<bool> {
165 async fn wait_enabled(&mut self, component: &'static str) {
166 if !*self.borrow() {
167 info!("{component} paused");
168 while !*self.borrow() {
169 let _ = self.changed().await;
170 }
171 info!("{component} resumed");
172 }
173 }
174}
175
176/// extension trait that adds `.error_for_status()` to futures returning a reqwest `Response`.
177pub trait ErrorForStatus: Future<Output = Result<reqwest::Response, reqwest::Error>> {
178 fn error_for_status(self) -> impl Future<Output = Result<reqwest::Response, reqwest::Error>>
179 where
180 Self: Sized,
181 {
182 futures::FutureExt::map(self, |r| r.and_then(|r| r.error_for_status()))
183 }
184}
185
// blanket impl: any future resolving to `Result<reqwest::Response, reqwest::Error>` gets
// `.error_for_status()` from the trait's default method.
impl<F: Future<Output = Result<reqwest::Response, reqwest::Error>>> ErrorForStatus for F {}
187
188/// extracts a retry delay in seconds from rate limit response headers.
189///
190/// checks in priority order:
191/// - `retry-after: <seconds>` (relative)
192/// - `ratelimit-reset: <unix timestamp>` (absolute) (ref pds sends this)
193pub fn parse_retry_after(resp: &reqwest::Response) -> Option<u64> {
194 let headers = resp.headers();
195
196 let retry_after = headers
197 .get(reqwest::header::RETRY_AFTER)
198 .and_then(|v| v.to_str().ok())
199 .and_then(|s| s.parse::<u64>().ok());
200
201 let rate_limit_reset = headers
202 .get("ratelimit-reset")
203 .and_then(|v| v.to_str().ok())
204 .and_then(|s| s.parse::<i64>().ok())
205 .map(|ts| {
206 let now = chrono::Utc::now().timestamp();
207 (ts - now).max(1) as u64
208 });
209
210 retry_after.or(rate_limit_reset)
211}
212
213// cloudflare-specific status codes
214pub const CONNECTION_TIMEOUT: StatusCode = unsafe {
215 match StatusCode::from_u16(522) {
216 Ok(s) => s,
217 _ => std::hint::unreachable_unchecked(),
218 }
219};
220pub const SITE_FROZEN: StatusCode = unsafe {
221 match StatusCode::from_u16(530) {
222 Ok(s) => s,
223 _ => std::hint::unreachable_unchecked(),
224 }
225};
226pub const SSL_HANDSHAKE_FAILURE: StatusCode = unsafe {
227 match StatusCode::from_u16(525) {
228 Ok(s) => s,
229 _ => std::hint::unreachable_unchecked(),
230 }
231};
232
233pub fn ser_status_code<S: Serializer>(s: &Option<StatusCode>, ser: S) -> Result<S::Ok, S::Error> {
234 match s {
235 Some(code) => ser.serialize_some(&code.as_u16()),
236 None => ser.serialize_none(),
237 }
238}
239
240pub fn deser_status_code<'de, D: Deserializer<'de>>(
241 deser: D,
242) -> Result<Option<StatusCode>, D::Error> {
243 Option::<u16>::deserialize(deser)?
244 .map(StatusCode::from_u16)
245 .transpose()
246 .map_err(serde::de::Error::custom)
247}
248
249pub fn opt_cid_serialize_str<S: Serializer>(v: &Option<cid::Cid>, s: S) -> Result<S::Ok, S::Error> {
250 match v {
251 Some(cid) => s.serialize_some(cid.to_string().as_str()),
252 None => s.serialize_none(),
253 }
254}
255
256pub fn did_key_serialize_str<S: Serializer>(v: &DidKey<'_>, s: S) -> Result<S::Ok, S::Error> {
257 s.serialize_str(&v.encode())
258}
259
260pub fn opt_did_key_serialize_str<S: Serializer>(
261 v: &Option<DidKey<'_>>,
262 s: S,
263) -> Result<S::Ok, S::Error> {
264 match v {
265 Some(k) => s.serialize_some(k.encode().as_str()),
266 None => s.serialize_none(),
267 }
268}
269
270pub fn repo_status_serialize_str<S: Serializer>(v: &RepoStatus, s: S) -> Result<S::Ok, S::Error> {
271 s.serialize_str(&v.to_string())
272}
273
274pub fn url_to_fluent_uri(url: &Url) -> fluent_uri::Uri<String> {
275 fluent_uri::Uri::parse(url.as_str())
276 .expect("that url is validated")
277 .to_owned()
278}
279
/// returns the placeholder handle `handle.invalid`.
pub(crate) fn invalid_handle() -> Handle<'static> {
    // NOTE(review): `Handle::unchecked` presumably skips validation, which would make this
    // sound only if the literal is a well-formed handle — confirm against jacquard_common.
    unsafe { Handle::unchecked("handle.invalid") }
}
283
284/// returns hash of value using ahash
285pub fn hash<T: Hash>(val: &T) -> u64 {
286 use std::hash::Hasher;
287 let mut hasher = ahash::AHasher::default();
288 val.hash(&mut hasher);
289 hasher.finish()
290}