A lexicon-driven AppView for ATProto. happyview.dev
backfill firehose jetstream atproto appview oauth lexicon
8
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: add support for API clients

Trezy a289f80b ef6e23c2

+1847 -324
+15
migrations/postgres/20260412000000_create_api_clients.sql
··· 1 + CREATE TABLE IF NOT EXISTS api_clients ( 2 + id TEXT PRIMARY KEY, 3 + client_key TEXT NOT NULL UNIQUE, 4 + name TEXT NOT NULL, 5 + client_id_url TEXT NOT NULL UNIQUE, 6 + client_uri TEXT NOT NULL, 7 + redirect_uris TEXT NOT NULL, 8 + scopes TEXT NOT NULL DEFAULT 'atproto', 9 + rate_limit_capacity INTEGER, 10 + rate_limit_refill_rate REAL, 11 + is_active INTEGER NOT NULL DEFAULT 1, 12 + created_by TEXT NOT NULL, 13 + created_at TEXT NOT NULL DEFAULT '', 14 + updated_at TEXT NOT NULL DEFAULT '' 15 + );
+1
migrations/postgres/20260412100000_auth_redirects_add_client_id.sql
··· 1 + ALTER TABLE auth_login_redirects ADD COLUMN IF NOT EXISTS client_id TEXT;
+1
migrations/postgres/20260412200000_drop_rate_limit_allowlist.sql
··· 1 + DROP TABLE IF EXISTS rate_limit_allowlist;
+15
migrations/sqlite/20260412000000_create_api_clients.sql
··· 1 + CREATE TABLE IF NOT EXISTS api_clients ( 2 + id TEXT PRIMARY KEY, 3 + client_key TEXT NOT NULL UNIQUE, 4 + name TEXT NOT NULL, 5 + client_id_url TEXT NOT NULL UNIQUE, 6 + client_uri TEXT NOT NULL, 7 + redirect_uris TEXT NOT NULL, 8 + scopes TEXT NOT NULL DEFAULT 'atproto', 9 + rate_limit_capacity INTEGER, 10 + rate_limit_refill_rate REAL, 11 + is_active INTEGER NOT NULL DEFAULT 1, 12 + created_by TEXT NOT NULL, 13 + created_at TEXT NOT NULL DEFAULT '', 14 + updated_at TEXT NOT NULL DEFAULT '' 15 + );
+1
migrations/sqlite/20260412100000_auth_redirects_add_client_id.sql
··· 1 + ALTER TABLE auth_login_redirects ADD COLUMN client_id TEXT;
+1
migrations/sqlite/20260412200000_drop_rate_limit_allowlist.sql
··· 1 + DROP TABLE IF EXISTS rate_limit_allowlist;
+471
src/admin/api_clients.rs
··· 1 + use axum::Json; 2 + use axum::extract::{Path, State}; 3 + use axum::http::StatusCode; 4 + use hex; 5 + use rand::Rng; 6 + use uuid::Uuid; 7 + 8 + use crate::AppState; 9 + use crate::db::{adapt_sql, now_rfc3339}; 10 + use crate::error::AppError; 11 + use crate::event_log::{EventLog, Severity, log_event}; 12 + 13 + use super::auth::UserAuth; 14 + use super::permissions::Permission; 15 + use super::types::{ 16 + ApiClientSummary, CreateApiClientBody, CreateApiClientResponse, UpdateApiClientBody, 17 + }; 18 + 19 + /// POST /admin/api-clients — create a new API client. 20 + pub(super) async fn create_api_client( 21 + State(state): State<AppState>, 22 + auth: UserAuth, 23 + Json(body): Json<CreateApiClientBody>, 24 + ) -> Result<(StatusCode, Json<CreateApiClientResponse>), AppError> { 25 + auth.require(Permission::ApiClientsCreate).await?; 26 + 27 + // Generate the client key: "hvc_" + 32 random hex chars. 28 + let mut random_bytes = [0u8; 16]; 29 + rand::rng().fill(&mut random_bytes); 30 + let client_key = format!("hvc_{}", hex::encode(random_bytes)); 31 + 32 + let id = Uuid::new_v4().to_string(); 33 + let now = now_rfc3339(); 34 + let redirect_uris_json = 35 + serde_json::to_string(&body.redirect_uris).unwrap_or_else(|_| "[]".to_string()); 36 + 37 + let insert_sql = adapt_sql( 38 + "INSERT INTO api_clients (id, client_key, name, client_id_url, client_uri, redirect_uris, scopes, rate_limit_capacity, rate_limit_refill_rate, is_active, created_by, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, 1, ?, ?, ?)", 39 + state.db_backend, 40 + ); 41 + 42 + sqlx::query(&insert_sql) 43 + .bind(&id) 44 + .bind(&client_key) 45 + .bind(&body.name) 46 + .bind(&body.client_id_url) 47 + .bind(&body.client_uri) 48 + .bind(&redirect_uris_json) 49 + .bind(&body.scopes) 50 + .bind(body.rate_limit_capacity) 51 + .bind(body.rate_limit_refill_rate) 52 + .bind(&auth.did) 53 + .bind(&now) 54 + .bind(&now) 55 + .execute(&state.db) 56 + .await 57 + .map_err(|e| AppError::Internal(format!("failed to create api client: {e}")))?; 58 + 59 + // Register the new client in the OAuth registry so it's usable immediately. 60 + let oauth_params = crate::auth::client_registry::ApiClientOAuthParams { 61 + plc_url: state.config.plc_url.clone(), 62 + state_store: state.oauth_state_store.clone(), 63 + session_store_pool: state.db.clone(), 64 + db_backend: state.db_backend, 65 + }; 66 + if let Err(e) = state.oauth.register_api_client( 67 + &body.client_id_url, 68 + &body.client_uri, 69 + body.redirect_uris.clone(), 70 + &body.scopes, 71 + &oauth_params, 72 + ) { 73 + tracing::warn!(client_id = %body.client_id_url, error = %e, "OAuth client registration failed (DB row created)"); 74 + } 75 + 76 + // Register per-client rate limit config if overrides are set. 77 + if let (Some(capacity), Some(refill_rate)) = 78 + (body.rate_limit_capacity, body.rate_limit_refill_rate) 79 + { 80 + let global = state.rate_limiter.global_config(); 81 + state.rate_limiter.register_client_config( 82 + client_key.clone(), 83 + crate::rate_limit::RateLimitConfig { 84 + capacity: capacity as u32, 85 + refill_rate, 86 + default_query_cost: global.default_query_cost, 87 + default_procedure_cost: global.default_procedure_cost, 88 + default_proxy_cost: global.default_proxy_cost, 89 + }, 90 + ); 91 + } 92 + 93 + log_event( 94 + &state.db, 95 + EventLog { 96 + event_type: "api_client.created".to_string(), 97 + severity: Severity::Info, 98 + actor_did: Some(auth.did.clone()), 99 + subject: Some(body.name.clone()), 100 + detail: serde_json::json!({ 101 + "client_key": client_key, 102 + "client_id_url": body.client_id_url, 103 + }), 104 + }, 105 + state.db_backend, 106 + ) 107 + .await; 108 + 109 + Ok(( 110 + StatusCode::CREATED, 111 + Json(CreateApiClientResponse { 112 + id, 113 + client_key, 114 + name: body.name, 115 + client_id_url: body.client_id_url, 116 + }), 117 + )) 118 + } 119 + 120 + /// GET /admin/api-clients — list all API clients. 121 + pub(super) async fn list_api_clients( 122 + State(state): State<AppState>, 123 + auth: UserAuth, 124 + ) -> Result<Json<Vec<ApiClientSummary>>, AppError> { 125 + auth.require(Permission::ApiClientsView).await?; 126 + 127 + let select_sql = adapt_sql( 128 + "SELECT id, client_key, name, client_id_url, client_uri, redirect_uris, scopes, rate_limit_capacity, rate_limit_refill_rate, is_active, created_by, created_at, updated_at FROM api_clients ORDER BY created_at DESC", 129 + state.db_backend, 130 + ); 131 + 132 + #[allow(clippy::type_complexity)] 133 + let rows: Vec<( 134 + String, 135 + String, 136 + String, 137 + String, 138 + String, 139 + String, 140 + String, 141 + Option<i32>, 142 + Option<f64>, 143 + i32, 144 + String, 145 + String, 146 + String, 147 + )> = sqlx::query_as(&select_sql) 148 + .fetch_all(&state.db) 149 + .await 150 + .map_err(|e| AppError::Internal(format!("failed to list api clients: {e}")))?; 151 + 152 + let clients: Vec<ApiClientSummary> = rows 153 + .into_iter() 154 + .map( 155 + |( 156 + id, 157 + client_key, 158 + name, 159 + client_id_url, 160 + client_uri, 161 + redirect_uris_json, 162 + scopes, 163 + rate_limit_capacity, 164 + rate_limit_refill_rate, 165 + is_active, 166 + created_by, 167 + created_at, 168 + updated_at, 169 + )| { 170 + let redirect_uris: Vec<String> = 171 + serde_json::from_str(&redirect_uris_json).unwrap_or_default(); 172 + ApiClientSummary { 173 + id, 174 + client_key, 175 + name, 176 + client_id_url, 177 + client_uri, 178 + redirect_uris, 179 + scopes, 180 + rate_limit_capacity, 181 + rate_limit_refill_rate, 182 + is_active: is_active != 0, 183 + created_by, 184 + created_at, 185 + updated_at, 186 + } 187 + }, 188 + ) 189 + .collect(); 190 + 191 + Ok(Json(clients)) 192 + } 193 + 194 + /// GET /admin/api-clients/:id — get a single API client. 195 + pub(super) async fn get_api_client( 196 + State(state): State<AppState>, 197 + auth: UserAuth, 198 + Path(id): Path<String>, 199 + ) -> Result<Json<ApiClientSummary>, AppError> { 200 + auth.require(Permission::ApiClientsView).await?; 201 + 202 + let select_sql = adapt_sql( 203 + "SELECT id, client_key, name, client_id_url, client_uri, redirect_uris, scopes, rate_limit_capacity, rate_limit_refill_rate, is_active, created_by, created_at, updated_at FROM api_clients WHERE id = ?", 204 + state.db_backend, 205 + ); 206 + 207 + type GetRow = ( 208 + String, 209 + String, 210 + String, 211 + String, 212 + String, 213 + String, 214 + String, 215 + Option<i32>, 216 + Option<f64>, 217 + i32, 218 + String, 219 + String, 220 + String, 221 + ); 222 + let row: Option<GetRow> = sqlx::query_as(&select_sql) 223 + .bind(&id) 224 + .fetch_optional(&state.db) 225 + .await 226 + .map_err(|e| AppError::Internal(format!("failed to get api client: {e}")))?; 227 + 228 + let Some(( 229 + id, 230 + client_key, 231 + name, 232 + client_id_url, 233 + client_uri, 234 + redirect_uris_json, 235 + scopes, 236 + rate_limit_capacity, 237 + rate_limit_refill_rate, 238 + is_active, 239 + created_by, 240 + created_at, 241 + updated_at, 242 + )) = row 243 + else { 244 + return Err(AppError::NotFound(format!("api client '{id}' not found"))); 245 + }; 246 + 247 + let redirect_uris: Vec<String> = serde_json::from_str(&redirect_uris_json).unwrap_or_default(); 248 + 249 + Ok(Json(ApiClientSummary { 250 + id, 251 + client_key, 252 + name, 253 + client_id_url, 254 + client_uri, 255 + redirect_uris, 256 + scopes, 257 + rate_limit_capacity, 258 + rate_limit_refill_rate, 259 + is_active: is_active != 0, 260 + created_by, 261 + created_at, 262 + updated_at, 263 + })) 264 + } 265 + 266 + /// PUT /admin/api-clients/:id — update an API client. 267 + pub(super) async fn update_api_client( 268 + State(state): State<AppState>, 269 + auth: UserAuth, 270 + Path(id): Path<String>, 271 + Json(body): Json<UpdateApiClientBody>, 272 + ) -> Result<StatusCode, AppError> { 273 + auth.require(Permission::ApiClientsEdit).await?; 274 + 275 + // Read current values 276 + let select_sql = adapt_sql( 277 + "SELECT client_key, name, client_id_url, client_uri, redirect_uris, scopes, rate_limit_capacity, rate_limit_refill_rate, is_active FROM api_clients WHERE id = ?", 278 + state.db_backend, 279 + ); 280 + 281 + type UpdateRow = ( 282 + String, 283 + String, 284 + String, 285 + String, 286 + String, 287 + String, 288 + Option<i32>, 289 + Option<f64>, 290 + i32, 291 + ); 292 + let row: Option<UpdateRow> = sqlx::query_as(&select_sql) 293 + .bind(&id) 294 + .fetch_optional(&state.db) 295 + .await 296 + .map_err(|e| AppError::Internal(format!("failed to get api client: {e}")))?; 297 + 298 + let Some(( 299 + client_key, 300 + cur_name, 301 + client_id_url, 302 + cur_client_uri, 303 + cur_redirect_uris, 304 + cur_scopes, 305 + cur_capacity, 306 + cur_refill, 307 + cur_active, 308 + )) = row 309 + else { 310 + return Err(AppError::NotFound(format!("api client '{id}' not found"))); 311 + }; 312 + 313 + let name = body.name.unwrap_or(cur_name); 314 + let client_uri = body.client_uri.unwrap_or(cur_client_uri); 315 + let redirect_uris_json = body 316 + .redirect_uris 317 + .map(|uris| serde_json::to_string(&uris).unwrap_or_else(|_| "[]".to_string())) 318 + .unwrap_or(cur_redirect_uris); 319 + let scopes = body.scopes.unwrap_or(cur_scopes); 320 + let capacity = body.rate_limit_capacity.unwrap_or(cur_capacity); 321 + let refill_rate = body.rate_limit_refill_rate.unwrap_or(cur_refill); 322 + let is_active = body 323 + .is_active 324 + .map(|a| if a { 1i32 } else { 0i32 }) 325 + .unwrap_or(cur_active); 326 + let now = now_rfc3339(); 327 + 328 + let update_sql = adapt_sql( 329 + "UPDATE api_clients SET name = ?, client_uri = ?, redirect_uris = ?, scopes = ?, rate_limit_capacity = ?, rate_limit_refill_rate = ?, is_active = ?, updated_at = ? WHERE id = ?", 330 + state.db_backend, 331 + ); 332 + 333 + sqlx::query(&update_sql) 334 + .bind(&name) 335 + .bind(&client_uri) 336 + .bind(&redirect_uris_json) 337 + .bind(&scopes) 338 + .bind(capacity) 339 + .bind(refill_rate) 340 + .bind(is_active) 341 + .bind(&now) 342 + .bind(&id) 343 + .execute(&state.db) 344 + .await 345 + .map_err(|e| AppError::Internal(format!("failed to update api client: {e}")))?; 346 + 347 + // Re-register or remove from OAuth registry based on active status. 348 + let oauth_params = crate::auth::client_registry::ApiClientOAuthParams { 349 + plc_url: state.config.plc_url.clone(), 350 + state_store: state.oauth_state_store.clone(), 351 + session_store_pool: state.db.clone(), 352 + db_backend: state.db_backend, 353 + }; 354 + if is_active != 0 { 355 + let redirect_uris: Vec<String> = 356 + serde_json::from_str(&redirect_uris_json).unwrap_or_default(); 357 + if let Err(e) = state.oauth.register_api_client( 358 + &client_id_url, 359 + &client_uri, 360 + redirect_uris, 361 + &scopes, 362 + &oauth_params, 363 + ) { 364 + tracing::warn!(client_id = %client_id_url, error = %e, "OAuth client re-registration failed"); 365 + } 366 + } else { 367 + state.oauth.remove(&client_id_url); 368 + } 369 + 370 + // Update per-client rate limit config. 371 + if is_active != 0 { 372 + if let (Some(cap), Some(refill)) = (capacity, refill_rate) { 373 + let global = state.rate_limiter.global_config(); 374 + state.rate_limiter.register_client_config( 375 + client_key, 376 + crate::rate_limit::RateLimitConfig { 377 + capacity: cap as u32, 378 + refill_rate: refill, 379 + default_query_cost: global.default_query_cost, 380 + default_procedure_cost: global.default_procedure_cost, 381 + default_proxy_cost: global.default_proxy_cost, 382 + }, 383 + ); 384 + } else { 385 + // Rate limit overrides were cleared — remove per-client config. 386 + state.rate_limiter.remove_client_config(&client_key); 387 + } 388 + } else { 389 + state.rate_limiter.remove_client_config(&client_key); 390 + } 391 + 392 + log_event( 393 + &state.db, 394 + EventLog { 395 + event_type: "api_client.updated".to_string(), 396 + severity: Severity::Info, 397 + actor_did: Some(auth.did.clone()), 398 + subject: Some(id), 399 + detail: serde_json::json!({}), 400 + }, 401 + state.db_backend, 402 + ) 403 + .await; 404 + 405 + Ok(StatusCode::NO_CONTENT) 406 + } 407 + 408 + /// DELETE /admin/api-clients/:id — delete an API client. 409 + pub(super) async fn delete_api_client( 410 + State(state): State<AppState>, 411 + auth: UserAuth, 412 + Path(id): Path<String>, 413 + ) -> Result<StatusCode, AppError> { 414 + auth.require(Permission::ApiClientsDelete).await?; 415 + 416 + // Look up client_id_url and client_key before deleting so we can remove from registries. 417 + let lookup_sql = adapt_sql( 418 + "SELECT client_id_url, client_key FROM api_clients WHERE id = ?", 419 + state.db_backend, 420 + ); 421 + let client_info: Option<(String, String)> = sqlx::query_as(&lookup_sql) 422 + .bind(&id) 423 + .fetch_optional(&state.db) 424 + .await 425 + .map_err(|e| AppError::Internal(format!("failed to look up api client: {e}")))?; 426 + 427 + let delete_sql = adapt_sql("DELETE FROM api_clients WHERE id = ?", state.db_backend); 428 + 429 + let result = sqlx::query(&delete_sql) 430 + .bind(&id) 431 + .execute(&state.db) 432 + .await 433 + .map_err(|e| AppError::Internal(format!("failed to delete api client: {e}")))?; 434 + 435 + if result.rows_affected() == 0 { 436 + return Err(AppError::NotFound(format!("api client '{id}' not found"))); 437 + } 438 + 439 + // Remove from OAuth registry and rate limiter. 440 + if let Some((url, key)) = client_info { 441 + state.oauth.remove(&url); 442 + state.rate_limiter.remove_client_config(&key); 443 + } 444 + 445 + log_event( 446 + &state.db, 447 + EventLog { 448 + event_type: "api_client.deleted".to_string(), 449 + severity: Severity::Info, 450 + actor_did: Some(auth.did.clone()), 451 + subject: Some(id), 452 + detail: serde_json::json!({}), 453 + }, 454 + state.db_backend, 455 + ) 456 + .await; 457 + 458 + Ok(StatusCode::NO_CONTENT) 459 + } 460 + 461 + #[cfg(test)] 462 + mod tests { 463 + #[test] 464 + fn test_client_key_prefix() { 465 + let mut random_bytes = [0u8; 16]; 466 + rand::Rng::fill(&mut rand::rng(), &mut random_bytes); 467 + let key = format!("hvc_{}", hex::encode(random_bytes)); 468 + assert!(key.starts_with("hvc_")); 469 + assert_eq!(key.len(), 4 + 32); // "hvc_" + 32 hex chars 470 + } 471 + }
+11 -5
src/admin/mod.rs
··· 1 + mod api_clients; 1 2 mod api_keys; 2 3 pub(crate) mod auth; 3 4 mod backfill; ··· 74 75 post(rate_limits::upsert).get(rate_limits::list), 75 76 ) 76 77 .route("/rate-limits/enabled", put(rate_limits::set_enabled)) 77 - .route("/rate-limits/allowlist", post(rate_limits::add_allowlist)) 78 - .route( 79 - "/rate-limits/allowlist/{id}", 80 - delete(rate_limits::remove_allowlist), 81 - ) 82 78 .route("/settings", get(settings::list)) 83 79 .route( 84 80 "/settings/logo", ··· 95 91 .route( 96 92 "/plugins/{id}/secrets", 97 93 get(plugins::get_secrets).put(plugins::update_secrets), 94 + ) 95 + .route( 96 + "/api-clients", 97 + post(api_clients::create_api_client).get(api_clients::list_api_clients), 98 + ) 99 + .route( 100 + "/api-clients/{id}", 101 + get(api_clients::get_api_client) 102 + .put(api_clients::update_api_client) 103 + .delete(api_clients::delete_api_client), 98 104 ) 99 105 }
+22 -1
src/admin/permissions.rs
··· 76 76 PluginsCreate, 77 77 #[serde(rename = "plugins:delete")] 78 78 PluginsDelete, 79 + 80 + #[serde(rename = "api-clients:view")] 81 + ApiClientsView, 82 + #[serde(rename = "api-clients:create")] 83 + ApiClientsCreate, 84 + #[serde(rename = "api-clients:edit")] 85 + ApiClientsEdit, 86 + #[serde(rename = "api-clients:delete")] 87 + ApiClientsDelete, 79 88 } 80 89 81 90 impl Permission { ··· 112 121 Self::PluginsRead => "plugins:read", 113 122 Self::PluginsCreate => "plugins:create", 114 123 Self::PluginsDelete => "plugins:delete", 124 + Self::ApiClientsView => "api-clients:view", 125 + Self::ApiClientsCreate => "api-clients:create", 126 + Self::ApiClientsEdit => "api-clients:edit", 127 + Self::ApiClientsDelete => "api-clients:delete", 115 128 } 116 129 } 117 130 118 - /// All 30 permissions. 131 + /// All permissions. 119 132 pub fn all() -> HashSet<Permission> { 120 133 HashSet::from([ 121 134 Self::LexiconsCreate, ··· 148 161 Self::PluginsRead, 149 162 Self::PluginsCreate, 150 163 Self::PluginsDelete, 164 + Self::ApiClientsView, 165 + Self::ApiClientsCreate, 166 + Self::ApiClientsEdit, 167 + Self::ApiClientsDelete, 151 168 ]) 152 169 } 153 170 } ··· 199 216 perms.insert(Permission::PluginsRead); 200 217 perms.insert(Permission::PluginsCreate); 201 218 perms.insert(Permission::PluginsDelete); 219 + perms.insert(Permission::ApiClientsView); 220 + perms.insert(Permission::ApiClientsCreate); 221 + perms.insert(Permission::ApiClientsEdit); 222 + perms.insert(Permission::ApiClientsDelete); 202 223 perms 203 224 } 204 225 Self::FullAccess => Permission::all(),
+2 -124
src/admin/rate_limits.rs
··· 1 1 use axum::Json; 2 - use axum::extract::{Path, State}; 2 + use axum::extract::State; 3 3 use axum::http::StatusCode; 4 4 5 5 use crate::AppState; ··· 9 9 10 10 use super::auth::UserAuth; 11 11 use super::permissions::Permission; 12 - use super::types::{ 13 - AddAllowlistBody, AllowlistEntry, RateLimitsResponse, SetEnabledBody, UpsertRateLimitBody, 14 - }; 12 + use super::types::{RateLimitsResponse, SetEnabledBody, UpsertRateLimitBody}; 15 13 16 14 /// GET /admin/rate-limits — list rate limit config. 17 15 pub(super) async fn list( ··· 44 42 let (capacity, refill_rate, default_query_cost, default_procedure_cost, default_proxy_cost) = 45 43 row.unwrap_or((100, 2.0, 1, 1, 1)); 46 44 47 - let allowlist_sql = adapt_sql( 48 - "SELECT id, cidr, note, created_at FROM rate_limit_allowlist ORDER BY id", 49 - backend, 50 - ); 51 - let allowlist_rows: Vec<(i32, String, Option<String>, String)> = sqlx::query_as(&allowlist_sql) 52 - .fetch_all(&state.db) 53 - .await 54 - .map_err(|e| AppError::Internal(format!("failed to list allowlist: {e}")))?; 55 - 56 - let allowlist: Vec<AllowlistEntry> = allowlist_rows 57 - .into_iter() 58 - .map(|(id, cidr, note, created_at)| AllowlistEntry { 59 - id, 60 - cidr, 61 - note, 62 - created_at, 63 - }) 64 - .collect(); 65 - 66 45 Ok(Json(RateLimitsResponse { 67 46 enabled: enabled == "true", 68 47 capacity, ··· 70 49 default_query_cost, 71 50 default_procedure_cost, 72 51 default_proxy_cost, 73 - allowlist, 74 52 })) 75 53 } 76 54 ··· 178 156 179 157 Ok(StatusCode::NO_CONTENT) 180 158 } 181 - 182 - /// POST /admin/rate-limits/allowlist — add an IP/CIDR to the allowlist. 183 - pub(super) async fn add_allowlist( 184 - State(state): State<AppState>, 185 - auth: UserAuth, 186 - Json(body): Json<AddAllowlistBody>, 187 - ) -> Result<StatusCode, AppError> { 188 - auth.require(Permission::RateLimitsCreate).await?; 189 - 190 - // Validate CIDR syntax; if it's a bare IP, append /32 or /128 191 - let cidr_str = if body.cidr.contains('/') { 192 - body.cidr.clone() 193 - } else if let Ok(ip) = body.cidr.parse::<std::net::IpAddr>() { 194 - match ip { 195 - std::net::IpAddr::V4(_) => format!("{}/32", body.cidr), 196 - std::net::IpAddr::V6(_) => format!("{}/128", body.cidr), 197 - } 198 - } else { 199 - return Err(AppError::BadRequest(format!( 200 - "invalid IP or CIDR: {}", 201 - body.cidr 202 - ))); 203 - }; 204 - 205 - // Validate it parses as IpNet 206 - if cidr_str.parse::<ipnet::IpNet>().is_err() { 207 - return Err(AppError::BadRequest(format!("invalid CIDR: {}", cidr_str))); 208 - } 209 - 210 - let backend = state.db_backend; 211 - let now = now_rfc3339(); 212 - let sql = adapt_sql( 213 - "INSERT INTO rate_limit_allowlist (cidr, note, created_at) VALUES (?, ?, ?)", 214 - backend, 215 - ); 216 - sqlx::query(&sql) 217 - .bind(&cidr_str) 218 - .bind(&body.note) 219 - .bind(&now) 220 - .execute(&state.db) 221 - .await 222 - .map_err(|e| AppError::Internal(format!("failed to add allowlist entry: {e}")))?; 223 - 224 - state.rate_limiter.reload_from_db(&state.db).await; 225 - 226 - log_event( 227 - &state.db, 228 - EventLog { 229 - event_type: "rate_limit.allowlist_added".to_string(), 230 - severity: Severity::Info, 231 - actor_did: Some(auth.did.clone()), 232 - subject: Some(cidr_str), 233 - detail: serde_json::json!({ "note": body.note }), 234 - }, 235 - state.db_backend, 236 - ) 237 - .await; 238 - 239 - Ok(StatusCode::CREATED) 240 - } 241 - 242 - /// DELETE /admin/rate-limits/allowlist/{id} — remove an allowlist entry. 243 - pub(super) async fn remove_allowlist( 244 - State(state): State<AppState>, 245 - auth: UserAuth, 246 - Path(id): Path<i32>, 247 - ) -> Result<StatusCode, AppError> { 248 - auth.require(Permission::RateLimitsDelete).await?; 249 - 250 - let backend = state.db_backend; 251 - let sql = adapt_sql("DELETE FROM rate_limit_allowlist WHERE id = ?", backend); 252 - let result = sqlx::query(&sql) 253 - .bind(id) 254 - .execute(&state.db) 255 - .await 256 - .map_err(|e| AppError::Internal(format!("failed to delete allowlist entry: {e}")))?; 257 - 258 - if result.rows_affected() == 0 { 259 - return Err(AppError::NotFound(format!( 260 - "allowlist entry {id} not found" 261 - ))); 262 - } 263 - 264 - state.rate_limiter.reload_from_db(&state.db).await; 265 - 266 - log_event( 267 - &state.db, 268 - EventLog { 269 - event_type: "rate_limit.allowlist_removed".to_string(), 270 - severity: Severity::Info, 271 - actor_did: Some(auth.did.clone()), 272 - subject: Some(id.to_string()), 273 - detail: serde_json::json!({}), 274 - }, 275 - state.db_backend, 276 - ) 277 - .await; 278 - 279 - Ok(StatusCode::NO_CONTENT) 280 - }
+56 -15
src/admin/types.rs
··· 297 297 } 298 298 299 299 // --------------------------------------------------------------------------- 300 + // API client types 301 + // --------------------------------------------------------------------------- 302 + 303 + #[derive(Deserialize)] 304 + pub(super) struct CreateApiClientBody { 305 + pub(super) name: String, 306 + pub(super) client_id_url: String, 307 + pub(super) client_uri: String, 308 + pub(super) redirect_uris: Vec<String>, 309 + #[serde(default = "default_scopes")] 310 + pub(super) scopes: String, 311 + pub(super) rate_limit_capacity: Option<i32>, 312 + pub(super) rate_limit_refill_rate: Option<f64>, 313 + } 314 + 315 + fn default_scopes() -> String { 316 + "atproto".to_string() 317 + } 318 + 319 + #[derive(Deserialize)] 320 + pub(super) struct UpdateApiClientBody { 321 + pub(super) name: Option<String>, 322 + pub(super) client_uri: Option<String>, 323 + pub(super) redirect_uris: Option<Vec<String>>, 324 + pub(super) scopes: Option<String>, 325 + pub(super) rate_limit_capacity: Option<Option<i32>>, 326 + pub(super) rate_limit_refill_rate: Option<Option<f64>>, 327 + pub(super) is_active: Option<bool>, 328 + } 329 + 330 + #[derive(Serialize)] 331 + pub(super) struct ApiClientSummary { 332 + pub(super) id: String, 333 + pub(super) client_key: String, 334 + pub(super) name: String, 335 + pub(super) client_id_url: String, 336 + pub(super) client_uri: String, 337 + pub(super) redirect_uris: Vec<String>, 338 + pub(super) scopes: String, 339 + pub(super) rate_limit_capacity: Option<i32>, 340 + pub(super) rate_limit_refill_rate: Option<f64>, 341 + pub(super) is_active: bool, 342 + pub(super) created_by: String, 343 + pub(super) created_at: String, 344 + pub(super) updated_at: String, 345 + } 346 + 347 + #[derive(Serialize)] 348 + pub(super) struct CreateApiClientResponse { 349 + pub(super) id: String, 350 + pub(super) client_key: String, 351 + pub(super) name: String, 352 + pub(super) client_id_url: String, 353 + } 354 + 355 + // --------------------------------------------------------------------------- 300 356 // Rate limit types 301 357 // --------------------------------------------------------------------------- 302 358 ··· 314 370 pub(super) enabled: bool, 315 371 } 316 372 317 - #[derive(Deserialize)] 318 - pub(super) struct AddAllowlistBody { 319 - pub(super) cidr: String, 320 - pub(super) note: Option<String>, 321 - } 322 - 323 373 #[derive(Serialize)] 324 374 pub(super) struct RateLimitsResponse { 325 375 pub(super) enabled: bool, ··· 328 378 pub(super) default_query_cost: i32, 329 379 pub(super) default_procedure_cost: i32, 330 380 pub(super) default_proxy_cost: i32, 331 - pub(super) allowlist: Vec<AllowlistEntry>, 332 - } 333 - 334 - #[derive(Serialize)] 335 - pub(super) struct AllowlistEntry { 336 - pub(super) id: i32, 337 - pub(super) cidr: String, 338 - pub(super) note: Option<String>, 339 - pub(super) created_at: String, 340 381 }
+250
src/auth/client_registry.rs
··· 1 + use dashmap::DashMap; 2 + use std::sync::Arc; 3 + 4 + use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig}; 5 + use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}; 6 + use atrium_oauth::{ 7 + AtprotoClientMetadata, AuthMethod, DefaultHttpClient, GrantType, OAuthClientConfig, 8 + OAuthResolverConfig, 9 + }; 10 + 11 + use crate::HappyViewOAuthClient; 12 + use crate::auth::oauth_store::{DbSessionStore, DbStateStore}; 13 + use crate::db::{DatabaseBackend, adapt_sql}; 14 + use crate::dns::NativeDnsResolver; 15 + 16 + /// Parameters needed to build an OAuth client for an API client registration. 17 + pub struct ApiClientOAuthParams { 18 + pub plc_url: String, 19 + pub state_store: DbStateStore, 20 + pub session_store_pool: sqlx::AnyPool, 21 + pub db_backend: DatabaseBackend, 22 + } 23 + 24 + /// Registry of OAuth clients, keyed by `client_id_url`. 25 + /// 26 + /// Each API client gets its own `OAuthClient` instance so the PDS auth screen 27 + /// shows the correct domain. The default client is HappyView's own identity, 28 + /// used for dashboard auth. 29 + pub struct OAuthClientRegistry { 30 + default_client: Arc<HappyViewOAuthClient>, 31 + clients: DashMap<String, Arc<HappyViewOAuthClient>>, 32 + } 33 + 34 + impl OAuthClientRegistry { 35 + pub fn new(default_client: Arc<HappyViewOAuthClient>) -> Self { 36 + Self { 37 + default_client, 38 + clients: DashMap::new(), 39 + } 40 + } 41 + 42 + /// Register an API client's OAuth client, keyed by its `client_id_url`. 43 + pub fn register(&self, client_id_url: String, client: Arc<HappyViewOAuthClient>) { 44 + self.clients.insert(client_id_url, client); 45 + } 46 + 47 + /// Remove an API client's OAuth client. 48 + pub fn remove(&self, client_id_url: &str) { 49 + self.clients.remove(client_id_url); 50 + } 51 + 52 + /// Look up a client by `client_id_url`. 53 + pub fn get(&self, client_id_url: &str) -> Option<Arc<HappyViewOAuthClient>> { 54 + self.clients.get(client_id_url).map(|r| r.value().clone()) 55 + } 56 + 57 + /// Look up a client by `client_id_url`, falling back to the default. 58 + pub fn get_or_default(&self, client_id_url: Option<&str>) -> Arc<HappyViewOAuthClient> { 59 + if let Some(url) = client_id_url { 60 + self.clients 61 + .get(url) 62 + .map(|r| r.value().clone()) 63 + .unwrap_or_else(|| self.default_client.clone()) 64 + } else { 65 + self.default_client.clone() 66 + } 67 + } 68 + 69 + /// Get the default (HappyView dashboard) client. 70 + pub fn default_client(&self) -> &Arc<HappyViewOAuthClient> { 71 + &self.default_client 72 + } 73 + 74 + /// Build and register a single OAuth client from API client metadata. 75 + /// Used when creating or updating an API client via the admin UI. 76 + pub fn register_api_client( 77 + &self, 78 + client_id_url: &str, 79 + client_uri: &str, 80 + redirect_uris: Vec<String>, 81 + scopes_str: &str, 82 + params: &ApiClientOAuthParams, 83 + ) -> Result<(), String> { 84 + let ApiClientOAuthParams { 85 + plc_url, 86 + state_store, 87 + session_store_pool, 88 + db_backend, 89 + } = params; 90 + let scopes = crate::auth::parse_scope_string(scopes_str); 91 + let scopes = if scopes.is_empty() { 92 + vec![atrium_oauth::Scope::Known( 93 + atrium_oauth::KnownScope::Atproto, 94 + )] 95 + } else { 96 + scopes 97 + }; 98 + 99 + let metadata = AtprotoClientMetadata { 100 + client_id: client_id_url.to_string(), 101 + client_uri: Some(client_uri.to_string()), 102 + redirect_uris, 103 + token_endpoint_auth_method: AuthMethod::None, 104 + grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 105 + scopes, 106 + jwks_uri: None, 107 + token_endpoint_auth_signing_alg: None, 108 + }; 109 + 110 + let http = Arc::new(DefaultHttpClient::default()); 111 + let resolver = OAuthResolverConfig { 112 + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 113 + plc_directory_url: plc_url.to_string(), 114 + http_client: Arc::clone(&http), 115 + }), 116 + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 117 + dns_txt_resolver: NativeDnsResolver::new(), 118 + http_client: Arc::clone(&http), 119 + }), 120 + authorization_server_metadata: Default::default(), 121 + protected_resource_metadata: Default::default(), 122 + }; 123 + 124 + match atrium_oauth::OAuthClient::new(OAuthClientConfig { 125 + client_metadata: metadata, 126 + keys: None, 127 + state_store: state_store.clone(), 128 + session_store: DbSessionStore::new(session_store_pool.clone(), *db_backend), 129 + resolver, 130 + }) { 131 + Ok(client) => { 132 + self.register(client_id_url.to_string(), Arc::new(client)); 133 + Ok(()) 134 + } 135 + Err(e) => Err(format!("failed to create OAuth client: {e}")), 136 + } 137 + } 138 + 139 + /// Load all active API clients from the database and register OAuth clients for each. 140 + pub async fn load_from_db( 141 + &self, 142 + db: &sqlx::AnyPool, 143 + db_backend: DatabaseBackend, 144 + plc_url: &str, 145 + state_store: DbStateStore, 146 + session_store_pool: sqlx::AnyPool, 147 + ) { 148 + let sql = adapt_sql( 149 + "SELECT client_id_url, client_uri, redirect_uris, scopes FROM api_clients WHERE is_active = 1", 150 + db_backend, 151 + ); 152 + 153 + let rows: Vec<(String, String, String, String)> = 154 + match sqlx::query_as(&sql).fetch_all(db).await { 155 + Ok(r) => r, 156 + Err(e) => { 157 + tracing::error!("Failed to load API clients from database: {e}"); 158 + return; 159 + } 160 + }; 161 + 162 + for (client_id_url, client_uri, redirect_uris_json, scopes_str) in rows { 163 + let redirect_uris: Vec<String> = 164 + serde_json::from_str(&redirect_uris_json).unwrap_or_default(); 165 + 166 + let scopes = crate::auth::parse_scope_string(&scopes_str); 167 + let scopes = if scopes.is_empty() { 168 + vec![atrium_oauth::Scope::Known( 169 + atrium_oauth::KnownScope::Atproto, 170 + )] 171 + } else { 172 + scopes 173 + }; 174 + 175 + let metadata = AtprotoClientMetadata { 176 + client_id: client_id_url.clone(), 177 + client_uri: Some(client_uri), 178 + redirect_uris, 179 + token_endpoint_auth_method: AuthMethod::None, 180 + grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 181 + scopes, 182 + jwks_uri: None, 183 + token_endpoint_auth_signing_alg: None, 184 + }; 185 + 186 + // Each OAuthClient needs its own resolver instances (they're not Clone) 187 + let http = Arc::new(DefaultHttpClient::default()); 188 + let resolver = OAuthResolverConfig { 189 + did_resolver: CommonDidResolver::new(CommonDidResolverConfig { 190 + plc_directory_url: plc_url.to_string(), 191 + http_client: Arc::clone(&http), 192 + }), 193 + handle_resolver: AtprotoHandleResolver::new(AtprotoHandleResolverConfig { 194 + dns_txt_resolver: NativeDnsResolver::new(), 195 + http_client: Arc::clone(&http), 196 + }), 197 + authorization_server_metadata: Default::default(), 198 + protected_resource_metadata: Default::default(), 199 + }; 200 + 201 + match atrium_oauth::OAuthClient::new(OAuthClientConfig { 202 + client_metadata: metadata, 203 + keys: None, 204 + state_store: state_store.clone(), 205 + session_store: DbSessionStore::new(session_store_pool.clone(), db_backend), 206 + resolver, 207 + }) { 208 + Ok(client) => { 209 + tracing::info!(client_id = %client_id_url, "Registered API client OAuth identity"); 210 + self.register(client_id_url, Arc::new(client)); 211 + } 212 + Err(e) => { 213 + tracing::error!(client_id = %client_id_url, error = %e, "Failed to create OAuth client for API client"); 214 + } 215 + } 216 + } 217 + } 218 + } 219 + 220 + #[cfg(test)] 221 + mod tests { 222 + use super::*; 223 + 224 + // Note: we can't easily construct real OAuthClient instances in unit tests 225 + // because they require resolvers, stores, etc. The registry logic is simple 226 + // enough that we test it via integration tests that stand up the full stack. 227 + // These tests verify the DashMap-based lookup logic using a mock approach. 228 + 229 + #[test] 230 + fn test_registry_stores_and_retrieves() { 231 + // We can at least verify the DashMap operations work correctly 232 + let map: DashMap<String, String> = DashMap::new(); 233 + map.insert("key1".to_string(), "val1".to_string()); 234 + 235 + assert!(map.get("key1").is_some()); 236 + assert!(map.get("key2").is_none()); 237 + 238 + map.remove("key1"); 239 + assert!(map.get("key1").is_none()); 240 + } 241 + 242 + #[test] 243 + fn test_registry_overwrite() { 244 + let map: DashMap<String, String> = DashMap::new(); 245 + map.insert("key1".to_string(), "val1".to_string()); 246 + map.insert("key1".to_string(), "val2".to_string()); 247 + 248 + assert_eq!(map.get("key1").unwrap().value(), "val2"); 249 + } 250 + }
+27 -4
src/auth/middleware.rs
··· 15 15 #[derive(Debug, Clone)] 16 16 pub struct Claims { 17 17 did: String, 18 + /// The API client key (e.g. "hvc_...") if the user authenticated via an API client. 19 + client_key: Option<String>, 18 20 } 19 21 22 + /// Separator used to encode `did` and `client_key` in a single cookie value. 23 + /// Newlines cannot appear in DIDs or client keys, so this is safe. 24 + const COOKIE_SEP: char = '\n'; 25 + 20 26 impl Claims { 21 27 /// The authenticated user's DID. 22 28 pub fn did(&self) -> &str { 23 29 &self.did 24 30 } 25 31 32 + /// The API client key, if the user logged in via an API client. 33 + pub fn client_key(&self) -> Option<&str> { 34 + self.client_key.as_deref() 35 + } 36 + 26 37 /// Test-only constructor. 27 38 #[cfg(test)] 28 39 pub fn new_for_test(did: String) -> Self { 29 - Self { did } 40 + Self { 41 + did, 42 + client_key: None, 43 + } 30 44 } 31 45 } 32 46 ··· 43 57 .map_err(|_| AppError::Auth("failed to read cookies".into()))?; 44 58 45 59 if let Some(cookie) = jar.get(COOKIE_NAME) { 46 - let did = cookie.value().to_string(); 47 - return Ok(Claims { did }); 60 + let value = cookie.value().to_string(); 61 + let (did, client_key) = if let Some((d, k)) = value.split_once(COOKIE_SEP) { 62 + (d.to_string(), Some(k.to_string())) 63 + } else { 64 + (value, None) 65 + }; 66 + return Ok(Claims { did, client_key }); 48 67 } 49 68 50 69 // Path 2: Authorization header ··· 63 82 // API key auth is handled by UserAuth extractor which looks up the key. 64 83 // We need to extract the DID from the api_keys table. 65 84 let did = resolve_api_key_did(state, token).await?; 66 - return Ok(Claims { did }); 85 + return Ok(Claims { 86 + did, 87 + client_key: None, 88 + }); 67 89 } 68 90 69 91 // Otherwise, try service auth JWT 70 92 let service_auth = super::service_auth::ServiceAuth::from_bearer(token, state).await?; 71 93 return Ok(Claims { 72 94 did: service_auth.did, 95 + client_key: None, 73 96 }); 74 97 } 75 98
+2
src/auth/mod.rs
··· 1 + pub mod client_registry; 1 2 pub mod middleware; 2 3 pub mod oauth_store; 3 4 pub mod routes; 4 5 pub mod service_auth; 5 6 7 + pub use client_registry::OAuthClientRegistry; 6 8 pub use middleware::Claims; 7 9 pub use routes::parse_scope_string; 8 10 pub use service_auth::ServiceAuth;
+59 -19
src/auth/routes.rs
··· 21 21 handle: String, 22 22 redirect_uri: Option<String>, 23 23 scope: Option<String>, 24 + client_id: Option<String>, 24 25 } 25 26 26 27 /// Parse a whitespace-separated OAuth scope string into typed `Scope` values. ··· 85 86 } 86 87 }; 87 88 88 - tracing::debug!(scopes = ?scopes, "resolved oauth scopes"); 89 + tracing::debug!(scopes = ?scopes, client_id = ?query.client_id, "resolved oauth scopes"); 90 + 91 + // Select the appropriate OAuth client based on client_id 92 + let oauth_client = state.oauth.get_or_default(query.client_id.as_deref()); 89 93 90 94 // Hold the authorize lock so that authorize() + take_last_state_key() are atomic. 91 95 // This prevents concurrent logins from swapping each other's state keys. ··· 96 100 ..Default::default() 97 101 }; 98 102 99 - let url = state 100 - .oauth 103 + let url = oauth_client 101 104 .authorize(&query.handle, options) 102 105 .await 103 106 .map_err(|e| AppError::Internal(format!("OAuth authorize failed: {e}")))?; ··· 113 116 114 117 // Store the redirect URI in the database, keyed by the OAuth state parameter. 115 118 // This avoids third-party cookie issues when Pentaract (cross-origin) calls this endpoint. 116 - if let Some(redirect_uri) = &query.redirect_uri { 117 - tracing::debug!(oauth_state = ?oauth_state, redirect_uri = %redirect_uri, "storing redirect for state"); 119 + // Store redirect URI and client_id for the callback to use 120 + if query.redirect_uri.is_some() || query.client_id.is_some() { 121 + let redirect_uri = query.redirect_uri.as_deref().unwrap_or(""); 122 + tracing::debug!(oauth_state = ?oauth_state, redirect_uri = %redirect_uri, client_id = ?query.client_id, "storing redirect for state"); 118 123 119 124 if let Some(oauth_state) = oauth_state { 120 125 let now = now_rfc3339(); 121 126 let expires_at = (chrono::Utc::now() + chrono::Duration::minutes(10)).to_rfc3339(); 122 127 let sql = adapt_sql( 123 - "INSERT INTO auth_login_redirects (state, redirect_uri, created_at, expires_at) VALUES (?, ?, ?, ?)", 128 + "INSERT INTO auth_login_redirects (state, redirect_uri, client_id, created_at, expires_at) VALUES (?, ?, ?, ?, ?)", 124 129 state.db_backend, 125 130 ); 126 131 let _ = sqlx::query(&sql) 127 132 .bind(&oauth_state) 128 133 .bind(redirect_uri) 134 + .bind(query.client_id.as_deref()) 129 135 .bind(&now) 130 136 .bind(&expires_at) 131 137 .execute(&state.db) ··· 145 151 ) -> Result<(SignedCookieJar<Key>, Redirect), AppError> { 146 152 tracing::debug!(state = ?query.state, "callback received"); 147 153 148 - // Look up the redirect URI from the database before the OAuth library consumes the state 149 - let redirect_url = if let Some(oauth_state) = &query.state { 154 + // Look up the redirect URI and client_id from the database before the OAuth library consumes the state 155 + let (redirect_url, client_id) = if let Some(oauth_state) = &query.state { 150 156 let sql = adapt_sql( 151 - "SELECT redirect_uri FROM auth_login_redirects WHERE state = ? AND expires_at > ?", 157 + "SELECT redirect_uri, client_id FROM auth_login_redirects WHERE state = ? AND expires_at > ?", 152 158 state.db_backend, 153 159 ); 154 160 let now = now_rfc3339(); 155 - let row: Option<(String,)> = sqlx::query_as(&sql) 161 + let row: Option<(String, Option<String>)> = sqlx::query_as(&sql) 156 162 .bind(oauth_state) 157 163 .bind(&now) 158 164 .fetch_optional(&state.db) ··· 172 178 } 173 179 174 180 tracing::debug!(found_redirect = ?row, "redirect lookup result"); 175 - row.map(|(uri,)| uri) 181 + match row { 182 + Some((uri, cid)) => { 183 + let uri = if uri.is_empty() { None } else { Some(uri) }; 184 + (uri, cid) 185 + } 186 + None => (None, None), 187 + } 176 188 } else { 177 189 tracing::debug!("no state in callback query"); 178 - None 190 + (None, None) 179 191 }; 192 + 193 + // Use the same OAuth client that was used for authorize 194 + let oauth_client = state.oauth.get_or_default(client_id.as_deref()); 180 195 181 196 let params = atrium_oauth::CallbackParams { 182 197 code: query.code, ··· 184 199 iss: query.iss, 185 200 }; 186 201 187 - let (session, _app_state) = state 188 - .oauth 202 + let (session, _app_state) = oauth_client 189 203 .callback(params) 190 204 .await 191 205 .map_err(|e| AppError::Internal(format!("OAuth callback failed: {e}")))?; ··· 195 209 .did() 196 210 .await 197 211 .ok_or_else(|| AppError::Internal("no DID in OAuth session".into()))?; 212 + 213 + // Look up the client_key for the API client so we can store it in the session cookie 214 + // for per-client rate limiting. 215 + let client_key = if let Some(ref cid) = client_id { 216 + let sql = adapt_sql( 217 + "SELECT client_key FROM api_clients WHERE client_id_url = ? AND is_active = 1", 218 + state.db_backend, 219 + ); 220 + let row: Option<(String,)> = sqlx::query_as(&sql) 221 + .bind(cid) 222 + .fetch_optional(&state.db) 223 + .await 224 + .unwrap_or(None); 225 + row.map(|(k,)| k) 226 + } else { 227 + None 228 + }; 198 229 199 230 // Use DB-stored redirect, or default to "/" 200 - let redirect_url = redirect_url.unwrap_or_else(|| "/".to_string()); 231 + let redirect_url = redirect_url.unwrap_or_else(|| "/".into()); 201 232 tracing::debug!(redirect_url = %redirect_url, "redirecting after callback"); 202 233 203 234 // Set the session cookie 204 235 // Must use SameSite=None for cross-origin requests (e.g., Pentaract calling HappyView) 205 - let mut session_cookie = Cookie::new(COOKIE_NAME, did.to_string()); 236 + // Encode did and optional client_key separated by newline. 237 + let did_str = did.as_ref(); 238 + let cookie_value = if let Some(ref ck) = client_key { 239 + format!("{did_str}\n{ck}") 240 + } else { 241 + did_str.to_string() 242 + }; 243 + let mut session_cookie = Cookie::new(COOKIE_NAME, cookie_value); 206 244 session_cookie.set_path("/"); 207 245 session_cookie.set_http_only(true); 208 246 session_cookie.set_same_site(axum_extra::extract::cookie::SameSite::None); ··· 227 265 jar: SignedCookieJar<Key>, 228 266 ) -> Result<SignedCookieJar<Key>, AppError> { 229 267 if let Some(cookie) = jar.get(COOKIE_NAME) { 230 - let did_str = cookie.value().to_string(); 268 + let raw = cookie.value().to_string(); 269 + let did_str = raw.split('\n').next().unwrap_or(&raw).to_string(); 231 270 if let Ok(did) = atrium_api::types::string::Did::new(did_str) { 232 - let _ = state.oauth.revoke(&did).await; 271 + let _ = state.oauth.default_client().revoke(&did).await; 233 272 } 234 273 } 235 274 ··· 254 293 let cookie = jar 255 294 .get(COOKIE_NAME) 256 295 .ok_or(AppError::Auth("not authenticated".into()))?; 257 - let did = cookie.value().to_string(); 296 + let raw = cookie.value().to_string(); 297 + let did = raw.split('\n').next().unwrap_or(&raw).to_string(); 258 298 259 299 let backend = state.db_backend; 260 300 let user: Option<(i32,)> =
+1 -1
src/lib.rs
··· 57 57 pub collections_tx: watch::Sender<Vec<String>>, 58 58 pub labeler_subscriptions_tx: watch::Sender<()>, 59 59 pub rate_limiter: Arc<RateLimiter>, 60 - pub oauth: Arc<HappyViewOAuthClient>, 60 + pub oauth: Arc<auth::OAuthClientRegistry>, 61 61 pub oauth_state_store: DbStateStore, 62 62 pub cookie_key: axum_extra::extract::cookie::Key, 63 63 pub plugin_registry: Arc<plugin::PluginRegistry>,
+3 -2
src/lua/atproto_api.rs
··· 346 346 default_procedure_cost: 1, 347 347 default_proxy_cost: 1, 348 348 }, 349 - vec![], 350 349 ), 351 - oauth: std::sync::Arc::new(oauth), 350 + oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new( 351 + oauth, 352 + ))), 352 353 oauth_state_store: crate::auth::oauth_store::DbStateStore::new( 353 354 test_db.clone(), 354 355 crate::db::DatabaseBackend::Sqlite,
+3 -2
src/lua/db_api.rs
··· 697 697 default_procedure_cost: 1, 698 698 default_proxy_cost: 1, 699 699 }, 700 - vec![], 701 700 ), 702 - oauth: std::sync::Arc::new(oauth), 701 + oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new( 702 + oauth, 703 + ))), 703 704 oauth_state_store: crate::auth::oauth_store::DbStateStore::new( 704 705 test_db.clone(), 705 706 crate::db::DatabaseBackend::Sqlite,
+3 -2
src/lua/execute.rs
··· 1023 1023 default_procedure_cost: 1, 1024 1024 default_proxy_cost: 1, 1025 1025 }, 1026 - vec![], 1027 1026 ), 1028 - oauth: std::sync::Arc::new(oauth), 1027 + oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new( 1028 + oauth, 1029 + ))), 1029 1030 oauth_state_store: crate::auth::oauth_store::DbStateStore::new( 1030 1031 test_db.clone(), 1031 1032 crate::db::DatabaseBackend::Sqlite,
+3 -2
src/lua/http_api.rs
··· 163 163 default_procedure_cost: 1, 164 164 default_proxy_cost: 1, 165 165 }, 166 - vec![], 167 166 ), 168 - oauth: std::sync::Arc::new(oauth), 167 + oauth: std::sync::Arc::new(crate::auth::OAuthClientRegistry::new(std::sync::Arc::new( 168 + oauth, 169 + ))), 169 170 oauth_state_store: crate::auth::oauth_store::DbStateStore::new( 170 171 test_db.clone(), 171 172 crate::db::DatabaseBackend::Sqlite,
+41 -3
src/main.rs
··· 5 5 use happyview::db; 6 6 use happyview::dns::NativeDnsResolver; 7 7 use happyview::lexicon::{LexiconRegistry, ParsedLexicon, ProcedureAction}; 8 - use happyview::rate_limit::RateLimiter; 8 + use happyview::rate_limit::{RateLimitConfig, RateLimiter}; 9 9 use happyview::resolve::{fetch_lexicon_from_pds, resolve_nsid_authority}; 10 10 use happyview::{AppState, jetstream, labeler, server}; 11 11 use tokio::sync::watch; ··· 266 266 267 267 // Initialize rate limiter from DB. 268 268 let rl_state = RateLimiter::load_from_db(&db_pool).await; 269 - let rate_limiter = RateLimiter::new(rl_state.enabled, rl_state.global, rl_state.allowlist); 269 + let rate_limiter = RateLimiter::new(rl_state.enabled, rl_state.global); 270 270 tokio::spawn(rate_limiter.clone().spawn_cleanup()); 271 + 272 + // Load per-client rate limit configs from api_clients table. 273 + { 274 + let client_configs: Vec<(String, i32, f64)> = sqlx::query_as( 275 + "SELECT client_key, rate_limit_capacity, rate_limit_refill_rate FROM api_clients WHERE is_active = 1 AND rate_limit_capacity IS NOT NULL AND rate_limit_refill_rate IS NOT NULL", 276 + ) 277 + .fetch_all(&db_pool) 278 + .await 279 + .unwrap_or_default(); 280 + 281 + let global = rate_limiter.global_config(); 282 + for (client_key, capacity, refill_rate) in client_configs { 283 + rate_limiter.register_client_config( 284 + client_key, 285 + RateLimitConfig { 286 + capacity: capacity as u32, 287 + refill_rate, 288 + default_query_cost: global.default_query_cost, 289 + default_procedure_cost: global.default_procedure_cost, 290 + default_proxy_cost: global.default_proxy_cost, 291 + }, 292 + ); 293 + } 294 + } 271 295 272 296 // Build atrium-oauth client 273 297 let dns = NativeDnsResolver::new(); ··· 370 394 let (collections_tx, collections_rx) = watch::channel(initial_collections); 371 395 let (labeler_subscriptions_tx, labeler_subscriptions_rx) = watch::channel(()); 372 396 397 + // Build the OAuth client registry and load API clients from DB 398 + let oauth_registry = Arc::new(happyview::auth::OAuthClientRegistry::new(Arc::new( 399 + oauth_client, 400 + ))); 401 + oauth_registry 402 + .load_from_db( 403 + &db_pool, 404 + db_backend, 405 + &config.plc_url, 406 + oauth_state_store.clone(), 407 + db_pool.clone(), 408 + ) 409 + .await; 410 + 373 411 let state = AppState { 374 412 config: config.clone(), 375 413 http, ··· 379 417 collections_tx, 380 418 labeler_subscriptions_tx, 381 419 rate_limiter, 382 - oauth: Arc::new(oauth_client), 420 + oauth: oauth_registry, 383 421 oauth_state_store, 384 422 cookie_key, 385 423 plugin_registry,
+189 -74
src/rate_limit.rs
··· 1 1 use arc_swap::ArcSwap; 2 2 use dashmap::DashMap; 3 - use ipnet::IpNet; 4 3 use sqlx::AnyPool; 5 - use std::net::IpAddr; 6 4 use std::sync::Arc; 7 5 use std::sync::atomic::{AtomicBool, Ordering}; 8 6 use std::time::{Instant, SystemTime, UNIX_EPOCH}; ··· 41 39 enabled: AtomicBool, 42 40 buckets: DashMap<String, TokenBucket>, 43 41 global_config: ArcSwap<RateLimitConfig>, 44 - allowlist: ArcSwap<Vec<IpNet>>, 42 + /// Per-client config overrides, keyed by client_key (e.g. "hvc_...") 43 + client_configs: DashMap<String, RateLimitConfig>, 45 44 } 46 45 47 46 pub struct RateLimiterState { 48 47 pub enabled: bool, 49 48 pub global: RateLimitConfig, 50 - pub allowlist: Vec<IpNet>, 51 49 } 52 50 53 51 fn now_unix() -> u64 { ··· 58 56 } 59 57 60 58 impl RateLimiter { 61 - pub fn new(enabled: bool, global: RateLimitConfig, allowlist: Vec<IpNet>) -> Arc<Self> { 59 + pub fn new(enabled: bool, global: RateLimitConfig) -> Arc<Self> { 62 60 Arc::new(Self { 63 61 enabled: AtomicBool::new(enabled), 64 62 buckets: DashMap::new(), 65 63 global_config: ArcSwap::new(Arc::new(global)), 66 - allowlist: ArcSwap::new(Arc::new(allowlist)), 64 + client_configs: DashMap::new(), 67 65 }) 68 66 } 69 67 70 - pub fn check(&self, key: &str, cost: u32, client_ip: Option<IpAddr>) -> CheckResult { 68 + pub fn check(&self, key: &str, cost: u32) -> CheckResult { 71 69 if !self.enabled.load(Ordering::Relaxed) { 72 70 return CheckResult::Disabled; 73 71 } 74 72 75 - if let Some(ip) = client_ip { 76 - let list = self.allowlist.load(); 77 - for net in list.iter() { 78 - if net.contains(&ip) { 79 - return CheckResult::Disabled; 80 - } 81 - } 82 - } 83 - 84 - let global = self.global_config.load(); 85 - let capacity = global.capacity; 86 - let refill_rate = global.refill_rate; 73 + // Use per-client config if available, otherwise fall back to global 74 + let (capacity, refill_rate) = if let Some(client_cfg) = self.client_configs.get(key) { 75 + (client_cfg.capacity, client_cfg.refill_rate) 76 + } else { 77 + let global = self.global_config.load(); 78 + (global.capacity, global.refill_rate) 79 + }; 87 80 let cost_f64 = cost as f64; 88 81 89 82 let now = Instant::now(); ··· 144 137 } 145 138 } 146 139 140 + /// Get a snapshot of the current global config. 141 + pub fn global_config(&self) -> Arc<RateLimitConfig> { 142 + self.global_config.load_full() 143 + } 144 + 147 145 pub fn set_enabled(&self, enabled: bool) { 148 146 self.enabled.store(enabled, Ordering::Relaxed); 149 147 } ··· 156 154 self.global_config.store(Arc::new(global)); 157 155 } 158 156 159 - pub fn update_allowlist(&self, entries: Vec<IpNet>) { 160 - self.allowlist.store(Arc::new(entries)); 157 + /// Register a per-client rate limit config override. 158 + pub fn register_client_config(&self, client_key: String, config: RateLimitConfig) { 159 + self.client_configs.insert(client_key, config); 160 + } 161 + 162 + /// Remove a per-client rate limit config override. 163 + pub fn remove_client_config(&self, client_key: &str) { 164 + self.client_configs.remove(client_key); 161 165 } 162 166 163 167 pub async fn spawn_cleanup(self: Arc<Self>) { ··· 210 214 }, 211 215 }; 212 216 213 - // Load allowlist 214 - let cidr_rows: Vec<(String,)> = sqlx::query_as("SELECT cidr FROM rate_limit_allowlist") 215 - .fetch_all(db) 216 - .await 217 - .unwrap_or_default(); 218 - 219 - let allowlist: Vec<IpNet> = cidr_rows 220 - .into_iter() 221 - .filter_map(|(cidr,)| cidr.parse().ok()) 222 - .collect(); 223 - 224 - RateLimiterState { 225 - enabled, 226 - global, 227 - allowlist, 228 - } 217 + RateLimiterState { enabled, global } 229 218 } 230 219 231 220 /// Reload all config from DB and apply to the live limiter. ··· 233 222 let state = Self::load_from_db(db).await; 234 223 self.set_enabled(state.enabled); 235 224 self.update_config(state.global); 236 - self.update_allowlist(state.allowlist); 237 225 } 238 226 } 239 227 ··· 252 240 default_procedure_cost: 1, 253 241 default_proxy_cost: 1, 254 242 }, 255 - vec![], 256 243 ); 257 244 258 245 // Should allow 3 requests (bucket starts full, cost=1 each) 259 246 for _ in 0..3 { 260 - assert!(matches!( 261 - rl.check("k", 1, None), 262 - CheckResult::Allowed { .. } 263 - )); 247 + assert!(matches!(rl.check("k", 1), CheckResult::Allowed { .. })); 264 248 } 265 249 // 4th should be limited 266 - assert!(matches!( 267 - rl.check("k", 1, None), 268 - CheckResult::Limited { .. } 269 - )); 250 + assert!(matches!(rl.check("k", 1), CheckResult::Limited { .. })); 270 251 } 271 252 272 253 #[test] ··· 280 261 default_procedure_cost: 1, 281 262 default_proxy_cost: 1, 282 263 }, 283 - vec![], 284 264 ); 285 265 286 266 // Cost of 5 should allow 2 requests (10 tokens total) 287 267 assert!(matches!( 288 - rl.check("k", 5, None), 268 + rl.check("k", 5), 289 269 CheckResult::Allowed { remaining: 5, .. } 290 270 )); 291 271 assert!(matches!( 292 - rl.check("k", 5, None), 272 + rl.check("k", 5), 293 273 CheckResult::Allowed { remaining: 0, .. } 294 274 )); 295 275 // 3rd should be limited 276 + assert!(matches!(rl.check("k", 5), CheckResult::Limited { .. })); 277 + } 278 + 279 + #[test] 280 + fn disabled_returns_disabled() { 281 + let rl = RateLimiter::new( 282 + false, 283 + RateLimitConfig { 284 + capacity: 1, 285 + refill_rate: 1.0, 286 + default_query_cost: 1, 287 + default_procedure_cost: 1, 288 + default_proxy_cost: 1, 289 + }, 290 + ); 291 + assert!(matches!(rl.check("k", 1), CheckResult::Disabled)); 292 + } 293 + 294 + #[test] 295 + fn default_cost_for_type() { 296 + let rl = RateLimiter::new( 297 + true, 298 + RateLimitConfig { 299 + capacity: 100, 300 + refill_rate: 10.0, 301 + default_query_cost: 2, 302 + default_procedure_cost: 5, 303 + default_proxy_cost: 3, 304 + }, 305 + ); 306 + 307 + assert_eq!(rl.default_cost_for_type("query"), 2); 308 + assert_eq!(rl.default_cost_for_type("procedure"), 5); 309 + assert_eq!(rl.default_cost_for_type("proxy"), 3); 310 + assert_eq!(rl.default_cost_for_type("unknown"), 1); 311 + } 312 + 313 + #[test] 314 + fn per_client_config_override() { 315 + let rl = RateLimiter::new( 316 + true, 317 + RateLimitConfig { 318 + capacity: 10, 319 + refill_rate: 1.0, 320 + default_query_cost: 1, 321 + default_procedure_cost: 1, 322 + default_proxy_cost: 1, 323 + }, 324 + ); 325 + 326 + // Register a client with lower capacity 327 + rl.register_client_config( 328 + "hvc_client1".to_string(), 329 + RateLimitConfig { 330 + capacity: 2, 331 + refill_rate: 0.001, 332 + default_query_cost: 1, 333 + default_procedure_cost: 1, 334 + default_proxy_cost: 1, 335 + }, 336 + ); 337 + 338 + // Client key should use client config (capacity=2) 296 339 assert!(matches!( 297 - rl.check("k", 5, None), 340 + rl.check("hvc_client1", 1), 341 + CheckResult::Allowed { .. } 342 + )); 343 + assert!(matches!( 344 + rl.check("hvc_client1", 1), 345 + CheckResult::Allowed { .. } 346 + )); 347 + assert!(matches!( 348 + rl.check("hvc_client1", 1), 349 + CheckResult::Limited { .. } 350 + )); 351 + 352 + // Other key should use global config (capacity=10) 353 + for _ in 0..10 { 354 + assert!(matches!( 355 + rl.check("other_key", 1), 356 + CheckResult::Allowed { .. } 357 + )); 358 + } 359 + assert!(matches!( 360 + rl.check("other_key", 1), 298 361 CheckResult::Limited { .. } 299 362 )); 300 363 } 301 364 302 365 #[test] 303 - fn disabled_returns_disabled() { 366 + fn per_client_config_fallback_to_global() { 304 367 let rl = RateLimiter::new( 305 - false, 368 + true, 306 369 RateLimitConfig { 307 - capacity: 1, 370 + capacity: 3, 308 371 refill_rate: 1.0, 309 372 default_query_cost: 1, 310 373 default_procedure_cost: 1, 311 374 default_proxy_cost: 1, 312 375 }, 313 - vec![], 314 376 ); 315 - assert!(matches!(rl.check("k", 1, None), CheckResult::Disabled)); 377 + 378 + // No client config registered — should use global (capacity=3) 379 + for _ in 0..3 { 380 + assert!(matches!( 381 + rl.check("hvc_unregistered", 1), 382 + CheckResult::Allowed { .. } 383 + )); 384 + } 385 + assert!(matches!( 386 + rl.check("hvc_unregistered", 1), 387 + CheckResult::Limited { .. } 388 + )); 316 389 } 317 390 318 391 #[test] 319 - fn allowlisted_ip_bypasses() { 392 + fn register_and_remove_client_config() { 320 393 let rl = RateLimiter::new( 321 394 true, 395 + RateLimitConfig { 396 + capacity: 10, 397 + refill_rate: 1.0, 398 + default_query_cost: 1, 399 + default_procedure_cost: 1, 400 + default_proxy_cost: 1, 401 + }, 402 + ); 403 + 404 + rl.register_client_config( 405 + "hvc_temp".to_string(), 322 406 RateLimitConfig { 323 407 capacity: 1, 324 408 refill_rate: 0.001, ··· 326 410 default_procedure_cost: 1, 327 411 default_proxy_cost: 1, 328 412 }, 329 - vec!["10.0.0.0/8".parse().unwrap()], 330 413 ); 331 414 332 - let ip: IpAddr = "10.0.0.5".parse().unwrap(); 333 - // Even after exhausting, allowlisted IP gets Disabled 334 - assert!(matches!(rl.check("k", 1, Some(ip)), CheckResult::Disabled)); 415 + // Should be limited after 1 request (client config capacity=1) 416 + assert!(matches!( 417 + rl.check("hvc_temp", 1), 418 + CheckResult::Allowed { .. } 419 + )); 420 + assert!(matches!( 421 + rl.check("hvc_temp", 1), 422 + CheckResult::Limited { .. } 423 + )); 424 + 425 + // Remove client config — new bucket should use global (capacity=10) 426 + rl.remove_client_config("hvc_temp"); 427 + // Note: the old bucket still exists and is exhausted, but capacity was 428 + // updated to global. A new bucket would get global capacity. 335 429 } 336 430 337 431 #[test] 338 - fn default_cost_for_type() { 432 + fn different_clients_get_separate_buckets() { 339 433 let rl = RateLimiter::new( 340 434 true, 341 435 RateLimitConfig { 342 - capacity: 100, 343 - refill_rate: 10.0, 344 - default_query_cost: 2, 345 - default_procedure_cost: 5, 346 - default_proxy_cost: 3, 436 + capacity: 2, 437 + refill_rate: 0.001, 438 + default_query_cost: 1, 439 + default_procedure_cost: 1, 440 + default_proxy_cost: 1, 347 441 }, 348 - vec![], 349 442 ); 350 443 351 - assert_eq!(rl.default_cost_for_type("query"), 2); 352 - assert_eq!(rl.default_cost_for_type("procedure"), 5); 353 - assert_eq!(rl.default_cost_for_type("proxy"), 3); 354 - assert_eq!(rl.default_cost_for_type("unknown"), 1); 444 + // Exhaust client A 445 + assert!(matches!( 446 + rl.check("clientA", 1), 447 + CheckResult::Allowed { .. } 448 + )); 449 + assert!(matches!( 450 + rl.check("clientA", 1), 451 + CheckResult::Allowed { .. } 452 + )); 453 + assert!(matches!( 454 + rl.check("clientA", 1), 455 + CheckResult::Limited { .. } 456 + )); 457 + 458 + // Client B should still have tokens 459 + assert!(matches!( 460 + rl.check("clientB", 1), 461 + CheckResult::Allowed { .. } 462 + )); 463 + assert!(matches!( 464 + rl.check("clientB", 1), 465 + CheckResult::Allowed { .. } 466 + )); 467 + assert!(matches!( 468 + rl.check("clientB", 1), 469 + CheckResult::Limited { .. } 470 + )); 355 471 } 356 472 357 473 #[test] ··· 365 481 default_procedure_cost: 1, 366 482 default_proxy_cost: 1, 367 483 }, 368 - vec![], 369 484 ); 370 485 assert!(rl.is_enabled()); 371 486 rl.set_enabled(false); 372 487 assert!(!rl.is_enabled()); 373 - assert!(matches!(rl.check("k", 1, None), CheckResult::Disabled)); 488 + assert!(matches!(rl.check("k", 1), CheckResult::Disabled)); 374 489 } 375 490 }
+1
src/repo/session.rs
··· 14 14 Did::new(did.to_string()).map_err(|_| AppError::Auth(format!("invalid DID: {did}")))?; 15 15 state 16 16 .oauth 17 + .default_client() 17 18 .restore(&did) 18 19 .await 19 20 .map_err(|e| AppError::Auth(format!("no OAuth session for {}: {e}", did.as_ref())))
-8
src/repo/upload_blob.rs
··· 2 2 use axum::extract::State; 3 3 use axum::http::HeaderMap; 4 4 use axum::response::Response; 5 - use std::net::IpAddr; 6 5 7 6 use crate::AppState; 8 7 use crate::auth::Claims; ··· 18 17 headers: HeaderMap, 19 18 body: Bytes, 20 19 ) -> Result<Response, AppError> { 21 - let client_ip: Option<IpAddr> = headers 22 - .get("x-forwarded-for") 23 - .and_then(|v| v.to_str().ok()) 24 - .and_then(|s| s.split(',').next()) 25 - .and_then(|s| s.trim().parse().ok()); 26 - 27 20 let rate_key = claims.did().to_string(); 28 21 let check = state.rate_limiter.check( 29 22 &rate_key, 30 23 state.rate_limiter.default_cost_for_type("procedure"), 31 - client_ip, 32 24 ); 33 25 34 26 if let CheckResult::Limited {
+6 -16
src/server.rs
··· 7 7 use bytes::Bytes; 8 8 use http_body_util::Full; 9 9 use std::convert::Infallible; 10 - use std::net::IpAddr; 11 10 use tower_http::cors::CorsLayer; 12 11 use tower_http::services::ServeDir; 13 12 use tower_http::trace::TraceLayer; ··· 106 105 } 107 106 108 107 async fn client_metadata(State(state): State<AppState>) -> Json<serde_json::Value> { 109 - let mut metadata = serde_json::to_value(&state.oauth.client_metadata).unwrap_or_default(); 108 + let mut metadata = 109 + serde_json::to_value(&state.oauth.default_client().client_metadata).unwrap_or_default(); 110 110 111 111 // The `client_id` field in the response must exactly match the URL the 112 112 // authorization server fetched. ··· 161 161 Json(metadata) 162 162 } 163 163 164 - fn ip_from_forwarded_for(value: Option<&str>) -> Option<IpAddr> { 165 - let forwarded = value?; 166 - let first = forwarded.split(',').next()?; 167 - first.trim().parse::<IpAddr>().ok() 168 - } 169 - 170 164 async fn get_profile( 171 165 State(state): State<AppState>, 172 166 claims: Claims, 173 - headers: HeaderMap, 167 + _headers: HeaderMap, 174 168 ) -> Result<Response, AppError> { 175 - let client_ip = 176 - ip_from_forwarded_for(headers.get("x-forwarded-for").and_then(|v| v.to_str().ok())); 177 169 let rate_key = claims.did().to_string(); 178 - let check = state.rate_limiter.check( 179 - &rate_key, 180 - state.rate_limiter.default_cost_for_type("query"), 181 - client_ip, 182 - ); 170 + let check = state 171 + .rate_limiter 172 + .check(&rate_key, state.rate_limiter.default_cost_for_type("query")); 183 173 184 174 if let CheckResult::Limited { 185 175 retry_after,
+11 -40
src/xrpc/mod.rs
··· 3 3 4 4 use axum::Json; 5 5 use axum::body::Body; 6 - use axum::extract::{ConnectInfo, FromRequestParts, Path, RawQuery, State}; 6 + use axum::extract::{FromRequestParts, Path, RawQuery, State}; 7 7 use axum::http::StatusCode; 8 8 use axum::http::request::Parts; 9 9 use axum::response::Response; 10 10 use serde_json::Value; 11 11 use std::collections::HashMap; 12 - use std::net::{IpAddr, SocketAddr}; 13 12 14 13 use crate::AppState; 15 14 use crate::auth::Claims; ··· 155 154 .unwrap()) 156 155 } 157 156 158 - /// Extract client IP from X-Forwarded-For header or ConnectInfo. 159 - fn extract_client_ip(parts: &Parts) -> Option<IpAddr> { 160 - if let Some(forwarded) = parts 161 - .headers 162 - .get("x-forwarded-for") 163 - .and_then(|v| v.to_str().ok()) 164 - && let Some(first) = forwarded.split(',').next() 165 - && let Ok(ip) = first.trim().parse::<IpAddr>() 166 - { 167 - return Some(ip); 168 - } 169 - parts 170 - .extensions 171 - .get::<ConnectInfo<SocketAddr>>() 172 - .map(|ci| ci.0.ip()) 173 - } 174 - 175 157 /// Apply rate limit headers to a response. 176 158 fn apply_rate_limit_headers(response: &mut Response, remaining: u32, limit: u32, reset: u64) { 177 159 let headers = response.headers_mut(); ··· 189 171 ) -> Result<Response, AppError> { 190 172 let raw_query = raw_query.unwrap_or_default(); 191 173 let mut params = parse_query_params(&raw_query); 192 - let client_ip = extract_client_ip(&parts); 193 174 let claims = Claims::from_request_parts(&mut parts, &state).await.ok(); 194 175 195 - // Rate limit check 176 + // Rate limit check — keyed by client_key for API client requests, "anonymous" otherwise 196 177 let rate_key = claims 197 178 .as_ref() 198 - .map(|c| c.did().to_string()) 199 - .unwrap_or_else(|| { 200 - client_ip 201 - .map(|ip| ip.to_string()) 202 - .unwrap_or_else(|| "unknown".to_string()) 203 - }); 179 + .and_then(|c| c.client_key().map(|k| k.to_string())) 180 + .unwrap_or_else(|| "anonymous".to_string()); 204 181 205 182 let lexicon = state.lexicons.get(&method).await; 206 183 ··· 214 191 state.rate_limiter.default_cost_for_type("proxy") 215 192 }; 216 193 217 - let check = state.rate_limiter.check(&rate_key, cost, client_ip); 194 + let check = state.rate_limiter.check(&rate_key, cost); 218 195 219 196 match check { 220 197 CheckResult::Limited { ··· 270 247 Ok(response) 271 248 } 272 249 273 - /// Extract client IP from X-Forwarded-For header value. 274 - fn ip_from_forwarded_for(value: Option<&str>) -> Option<IpAddr> { 275 - let forwarded = value?; 276 - let first = forwarded.split(',').next()?; 277 - first.trim().parse::<IpAddr>().ok() 278 - } 279 - 280 250 /// Catch-all POST handler for XRPC procedures. 281 251 pub async fn xrpc_post( 282 252 State(state): State<AppState>, 283 253 Path(method): Path<String>, 284 254 RawQuery(raw_query): RawQuery, 285 255 claims: Claims, 286 - headers: axum::http::HeaderMap, 256 + _headers: axum::http::HeaderMap, 287 257 Json(body): Json<serde_json::Value>, 288 258 ) -> Result<Response, AppError> { 289 259 let raw_query = raw_query.unwrap_or_default(); 290 260 let mut params = parse_query_params(&raw_query); 291 - let client_ip = 292 - ip_from_forwarded_for(headers.get("x-forwarded-for").and_then(|v| v.to_str().ok())); 293 - let rate_key = claims.did().to_string(); 261 + let rate_key = claims 262 + .client_key() 263 + .map(|k| k.to_string()) 264 + .unwrap_or_else(|| "anonymous".to_string()); 294 265 295 266 let lexicon = state.lexicons.get(&method).await; 296 267 ··· 304 275 state.rate_limiter.default_cost_for_type("proxy") 305 276 }; 306 277 307 - let check = state.rate_limiter.check(&rate_key, cost, client_ip); 278 + let check = state.rate_limiter.check(&rate_key, cost); 308 279 309 280 match check { 310 281 CheckResult::Limited {
+3 -2
tests/common/app.rs
··· 123 123 default_procedure_cost: 1, 124 124 default_proxy_cost: 1, 125 125 }, 126 - vec![], 127 126 ), 128 - oauth: std::sync::Arc::new(oauth), 127 + oauth: std::sync::Arc::new(happyview::auth::OAuthClientRegistry::new( 128 + std::sync::Arc::new(oauth), 129 + )), 129 130 oauth_state_store: happyview::auth::oauth_store::DbStateStore::new( 130 131 pool.clone(), 131 132 backend,
+643
tests/e2e_api_clients.rs
··· 1 + mod common; 2 + 3 + use axum::body::Body; 4 + use axum::http::{Request, StatusCode}; 5 + use http_body_util::BodyExt; 6 + use serde_json::{Value, json}; 7 + use serial_test::serial; 8 + use tower::ServiceExt; 9 + 10 + use common::app::TestApp; 11 + 12 + // --------------------------------------------------------------------------- 13 + // Helpers 14 + // --------------------------------------------------------------------------- 15 + 16 + async fn json_body(resp: axum::response::Response) -> Value { 17 + let body = resp.into_body().collect().await.unwrap().to_bytes(); 18 + serde_json::from_slice(&body).unwrap() 19 + } 20 + 21 + fn admin_get( 22 + uri: &str, 23 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 24 + ) -> Request<Body> { 25 + Request::builder() 26 + .uri(uri) 27 + .header(cookie.0, cookie.1) 28 + .body(Body::empty()) 29 + .unwrap() 30 + } 31 + 32 + fn admin_post( 33 + uri: &str, 34 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 35 + body: &Value, 36 + ) -> Request<Body> { 37 + Request::builder() 38 + .method("POST") 39 + .uri(uri) 40 + .header(cookie.0, cookie.1) 41 + .header("content-type", "application/json") 42 + .body(Body::from(serde_json::to_vec(body).unwrap())) 43 + .unwrap() 44 + } 45 + 46 + fn admin_put( 47 + uri: &str, 48 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 49 + body: &Value, 50 + ) -> Request<Body> { 51 + Request::builder() 52 + .method("PUT") 53 + .uri(uri) 54 + .header(cookie.0, cookie.1) 55 + .header("content-type", "application/json") 56 + .body(Body::from(serde_json::to_vec(body).unwrap())) 57 + .unwrap() 58 + } 59 + 60 + fn admin_delete( 61 + uri: &str, 62 + cookie: (axum::http::HeaderName, axum::http::HeaderValue), 63 + ) -> Request<Body> { 64 + Request::builder() 65 + .method("DELETE") 66 + .uri(uri) 67 + .header(cookie.0, cookie.1) 68 + .body(Body::empty()) 69 + .unwrap() 70 + } 71 + 72 + fn sample_api_client_body() -> Value { 73 + json!({ 74 + "name": "Test App", 75 + "client_id_url": "https://testapp.example.com/oauth-client-metadata.json", 76 + "client_uri": "https://testapp.example.com", 77 + "redirect_uris": ["https://happyview.example.com/auth/callback"], 78 + "scopes": "atproto" 79 + }) 80 + } 81 + 82 + // --------------------------------------------------------------------------- 83 + // Create 84 + // --------------------------------------------------------------------------- 85 + 86 + #[tokio::test] 87 + #[serial] 88 + #[ignore] 89 + async fn create_api_client_returns_201() { 90 + let app = TestApp::new().await; 91 + let body = sample_api_client_body(); 92 + 93 + let resp = app 94 + .router 95 + .clone() 96 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body)) 97 + .await 98 + .unwrap(); 99 + 100 + assert_eq!(resp.status(), StatusCode::CREATED); 101 + let json = json_body(resp).await; 102 + assert_eq!(json["name"], "Test App"); 103 + assert_eq!( 104 + json["client_id_url"], 105 + "https://testapp.example.com/oauth-client-metadata.json" 106 + ); 107 + let key = json["client_key"].as_str().unwrap(); 108 + assert!(key.starts_with("hvc_"), "client_key should start with hvc_"); 109 + assert_eq!(key.len(), 36); // "hvc_" (4) + 32 hex chars 110 + assert!(json["id"].as_str().is_some()); 111 + } 112 + 113 + #[tokio::test] 114 + #[serial] 115 + #[ignore] 116 + async fn create_api_client_duplicate_client_id_url_fails() { 117 + let app = TestApp::new().await; 118 + let body = sample_api_client_body(); 119 + 120 + // First create succeeds 121 + let resp = app 122 + .router 123 + .clone() 124 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body)) 125 + .await 126 + .unwrap(); 127 + assert_eq!(resp.status(), StatusCode::CREATED); 128 + 129 + // Second create with same client_id_url should fail (UNIQUE constraint) 130 + let resp = app 131 + .router 132 + .clone() 133 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body)) 134 + .await 135 + .unwrap(); 136 + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); 137 + } 138 + 139 + #[tokio::test] 140 + #[serial] 141 + #[ignore] 142 + async fn create_api_client_registers_in_oauth_registry() { 143 + let app = TestApp::new().await; 144 + let body = sample_api_client_body(); 145 + 146 + let resp = app 147 + .router 148 + .clone() 149 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body)) 150 + .await 151 + .unwrap(); 152 + assert_eq!(resp.status(), StatusCode::CREATED); 153 + 154 + // The OAuth registry should now have this client 155 + let client_id_url = "https://testapp.example.com/oauth-client-metadata.json"; 156 + assert!( 157 + app.state.oauth.get(client_id_url).is_some(), 158 + "OAuth registry should contain the newly created client" 159 + ); 160 + } 161 + 162 + // --------------------------------------------------------------------------- 163 + // List 164 + // --------------------------------------------------------------------------- 165 + 166 + #[tokio::test] 167 + #[serial] 168 + #[ignore] 169 + async fn list_api_clients_empty() { 170 + let app = TestApp::new().await; 171 + 172 + let resp = app 173 + .router 174 + .clone() 175 + .oneshot(admin_get("/admin/api-clients", app.admin_cookie())) 176 + .await 177 + .unwrap(); 178 + 179 + assert_eq!(resp.status(), StatusCode::OK); 180 + let json = json_body(resp).await; 181 + assert!(json.as_array().unwrap().is_empty()); 182 + } 183 + 184 + #[tokio::test] 185 + #[serial] 186 + #[ignore] 187 + async fn list_api_clients_returns_created_clients() { 188 + let app = TestApp::new().await; 189 + 190 + // Create two clients 191 + let body1 = json!({ 192 + "name": "App One", 193 + "client_id_url": "https://one.example.com/oauth-client-metadata.json", 194 + "client_uri": "https://one.example.com", 195 + "redirect_uris": ["https://happyview.example.com/auth/callback"], 196 + "scopes": "atproto" 197 + }); 198 + let body2 = json!({ 199 + "name": "App Two", 200 + "client_id_url": "https://two.example.com/oauth-client-metadata.json", 201 + "client_uri": "https://two.example.com", 202 + "redirect_uris": ["https://happyview.example.com/auth/callback"], 203 + "scopes": "atproto" 204 + }); 205 + 206 + app.router 207 + .clone() 208 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body1)) 209 + .await 210 + .unwrap(); 211 + app.router 212 + .clone() 213 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body2)) 214 + .await 215 + .unwrap(); 216 + 217 + let resp = app 218 + .router 219 + .clone() 220 + .oneshot(admin_get("/admin/api-clients", app.admin_cookie())) 221 + .await 222 + .unwrap(); 223 + 224 + assert_eq!(resp.status(), StatusCode::OK); 225 + let json = json_body(resp).await; 226 + let arr = json.as_array().unwrap(); 227 + assert_eq!(arr.len(), 2); 228 + 229 + // Verify fields are present 230 + for client in arr { 231 + assert!(client["id"].as_str().is_some()); 232 + assert!(client["client_key"].as_str().is_some()); 233 + assert!(client["name"].as_str().is_some()); 234 + assert!(client["client_id_url"].as_str().is_some()); 235 + assert!(client["is_active"].as_bool().is_some()); 236 + assert!(client["created_by"].as_str().is_some()); 237 + } 238 + } 239 + 240 + // --------------------------------------------------------------------------- 241 + // Get 242 + // --------------------------------------------------------------------------- 243 + 244 + #[tokio::test] 245 + #[serial] 246 + #[ignore] 247 + async fn get_api_client_returns_details() { 248 + let app = TestApp::new().await; 249 + let body = sample_api_client_body(); 250 + 251 + let create_resp = app 252 + .router 253 + .clone() 254 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body)) 255 + .await 256 + .unwrap(); 257 + let created = json_body(create_resp).await; 258 + let id = created["id"].as_str().unwrap(); 259 + 260 + let resp = app 261 + .router 262 + .clone() 263 + .oneshot(admin_get( 264 + &format!("/admin/api-clients/{id}"), 265 + app.admin_cookie(), 266 + )) 267 + .await 268 + .unwrap(); 269 + 270 + assert_eq!(resp.status(), StatusCode::OK); 271 + let json = json_body(resp).await; 272 + assert_eq!(json["name"], "Test App"); 273 + assert_eq!( 274 + json["client_id_url"], 275 + "https://testapp.example.com/oauth-client-metadata.json" 276 + ); 277 + assert_eq!(json["client_uri"], "https://testapp.example.com"); 278 + assert_eq!(json["scopes"], "atproto"); 279 + assert_eq!(json["is_active"], true); 280 + assert_eq!(json["created_by"], "did:plc:testadmin"); 281 + assert!(json["redirect_uris"].as_array().unwrap().len() == 1); 282 + } 283 + 284 + #[tokio::test] 285 + #[serial] 286 + #[ignore] 287 + async fn get_api_client_not_found() { 288 + let app = TestApp::new().await; 289 + 290 + let resp = app 291 + .router 292 + .clone() 293 + .oneshot(admin_get( 294 + "/admin/api-clients/00000000-0000-0000-0000-000000000000", 295 + app.admin_cookie(), 296 + )) 297 + .await 298 + .unwrap(); 299 + 300 + assert_eq!(resp.status(), StatusCode::NOT_FOUND); 301 + } 302 + 303 + // --------------------------------------------------------------------------- 304 + // Update 305 + // --------------------------------------------------------------------------- 306 + 307 + #[tokio::test] 308 + #[serial] 309 + #[ignore] 310 + async fn update_api_client_changes_fields() { 311 + let app = TestApp::new().await; 312 + 313 + // Create 314 + let create_resp = app 315 + .router 316 + .clone() 317 + .oneshot(admin_post( 318 + "/admin/api-clients", 319 + app.admin_cookie(), 320 + &sample_api_client_body(), 321 + )) 322 + .await 323 + .unwrap(); 324 + let created = json_body(create_resp).await; 325 + let id = created["id"].as_str().unwrap(); 326 + 327 + // Update 328 + let update_body = json!({ 329 + "name": "Updated App", 330 + "scopes": "atproto transition:generic" 331 + }); 332 + let resp = app 333 + .router 334 + .clone() 335 + .oneshot(admin_put( 336 + &format!("/admin/api-clients/{id}"), 337 + app.admin_cookie(), 338 + &update_body, 339 + )) 340 + .await 341 + .unwrap(); 342 + assert_eq!(resp.status(), StatusCode::NO_CONTENT); 343 + 344 + // Verify 345 + let resp = app 346 + .router 347 + .clone() 348 + .oneshot(admin_get( 349 + &format!("/admin/api-clients/{id}"), 350 + app.admin_cookie(), 351 + )) 352 + .await 353 + .unwrap(); 354 + let json = json_body(resp).await; 355 + assert_eq!(json["name"], "Updated App"); 356 + assert_eq!(json["scopes"], "atproto transition:generic"); 357 + // Unchanged fields should remain 358 + assert_eq!(json["client_uri"], "https://testapp.example.com"); 359 + } 360 + 361 + #[tokio::test] 362 + #[serial] 363 + #[ignore] 364 + async fn update_api_client_not_found() { 365 + let app = TestApp::new().await; 366 + 367 + let resp = app 368 + .router 369 + .clone() 370 + .oneshot(admin_put( 371 + "/admin/api-clients/00000000-0000-0000-0000-000000000000", 372 + app.admin_cookie(), 373 + &json!({"name": "Nope"}), 374 + )) 375 + .await 376 + .unwrap(); 377 + 378 + assert_eq!(resp.status(), StatusCode::NOT_FOUND); 379 + } 380 + 381 + #[tokio::test] 382 + #[serial] 383 + #[ignore] 384 + async fn update_api_client_deactivate_removes_from_registry() { 385 + let app = TestApp::new().await; 386 + 387 + let create_resp = app 388 + .router 389 + .clone() 390 + .oneshot(admin_post( 391 + "/admin/api-clients", 392 + app.admin_cookie(), 393 + &sample_api_client_body(), 394 + )) 395 + .await 396 + .unwrap(); 397 + let created = json_body(create_resp).await; 398 + let id = created["id"].as_str().unwrap(); 399 + 400 + let client_id_url = "https://testapp.example.com/oauth-client-metadata.json"; 401 + assert!(app.state.oauth.get(client_id_url).is_some()); 402 + 403 + // Deactivate 404 + let resp = app 405 + .router 406 + .clone() 407 + .oneshot(admin_put( 408 + &format!("/admin/api-clients/{id}"), 409 + app.admin_cookie(), 410 + &json!({"is_active": false}), 411 + )) 412 + .await 413 + .unwrap(); 414 + assert_eq!(resp.status(), StatusCode::NO_CONTENT); 415 + 416 + // Should be removed from registry 417 + assert!( 418 + app.state.oauth.get(client_id_url).is_none(), 419 + "Deactivated client should be removed from OAuth registry" 420 + ); 421 + } 422 + 423 + // --------------------------------------------------------------------------- 424 + // Delete 425 + // --------------------------------------------------------------------------- 426 + 427 + #[tokio::test] 428 + #[serial] 429 + #[ignore] 430 + async fn delete_api_client_returns_204() { 431 + let app = TestApp::new().await; 432 + 433 + let create_resp = app 434 + .router 435 + .clone() 436 + .oneshot(admin_post( 437 + "/admin/api-clients", 438 + app.admin_cookie(), 439 + &sample_api_client_body(), 440 + )) 441 + .await 442 + .unwrap(); 443 + let created = json_body(create_resp).await; 444 + let id = created["id"].as_str().unwrap(); 445 + 446 + let resp = app 447 + .router 448 + .clone() 449 + .oneshot(admin_delete( 450 + &format!("/admin/api-clients/{id}"), 451 + app.admin_cookie(), 452 + )) 453 + .await 454 + .unwrap(); 455 + 456 + assert_eq!(resp.status(), StatusCode::NO_CONTENT); 457 + 458 + // Verify gone from list 459 + let resp = app 460 + .router 461 + .clone() 462 + .oneshot(admin_get("/admin/api-clients", app.admin_cookie())) 463 + .await 464 + .unwrap(); 465 + let json = json_body(resp).await; 466 + assert!(json.as_array().unwrap().is_empty()); 467 + } 468 + 469 + #[tokio::test] 470 + #[serial] 471 + #[ignore] 472 + async fn delete_api_client_removes_from_oauth_registry() { 473 + let app = TestApp::new().await; 474 + 475 + let create_resp = app 476 + .router 477 + .clone() 478 + .oneshot(admin_post( 479 + "/admin/api-clients", 480 + app.admin_cookie(), 481 + &sample_api_client_body(), 482 + )) 483 + .await 484 + .unwrap(); 485 + let created = json_body(create_resp).await; 486 + let id = created["id"].as_str().unwrap(); 487 + 488 + let client_id_url = "https://testapp.example.com/oauth-client-metadata.json"; 489 + assert!(app.state.oauth.get(client_id_url).is_some()); 490 + 491 + app.router 492 + .clone() 493 + .oneshot(admin_delete( 494 + &format!("/admin/api-clients/{id}"), 495 + app.admin_cookie(), 496 + )) 497 + .await 498 + .unwrap(); 499 + 500 + assert!( 501 + app.state.oauth.get(client_id_url).is_none(), 502 + "Deleted client should be removed from OAuth registry" 503 + ); 504 + } 505 + 506 + #[tokio::test] 507 + #[serial] 508 + #[ignore] 509 + async fn delete_api_client_not_found() { 510 + let app = TestApp::new().await; 511 + 512 + let resp = app 513 + .router 514 + .clone() 515 + .oneshot(admin_delete( 516 + "/admin/api-clients/00000000-0000-0000-0000-000000000000", 517 + app.admin_cookie(), 518 + )) 519 + .await 520 + .unwrap(); 521 + 522 + assert_eq!(resp.status(), StatusCode::NOT_FOUND); 523 + } 524 + 525 + // --------------------------------------------------------------------------- 526 + // Permission enforcement 527 + // --------------------------------------------------------------------------- 528 + 529 + #[tokio::test] 530 + #[serial] 531 + #[ignore] 532 + async fn api_clients_no_auth_returns_401() { 533 + let app = TestApp::new().await; 534 + 535 + let resp = app 536 + .router 537 + .clone() 538 + .oneshot( 539 + Request::builder() 540 + .uri("/admin/api-clients") 541 + .body(Body::empty()) 542 + .unwrap(), 543 + ) 544 + .await 545 + .unwrap(); 546 + 547 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 548 + } 549 + 550 + #[tokio::test] 551 + #[serial] 552 + #[ignore] 553 + async fn api_clients_non_admin_returns_403() { 554 + let app = TestApp::new().await; 555 + 556 + let resp = app 557 + .router 558 + .clone() 559 + .oneshot(admin_get( 560 + "/admin/api-clients", 561 + common::auth::admin_cookie_header("did:plc:notadmin", &app.state.cookie_key), 562 + )) 563 + .await 564 + .unwrap(); 565 + 566 + assert_eq!(resp.status(), StatusCode::FORBIDDEN); 567 + } 568 + 569 + // --------------------------------------------------------------------------- 570 + // OAuth registry (unit-level via AppState) 571 + // --------------------------------------------------------------------------- 572 + 573 + #[tokio::test] 574 + #[serial] 575 + #[ignore] 576 + async fn oauth_registry_get_or_default_returns_default_for_unknown() { 577 + let app = TestApp::new().await; 578 + 579 + let client = app 580 + .state 581 + .oauth 582 + .get_or_default(Some("https://unknown.example.com/metadata.json")); 583 + let default = app.state.oauth.default_client(); 584 + 585 + // Should be the same Arc (default client) 586 + assert!(std::sync::Arc::ptr_eq(&client, default)); 587 + } 588 + 589 + #[tokio::test] 590 + #[serial] 591 + #[ignore] 592 + async fn oauth_registry_get_or_default_returns_default_for_none() { 593 + let app = TestApp::new().await; 594 + 595 + let client = app.state.oauth.get_or_default(None); 596 + let default = app.state.oauth.default_client(); 597 + 598 + assert!(std::sync::Arc::ptr_eq(&client, default)); 599 + } 600 + 601 + // --------------------------------------------------------------------------- 602 + // Rate limit config on API clients 603 + // --------------------------------------------------------------------------- 604 + 605 + #[tokio::test] 606 + #[serial] 607 + #[ignore] 608 + async fn create_api_client_with_rate_limit_overrides() { 609 + let app = TestApp::new().await; 610 + let body = json!({ 611 + "name": "Rate Limited App", 612 + "client_id_url": "https://ratelimited.example.com/oauth-client-metadata.json", 613 + "client_uri": "https://ratelimited.example.com", 614 + "redirect_uris": ["https://happyview.example.com/auth/callback"], 615 + "scopes": "atproto", 616 + "rate_limit_capacity": 50, 617 + "rate_limit_refill_rate": 1.5 618 + }); 619 + 620 + let resp = app 621 + .router 622 + .clone() 623 + .oneshot(admin_post("/admin/api-clients", app.admin_cookie(), &body)) 624 + .await 625 + .unwrap(); 626 + assert_eq!(resp.status(), StatusCode::CREATED); 627 + let created = json_body(resp).await; 628 + let id = created["id"].as_str().unwrap(); 629 + 630 + // Verify overrides persisted 631 + let resp = app 632 + .router 633 + .clone() 634 + .oneshot(admin_get( 635 + &format!("/admin/api-clients/{id}"), 636 + app.admin_cookie(), 637 + )) 638 + .await 639 + .unwrap(); 640 + let json = json_body(resp).await; 641 + assert_eq!(json["rate_limit_capacity"], 50); 642 + assert_eq!(json["rate_limit_refill_rate"], 1.5); 643 + }
+3 -2
tests/lua_atproto_api.rs
··· 79 79 default_procedure_cost: 1, 80 80 default_proxy_cost: 1, 81 81 }, 82 - vec![], 83 82 ), 84 - oauth: std::sync::Arc::new(oauth), 83 + oauth: std::sync::Arc::new(happyview::auth::OAuthClientRegistry::new( 84 + std::sync::Arc::new(oauth), 85 + )), 85 86 oauth_state_store: happyview::auth::oauth_store::DbStateStore::new(pool.clone(), backend), 86 87 cookie_key: axum_extra::extract::cookie::Key::derive_from(b"test-secret"), 87 88 plugin_registry: std::sync::Arc::new(happyview::plugin::PluginRegistry::new()),
+3 -2
tests/lua_db_api.rs
··· 82 82 default_procedure_cost: 1, 83 83 default_proxy_cost: 1, 84 84 }, 85 - vec![], 86 85 ), 87 - oauth: std::sync::Arc::new(oauth), 86 + oauth: std::sync::Arc::new(happyview::auth::OAuthClientRegistry::new( 87 + std::sync::Arc::new(oauth), 88 + )), 88 89 oauth_state_store: happyview::auth::oauth_store::DbStateStore::new(pool.clone(), backend), 89 90 cookie_key: axum_extra::extract::cookie::Key::derive_from(b"test-secret"), 90 91 plugin_registry: std::sync::Arc::new(happyview::plugin::PluginRegistry::new()),