A lexicon-driven AppView for ATProto. happyview.dev
backfill firehose jetstream atproto appview oauth lexicon
8
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: add support for multiple domains

Trezy cd87c9f4 1a2085a9

+1449 -661
+2
migrations/postgres/20260414000000_drop_rate_limits_tables.sql
··· 1 + DROP TABLE IF EXISTS rate_limits; 2 + DROP TABLE IF EXISTS rate_limit_settings;
+7
migrations/postgres/20260415000000_create_domains.sql
··· 1 + CREATE TABLE IF NOT EXISTS domains ( 2 + id TEXT PRIMARY KEY, 3 + url TEXT NOT NULL UNIQUE, 4 + is_primary INTEGER NOT NULL DEFAULT 0, 5 + created_at TEXT NOT NULL, 6 + updated_at TEXT NOT NULL 7 + );
+2
migrations/sqlite/20260414000000_drop_rate_limits_tables.sql
··· 1 + DROP TABLE IF EXISTS rate_limits; 2 + DROP TABLE IF EXISTS rate_limit_settings;
+7
migrations/sqlite/20260415000000_create_domains.sql
··· 1 + CREATE TABLE IF NOT EXISTS domains ( 2 + id TEXT PRIMARY KEY, 3 + url TEXT NOT NULL UNIQUE, 4 + is_primary INTEGER NOT NULL DEFAULT 0, 5 + created_at TEXT NOT NULL, 6 + updated_at TEXT NOT NULL 7 + );
+8 -8
src/admin/api_clients.rs
··· 94 94 if let (Some(capacity), Some(refill_rate)) = 95 95 (body.rate_limit_capacity, body.rate_limit_refill_rate) 96 96 { 97 - let global = state.rate_limiter.global_config(); 97 + let defaults = state.rate_limiter.defaults(); 98 98 state.rate_limiter.register_client_config( 99 99 client_key.clone(), 100 100 crate::rate_limit::RateLimitConfig { 101 101 capacity: capacity as u32, 102 102 refill_rate, 103 - default_query_cost: global.default_query_cost, 104 - default_procedure_cost: global.default_procedure_cost, 105 - default_proxy_cost: global.default_proxy_cost, 103 + default_query_cost: defaults.query_cost, 104 + default_procedure_cost: defaults.procedure_cost, 105 + default_proxy_cost: defaults.proxy_cost, 106 106 }, 107 107 ); 108 108 } ··· 397 397 }, 398 398 ); 399 399 if let (Some(cap), Some(refill)) = (capacity, refill_rate) { 400 - let global = state.rate_limiter.global_config(); 400 + let defaults = state.rate_limiter.defaults(); 401 401 state.rate_limiter.register_client_config( 402 402 client_key, 403 403 crate::rate_limit::RateLimitConfig { 404 404 capacity: cap as u32, 405 405 refill_rate: refill, 406 - default_query_cost: global.default_query_cost, 407 - default_procedure_cost: global.default_procedure_cost, 408 - default_proxy_cost: global.default_proxy_cost, 406 + default_query_cost: defaults.query_cost, 407 + default_procedure_cost: defaults.procedure_cost, 408 + default_proxy_cost: defaults.proxy_cost, 409 409 }, 410 410 ); 411 411 } else {
+306
src/admin/domains.rs
··· 1 + use axum::Json; 2 + use axum::extract::{Path, State}; 3 + use axum::http::StatusCode; 4 + 5 + use crate::AppState; 6 + use crate::db::{adapt_sql, now_rfc3339}; 7 + use crate::domain::Domain; 8 + use crate::error::AppError; 9 + use crate::event_log::{EventLog, Severity, log_event}; 10 + 11 + use super::auth::UserAuth; 12 + use super::permissions::Permission; 13 + use super::types::{CreateDomainBody, DomainResponse}; 14 + 15 + fn domain_to_response(d: &Domain) -> DomainResponse { 16 + DomainResponse { 17 + id: d.id.clone(), 18 + url: d.url.clone(), 19 + is_primary: d.is_primary, 20 + created_at: d.created_at.clone(), 21 + updated_at: d.updated_at.clone(), 22 + } 23 + } 24 + 25 + /// GET /admin/domains 26 + pub(super) async fn list( 27 + State(state): State<AppState>, 28 + auth: UserAuth, 29 + ) -> Result<Json<Vec<DomainResponse>>, AppError> { 30 + auth.require(Permission::SettingsManage).await?; 31 + 32 + let sql = adapt_sql( 33 + "SELECT id, url, is_primary, created_at, updated_at FROM domains ORDER BY created_at", 34 + state.db_backend, 35 + ); 36 + let rows: Vec<(String, String, i32, String, String)> = sqlx::query_as(&sql) 37 + .fetch_all(&state.db) 38 + .await 39 + .map_err(|e| AppError::Internal(format!("failed to list domains: {e}")))?; 40 + 41 + let domains: Vec<DomainResponse> = rows 42 + .into_iter() 43 + .map( 44 + |(id, url, is_primary, created_at, updated_at)| DomainResponse { 45 + id, 46 + url, 47 + is_primary: is_primary != 0, 48 + created_at, 49 + updated_at, 50 + }, 51 + ) 52 + .collect(); 53 + 54 + Ok(Json(domains)) 55 + } 56 + 57 + /// POST /admin/domains 58 + pub(super) async fn create( 59 + State(state): State<AppState>, 60 + auth: UserAuth, 61 + Json(body): Json<CreateDomainBody>, 62 + ) -> Result<(StatusCode, Json<DomainResponse>), AppError> { 63 + auth.require(Permission::SettingsManage).await?; 64 + 65 + let url = body.url.trim_end_matches('/').to_string(); 66 + 67 + let parsed = 68 + reqwest::Url::parse(&url).map_err(|_| AppError::BadRequest("invalid URL".into()))?; 69 + 70 + if parsed.path() != "/" && !parsed.path().is_empty() { 71 + return Err(AppError::BadRequest("URL must not contain a path".into())); 72 + } 73 + 74 + if parsed.host_str().is_none() { 75 + return Err(AppError::BadRequest("URL must contain a host".into())); 76 + } 77 + 78 + let is_loopback = state.config.public_url.contains("127.0.0.1") 79 + || state.config.public_url.contains("[::1]") 80 + || state.config.public_url.contains("localhost"); 81 + 82 + if parsed.scheme() != "https" && !is_loopback { 83 + return Err(AppError::BadRequest("URL scheme must be https".into())); 84 + } 85 + 86 + // Check for duplicates 87 + let existing: Option<(String,)> = sqlx::query_as(&adapt_sql( 88 + "SELECT id FROM domains WHERE url = ?", 89 + state.db_backend, 90 + )) 91 + .bind(&url) 92 + .fetch_optional(&state.db) 93 + .await 94 + .map_err(|e| AppError::Internal(format!("failed to check domain: {e}")))?; 95 + 96 + if existing.is_some() { 97 + return Err(AppError::BadRequest(format!( 98 + "domain '{url}' already exists" 99 + ))); 100 + } 101 + 102 + let id = uuid::Uuid::new_v4().to_string(); 103 + let now = now_rfc3339(); 104 + 105 + let sql = adapt_sql( 106 + "INSERT INTO domains (id, url, is_primary, created_at, updated_at) VALUES (?, ?, 0, ?, ?)", 107 + state.db_backend, 108 + ); 109 + sqlx::query(&sql) 110 + .bind(&id) 111 + .bind(&url) 112 + .bind(&now) 113 + .bind(&now) 114 + .execute(&state.db) 115 + .await 116 + .map_err(|e| AppError::Internal(format!("failed to create domain: {e}")))?; 117 + 118 + let domain = Domain { 119 + id: id.clone(), 120 + url: url.clone(), 121 + is_primary: false, 122 + created_at: now.clone(), 123 + updated_at: now, 124 + }; 125 + 126 + // Register the OAuth client for this domain 127 + state 128 + .oauth 129 + .register_domain_client(url.clone(), state.oauth.primary_client()); 130 + 131 + // Build a proper OAuth client if not loopback 132 + let domain_is_loopback = 133 + url.contains("127.0.0.1") || url.contains("[::1]") || url.contains("localhost"); 134 + if !domain_is_loopback { 135 + let client_id_url = format!("{}/oauth-client-metadata.json", url.trim_end_matches('/')); 136 + let callback = format!("{}/auth/callback", url.trim_end_matches('/')); 137 + if let Err(e) = state.oauth.register_api_client( 138 + &client_id_url, 139 + &url, 140 + vec![callback], 141 + "atproto", 142 + &crate::auth::client_registry::ApiClientOAuthParams { 143 + plc_url: state.config.plc_url.clone(), 144 + state_store: state.oauth_state_store.clone(), 145 + session_store_pool: state.db.clone(), 146 + db_backend: state.db_backend, 147 + }, 148 + ) { 149 + tracing::error!(domain = %url, error = %e, "Failed to create OAuth client for domain"); 150 + } else { 151 + // Move from `clients` (where register_api_client puts it) to domain_clients + clients 152 + if let Some(client) = state.oauth.get(&client_id_url) { 153 + state.oauth.remove(&client_id_url); 154 + state.oauth.register_domain_client(url.clone(), client); 155 + } 156 + } 157 + } 158 + 159 + // Update in-memory cache 160 + state.domain_cache.insert(domain.clone()).await; 161 + 162 + log_event( 163 + &state.db, 164 + EventLog { 165 + event_type: "domain.created".to_string(), 166 + severity: Severity::Info, 167 + actor_did: Some(auth.did.clone()), 168 + subject: Some(url), 169 + detail: serde_json::json!({ "id": id }), 170 + }, 171 + state.db_backend, 172 + ) 173 + .await; 174 + 175 + let response = domain_to_response(&domain); 176 + Ok((StatusCode::CREATED, Json(response))) 177 + } 178 + 179 + /// DELETE /admin/domains/{id} 180 + pub(super) async fn delete( 181 + State(state): State<AppState>, 182 + auth: UserAuth, 183 + Path(id): Path<String>, 184 + ) -> Result<StatusCode, AppError> { 185 + auth.require(Permission::SettingsManage).await?; 186 + 187 + let sql = adapt_sql( 188 + "SELECT id, url, is_primary, created_at, updated_at FROM domains WHERE id = ?", 189 + state.db_backend, 190 + ); 191 + let row: Option<(String, String, i32, String, String)> = sqlx::query_as(&sql) 192 + .bind(&id) 193 + .fetch_optional(&state.db) 194 + .await 195 + .map_err(|e| AppError::Internal(format!("failed to find domain: {e}")))?; 196 + 197 + let (_, url, is_primary, _, _) = 198 + row.ok_or_else(|| AppError::NotFound("domain not found".into()))?; 199 + 200 + if is_primary != 0 { 201 + return Err(AppError::BadRequest( 202 + "cannot delete the primary domain — set a different domain as primary first".into(), 203 + )); 204 + } 205 + 206 + let delete_sql = adapt_sql("DELETE FROM domains WHERE id = ?", state.db_backend); 207 + sqlx::query(&delete_sql) 208 + .bind(&id) 209 + .execute(&state.db) 210 + .await 211 + .map_err(|e| AppError::Internal(format!("failed to delete domain: {e}")))?; 212 + 213 + // Remove OAuth client and cache entry 214 + state.oauth.remove_domain_client(&url); 215 + let host = url 216 + .strip_prefix("https://") 217 + .or_else(|| url.strip_prefix("http://")) 218 + .unwrap_or(&url); 219 + state.domain_cache.remove(host).await; 220 + 221 + log_event( 222 + &state.db, 223 + EventLog { 224 + event_type: "domain.deleted".to_string(), 225 + severity: Severity::Info, 226 + actor_did: Some(auth.did.clone()), 227 + subject: Some(url), 228 + detail: serde_json::json!({ "id": id }), 229 + }, 230 + state.db_backend, 231 + ) 232 + .await; 233 + 234 + Ok(StatusCode::NO_CONTENT) 235 + } 236 + 237 + /// POST /admin/domains/{id}/primary 238 + pub(super) async fn set_primary( 239 + State(state): State<AppState>, 240 + auth: UserAuth, 241 + Path(id): Path<String>, 242 + ) -> Result<StatusCode, AppError> { 243 + auth.require(Permission::SettingsManage).await?; 244 + 245 + let sql = adapt_sql( 246 + "SELECT id, url, is_primary, created_at, updated_at FROM domains WHERE id = ?", 247 + state.db_backend, 248 + ); 249 + let row: Option<(String, String, i32, String, String)> = sqlx::query_as(&sql) 250 + .bind(&id) 251 + .fetch_optional(&state.db) 252 + .await 253 + .map_err(|e| AppError::Internal(format!("failed to find domain: {e}")))?; 254 + 255 + let (_, url, _, _, _) = row.ok_or_else(|| AppError::NotFound("domain not found".into()))?; 256 + 257 + let now = now_rfc3339(); 258 + 259 + let unset_sql = adapt_sql( 260 + "UPDATE domains SET is_primary = 0, updated_at = ? WHERE is_primary = 1", 261 + state.db_backend, 262 + ); 263 + sqlx::query(&unset_sql) 264 + .bind(&now) 265 + .execute(&state.db) 266 + .await 267 + .map_err(|e| AppError::Internal(format!("failed to unset primary: {e}")))?; 268 + 269 + let set_sql = adapt_sql( 270 + "UPDATE domains SET is_primary = 1, updated_at = ? WHERE id = ?", 271 + state.db_backend, 272 + ); 273 + sqlx::query(&set_sql) 274 + .bind(&now) 275 + .bind(&id) 276 + .execute(&state.db) 277 + .await 278 + .map_err(|e| AppError::Internal(format!("failed to set primary: {e}")))?; 279 + 280 + // Update cache 281 + let host = url 282 + .strip_prefix("https://") 283 + .or_else(|| url.strip_prefix("http://")) 284 + .unwrap_or(&url); 285 + state.domain_cache.set_primary(host).await; 286 + 287 + // Update OAuth primary client 288 + if let Some(client) = state.oauth.get_domain_client(&url) { 289 + state.oauth.set_primary_client(client); 290 + } 291 + 292 + log_event( 293 + &state.db, 294 + EventLog { 295 + event_type: "domain.primary_changed".to_string(), 296 + severity: Severity::Info, 297 + actor_did: Some(auth.did.clone()), 298 + subject: Some(url), 299 + detail: serde_json::json!({ "id": id }), 300 + }, 301 + state.db_backend, 302 + ) 303 + .await; 304 + 305 + Ok(StatusCode::NO_CONTENT) 306 + }
+4 -6
src/admin/mod.rs
··· 2 2 mod api_keys; 3 3 pub(crate) mod auth; 4 4 mod backfill; 5 + mod domains; 5 6 mod events; 6 7 mod labelers; 7 8 mod lexicons; 8 9 mod network_lexicons; 9 10 pub(crate) mod permissions; 10 11 mod plugins; 11 - mod rate_limits; 12 12 mod records; 13 13 mod script_variables; 14 14 pub mod settings; ··· 70 70 "/labelers/{did}", 71 71 patch(labelers::update).delete(labelers::delete), 72 72 ) 73 - .route( 74 - "/rate-limits", 75 - post(rate_limits::upsert).get(rate_limits::list), 76 - ) 77 - .route("/rate-limits/enabled", put(rate_limits::set_enabled)) 78 73 .route("/settings", get(settings::list)) 79 74 .route( 80 75 "/settings/logo", ··· 104 99 .put(api_clients::update_api_client) 105 100 .delete(api_clients::delete_api_client), 106 101 ) 102 + .route("/domains", post(domains::create).get(domains::list)) 103 + .route("/domains/{id}", delete(domains::delete)) 104 + .route("/domains/{id}/primary", post(domains::set_primary)) 107 105 }
-16
src/admin/permissions.rs
··· 60 60 #[serde(rename = "labelers:delete")] 61 61 LabelersDelete, 62 62 63 - #[serde(rename = "rate-limits:read")] 64 - RateLimitsRead, 65 - #[serde(rename = "rate-limits:create")] 66 - RateLimitsCreate, 67 - #[serde(rename = "rate-limits:delete")] 68 - RateLimitsDelete, 69 - 70 63 #[serde(rename = "settings:manage")] 71 64 SettingsManage, 72 65 ··· 114 107 Self::LabelersCreate => "labelers:create", 115 108 Self::LabelersRead => "labelers:read", 116 109 Self::LabelersDelete => "labelers:delete", 117 - Self::RateLimitsRead => "rate-limits:read", 118 - Self::RateLimitsCreate => "rate-limits:create", 119 - Self::RateLimitsDelete => "rate-limits:delete", 120 110 Self::SettingsManage => "settings:manage", 121 111 Self::PluginsRead => "plugins:read", 122 112 Self::PluginsCreate => "plugins:create", ··· 154 144 Self::LabelersCreate, 155 145 Self::LabelersRead, 156 146 Self::LabelersDelete, 157 - Self::RateLimitsRead, 158 - Self::RateLimitsCreate, 159 - Self::RateLimitsDelete, 160 147 Self::SettingsManage, 161 148 Self::PluginsRead, 162 149 Self::PluginsCreate, ··· 209 196 perms.insert(Permission::LabelersCreate); 210 197 perms.insert(Permission::LabelersRead); 211 198 perms.insert(Permission::LabelersDelete); 212 - perms.insert(Permission::RateLimitsRead); 213 - perms.insert(Permission::RateLimitsCreate); 214 - perms.insert(Permission::RateLimitsDelete); 215 199 perms.insert(Permission::SettingsManage); 216 200 perms.insert(Permission::PluginsRead); 217 201 perms.insert(Permission::PluginsCreate);
-158
src/admin/rate_limits.rs
··· 1 - use axum::Json; 2 - use axum::extract::State; 3 - use axum::http::StatusCode; 4 - 5 - use crate::AppState; 6 - use crate::db::{adapt_sql, now_rfc3339}; 7 - use crate::error::AppError; 8 - use crate::event_log::{EventLog, Severity, log_event}; 9 - 10 - use super::auth::UserAuth; 11 - use super::permissions::Permission; 12 - use super::types::{RateLimitsResponse, SetEnabledBody, UpsertRateLimitBody}; 13 - 14 - /// GET /admin/rate-limits — list rate limit config. 15 - pub(super) async fn list( 16 - State(state): State<AppState>, 17 - auth: UserAuth, 18 - ) -> Result<Json<RateLimitsResponse>, AppError> { 19 - auth.require(Permission::RateLimitsRead).await?; 20 - 21 - let backend = state.db_backend; 22 - 23 - let enabled_sql = adapt_sql( 24 - "SELECT value FROM rate_limit_settings WHERE key = 'enabled'", 25 - backend, 26 - ); 27 - let enabled: String = sqlx::query_scalar(&enabled_sql) 28 - .fetch_optional(&state.db) 29 - .await 30 - .map_err(|e| AppError::Internal(format!("failed to read rate limit settings: {e}")))? 31 - .unwrap_or_else(|| "true".to_string()); 32 - 33 - let limits_sql = adapt_sql( 34 - "SELECT capacity, refill_rate, default_query_cost, default_procedure_cost, default_proxy_cost FROM rate_limits WHERE method IS NULL", 35 - backend, 36 - ); 37 - let row: Option<(i32, f64, i32, i32, i32)> = sqlx::query_as(&limits_sql) 38 - .fetch_optional(&state.db) 39 - .await 40 - .map_err(|e| AppError::Internal(format!("failed to read rate limits: {e}")))?; 41 - 42 - let (capacity, refill_rate, default_query_cost, default_procedure_cost, default_proxy_cost) = 43 - row.unwrap_or((100, 2.0, 1, 1, 1)); 44 - 45 - Ok(Json(RateLimitsResponse { 46 - enabled: enabled == "true", 47 - capacity, 48 - refill_rate, 49 - default_query_cost, 50 - default_procedure_cost, 51 - default_proxy_cost, 52 - })) 53 - } 54 - 55 - /// POST /admin/rate-limits — upsert the global rate limit config. 56 - pub(super) async fn upsert( 57 - State(state): State<AppState>, 58 - auth: UserAuth, 59 - Json(body): Json<UpsertRateLimitBody>, 60 - ) -> Result<StatusCode, AppError> { 61 - auth.require(Permission::RateLimitsCreate).await?; 62 - 63 - let backend = state.db_backend; 64 - let now = now_rfc3339(); 65 - let sql = adapt_sql( 66 - r#" 67 - INSERT INTO rate_limits (method, capacity, refill_rate, default_query_cost, default_procedure_cost, default_proxy_cost, created_at) 68 - VALUES (NULL, ?, ?, ?, ?, ?, ?) 69 - ON CONFLICT (method) DO UPDATE SET 70 - capacity = EXCLUDED.capacity, 71 - refill_rate = EXCLUDED.refill_rate, 72 - default_query_cost = EXCLUDED.default_query_cost, 73 - default_procedure_cost = EXCLUDED.default_procedure_cost, 74 - default_proxy_cost = EXCLUDED.default_proxy_cost, 75 - updated_at = ? 76 - "#, 77 - backend, 78 - ); 79 - sqlx::query(&sql) 80 - .bind(body.capacity as i32) 81 - .bind(body.refill_rate) 82 - .bind(body.default_query_cost as i32) 83 - .bind(body.default_procedure_cost as i32) 84 - .bind(body.default_proxy_cost as i32) 85 - .bind(&now) 86 - .bind(&now) 87 - .execute(&state.db) 88 - .await 89 - .map_err(|e| AppError::Internal(format!("failed to upsert rate limit: {e}")))?; 90 - 91 - state.rate_limiter.reload_from_db(&state.db).await; 92 - 93 - log_event( 94 - &state.db, 95 - EventLog { 96 - event_type: "rate_limit.upserted".to_string(), 97 - severity: Severity::Info, 98 - actor_did: Some(auth.did.clone()), 99 - subject: None, 100 - detail: serde_json::json!({ 101 - "capacity": body.capacity, 102 - "refill_rate": body.refill_rate, 103 - "default_query_cost": body.default_query_cost, 104 - "default_procedure_cost": body.default_procedure_cost, 105 - "default_proxy_cost": body.default_proxy_cost, 106 - }), 107 - }, 108 - state.db_backend, 109 - ) 110 - .await; 111 - 112 - Ok(StatusCode::CREATED) 113 - } 114 - 115 - /// PUT /admin/rate-limits/enabled — toggle rate limiting. 116 - pub(super) async fn set_enabled( 117 - State(state): State<AppState>, 118 - auth: UserAuth, 119 - Json(body): Json<SetEnabledBody>, 120 - ) -> Result<StatusCode, AppError> { 121 - auth.require(Permission::RateLimitsCreate).await?; 122 - 123 - let value = if body.enabled { "true" } else { "false" }; 124 - 125 - let backend = state.db_backend; 126 - let now = now_rfc3339(); 127 - let sql = adapt_sql( 128 - r#" 129 - INSERT INTO rate_limit_settings (key, value) 130 - VALUES ('enabled', ?) 131 - ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, updated_at = ? 132 - "#, 133 - backend, 134 - ); 135 - sqlx::query(&sql) 136 - .bind(value) 137 - .bind(&now) 138 - .execute(&state.db) 139 - .await 140 - .map_err(|e| AppError::Internal(format!("failed to update rate limit settings: {e}")))?; 141 - 142 - state.rate_limiter.set_enabled(body.enabled); 143 - 144 - log_event( 145 - &state.db, 146 - EventLog { 147 - event_type: "rate_limit.toggled".to_string(), 148 - severity: Severity::Info, 149 - actor_did: Some(auth.did.clone()), 150 - subject: None, 151 - detail: serde_json::json!({ "enabled": body.enabled }), 152 - }, 153 - state.db_backend, 154 - ) 155 - .await; 156 - 157 - Ok(StatusCode::NO_CONTENT) 158 - }
+18 -28
src/admin/types.rs
··· 325 325 } 326 326 327 327 // --------------------------------------------------------------------------- 328 + // Domain types 329 + // --------------------------------------------------------------------------- 330 + 331 + #[derive(Serialize)] 332 + pub(super) struct DomainResponse { 333 + pub(super) id: String, 334 + pub(super) url: String, 335 + pub(super) is_primary: bool, 336 + pub(super) created_at: String, 337 + pub(super) updated_at: String, 338 + } 339 + 340 + #[derive(Deserialize)] 341 + pub(super) struct CreateDomainBody { 342 + pub(super) url: String, 343 + } 344 + 345 + // --------------------------------------------------------------------------- 328 346 // API client types 329 347 // --------------------------------------------------------------------------- 330 348 ··· 380 398 pub(super) name: String, 381 399 pub(super) client_id_url: String, 382 400 } 383 - 384 - // --------------------------------------------------------------------------- 385 - // Rate limit types 386 - // --------------------------------------------------------------------------- 387 - 388 - #[derive(Deserialize)] 389 - pub(super) struct UpsertRateLimitBody { 390 - pub(super) capacity: u32, 391 - pub(super) refill_rate: f64, 392 - pub(super) default_query_cost: u32, 393 - pub(super) default_procedure_cost: u32, 394 - pub(super) default_proxy_cost: u32, 395 - } 396 - 397 - #[derive(Deserialize)] 398 - pub(super) struct SetEnabledBody { 399 - pub(super) enabled: bool, 400 - } 401 - 402 - #[derive(Serialize)] 403 - pub(super) struct RateLimitsResponse { 404 - pub(super) enabled: bool, 405 - pub(super) capacity: i32, 406 - pub(super) refill_rate: f64, 407 - pub(super) default_query_cost: i32, 408 - pub(super) default_procedure_cost: i32, 409 - pub(super) default_proxy_cost: i32, 410 - }
+54 -9
src/auth/client_registry.rs
··· 1 + use arc_swap::ArcSwap; 1 2 use dashmap::DashMap; 2 3 use std::sync::Arc; 3 4 ··· 27 28 /// shows the correct domain. The default client is HappyView's own identity, 28 29 /// used for dashboard auth. 29 30 pub struct OAuthClientRegistry { 30 - default_client: Arc<HappyViewOAuthClient>, 31 + primary_client: ArcSwap<HappyViewOAuthClient>, 32 + domain_clients: DashMap<String, Arc<HappyViewOAuthClient>>, 31 33 clients: DashMap<String, Arc<HappyViewOAuthClient>>, 32 34 } 33 35 34 36 impl OAuthClientRegistry { 35 - pub fn new(default_client: Arc<HappyViewOAuthClient>) -> Self { 37 + pub fn new(primary_client: Arc<HappyViewOAuthClient>) -> Self { 36 38 Self { 37 - default_client, 39 + primary_client: ArcSwap::new(primary_client), 40 + domain_clients: DashMap::new(), 38 41 clients: DashMap::new(), 39 42 } 40 43 } ··· 54 57 self.clients.get(client_id_url).map(|r| r.value().clone()) 55 58 } 56 59 57 - /// Look up a client by `client_id_url`, falling back to the default. 60 + /// Look up a client by `client_id_url`, falling back to the primary client. 58 61 pub fn get_or_default(&self, client_id_url: Option<&str>) -> Arc<HappyViewOAuthClient> { 59 62 if let Some(url) = client_id_url { 60 63 self.clients 61 64 .get(url) 62 65 .map(|r| r.value().clone()) 63 - .unwrap_or_else(|| self.default_client.clone()) 66 + .unwrap_or_else(|| self.primary_client.load_full()) 64 67 } else { 65 - self.default_client.clone() 68 + self.primary_client.load_full() 66 69 } 67 70 } 68 71 69 - /// Get the default (HappyView dashboard) client. 70 - pub fn default_client(&self) -> &Arc<HappyViewOAuthClient> { 71 - &self.default_client 72 + /// Get the primary (HappyView dashboard) client. 73 + pub fn primary_client(&self) -> Arc<HappyViewOAuthClient> { 74 + self.primary_client.load_full() 75 + } 76 + 77 + /// Register a domain-specific OAuth client. 78 + /// Inserts into both `domain_clients` (keyed by domain URL, for `get_for_domain`) 79 + /// and `clients` (keyed by client_id_url, for `get_or_default`). 80 + pub fn register_domain_client(&self, domain_url: String, client: Arc<HappyViewOAuthClient>) { 81 + let client_id_url = format!( 82 + "{}/oauth-client-metadata.json", 83 + domain_url.trim_end_matches('/') 84 + ); 85 + self.domain_clients.insert(domain_url, Arc::clone(&client)); 86 + self.clients.insert(client_id_url, client); 87 + } 88 + 89 + /// Remove a domain-specific OAuth client from both maps. 90 + pub fn remove_domain_client(&self, domain_url: &str) { 91 + self.domain_clients.remove(domain_url); 92 + let client_id_url = format!( 93 + "{}/oauth-client-metadata.json", 94 + domain_url.trim_end_matches('/') 95 + ); 96 + self.clients.remove(&client_id_url); 97 + } 98 + 99 + /// Look up a domain-specific OAuth client. 100 + pub fn get_domain_client(&self, domain_url: &str) -> Option<Arc<HappyViewOAuthClient>> { 101 + self.domain_clients 102 + .get(domain_url) 103 + .map(|r| r.value().clone()) 104 + } 105 + 106 + /// Get the OAuth client for a domain, falling back to the primary client. 107 + pub fn get_for_domain(&self, domain_url: &str) -> Arc<HappyViewOAuthClient> { 108 + self.domain_clients 109 + .get(domain_url) 110 + .map(|r| r.value().clone()) 111 + .unwrap_or_else(|| self.primary_client.load_full()) 112 + } 113 + 114 + /// Replace the primary OAuth client (e.g. when admin changes the primary domain). 115 + pub fn set_primary_client(&self, client: Arc<HappyViewOAuthClient>) { 116 + self.primary_client.store(client); 72 117 } 73 118 74 119 /// Build and register a single OAuth client from API client metadata.
+16 -5
src/auth/routes.rs
··· 57 57 async fn login( 58 58 State(state): State<AppState>, 59 59 jar: SignedCookieJar<Key>, 60 + domain: Option<axum::extract::Extension<std::sync::Arc<crate::domain::Domain>>>, 60 61 Query(query): Query<LoginQuery>, 61 62 ) -> Result<(SignedCookieJar<Key>, Json<serde_json::Value>), AppError> { 62 63 tracing::debug!(handle = %query.handle, redirect_uri = ?query.redirect_uri, scope = ?query.scope, "login request"); ··· 77 78 78 79 tracing::debug!(scopes = ?scopes, client_id = ?query.client_id, "resolved oauth scopes"); 79 80 81 + // For dashboard logins (no explicit client_id), use the domain's OAuth client 82 + let domain_url = domain.map(|d| d.0.url.clone()); 83 + let effective_client_id = if query.client_id.is_some() { 84 + query.client_id.clone() 85 + } else { 86 + domain_url 87 + .as_ref() 88 + .map(|du| format!("{}/oauth-client-metadata.json", du.trim_end_matches('/'))) 89 + }; 90 + 80 91 // Select the appropriate OAuth client based on client_id 81 - let oauth_client = state.oauth.get_or_default(query.client_id.as_deref()); 92 + let oauth_client = state.oauth.get_or_default(effective_client_id.as_deref()); 82 93 83 94 // Hold the authorize lock so that authorize() + take_last_state_key() are atomic. 84 95 // This prevents concurrent logins from swapping each other's state keys. ··· 106 117 // Store the redirect URI in the database, keyed by the OAuth state parameter. 107 118 // This avoids third-party cookie issues when Pentaract (cross-origin) calls this endpoint. 108 119 // Store redirect URI and client_id for the callback to use 109 - if query.redirect_uri.is_some() || query.client_id.is_some() { 120 + if query.redirect_uri.is_some() || effective_client_id.is_some() { 110 121 let redirect_uri = query.redirect_uri.as_deref().unwrap_or(""); 111 - tracing::debug!(oauth_state = ?oauth_state, redirect_uri = %redirect_uri, client_id = ?query.client_id, "storing redirect for state"); 122 + tracing::debug!(oauth_state = ?oauth_state, redirect_uri = %redirect_uri, client_id = ?effective_client_id, "storing redirect for state"); 112 123 113 124 if let Some(oauth_state) = oauth_state { 114 125 let now = now_rfc3339(); ··· 120 131 let _ = sqlx::query(&sql) 121 132 .bind(&oauth_state) 122 133 .bind(redirect_uri) 123 - .bind(query.client_id.as_deref()) 134 + .bind(effective_client_id.as_deref()) 124 135 .bind(&now) 125 136 .bind(&expires_at) 126 137 .execute(&state.db) ··· 257 268 let raw = cookie.value().to_string(); 258 269 let did_str = raw.split('\n').next().unwrap_or(&raw).to_string(); 259 270 if let Ok(did) = atrium_api::types::string::Did::new(did_str) { 260 - let _ = state.oauth.default_client().revoke(&did).await; 271 + let _ = state.oauth.primary_client().revoke(&did).await; 261 272 } 262 273 } 263 274
+198
src/domain.rs
··· 1 + use serde::{Deserialize, Serialize}; 2 + use std::collections::HashMap; 3 + use std::sync::Arc; 4 + use tokio::sync::RwLock; 5 + 6 + #[derive(Debug, Clone, Serialize, Deserialize)] 7 + pub struct Domain { 8 + pub id: String, 9 + pub url: String, 10 + pub is_primary: bool, 11 + pub created_at: String, 12 + pub updated_at: String, 13 + } 14 + 15 + impl Domain { 16 + pub fn host(&self) -> &str { 17 + let after_scheme = self 18 + .url 19 + .strip_prefix("https://") 20 + .or_else(|| self.url.strip_prefix("http://")) 21 + .unwrap_or(&self.url); 22 + after_scheme.split(':').next().unwrap_or(after_scheme) 23 + } 24 + } 25 + 26 + #[derive(Clone)] 27 + pub struct DomainCache { 28 + by_host: Arc<RwLock<HashMap<String, Arc<Domain>>>>, 29 + primary: Arc<RwLock<Option<Arc<Domain>>>>, 30 + } 31 + 32 + impl DomainCache { 33 + pub fn new() -> Self { 34 + Self { 35 + by_host: Arc::new(RwLock::new(HashMap::new())), 36 + primary: Arc::new(RwLock::new(None)), 37 + } 38 + } 39 + 40 + pub async fn load(&self, domains: Vec<Domain>) { 41 + let mut by_host = self.by_host.write().await; 42 + let mut primary = self.primary.write().await; 43 + 44 + by_host.clear(); 45 + *primary = None; 46 + 47 + for domain in domains { 48 + let arc = Arc::new(domain); 49 + if arc.is_primary { 50 + *primary = Some(arc.clone()); 51 + } 52 + by_host.insert(arc.host().to_string(), arc); 53 + } 54 + } 55 + 56 + pub async fn get(&self, host: &str) -> Option<Arc<Domain>> { 57 + let by_host = self.by_host.read().await; 58 + by_host.get(host).cloned() 59 + } 60 + 61 + pub async fn primary(&self) -> Option<Arc<Domain>> { 62 + let primary = self.primary.read().await; 63 + primary.clone() 64 + } 65 + 66 + pub async fn insert(&self, domain: Domain) { 67 + let arc = Arc::new(domain); 68 + let mut by_host = self.by_host.write().await; 69 + let mut primary = self.primary.write().await; 70 + 71 + if arc.is_primary { 72 + *primary = Some(arc.clone()); 73 + } 74 + by_host.insert(arc.host().to_string(), arc); 75 + } 76 + 77 + pub async fn remove(&self, host: &str) { 78 + let mut by_host = self.by_host.write().await; 79 + let removed = by_host.remove(host); 80 + 81 + if let Some(domain) = removed 82 + && domain.is_primary 83 + { 84 + let mut primary = self.primary.write().await; 85 + *primary = None; 86 + } 87 + } 88 + 89 + pub async fn set_primary(&self, host: &str) { 90 + let by_host = self.by_host.read().await; 91 + if let Some(domain) = by_host.get(host).cloned() { 92 + drop(by_host); 93 + let mut primary = self.primary.write().await; 94 + *primary = Some(domain); 95 + } 96 + } 97 + 98 + pub async fn all(&self) -> Vec<Arc<Domain>> { 99 + let by_host = self.by_host.read().await; 100 + by_host.values().cloned().collect() 101 + } 102 + } 103 + 104 + impl Default for DomainCache { 105 + fn default() -> Self { 106 + Self::new() 107 + } 108 + } 109 + 110 + #[cfg(test)] 111 + mod tests { 112 + use super::*; 113 + use uuid::Uuid; 114 + 115 + fn make_domain(url: &str, is_primary: bool) -> Domain { 116 + Domain { 117 + id: Uuid::new_v4().to_string(), 118 + url: url.to_string(), 119 + is_primary, 120 + created_at: "2024-01-01T00:00:00Z".to_string(), 121 + updated_at: "2024-01-01T00:00:00Z".to_string(), 122 + } 123 + } 124 + 125 + #[test] 126 + fn host_strips_https() { 127 + let domain = make_domain("https://example.com", false); 128 + assert_eq!(domain.host(), "example.com"); 129 + } 130 + 131 + #[test] 132 + fn host_strips_http() { 133 + let domain = make_domain("http://localhost:3000", false); 134 + assert_eq!(domain.host(), "localhost"); 135 + } 136 + 137 + #[tokio::test] 138 + async fn load_and_get() { 139 + let cache = DomainCache::new(); 140 + let domains = vec![ 141 + make_domain("https://example.com", true), 142 + make_domain("https://other.com", false), 143 + ]; 144 + cache.load(domains).await; 145 + 146 + let found = cache.get("example.com").await; 147 + assert!(found.is_some()); 148 + assert_eq!(found.unwrap().url, "https://example.com"); 149 + 150 + assert!(cache.get("other.com").await.is_some()); 151 + 152 + let missing = cache.get("unknown.com").await; 153 + assert!(missing.is_none()); 154 + } 155 + 156 + #[tokio::test] 157 + async fn primary_returns_primary_domain() { 158 + let cache = DomainCache::new(); 159 + let domains = vec![ 160 + make_domain("https://example.com", false), 161 + make_domain("https://primary.com", true), 162 + ]; 163 + cache.load(domains).await; 164 + 165 + let primary = cache.primary().await; 166 + assert!(primary.is_some()); 167 + assert_eq!(primary.unwrap().url, "https://primary.com"); 168 + } 169 + 170 + #[tokio::test] 171 + async fn insert_and_remove() { 172 + let cache = DomainCache::new(); 173 + let domain = make_domain("https://example.com", false); 174 + cache.insert(domain).await; 175 + 176 + assert!(cache.get("example.com").await.is_some()); 177 + 178 + cache.remove("example.com").await; 179 + assert!(cache.get("example.com").await.is_none()); 180 + } 181 + 182 + #[tokio::test] 183 + async fn set_primary_updates() { 184 + let cache = DomainCache::new(); 185 + let domains = vec![ 186 + make_domain("https://example.com", true), 187 + make_domain("https://other.com", false), 188 + ]; 189 + cache.load(domains).await; 190 + 191 + // Initially example.com is primary 192 + assert_eq!(cache.primary().await.unwrap().url, "https://example.com"); 193 + 194 + // Change primary to other.com 195 + cache.set_primary("other.com").await; 196 + assert_eq!(cache.primary().await.unwrap().url, "https://other.com"); 197 + } 198 + }
+38
src/domain_middleware.rs
··· 1 + use axum::{ 2 + extract::{Request, State}, 3 + http::StatusCode, 4 + middleware::Next, 5 + response::Response, 6 + }; 7 + use std::sync::Arc; 8 + 9 + use crate::AppState; 10 + use crate::domain::Domain; 11 + 12 + pub async fn resolve_domain( 13 + State(state): State<AppState>, 14 + mut req: Request, 15 + next: Next, 16 + ) -> Result<Response, (StatusCode, &'static str)> { 17 + let host = req 18 + .headers() 19 + .get("x-forwarded-host") 20 + .or_else(|| req.headers().get("host")) 21 + .and_then(|v| v.to_str().ok()) 22 + .map(|h| h.split(':').next().unwrap_or(h)) 23 + .unwrap_or(""); 24 + 25 + let domain = state.domain_cache.get(host).await; 26 + 27 + match domain { 28 + Some(domain) => { 29 + req.extensions_mut().insert(domain); 30 + Ok(next.run(req).await) 31 + } 32 + None => Err((StatusCode::MISDIRECTED_REQUEST, "Unknown host")), 33 + } 34 + } 35 + 36 + pub fn extract_domain(req: &Request) -> Option<Arc<Domain>> { 37 + req.extensions().get::<Arc<Domain>>().cloned() 38 + }
+6 -1
src/external_auth/routes.rs
··· 75 75 State(app_state): State<AppState>, 76 76 Path(plugin_id): Path<String>, 77 77 Query(query): Query<AuthorizeQuery>, 78 + domain: Option<axum::extract::Extension<std::sync::Arc<crate::domain::Domain>>>, 78 79 claims: Claims, 79 80 ) -> Result<Json<serde_json::Value>, AppError> { 80 81 let _plugin = app_state ··· 128 129 129 130 // Build the backend callback URL for OpenID/OAuth return_to 130 131 // This ensures the auth provider redirects back to the backend, not the frontend 132 + let domain_url = domain 133 + .map(|d| d.0.url.clone()) 134 + .unwrap_or_else(|| app_state.config.public_url.clone()); 135 + 131 136 let callback_url = format!( 132 137 "{}/external-auth/{}/callback", 133 - app_state.config.public_url.trim_end_matches('/'), 138 + domain_url.trim_end_matches('/'), 134 139 plugin_id 135 140 ); 136 141
+3
src/lib.rs
··· 3 3 pub mod config; 4 4 pub mod db; 5 5 pub mod dns; 6 + pub mod domain; 7 + pub mod domain_middleware; 6 8 pub mod error; 7 9 pub mod event_log; 8 10 pub mod external_auth; ··· 54 56 pub http: reqwest::Client, 55 57 pub db: sqlx::AnyPool, 56 58 pub db_backend: DatabaseBackend, 59 + pub domain_cache: domain::DomainCache, 57 60 pub lexicons: LexiconRegistry, 58 61 pub collections_tx: watch::Sender<Vec<String>>, 59 62 pub labeler_subscriptions_tx: watch::Sender<()>,
+5 -7
src/lua/atproto_api.rs
··· 336 336 http: reqwest::Client::new(), 337 337 db: test_db.clone(), 338 338 db_backend: DatabaseBackend::Sqlite, 339 + domain_cache: crate::domain::DomainCache::new(), 339 340 lexicons: LexiconRegistry::new(), 340 341 collections_tx: tx, 341 342 labeler_subscriptions_tx: labeler_tx, 342 343 rate_limiter: crate::rate_limit::RateLimiter::new( 343 - false, 344 - crate::rate_limit::RateLimitConfig { 345 - capacity: 100, 346 - refill_rate: 2.0, 347 - default_query_cost: 1, 348 - default_procedure_cost: 1, 349 - default_proxy_cost: 1, 344 + crate::rate_limit::RateLimitDefaults { 345 + query_cost: 1, 346 + procedure_cost: 1, 347 + proxy_cost: 1, 350 348 }, 351 349 ), 352 350 oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new(
+5 -7
src/lua/db_api.rs
··· 687 687 http: reqwest::Client::new(), 688 688 db: test_db.clone(), 689 689 db_backend: DatabaseBackend::Sqlite, 690 + domain_cache: crate::domain::DomainCache::new(), 690 691 lexicons: LexiconRegistry::new(), 691 692 collections_tx: tx, 692 693 labeler_subscriptions_tx: labeler_tx, 693 694 rate_limiter: crate::rate_limit::RateLimiter::new( 694 - false, 695 - crate::rate_limit::RateLimitConfig { 696 - capacity: 100, 697 - refill_rate: 2.0, 698 - default_query_cost: 1, 699 - default_procedure_cost: 1, 700 - default_proxy_cost: 1, 695 + crate::rate_limit::RateLimitDefaults { 696 + query_cost: 1, 697 + procedure_cost: 1, 698 + proxy_cost: 1, 701 699 }, 702 700 ), 703 701 oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new(
+5 -7
src/lua/execute.rs
··· 1068 1068 http: reqwest::Client::new(), 1069 1069 db: test_db.clone(), 1070 1070 db_backend: DatabaseBackend::Sqlite, 1071 + domain_cache: crate::domain::DomainCache::new(), 1071 1072 lexicons: LexiconRegistry::new(), 1072 1073 collections_tx: tx, 1073 1074 labeler_subscriptions_tx: labeler_tx, 1074 1075 rate_limiter: crate::rate_limit::RateLimiter::new( 1075 - false, 1076 - crate::rate_limit::RateLimitConfig { 1077 - capacity: 100, 1078 - refill_rate: 2.0, 1079 - default_query_cost: 1, 1080 - default_procedure_cost: 1, 1081 - default_proxy_cost: 1, 1076 + crate::rate_limit::RateLimitDefaults { 1077 + query_cost: 1, 1078 + procedure_cost: 1, 1079 + proxy_cost: 1, 1082 1080 }, 1083 1081 ), 1084 1082 oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new(
+5 -7
src/lua/http_api.rs
··· 153 153 http: reqwest::Client::new(), 154 154 db: test_db.clone(), 155 155 db_backend: crate::db::DatabaseBackend::Sqlite, 156 + domain_cache: crate::domain::DomainCache::new(), 156 157 lexicons: LexiconRegistry::new(), 157 158 collections_tx: tx, 158 159 labeler_subscriptions_tx: labeler_tx, 159 160 rate_limiter: crate::rate_limit::RateLimiter::new( 160 - false, 161 - crate::rate_limit::RateLimitConfig { 162 - capacity: 100, 163 - refill_rate: 2.0, 164 - default_query_cost: 1, 165 - default_procedure_cost: 1, 166 - default_proxy_cost: 1, 161 + crate::rate_limit::RateLimitDefaults { 162 + query_cost: 1, 163 + procedure_cost: 1, 164 + proxy_cost: 1, 167 165 }, 168 166 ), 169 167 oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new(
+5 -7
src/lua/xrpc_api.rs
··· 257 257 http: reqwest::Client::new(), 258 258 db: test_db.clone(), 259 259 db_backend: DatabaseBackend::Sqlite, 260 + domain_cache: crate::domain::DomainCache::new(), 260 261 lexicons: LexiconRegistry::new(), 261 262 collections_tx: tx, 262 263 labeler_subscriptions_tx: labeler_tx, 263 264 rate_limiter: crate::rate_limit::RateLimiter::new( 264 - false, 265 - crate::rate_limit::RateLimitConfig { 266 - capacity: 100, 267 - refill_rate: 2.0, 268 - default_query_cost: 1, 269 - default_procedure_cost: 1, 270 - default_proxy_cost: 1, 265 + crate::rate_limit::RateLimitDefaults { 266 + query_cost: 1, 267 + procedure_cost: 1, 268 + proxy_cost: 1, 271 269 }, 272 270 ), 273 271 oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new(
+192 -11
src/main.rs
··· 5 5 use happyview::db; 6 6 use happyview::dns::NativeDnsResolver; 7 7 use happyview::lexicon::{LexiconRegistry, ParsedLexicon, ProcedureAction}; 8 - use happyview::rate_limit::{RateLimitConfig, RateLimiter}; 8 + use happyview::rate_limit::{RateLimitDefaults, RateLimiter}; 9 9 use happyview::resolve::{fetch_lexicon_from_pds, resolve_nsid_authority}; 10 10 use happyview::{AppState, jetstream, labeler, server}; 11 + use sqlx::Row; 11 12 use tokio::sync::watch; 12 13 use tracing::{info, warn}; 13 14 ··· 264 265 } 265 266 } 266 267 267 - // Initialize rate limiter from DB. 268 - let rl_state = RateLimiter::load_from_db(&db_pool).await; 269 - let rate_limiter = RateLimiter::new(rl_state.enabled, rl_state.global); 268 + // Seed and load per-instance default token costs from instance_settings. 269 + let defaults = seed_and_load_rate_limit_defaults(&db_pool, db_backend).await; 270 + let rate_limiter = RateLimiter::new(defaults); 270 271 tokio::spawn(rate_limiter.clone().spawn_cleanup()); 271 272 272 273 // Load per-client rate limit configs and identities from api_clients table. ··· 279 280 .await 280 281 .unwrap_or_default(); 281 282 282 - let global = rate_limiter.global_config(); 283 283 for (client_key, secret_hash, client_uri, capacity, refill_rate) in client_rows { 284 284 rate_limiter.register_client_identity( 285 285 client_key.clone(), ··· 291 291 if let (Some(cap), Some(refill)) = (capacity, refill_rate) { 292 292 rate_limiter.register_client_config( 293 293 client_key, 294 - RateLimitConfig { 294 + happyview::rate_limit::RateLimitConfig { 295 295 capacity: cap as u32, 296 296 refill_rate: refill, 297 - default_query_cost: global.default_query_cost, 298 - default_procedure_cost: global.default_procedure_cost, 299 - default_proxy_cost: global.default_proxy_cost, 297 + default_query_cost: defaults.query_cost, 298 + default_procedure_cost: defaults.procedure_cost, 299 + default_proxy_cost: defaults.proxy_cost, 300 300 }, 301 301 ); 302 302 } 303 303 } 304 304 } 305 305 306 + // Seed and load domain cache 307 + let domain_cache = happyview::domain::DomainCache::new(); 308 + { 309 + let count_sql = happyview::db::adapt_sql("SELECT COUNT(*) FROM domains", db_backend); 310 + let row = sqlx::query(&count_sql) 311 + .fetch_one(&db_pool) 312 + .await 313 + .expect("Failed to count domains"); 314 + let count: i64 = row.try_get(0).unwrap_or(0); 315 + 316 + if count == 0 { 317 + let id = uuid::Uuid::new_v4().to_string(); 318 + let now = happyview::db::now_rfc3339(); 319 + let insert_sql = happyview::db::adapt_sql( 320 + "INSERT INTO domains (id, url, is_primary, created_at, updated_at) VALUES (?, ?, 1, ?, ?)", 321 + db_backend, 322 + ); 323 + sqlx::query(&insert_sql) 324 + .bind(&id) 325 + .bind(&config.public_url) 326 + .bind(&now) 327 + .bind(&now) 328 + .execute(&db_pool) 329 + .await 330 + .expect("Failed to insert primary domain"); 331 + info!("Seeded primary domain: {}", config.public_url); 332 + } 333 + 334 + let select_sql = happyview::db::adapt_sql( 335 + "SELECT id, url, is_primary, created_at, updated_at FROM domains", 336 + db_backend, 337 + ); 338 + let rows = sqlx::query(&select_sql) 339 + .fetch_all(&db_pool) 340 + .await 341 + .expect("Failed to load domains"); 342 + 343 + let domains: Vec<happyview::domain::Domain> = rows 344 + .into_iter() 345 + .map(|row| { 346 + let is_primary_int: i32 = row.try_get("is_primary").unwrap_or(0); 347 + happyview::domain::Domain { 348 + id: row.try_get("id").unwrap_or_default(), 349 + url: row.try_get("url").unwrap_or_default(), 350 + is_primary: is_primary_int != 0, 351 + created_at: row.try_get("created_at").unwrap_or_default(), 352 + updated_at: row.try_get("updated_at").unwrap_or_default(), 353 + } 354 + }) 355 + .collect(); 356 + 357 + let domain_count = domains.len(); 358 + domain_cache.load(domains).await; 359 + info!("Loaded {} domain(s) into cache", domain_count); 360 + } 361 + 306 362 // Build atrium-oauth client 307 363 let dns = NativeDnsResolver::new(); 308 364 let callback_url = format!("{}/auth/callback", config.public_url.trim_end_matches('/')); ··· 394 450 let (labeler_subscriptions_tx, labeler_subscriptions_rx) = watch::channel(()); 395 451 396 452 // Build the OAuth client registry and load API clients from DB 397 - let oauth_registry = Arc::new(happyview::auth::OAuthClientRegistry::new(Arc::new( 398 - oauth_client, 453 + let oauth_client_arc = Arc::new(oauth_client); 454 + let oauth_registry = Arc::new(happyview::auth::OAuthClientRegistry::new(Arc::clone( 455 + &oauth_client_arc, 399 456 ))); 400 457 oauth_registry 401 458 .load_from_db( ··· 407 464 ) 408 465 .await; 409 466 467 + // Register the primary domain's OAuth client in domain_clients 468 + if let Some(ref pd) = domain_cache.primary().await { 469 + oauth_registry.register_domain_client(pd.url.clone(), Arc::clone(&oauth_client_arc)); 470 + } 471 + 472 + // Build OAuth clients for all non-primary domains 473 + let all_domains = domain_cache.all().await; 474 + for domain in &all_domains { 475 + if domain.is_primary { 476 + continue; // Already registered above 477 + } 478 + 479 + let domain_callback_url = format!("{}/auth/callback", domain.url.trim_end_matches('/')); 480 + let domain_client_id = format!( 481 + "{}/oauth-client-metadata.json", 482 + domain.url.trim_end_matches('/') 483 + ); 484 + 485 + let domain_http = Arc::new(DefaultHttpClient::default()); 486 + let domain_resolver = OAuthResolverConfig { 487 + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 488 + plc_directory_url: config.plc_url.clone(), 489 + http_client: Arc::clone(&domain_http), 490 + }), 491 + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 492 + dns_txt_resolver: NativeDnsResolver::new(), 493 + http_client: Arc::clone(&domain_http), 494 + }), 495 + authorization_server_metadata: Default::default(), 496 + protected_resource_metadata: Default::default(), 497 + }; 498 + 499 + match atrium_oauth::OAuthClient::new(OAuthClientConfig { 500 + client_metadata: AtprotoClientMetadata { 501 + client_id: domain_client_id, 502 + client_uri: Some(domain.url.clone()), 503 + redirect_uris: vec![domain_callback_url], 504 + token_endpoint_auth_method: AuthMethod::None, 505 + grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 506 + scopes: vec![Scope::Known(KnownScope::Atproto)], 507 + jwks_uri: None, 508 + token_endpoint_auth_signing_alg: None, 509 + }, 510 + keys: None, 511 + state_store: oauth_state_store.clone(), 512 + session_store: DbSessionStore::new(db_pool.clone(), db_backend), 513 + resolver: domain_resolver, 514 + }) { 515 + Ok(client) => { 516 + info!(domain = %domain.url, "Registered domain OAuth client"); 517 + oauth_registry.register_domain_client(domain.url.clone(), Arc::new(client)); 518 + } 519 + Err(e) => { 520 + tracing::error!(domain = %domain.url, error = %e, "Failed to create domain OAuth client"); 521 + } 522 + } 523 + } 524 + 410 525 let official_registry: happyview::plugin::official_registry::SharedRegistry = 411 526 std::sync::Arc::new(tokio::sync::RwLock::new( 412 527 happyview::plugin::official_registry::OfficialRegistryState::default(), ··· 424 539 http, 425 540 db: db_pool, 426 541 db_backend, 542 + domain_cache: domain_cache.clone(), 427 543 lexicons, 428 544 collections_tx, 429 545 labeler_subscriptions_tx, ··· 460 576 461 577 axum::serve(listener, app).await.expect("server error"); 462 578 } 579 + 580 + async fn seed_and_load_rate_limit_defaults( 581 + pool: &sqlx::AnyPool, 582 + backend: happyview::db::DatabaseBackend, 583 + ) -> RateLimitDefaults { 584 + use happyview::rate_limit::{ 585 + SEED_DEFAULT_PROCEDURE_COST, SEED_DEFAULT_PROXY_COST, SEED_DEFAULT_QUERY_COST, 586 + SETTING_DEFAULT_PROCEDURE_COST, SETTING_DEFAULT_PROXY_COST, SETTING_DEFAULT_QUERY_COST, 587 + }; 588 + 589 + async fn seed_and_read( 590 + pool: &sqlx::AnyPool, 591 + backend: happyview::db::DatabaseBackend, 592 + key: &str, 593 + seed: u32, 594 + ) -> u32 { 595 + if happyview::admin::settings::get_setting(pool, key, backend) 596 + .await 597 + .is_none() 598 + { 599 + let now = happyview::db::now_rfc3339(); 600 + let sql = happyview::db::adapt_sql( 601 + "INSERT INTO instance_settings (key, value, updated_at) VALUES (?, ?, ?) ON CONFLICT (key) DO NOTHING", 602 + backend, 603 + ); 604 + if let Err(e) = sqlx::query(&sql) 605 + .bind(key) 606 + .bind(seed.to_string()) 607 + .bind(&now) 608 + .execute(pool) 609 + .await 610 + { 611 + warn!(error = %e, key = key, "failed to seed rate-limit default"); 612 + } 613 + } 614 + happyview::admin::settings::get_setting(pool, key, backend) 615 + .await 616 + .and_then(|s| s.parse::<u32>().ok()) 617 + .unwrap_or(seed) 618 + } 619 + 620 + RateLimitDefaults { 621 + query_cost: seed_and_read( 622 + pool, 623 + backend, 624 + SETTING_DEFAULT_QUERY_COST, 625 + SEED_DEFAULT_QUERY_COST, 626 + ) 627 + .await, 628 + procedure_cost: seed_and_read( 629 + pool, 630 + backend, 631 + SETTING_DEFAULT_PROCEDURE_COST, 632 + SEED_DEFAULT_PROCEDURE_COST, 633 + ) 634 + .await, 635 + proxy_cost: seed_and_read( 636 + pool, 637 + backend, 638 + SETTING_DEFAULT_PROXY_COST, 639 + SEED_DEFAULT_PROXY_COST, 640 + ) 641 + .await, 642 + } 643 + }
+123 -321
src/rate_limit.rs
··· 1 - use arc_swap::ArcSwap; 2 1 use dashmap::DashMap; 3 - use sqlx::AnyPool; 4 2 use std::sync::Arc; 5 - use std::sync::atomic::{AtomicBool, Ordering}; 6 3 use std::time::{Instant, SystemTime, UNIX_EPOCH}; 4 + 5 + /// Hardcoded seed values for the per-instance default token costs. These are 6 + /// only used by the startup seeding step in `main.rs` to populate fresh 7 + /// `instance_settings` rows; at runtime the values are read from the DB into 8 + /// `RateLimitDefaults`. 9 + pub const SEED_DEFAULT_QUERY_COST: u32 = 1; 10 + pub const SEED_DEFAULT_PROCEDURE_COST: u32 = 1; 11 + pub const SEED_DEFAULT_PROXY_COST: u32 = 1; 12 + 13 + /// `instance_settings` keys for the seeded defaults. 14 + pub const SETTING_DEFAULT_QUERY_COST: &str = "rate_limit.default_query_cost"; 15 + pub const SETTING_DEFAULT_PROCEDURE_COST: &str = "rate_limit.default_procedure_cost"; 16 + pub const SETTING_DEFAULT_PROXY_COST: &str = "rate_limit.default_proxy_cost"; 17 + 18 + /// Default token costs per XRPC request type, loaded from `instance_settings` 19 + /// at startup. Owned by the `RateLimiter`. 20 + #[derive(Clone, Copy)] 21 + pub struct RateLimitDefaults { 22 + pub query_cost: u32, 23 + pub procedure_cost: u32, 24 + pub proxy_cost: u32, 25 + } 7 26 8 27 pub struct RateLimitConfig { 9 28 pub capacity: u32, ··· 44 63 } 45 64 46 65 pub struct RateLimiter { 47 - enabled: AtomicBool, 66 + defaults: RateLimitDefaults, 48 67 buckets: DashMap<String, TokenBucket>, 49 - global_config: ArcSwap<RateLimitConfig>, 50 - /// Per-client config overrides, keyed by client_key (e.g. "hvc_...") 68 + /// Per-client config, keyed by client_key (e.g. "hvc_..."). Presence in 69 + /// this map is the *only* thing that enables rate limiting for a key — 70 + /// unregistered keys are always allowed. 51 71 client_configs: DashMap<String, RateLimitConfig>, 52 72 /// Registered client identities, keyed by client_key 53 73 client_identities: DashMap<String, ClientIdentity>, 54 74 } 55 75 56 - pub struct RateLimiterState { 57 - pub enabled: bool, 58 - pub global: RateLimitConfig, 59 - } 60 - 61 76 fn now_unix() -> u64 { 62 77 SystemTime::now() 63 78 .duration_since(UNIX_EPOCH) ··· 66 81 } 67 82 68 83 impl RateLimiter { 69 - pub fn new(enabled: bool, global: RateLimitConfig) -> Arc<Self> { 84 + pub fn new(defaults: RateLimitDefaults) -> Arc<Self> { 70 85 Arc::new(Self { 71 - enabled: AtomicBool::new(enabled), 86 + defaults, 72 87 buckets: DashMap::new(), 73 - global_config: ArcSwap::new(Arc::new(global)), 74 88 client_configs: DashMap::new(), 75 89 client_identities: DashMap::new(), 76 90 }) 77 91 } 78 92 79 - pub fn check(&self, key: &str, cost: u32) -> CheckResult { 80 - if !self.enabled.load(Ordering::Relaxed) { 81 - return CheckResult::Disabled; 82 - } 93 + pub fn defaults(&self) -> RateLimitDefaults { 94 + self.defaults 95 + } 83 96 84 - // Use per-client config if available, otherwise fall back to global 85 - let (capacity, refill_rate) = if let Some(client_cfg) = self.client_configs.get(key) { 86 - (client_cfg.capacity, client_cfg.refill_rate) 87 - } else { 88 - let global = self.global_config.load(); 89 - (global.capacity, global.refill_rate) 97 + pub fn check(&self, key: &str, cost: u32) -> CheckResult { 98 + let (capacity, refill_rate) = match self.client_configs.get(key) { 99 + Some(cfg) => (cfg.capacity, cfg.refill_rate), 100 + None => return CheckResult::Disabled, 90 101 }; 91 102 let cost_f64 = cost as f64; 92 103 ··· 103 114 last_access: now, 104 115 }); 105 116 106 - // Hot-reload config changes 107 117 bucket.capacity = capacity; 108 118 bucket.refill_rate = refill_rate; 109 119 110 - // Refill tokens 111 120 let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); 112 121 bucket.tokens = (bucket.tokens + elapsed * refill_rate).min(capacity as f64); 113 122 bucket.last_refill = now; ··· 137 146 } 138 147 } 139 148 140 - /// Get the default cost for a given request type. 141 - pub fn default_cost_for_type(&self, request_type: &str) -> u32 { 142 - let config = self.global_config.load(); 149 + /// Get the default cost for a request type. Looks up the per-client 150 + /// override if one is registered, otherwise falls back to the seeded 151 + /// instance defaults. 152 + pub fn default_cost_for_type(&self, client_key: &str, request_type: &str) -> u32 { 153 + if let Some(cfg) = self.client_configs.get(client_key) { 154 + return match request_type { 155 + "query" => cfg.default_query_cost, 156 + "procedure" => cfg.default_procedure_cost, 157 + "proxy" => cfg.default_proxy_cost, 158 + _ => 1, 159 + }; 160 + } 143 161 match request_type { 144 - "query" => config.default_query_cost, 145 - "procedure" => config.default_procedure_cost, 146 - "proxy" => config.default_proxy_cost, 162 + "query" => self.defaults.query_cost, 163 + "procedure" => self.defaults.procedure_cost, 164 + "proxy" => self.defaults.proxy_cost, 147 165 _ => 1, 148 166 } 149 167 } 150 168 151 - /// Get a snapshot of the current global config. 152 - pub fn global_config(&self) -> Arc<RateLimitConfig> { 153 - self.global_config.load_full() 154 - } 155 - 156 - pub fn set_enabled(&self, enabled: bool) { 157 - self.enabled.store(enabled, Ordering::Relaxed); 158 - } 159 - 160 - pub fn is_enabled(&self) -> bool { 161 - self.enabled.load(Ordering::Relaxed) 162 - } 163 - 164 - pub fn update_config(&self, global: RateLimitConfig) { 165 - self.global_config.store(Arc::new(global)); 166 - } 167 - 168 - /// Register a per-client rate limit config override. 169 169 pub fn register_client_config(&self, client_key: String, config: RateLimitConfig) { 170 170 self.client_configs.insert(client_key, config); 171 171 } 172 172 173 - /// Remove a per-client rate limit config override. 174 173 pub fn remove_client_config(&self, client_key: &str) { 175 174 self.client_configs.remove(client_key); 176 175 } 177 176 178 - /// Register a client identity (key, secret hash, and client URI). 179 177 pub fn register_client_identity(&self, client_key: String, identity: ClientIdentity) { 180 178 self.client_identities.insert(client_key, identity); 181 179 } 182 180 183 - /// Remove a client identity. 184 181 pub fn remove_client_identity(&self, client_key: &str) { 185 182 self.client_identities.remove(client_key); 186 183 } 187 184 188 - /// Validate a client key + secret combination. Returns true if the secret 189 - /// hash matches the stored hash for this client key. 190 185 pub fn validate_client_secret(&self, client_key: &str, secret: &str) -> bool { 191 186 use sha2::{Digest, Sha256}; 192 187 if let Some(identity) = self.client_identities.get(client_key) { ··· 197 192 } 198 193 } 199 194 200 - /// Validate a client key + origin combination. Returns true if the origin 201 - /// matches the registered client_uri for this client key. 202 195 pub fn validate_client_origin(&self, client_key: &str, origin: &str) -> bool { 203 196 if let Some(identity) = self.client_identities.get(client_key) { 204 - // Compare origins: strip trailing slash for consistency 205 197 let registered = identity.client_uri.trim_end_matches('/'); 206 198 let provided = origin.trim_end_matches('/'); 207 199 registered == provided ··· 210 202 } 211 203 } 212 204 213 - /// Check whether a client key is registered. 214 205 pub fn is_valid_client_key(&self, client_key: &str) -> bool { 215 206 self.client_identities.contains_key(client_key) 216 207 } 217 208 218 209 pub async fn spawn_cleanup(self: Arc<Self>) { 219 210 let interval = tokio::time::Duration::from_secs(60); 220 - let stale_threshold = std::time::Duration::from_secs(300); // 5 minutes 211 + let stale_threshold = std::time::Duration::from_secs(300); 221 212 loop { 222 213 tokio::time::sleep(interval).await; 223 214 let now = Instant::now(); ··· 225 216 .retain(|_, bucket| now.duration_since(bucket.last_access) < stale_threshold); 226 217 } 227 218 } 228 - 229 - pub async fn load_from_db(db: &AnyPool) -> RateLimiterState { 230 - // Load enabled flag 231 - let enabled: bool = sqlx::query_scalar::<_, String>( 232 - "SELECT value FROM rate_limit_settings WHERE key = 'enabled'", 233 - ) 234 - .fetch_optional(db) 235 - .await 236 - .ok() 237 - .flatten() 238 - .map(|v| v == "true") 239 - .unwrap_or(true); 240 - 241 - // Load global rate limit config (method IS NULL row) 242 - let row: Option<(i32, f64, i32, i32, i32)> = sqlx::query_as( 243 - "SELECT capacity, refill_rate, default_query_cost, default_procedure_cost, default_proxy_cost FROM rate_limits WHERE method IS NULL", 244 - ) 245 - .fetch_optional(db) 246 - .await 247 - .unwrap_or(None); 248 - 249 - let global = match row { 250 - Some((capacity, refill_rate, query_cost, procedure_cost, proxy_cost)) => { 251 - RateLimitConfig { 252 - capacity: capacity as u32, 253 - refill_rate, 254 - default_query_cost: query_cost as u32, 255 - default_procedure_cost: procedure_cost as u32, 256 - default_proxy_cost: proxy_cost as u32, 257 - } 258 - } 259 - None => RateLimitConfig { 260 - capacity: 100, 261 - refill_rate: 2.0, 262 - default_query_cost: 1, 263 - default_procedure_cost: 1, 264 - default_proxy_cost: 1, 265 - }, 266 - }; 267 - 268 - RateLimiterState { enabled, global } 269 - } 270 - 271 - /// Reload all config from DB and apply to the live limiter. 272 - pub async fn reload_from_db(&self, db: &AnyPool) { 273 - let state = Self::load_from_db(db).await; 274 - self.set_enabled(state.enabled); 275 - self.update_config(state.global); 276 - } 277 219 } 278 220 279 221 #[cfg(test)] 280 222 mod tests { 281 223 use super::*; 282 224 283 - #[test] 284 - fn basic_allow_and_exhaust() { 285 - let rl = RateLimiter::new( 286 - true, 287 - RateLimitConfig { 288 - capacity: 3, 289 - refill_rate: 1.0, 290 - default_query_cost: 1, 291 - default_procedure_cost: 1, 292 - default_proxy_cost: 1, 293 - }, 294 - ); 295 - 296 - // Should allow 3 requests (bucket starts full, cost=1 each) 297 - for _ in 0..3 { 298 - assert!(matches!(rl.check("k", 1), CheckResult::Allowed { .. })); 225 + fn defaults() -> RateLimitDefaults { 226 + RateLimitDefaults { 227 + query_cost: 1, 228 + procedure_cost: 1, 229 + proxy_cost: 1, 299 230 } 300 - // 4th should be limited 301 - assert!(matches!(rl.check("k", 1), CheckResult::Limited { .. })); 302 231 } 303 232 304 - #[test] 305 - fn cost_deducts_multiple_tokens() { 306 - let rl = RateLimiter::new( 307 - true, 308 - RateLimitConfig { 309 - capacity: 10, 310 - refill_rate: 1.0, 311 - default_query_cost: 1, 312 - default_procedure_cost: 1, 313 - default_proxy_cost: 1, 314 - }, 315 - ); 316 - 317 - // Cost of 5 should allow 2 requests (10 tokens total) 318 - assert!(matches!( 319 - rl.check("k", 5), 320 - CheckResult::Allowed { remaining: 5, .. } 321 - )); 322 - assert!(matches!( 323 - rl.check("k", 5), 324 - CheckResult::Allowed { remaining: 0, .. } 325 - )); 326 - // 3rd should be limited 327 - assert!(matches!(rl.check("k", 5), CheckResult::Limited { .. })); 233 + fn cfg(capacity: u32, refill_rate: f64) -> RateLimitConfig { 234 + RateLimitConfig { 235 + capacity, 236 + refill_rate, 237 + default_query_cost: 1, 238 + default_procedure_cost: 1, 239 + default_proxy_cost: 1, 240 + } 328 241 } 329 242 330 243 #[test] 331 - fn disabled_returns_disabled() { 332 - let rl = RateLimiter::new( 333 - false, 334 - RateLimitConfig { 335 - capacity: 1, 336 - refill_rate: 1.0, 337 - default_query_cost: 1, 338 - default_procedure_cost: 1, 339 - default_proxy_cost: 1, 340 - }, 341 - ); 342 - assert!(matches!(rl.check("k", 1), CheckResult::Disabled)); 244 + fn unregistered_key_is_not_rate_limited() { 245 + let rl = RateLimiter::new(defaults()); 246 + for _ in 0..1000 { 247 + assert!(matches!(rl.check("anything", 1), CheckResult::Disabled)); 248 + } 343 249 } 344 250 345 251 #[test] 346 - fn default_cost_for_type() { 347 - let rl = RateLimiter::new( 348 - true, 349 - RateLimitConfig { 350 - capacity: 100, 351 - refill_rate: 10.0, 352 - default_query_cost: 2, 353 - default_procedure_cost: 5, 354 - default_proxy_cost: 3, 355 - }, 356 - ); 252 + fn registered_client_is_rate_limited() { 253 + let rl = RateLimiter::new(defaults()); 254 + rl.register_client_config("hvc_a".to_string(), cfg(3, 0.001)); 357 255 358 - assert_eq!(rl.default_cost_for_type("query"), 2); 359 - assert_eq!(rl.default_cost_for_type("procedure"), 5); 360 - assert_eq!(rl.default_cost_for_type("proxy"), 3); 361 - assert_eq!(rl.default_cost_for_type("unknown"), 1); 256 + for _ in 0..3 { 257 + assert!(matches!(rl.check("hvc_a", 1), CheckResult::Allowed { .. })); 258 + } 259 + assert!(matches!(rl.check("hvc_a", 1), CheckResult::Limited { .. })); 362 260 } 363 261 364 262 #[test] 365 - fn per_client_config_override() { 366 - let rl = RateLimiter::new( 367 - true, 368 - RateLimitConfig { 369 - capacity: 10, 370 - refill_rate: 1.0, 371 - default_query_cost: 1, 372 - default_procedure_cost: 1, 373 - default_proxy_cost: 1, 374 - }, 375 - ); 263 + fn cost_deducts_multiple_tokens() { 264 + let rl = RateLimiter::new(defaults()); 265 + rl.register_client_config("hvc_a".to_string(), cfg(10, 0.001)); 376 266 377 - // Register a client with lower capacity 378 - rl.register_client_config( 379 - "hvc_client1".to_string(), 380 - RateLimitConfig { 381 - capacity: 2, 382 - refill_rate: 0.001, 383 - default_query_cost: 1, 384 - default_procedure_cost: 1, 385 - default_proxy_cost: 1, 386 - }, 387 - ); 388 - 389 - // Client key should use client config (capacity=2) 390 - assert!(matches!( 391 - rl.check("hvc_client1", 1), 392 - CheckResult::Allowed { .. } 393 - )); 394 267 assert!(matches!( 395 - rl.check("hvc_client1", 1), 396 - CheckResult::Allowed { .. } 397 - )); 398 - assert!(matches!( 399 - rl.check("hvc_client1", 1), 400 - CheckResult::Limited { .. } 268 + rl.check("hvc_a", 5), 269 + CheckResult::Allowed { remaining: 5, .. } 401 270 )); 402 - 403 - // Other key should use global config (capacity=10) 404 - for _ in 0..10 { 405 - assert!(matches!( 406 - rl.check("other_key", 1), 407 - CheckResult::Allowed { .. } 408 - )); 409 - } 410 271 assert!(matches!( 411 - rl.check("other_key", 1), 412 - CheckResult::Limited { .. } 272 + rl.check("hvc_a", 5), 273 + CheckResult::Allowed { remaining: 0, .. } 413 274 )); 275 + assert!(matches!(rl.check("hvc_a", 5), CheckResult::Limited { .. })); 414 276 } 415 277 416 278 #[test] 417 - fn per_client_config_fallback_to_global() { 418 - let rl = RateLimiter::new( 419 - true, 420 - RateLimitConfig { 421 - capacity: 3, 422 - refill_rate: 1.0, 423 - default_query_cost: 1, 424 - default_procedure_cost: 1, 425 - default_proxy_cost: 1, 426 - }, 427 - ); 279 + fn different_clients_get_separate_buckets() { 280 + let rl = RateLimiter::new(defaults()); 281 + rl.register_client_config("hvc_a".to_string(), cfg(2, 0.001)); 282 + rl.register_client_config("hvc_b".to_string(), cfg(2, 0.001)); 428 283 429 - // No client config registered — should use global (capacity=3) 430 - for _ in 0..3 { 431 - assert!(matches!( 432 - rl.check("hvc_unregistered", 1), 433 - CheckResult::Allowed { .. } 434 - )); 435 - } 436 - assert!(matches!( 437 - rl.check("hvc_unregistered", 1), 438 - CheckResult::Limited { .. } 439 - )); 284 + assert!(matches!(rl.check("hvc_a", 1), CheckResult::Allowed { .. })); 285 + assert!(matches!(rl.check("hvc_a", 1), CheckResult::Allowed { .. })); 286 + assert!(matches!(rl.check("hvc_a", 1), CheckResult::Limited { .. })); 287 + 288 + assert!(matches!(rl.check("hvc_b", 1), CheckResult::Allowed { .. })); 289 + assert!(matches!(rl.check("hvc_b", 1), CheckResult::Allowed { .. })); 290 + assert!(matches!(rl.check("hvc_b", 1), CheckResult::Limited { .. })); 440 291 } 441 292 442 293 #[test] 443 - fn register_and_remove_client_config() { 444 - let rl = RateLimiter::new( 445 - true, 446 - RateLimitConfig { 447 - capacity: 10, 448 - refill_rate: 1.0, 449 - default_query_cost: 1, 450 - default_procedure_cost: 1, 451 - default_proxy_cost: 1, 452 - }, 453 - ); 454 - 455 - rl.register_client_config( 456 - "hvc_temp".to_string(), 457 - RateLimitConfig { 458 - capacity: 1, 459 - refill_rate: 0.001, 460 - default_query_cost: 1, 461 - default_procedure_cost: 1, 462 - default_proxy_cost: 1, 463 - }, 464 - ); 465 - 466 - // Should be limited after 1 request (client config capacity=1) 294 + fn remove_client_config_disables_limiting() { 295 + let rl = RateLimiter::new(defaults()); 296 + rl.register_client_config("hvc_temp".to_string(), cfg(1, 0.001)); 467 297 assert!(matches!( 468 298 rl.check("hvc_temp", 1), 469 299 CheckResult::Allowed { .. } ··· 473 303 CheckResult::Limited { .. } 474 304 )); 475 305 476 - // Remove client config — new bucket should use global (capacity=10) 477 306 rl.remove_client_config("hvc_temp"); 478 - // Note: the old bucket still exists and is exhausted, but capacity was 479 - // updated to global. A new bucket would get global capacity. 307 + assert!(matches!(rl.check("hvc_temp", 1), CheckResult::Disabled)); 480 308 } 481 309 482 310 #[test] 483 - fn different_clients_get_separate_buckets() { 484 - let rl = RateLimiter::new( 485 - true, 486 - RateLimitConfig { 487 - capacity: 2, 488 - refill_rate: 0.001, 489 - default_query_cost: 1, 490 - default_procedure_cost: 1, 491 - default_proxy_cost: 1, 492 - }, 493 - ); 494 - 495 - // Exhaust client A 496 - assert!(matches!( 497 - rl.check("clientA", 1), 498 - CheckResult::Allowed { .. } 499 - )); 500 - assert!(matches!( 501 - rl.check("clientA", 1), 502 - CheckResult::Allowed { .. } 503 - )); 504 - assert!(matches!( 505 - rl.check("clientA", 1), 506 - CheckResult::Limited { .. } 507 - )); 508 - 509 - // Client B should still have tokens 510 - assert!(matches!( 511 - rl.check("clientB", 1), 512 - CheckResult::Allowed { .. } 513 - )); 514 - assert!(matches!( 515 - rl.check("clientB", 1), 516 - CheckResult::Allowed { .. } 517 - )); 518 - assert!(matches!( 519 - rl.check("clientB", 1), 520 - CheckResult::Limited { .. } 521 - )); 311 + fn default_cost_for_type_uses_seeded_defaults_when_no_client_override() { 312 + let rl = RateLimiter::new(RateLimitDefaults { 313 + query_cost: 2, 314 + procedure_cost: 5, 315 + proxy_cost: 3, 316 + }); 317 + assert_eq!(rl.default_cost_for_type("nope", "query"), 2); 318 + assert_eq!(rl.default_cost_for_type("nope", "procedure"), 5); 319 + assert_eq!(rl.default_cost_for_type("nope", "proxy"), 3); 522 320 } 523 321 524 322 #[test] 525 - fn toggle_enabled() { 526 - let rl = RateLimiter::new( 527 - true, 323 + fn default_cost_for_type_uses_per_client_override() { 324 + let rl = RateLimiter::new(RateLimitDefaults { 325 + query_cost: 2, 326 + procedure_cost: 5, 327 + proxy_cost: 3, 328 + }); 329 + rl.register_client_config( 330 + "hvc_a".to_string(), 528 331 RateLimitConfig { 529 - capacity: 1, 332 + capacity: 100, 530 333 refill_rate: 1.0, 531 - default_query_cost: 1, 532 - default_procedure_cost: 1, 533 - default_proxy_cost: 1, 334 + default_query_cost: 7, 335 + default_procedure_cost: 8, 336 + default_proxy_cost: 9, 534 337 }, 535 338 ); 536 - assert!(rl.is_enabled()); 537 - rl.set_enabled(false); 538 - assert!(!rl.is_enabled()); 539 - assert!(matches!(rl.check("k", 1), CheckResult::Disabled)); 339 + assert_eq!(rl.default_cost_for_type("hvc_a", "query"), 7); 340 + assert_eq!(rl.default_cost_for_type("hvc_a", "procedure"), 8); 341 + assert_eq!(rl.default_cost_for_type("hvc_a", "proxy"), 9); 540 342 } 541 343 }
+1 -1
src/repo/session.rs
··· 14 14 Did::new(did.to_string()).map_err(|_| AppError::Auth(format!("invalid DID: {did}")))?; 15 15 state 16 16 .oauth 17 - .default_client() 17 + .primary_client() 18 18 .restore(&did) 19 19 .await 20 20 .map_err(|e| AppError::Auth(format!("no OAuth session for {}: {e}", did.as_ref())))
+12 -9
src/repo/upload_blob.rs
··· 17 17 headers: HeaderMap, 18 18 body: Bytes, 19 19 ) -> Result<Response, AppError> { 20 - let rate_key = claims.did().to_string(); 21 - let check = state.rate_limiter.check( 22 - &rate_key, 23 - state.rate_limiter.default_cost_for_type("procedure"), 24 - ); 20 + let check = if let Some(client_key) = claims.client_key() { 21 + let cost = state 22 + .rate_limiter 23 + .default_cost_for_type(client_key, "procedure"); 24 + Some(state.rate_limiter.check(client_key, cost)) 25 + } else { 26 + None 27 + }; 25 28 26 - if let CheckResult::Limited { 29 + if let Some(CheckResult::Limited { 27 30 retry_after, 28 31 limit, 29 32 reset, 30 - } = check 33 + }) = check 31 34 { 32 35 return Err(AppError::RateLimited { 33 36 retry_after, ··· 45 48 46 49 let mut response = pds_post_blob(&state, &session, content_type, body).await?; 47 50 48 - if let CheckResult::Allowed { 51 + if let Some(CheckResult::Allowed { 49 52 remaining, 50 53 limit, 51 54 reset, 52 - } = check 55 + }) = check 53 56 { 54 57 let h = response.headers_mut(); 55 58 h.insert("RateLimit-Limit", limit.into());
+46 -23
src/server.rs
··· 13 13 use crate::AppState; 14 14 use crate::admin; 15 15 use crate::auth::Claims; 16 + use crate::domain_middleware::resolve_domain; 16 17 use crate::error::AppError; 17 18 use crate::profile; 18 19 use crate::rate_limit::CheckResult; ··· 60 61 61 62 let serve_dir = ServeDir::new(&static_dir).not_found_service(spa_fallback); 62 63 63 - Router::new() 64 - .route("/health", get(health)) 65 - .route("/settings/logo", get(crate::admin::settings::serve_logo)) 66 - .nest("/admin", admin::admin_routes(state.clone())) 64 + let domain_routes = Router::new() 67 65 .nest("/auth", crate::auth::routes::routes()) 68 66 .nest("/external-auth", crate::external_auth::routes()) 69 67 // https://atproto.com/specs/oauth#types-of-clients ··· 76 74 // Catch-all for dynamically registered lexicons 77 75 .route("/xrpc/{method}", get(xrpc::xrpc_get).post(xrpc::xrpc_post)) 78 76 .route("/config", get(config_endpoint)) 77 + .route("/settings/logo", get(crate::admin::settings::serve_logo)) 78 + .layer(axum::middleware::from_fn_with_state( 79 + state.clone(), 80 + resolve_domain, 81 + )); 82 + 83 + Router::new() 84 + .route("/health", get(health)) 85 + .nest("/admin", admin::admin_routes(state.clone())) 86 + .merge(domain_routes) 79 87 .fallback_service(serve_dir) 80 88 .layer(TraceLayer::new_for_http()) 81 89 .layer( ··· 98 106 "ok" 99 107 } 100 108 101 - async fn config_endpoint(State(state): State<AppState>) -> Json<serde_json::Value> { 109 + async fn config_endpoint( 110 + State(state): State<AppState>, 111 + req: axum::extract::Request, 112 + ) -> Json<serde_json::Value> { 113 + let domain_url = crate::domain_middleware::extract_domain(&req) 114 + .map(|d| d.url.clone()) 115 + .unwrap_or_else(|| state.config.public_url.clone()); 116 + 102 117 let pool = &state.db; 103 118 let backend = state.db_backend; 104 119 ··· 112 127 let logo_url = if has_logo_data { 113 128 Some(format!( 114 129 "{}/settings/logo", 115 - state.config.public_url.trim_end_matches('/') 130 + domain_url.trim_end_matches('/') 116 131 )) 117 132 } else { 118 133 crate::admin::settings::get_setting(pool, "logo_uri", backend) ··· 126 141 }; 127 142 128 143 Json(serde_json::json!({ 129 - "public_url": state.config.public_url, 144 + "public_url": domain_url, 130 145 "version": version, 131 146 "database_backend": format!("{:?}", state.config.database_backend).to_lowercase(), 132 147 "jetstream_url": state.config.jetstream_url, ··· 139 154 })) 140 155 } 141 156 142 - async fn client_metadata(State(state): State<AppState>) -> Json<serde_json::Value> { 143 - let mut metadata = 144 - serde_json::to_value(&state.oauth.default_client().client_metadata).unwrap_or_default(); 157 + async fn client_metadata( 158 + State(state): State<AppState>, 159 + req: axum::extract::Request, 160 + ) -> Json<serde_json::Value> { 161 + let domain_url = crate::domain_middleware::extract_domain(&req) 162 + .map(|d| d.url.clone()) 163 + .unwrap_or_else(|| state.config.public_url.clone()); 164 + 165 + let oauth_client = state.oauth.get_for_domain(&domain_url); 166 + let mut metadata = serde_json::to_value(&oauth_client.client_metadata).unwrap_or_default(); 145 167 146 168 // The `client_id` field in the response must exactly match the URL the 147 169 // authorization server fetched. 148 170 let client_id = format!( 149 171 "{}/oauth-client-metadata.json", 150 - state.config.public_url.trim_end_matches('/') 172 + domain_url.trim_end_matches('/') 151 173 ); 152 174 metadata["client_id"] = serde_json::Value::String(client_id); 153 175 ··· 169 191 if has_logo_data { 170 192 metadata["logo_uri"] = serde_json::Value::String(format!( 171 193 "{}/settings/logo", 172 - state.config.public_url.trim_end_matches('/') 194 + domain_url.trim_end_matches('/') 173 195 )); 174 196 } else if let Some(uri) = crate::admin::settings::get_setting(pool, "logo_uri", backend).await { 175 197 metadata["logo_uri"] = serde_json::Value::String(uri); ··· 187 209 } 188 210 189 211 async fn get_profile(State(state): State<AppState>, claims: Claims) -> Result<Response, AppError> { 190 - let rate_key = claims 191 - .client_key() 192 - .map(|k| k.to_string()) 193 - .unwrap_or_else(|| claims.did().to_string()); 194 - let check = state 195 - .rate_limiter 196 - .check(&rate_key, state.rate_limiter.default_cost_for_type("query")); 212 + let check = if let Some(client_key) = claims.client_key() { 213 + let cost = state 214 + .rate_limiter 215 + .default_cost_for_type(client_key, "query"); 216 + Some(state.rate_limiter.check(client_key, cost)) 217 + } else { 218 + None 219 + }; 197 220 198 - if let CheckResult::Limited { 221 + if let Some(CheckResult::Limited { 199 222 retry_after, 200 223 limit, 201 224 reset, 202 - } = check 225 + }) = check 203 226 { 204 227 return Err(AppError::RateLimited { 205 228 retry_after, ··· 212 235 profile::resolve_profile(&state.http, &state.config.plc_url, claims.did()).await?; 213 236 let mut response = Json(profile).into_response(); 214 237 215 - if let CheckResult::Allowed { 238 + if let Some(CheckResult::Allowed { 216 239 remaining, 217 240 limit, 218 241 reset, 219 - } = check 242 + }) = check 220 243 { 221 244 let h = response.headers_mut(); 222 245 h.insert("RateLimit-Limit", limit.into());
+8 -4
src/xrpc/mod.rs
··· 267 267 let cost = if let Some(ref lex) = lexicon { 268 268 lex.token_cost.unwrap_or_else(|| { 269 269 let type_str = format!("{:?}", lex.lexicon_type).to_lowercase(); 270 - state.rate_limiter.default_cost_for_type(&type_str) 270 + state 271 + .rate_limiter 272 + .default_cost_for_type(&rate_key, &type_str) 271 273 }) 272 274 } else { 273 - state.rate_limiter.default_cost_for_type("proxy") 275 + state.rate_limiter.default_cost_for_type(&rate_key, "proxy") 274 276 }; 275 277 276 278 let check = state.rate_limiter.check(&rate_key, cost); ··· 349 351 let cost = if let Some(ref lex) = lexicon { 350 352 lex.token_cost.unwrap_or_else(|| { 351 353 let type_str = format!("{:?}", lex.lexicon_type).to_lowercase(); 352 - state.rate_limiter.default_cost_for_type(&type_str) 354 + state 355 + .rate_limiter 356 + .default_cost_for_type(&rate_key, &type_str) 353 357 }) 354 358 } else { 355 - state.rate_limiter.default_cost_for_type("proxy") 359 + state.rate_limiter.default_cost_for_type(&rate_key, "proxy") 356 360 }; 357 361 358 362 let check = state.rate_limiter.check(&rate_key, cost);
+5 -7
tests/common/app.rs
··· 122 122 http: reqwest::Client::new(), 123 123 db: pool.clone(), 124 124 db_backend: backend, 125 + domain_cache: happyview::domain::DomainCache::new(), 125 126 lexicons, 126 127 collections_tx, 127 128 labeler_subscriptions_tx, 128 129 rate_limiter: happyview::rate_limit::RateLimiter::new( 129 - false, 130 - happyview::rate_limit::RateLimitConfig { 131 - capacity: 100, 132 - refill_rate: 2.0, 133 - default_query_cost: 1, 134 - default_procedure_cost: 1, 135 - default_proxy_cost: 1, 130 + happyview::rate_limit::RateLimitDefaults { 131 + query_cost: 1, 132 + procedure_cost: 1, 133 + proxy_cost: 1, 136 134 }, 137 135 ), 138 136 oauth: std::sync::Arc::new(happyview::auth::OAuthClientRegistry::new(
+2 -1
tests/common/db.rs
··· 20 20 match backend { 21 21 DatabaseBackend::Postgres => { 22 22 sqlx::query( 23 - "TRUNCATE records, lexicons, backfill_jobs, users, user_permissions, api_keys, event_logs, script_variables, dead_letter_hooks, record_refs, labeler_subscriptions, labels, instance_settings RESTART IDENTITY CASCADE", 23 + "TRUNCATE records, lexicons, backfill_jobs, users, user_permissions, api_keys, event_logs, script_variables, dead_letter_hooks, record_refs, labeler_subscriptions, labels, instance_settings, domains RESTART IDENTITY CASCADE", 24 24 ) 25 25 .execute(pool) 26 26 .await ··· 41 41 "labeler_subscriptions", 42 42 "labels", 43 43 "instance_settings", 44 + "domains", 44 45 ]; 45 46 for table in tables { 46 47 sqlx::query(&format!("DELETE FROM {table}"))
+4 -4
tests/e2e_api_clients.rs
··· 580 580 .state 581 581 .oauth 582 582 .get_or_default(Some("https://unknown.example.com/metadata.json")); 583 - let default = app.state.oauth.default_client(); 583 + let default = app.state.oauth.primary_client(); 584 584 585 585 // Should be the same Arc (default client) 586 - assert!(std::sync::Arc::ptr_eq(&client, default)); 586 + assert!(std::sync::Arc::ptr_eq(&client, &default)); 587 587 } 588 588 589 589 #[tokio::test] ··· 593 593 let app = TestApp::new().await; 594 594 595 595 let client = app.state.oauth.get_or_default(None); 596 - let default = app.state.oauth.default_client(); 596 + let default = app.state.oauth.primary_client(); 597 597 598 - assert!(std::sync::Arc::ptr_eq(&client, default)); 598 + assert!(std::sync::Arc::ptr_eq(&client, &default)); 599 599 } 600 600 601 601 // ---------------------------------------------------------------------------
+352
tests/e2e_domains.rs
··· 1 + mod common; 2 + 3 + use axum::body::Body; 4 + use axum::http::{Method, Request, StatusCode}; 5 + use http_body_util::BodyExt; 6 + use serde_json::{Value, json}; 7 + use serial_test::serial; 8 + use tower::ServiceExt; 9 + 10 + use common::app::TestApp; 11 + 12 + // --------------------------------------------------------------------------- 13 + // Helpers 14 + // --------------------------------------------------------------------------- 15 + 16 + async fn json_body(resp: axum::response::Response) -> Value { 17 + let body = resp.into_body().collect().await.unwrap().to_bytes(); 18 + serde_json::from_slice(&body).unwrap() 19 + } 20 + 21 + fn admin_get( 22 + uri: &str, 23 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 24 + ) -> Request<Body> { 25 + Request::builder() 26 + .uri(uri) 27 + .header(cookie.0, cookie.1) 28 + .body(Body::empty()) 29 + .unwrap() 30 + } 31 + 32 + fn admin_post( 33 + uri: &str, 34 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 35 + body: &Value, 36 + ) -> Request<Body> { 37 + Request::builder() 38 + .method(Method::POST) 39 + .uri(uri) 40 + .header(cookie.0, cookie.1) 41 + .header("content-type", "application/json") 42 + .body(Body::from(serde_json::to_vec(body).unwrap())) 43 + .unwrap() 44 + } 45 + 46 + fn admin_delete( 47 + uri: &str, 48 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 49 + ) -> Request<Body> { 50 + Request::builder() 51 + .method(Method::DELETE) 52 + .uri(uri) 53 + .header(cookie.0, cookie.1) 54 + .body(Body::empty()) 55 + .unwrap() 56 + } 57 + 58 + fn get_with_host(uri: &str, host: &str) -> Request<Body> { 59 + Request::builder() 60 + .uri(uri) 61 + .header("host", host) 62 + .body(Body::empty()) 63 + .unwrap() 64 + } 65 + 66 + async fn seed_domain(app: &TestApp, id: &str, url: &str, is_primary: bool) { 67 + let now = happyview::db::now_rfc3339(); 68 + let sql = happyview::db::adapt_sql( 69 + "INSERT INTO domains (id, url, is_primary, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", 70 + app.state.db_backend, 71 + ); 72 + sqlx::query(&sql) 73 + .bind(id) 74 + .bind(url) 75 + .bind(if is_primary { 1i32 } else { 0i32 }) 76 + .bind(&now) 77 + .bind(&now) 78 + .execute(&app.state.db) 79 + .await 80 + .unwrap(); 81 + app.state 82 + .domain_cache 83 + .insert(happyview::domain::Domain { 84 + id: id.into(), 85 + url: url.into(), 86 + is_primary, 87 + created_at: now.clone(), 88 + updated_at: now, 89 + }) 90 + .await; 91 + } 92 + 93 + // --------------------------------------------------------------------------- 94 + // Domains tests 95 + // --------------------------------------------------------------------------- 96 + 97 + #[tokio::test] 98 + #[serial] 99 + #[ignore] 100 + async fn domains_list_returns_seeded_domain() { 101 + let app = TestApp::new().await; 102 + 103 + seed_domain(&app, "primary-id", "http://127.0.0.1:0", true).await; 104 + 105 + let resp = app 106 + .router 107 + .clone() 108 + .oneshot(admin_get("/admin/domains", app.admin_cookie())) 109 + .await 110 + .unwrap(); 111 + 112 + assert_eq!(resp.status(), StatusCode::OK); 113 + let json = json_body(resp).await; 114 + let domains = json.as_array().expect("expected array"); 115 + assert_eq!(domains.len(), 1, "expected 1 domain, got {}", domains.len()); 116 + assert_eq!(domains[0]["url"], "http://127.0.0.1:0"); 117 + assert_eq!(domains[0]["is_primary"], true); 118 + } 119 + 120 + #[tokio::test] 121 + #[serial] 122 + #[ignore] 123 + async fn domains_create_and_delete() { 124 + let app = TestApp::new().await; 125 + 126 + seed_domain(&app, "primary-id", "http://127.0.0.1:0", true).await; 127 + 128 + // Create a new domain 129 + let resp = app 130 + .router 131 + .clone() 132 + .oneshot(admin_post( 133 + "/admin/domains", 134 + app.admin_cookie(), 135 + &json!({ "url": "http://127.0.0.1:9999" }), 136 + )) 137 + .await 138 + .unwrap(); 139 + 140 + assert_eq!( 141 + resp.status(), 142 + StatusCode::CREATED, 143 + "expected 201 on create, got {}", 144 + resp.status() 145 + ); 146 + let json = json_body(resp).await; 147 + assert!( 148 + json["id"].is_string(), 149 + "expected id in response, got {:?}", 150 + json 151 + ); 152 + assert_eq!(json["url"], "http://127.0.0.1:9999"); 153 + assert_eq!(json["is_primary"], false); 154 + 155 + let new_id = json["id"].as_str().unwrap().to_string(); 156 + 157 + // Delete the newly created domain 158 + let resp = app 159 + .router 160 + .clone() 161 + .oneshot(admin_delete( 162 + &format!("/admin/domains/{new_id}"), 163 + app.admin_cookie(), 164 + )) 165 + .await 166 + .unwrap(); 167 + 168 + assert_eq!( 169 + resp.status(), 170 + StatusCode::NO_CONTENT, 171 + "expected 204 on delete, got {}", 172 + resp.status() 173 + ); 174 + } 175 + 176 + #[tokio::test] 177 + #[serial] 178 + #[ignore] 179 + async fn domains_duplicate_url_returns_400() { 180 + let app = TestApp::new().await; 181 + 182 + seed_domain(&app, "primary-id", "http://127.0.0.1:0", true).await; 183 + 184 + // Attempt to create a domain with the same URL 185 + let resp = app 186 + .router 187 + .clone() 188 + .oneshot(admin_post( 189 + "/admin/domains", 190 + app.admin_cookie(), 191 + &json!({ "url": "http://127.0.0.1:0" }), 192 + )) 193 + .await 194 + .unwrap(); 195 + 196 + assert_eq!( 197 + resp.status(), 198 + StatusCode::BAD_REQUEST, 199 + "expected 400 on duplicate URL, got {}", 200 + resp.status() 201 + ); 202 + } 203 + 204 + #[tokio::test] 205 + #[serial] 206 + #[ignore] 207 + async fn domains_cannot_delete_primary() { 208 + let app = TestApp::new().await; 209 + 210 + seed_domain(&app, "primary-id", "http://127.0.0.1:0", true).await; 211 + 212 + let resp = app 213 + .router 214 + .clone() 215 + .oneshot(admin_delete( 216 + "/admin/domains/primary-id", 217 + app.admin_cookie(), 218 + )) 219 + .await 220 + .unwrap(); 221 + 222 + assert_eq!( 223 + resp.status(), 224 + StatusCode::BAD_REQUEST, 225 + "expected 400 when deleting primary domain, got {}", 226 + resp.status() 227 + ); 228 + } 229 + 230 + #[tokio::test] 231 + #[serial] 232 + #[ignore] 233 + async fn domains_set_primary() { 234 + let app = TestApp::new().await; 235 + 236 + seed_domain(&app, "id-a", "http://127.0.0.1:0", true).await; 237 + seed_domain(&app, "id-b", "http://127.0.0.1:9999", false).await; 238 + 239 + // Set domain b as primary 240 + let resp = app 241 + .router 242 + .clone() 243 + .oneshot(admin_post( 244 + "/admin/domains/id-b/primary", 245 + app.admin_cookie(), 246 + &json!({}), 247 + )) 248 + .await 249 + .unwrap(); 250 + 251 + assert_eq!( 252 + resp.status(), 253 + StatusCode::NO_CONTENT, 254 + "expected 204 on set primary, got {}", 255 + resp.status() 256 + ); 257 + 258 + // Verify domain b is now primary 259 + let resp = app 260 + .router 261 + .clone() 262 + .oneshot(admin_get("/admin/domains", app.admin_cookie())) 263 + .await 264 + .unwrap(); 265 + 266 + assert_eq!(resp.status(), StatusCode::OK); 267 + let json = json_body(resp).await; 268 + let domains = json.as_array().expect("expected array"); 269 + let domain_b = domains 270 + .iter() 271 + .find(|d| d["id"] == "id-b") 272 + .expect("domain b not found"); 273 + assert_eq!( 274 + domain_b["is_primary"], true, 275 + "expected domain b to be primary" 276 + ); 277 + } 278 + 279 + #[tokio::test] 280 + #[serial] 281 + #[ignore] 282 + async fn unknown_host_returns_421_on_domain_scoped_routes() { 283 + let app = TestApp::new().await; 284 + 285 + // No domains seeded — cache is empty 286 + let resp = app 287 + .router 288 + .clone() 289 + .oneshot(get_with_host("/config", "unknown.example.com")) 290 + .await 291 + .unwrap(); 292 + 293 + assert_eq!( 294 + resp.status(), 295 + StatusCode::MISDIRECTED_REQUEST, 296 + "expected 421 for unknown host, got {}", 297 + resp.status() 298 + ); 299 + } 300 + 301 + #[tokio::test] 302 + #[serial] 303 + #[ignore] 304 + async fn health_check_bypasses_domain_resolution() { 305 + let app = TestApp::new().await; 306 + 307 + // No domains seeded — cache is empty 308 + let resp = app 309 + .router 310 + .clone() 311 + .oneshot(get_with_host("/health", "unknown.example.com")) 312 + .await 313 + .unwrap(); 314 + 315 + assert_eq!( 316 + resp.status(), 317 + StatusCode::OK, 318 + "expected 200 on /health regardless of host, got {}", 319 + resp.status() 320 + ); 321 + } 322 + 323 + #[tokio::test] 324 + #[serial] 325 + #[ignore] 326 + async fn domain_scoped_route_works_with_known_host() { 327 + let app = TestApp::new().await; 328 + 329 + // Domain.host() for "http://localhost:3000" is "localhost:3000" 330 + seed_domain(&app, "local-id", "http://localhost:3000", true).await; 331 + 332 + let resp = app 333 + .router 334 + .clone() 335 + .oneshot(get_with_host("/config", "localhost:3000")) 336 + .await 337 + .unwrap(); 338 + 339 + assert_eq!( 340 + resp.status(), 341 + StatusCode::OK, 342 + "expected 200 on /config with known host, got {}", 343 + resp.status() 344 + ); 345 + 346 + let json = json_body(resp).await; 347 + assert_eq!( 348 + json["public_url"], "http://localhost:3000", 349 + "expected public_url to match domain URL, got {:?}", 350 + json["public_url"] 351 + ); 352 + }
+5 -7
tests/lua_atproto_api.rs
··· 73 73 collections_tx: tx, 74 74 labeler_subscriptions_tx: labeler_tx, 75 75 rate_limiter: happyview::rate_limit::RateLimiter::new( 76 - false, 77 - happyview::rate_limit::RateLimitConfig { 78 - capacity: 100, 79 - refill_rate: 2.0, 80 - default_query_cost: 1, 81 - default_procedure_cost: 1, 82 - default_proxy_cost: 1, 76 + happyview::rate_limit::RateLimitDefaults { 77 + query_cost: 1, 78 + procedure_cost: 1, 79 + proxy_cost: 1, 83 80 }, 84 81 ), 85 82 oauth: std::sync::Arc::new(happyview::auth::OAuthClientRegistry::new( ··· 97 94 )), 98 95 official_registry_config: happyview::plugin::official_registry::RegistryConfig::production( 99 96 ), 97 + domain_cache: happyview::domain::DomainCache::new(), 100 98 } 101 99 } 102 100
+5 -7
tests/lua_db_api.rs
··· 76 76 collections_tx: tx, 77 77 labeler_subscriptions_tx: labeler_tx, 78 78 rate_limiter: happyview::rate_limit::RateLimiter::new( 79 - false, 80 - happyview::rate_limit::RateLimitConfig { 81 - capacity: 100, 82 - refill_rate: 2.0, 83 - default_query_cost: 1, 84 - default_procedure_cost: 1, 85 - default_proxy_cost: 1, 79 + happyview::rate_limit::RateLimitDefaults { 80 + query_cost: 1, 81 + procedure_cost: 1, 82 + proxy_cost: 1, 86 83 }, 87 84 ), 88 85 oauth: std::sync::Arc::new(happyview::auth::OAuthClientRegistry::new( ··· 100 97 )), 101 98 official_registry_config: happyview::plugin::official_registry::RegistryConfig::production( 102 99 ), 100 + domain_cache: happyview::domain::DomainCache::new(), 103 101 } 104 102 } 105 103