···66use smol_str::SmolStr;
7788use crate::config::RateTier;
99-use crate::db::pds_tiers as db_pds;
99+use crate::db::pds_meta as db_pds;
1010+use crate::pds_meta::PdsMeta;
1011use crate::state::AppState;
11121213/// a single PDS-to-tier assignment.
···4142pub struct PdsControl(pub(super) Arc<AppState>);
42434344impl PdsControl {
4545+ async fn update<F, G>(&self, db_op: F, mem_op: G) -> Result<()>
4646+ where
4747+ F: FnOnce(&mut fjall::OwnedWriteBatch, &fjall::Keyspace) + Send + 'static,
4848+ G: FnOnce(&mut PdsMeta),
4949+ {
5050+ let state = self.0.clone();
5151+ tokio::task::spawn_blocking(move || {
5252+ let mut batch = state.db.inner.batch();
5353+ db_op(&mut batch, &state.db.filter);
5454+ batch.commit().into_diagnostic()?;
5555+ state.db.persist()
5656+ })
5757+ .await
5858+ .into_diagnostic()??;
5959+6060+ let mut snapshot = (**self.0.pds_meta.load()).clone();
6161+ mem_op(&mut snapshot);
6262+ self.0.pds_meta.store(Arc::new(snapshot));
6363+6464+ Ok(())
6565+ }
6666+4467 /// list all current per-PDS tier assignments.
4545- pub async fn list_assignments(&self) -> Vec<PdsTierAssignment> {
4646- let snapshot = self.0.pds_tiers.load();
6868+ pub async fn list_tiers(&self) -> HashMap<String, String> {
6969+ let snapshot = self.0.pds_meta.load();
4770 snapshot
7171+ .tiers
4872 .iter()
4949- .map(|(host, tier)| PdsTierAssignment {
5050- host: host.clone(),
5151- tier: tier.to_string(),
5252- })
7373+ .map(|(host, tier)| (host.clone(), tier.to_string()))
5374 .collect()
5475 }
55767777+ /// returns the assigned tier for `host`, or "default" if none is assigned.
7878+ pub fn get_tier(&self, host: impl AsRef<str>) -> String {
7979+ let snapshot = self.0.pds_meta.load();
8080+ snapshot
8181+ .tiers
8282+ .get(host.as_ref())
8383+ .map(|t| t.to_string())
8484+ .unwrap_or_else(|| "default".to_string())
8585+ }
8686+8787+ /// returns true if `host` is currently banned.
8888+ pub fn is_banned(&self, host: impl AsRef<str>) -> bool {
8989+ self.0.pds_meta.load().is_banned(host.as_ref())
9090+ }
9191+9292+ /// list all currently banned PDS hosts.
9393+ pub async fn list_banned(&self) -> Vec<String> {
9494+ let snapshot = self.0.pds_meta.load();
9595+ snapshot.banned.iter().cloned().collect()
9696+ }
9797+5698 /// list all configured rate tier definitions.
5799 pub fn list_rate_tiers(&self) -> HashMap<String, PdsTierDefinition> {
58100 self.0
···6410665107 /// assign `host` to `tier`, persisting the change to the database.
66108 /// returns an error if `tier` is not a known tier name.
6767- pub async fn set_tier(&self, host: String, tier: String) -> Result<()> {
109109+ pub async fn set_tier(&self, host: impl AsRef<str>, tier: String) -> Result<()> {
68110 if !self.0.rate_tiers.contains_key(&tier) {
69111 miette::bail!(
70112 "unknown tier '{tier}'; known tiers: {:?}",
···72114 );
73115 }
741167575- let state = self.0.clone();
117117+ let host = host.as_ref().to_string();
76118 let host_clone = host.clone();
77119 let tier_clone = tier.clone();
7878- tokio::task::spawn_blocking(move || {
7979- let mut batch = state.db.inner.batch();
8080- db_pds::set(&mut batch, &state.db.filter, &host_clone, &tier_clone);
8181- batch.commit().into_diagnostic()?;
8282- state.db.persist()
8383- })
120120+ self.update(
121121+ move |batch, ks| db_pds::set_tier(batch, ks, &host_clone, &tier_clone),
122122+ move |meta| {
123123+ meta.tiers.insert(host, SmolStr::new(&tier));
124124+ },
125125+ )
84126 .await
8585- .into_diagnostic()??;
8686-8787- let mut snapshot = (**self.0.pds_tiers.load()).clone();
8888- snapshot.insert(host, SmolStr::new(&tier));
8989- self.0.pds_tiers.store(Arc::new(snapshot));
9090-9191- Ok(())
92127 }
9312894129 /// remove any explicit tier assignment for `host`, reverting it to the default tier.
9595- pub async fn remove_tier(&self, host: String) -> Result<()> {
9696- let state = self.0.clone();
130130+ pub async fn remove_tier(&self, host: impl AsRef<str>) -> Result<()> {
131131+ let host = host.as_ref().to_string();
97132 let host_clone = host.clone();
9898- tokio::task::spawn_blocking(move || {
9999- let mut batch = state.db.inner.batch();
100100- db_pds::remove(&mut batch, &state.db.filter, &host_clone);
101101- batch.commit().into_diagnostic()?;
102102- state.db.persist()
103103- })
133133+ self.update(
134134+ move |batch, ks| db_pds::remove_tier(batch, ks, &host_clone),
135135+ move |meta| {
136136+ meta.tiers.remove(&host);
137137+ },
138138+ )
104139 .await
105105- .into_diagnostic()??;
140140+ }
106141107107- let mut snapshot = (**self.0.pds_tiers.load()).clone();
108108- snapshot.remove(&host);
109109- self.0.pds_tiers.store(Arc::new(snapshot));
142142+ /// ban `host`, persisting the change to the database.
143143+ pub async fn ban(&self, host: impl AsRef<str>) -> Result<()> {
144144+ let host = host.as_ref().to_string();
145145+ let host_clone = host.clone();
146146+ self.update(
147147+ move |batch, ks| db_pds::set_banned(batch, ks, &host_clone),
148148+ move |meta| {
149149+ meta.banned.insert(host);
150150+ },
151151+ )
152152+ .await
153153+ }
110154111111- Ok(())
155155+ /// unban `host`, removing it from the database.
156156+ pub async fn unban(&self, host: impl AsRef<str>) -> Result<()> {
157157+ let host = host.as_ref().to_string();
158158+ let host_clone = host.clone();
159159+ self.update(
160160+ move |batch, ks| db_pds::remove_banned(batch, ks, &host_clone),
161161+ move |meta| {
162162+ meta.banned.remove(&host);
163163+ },
164164+ )
165165+ .await
112166 }
113167}
+2-2
src/db/mod.rs
···2727pub mod filter;
2828pub mod keys;
2929pub mod migration;
3030-pub mod pds_tiers;
3030+pub mod pds_meta;
3131pub mod types;
32323333use tokio::sync::broadcast;
···389389 opts()
390390 // only iterators are used here
391391 .expect_point_read_hits(true)
392392- .max_memtable_size(mb(16))
392392+ .max_memtable_size(mb(8))
393393 // did -> failed state, not very compressable
394394 .data_block_size_policy(BlockSizePolicy::all(kb(2)))
395395 .data_block_compression_policy(CompressionPolicy::disabled())
···105105 // this is not for connection throttling (thats handled by ThrottleHandle)
106106 // its for stream errors (cbor decode etc)
107107 let mut backoff = Duration::from_secs(0);
108108- const MAX_BACKOFF: Duration = Duration::from_secs(60 * 60); // 1 ohur
108108+ const MAX_BACKOFF: Duration = Duration::from_secs(60 * 60); // 1 hour
109109110110 loop {
111111+ if self.state.pds_meta.load().is_banned(host) {
112112+ break Ok(());
113113+ }
114114+111115 self.enabled.wait_enabled("firehose").await;
112116113117 tokio::time::sleep(backoff).await;
···169173 match decode_frame(&bytes) {
170174 Ok(msg) => {
171175 if self.is_pds {
176176+ let tier = {
177177+ let meta = self.state.pds_meta.load();
178178+ let banned = meta.is_banned(host);
179179+ if banned {
180180+ break Ok(());
181181+ }
182182+ meta.tier_for(host, &self.state.rate_tiers)
183183+ };
172184 let accounts = self.state.db.get_count(&count_key).await;
173173- let tier = self.state.pds_tier_for(&host);
174185 tokio::select! {
175186 _ = self.throttle.wait_for_allow(accounts, &tier) => {}
176187 _ = self.enabled.changed() => {
+5-1
src/lib.rs
···11pub mod config;
22+/// hydrant main api, includes the Hydrant type for programmatic control.
23pub mod control;
33-pub mod filter;
44+pub(crate) mod filter;
55+pub(crate) mod pds_meta;
46pub mod types;
5768#[cfg(all(feature = "relay", feature = "indexer"))]
···2325pub(crate) mod resolver;
2426pub(crate) mod state;
2527pub(crate) mod util;
2828+2929+pub use filter::FilterMode;