A lexicon-driven AppView for ATProto. happyview.dev
backfill firehose jetstream atproto appview oauth lexicon
8
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix(xrpc): fix XRPC writes to use the new auth system

Trezy c6d77c35 8d6da9e2

+1050 -233
+24 -8
src/admin/api_clients.rs
··· 30 30 rand::rng().fill(&mut random_bytes); 31 31 let client_key = format!("hvc_{}", hex::encode(random_bytes)); 32 32 33 - // Generate the client secret: "hvs_" + 64 random hex chars. 34 - let mut secret_bytes = [0u8; 32]; 35 - rand::rng().fill(&mut secret_bytes); 36 - let client_secret = format!("hvs_{}", hex::encode(secret_bytes)); 37 - let client_secret_hash = hex::encode(Sha256::digest(client_secret.as_bytes())); 33 + // Generate the client secret for confidential clients only. 34 + let (client_secret, client_secret_hash) = if body.client_type == "confidential" { 35 + let mut secret_bytes = [0u8; 32]; 36 + rand::rng().fill(&mut secret_bytes); 37 + let secret = format!("hvs_{}", hex::encode(secret_bytes)); 38 + let hash = hex::encode(Sha256::digest(secret.as_bytes())); 39 + (Some(secret), hash) 40 + } else { 41 + (None, String::new()) 42 + }; 38 43 39 44 let id = Uuid::new_v4().to_string(); 40 45 let now = now_rfc3339(); ··· 92 97 state.rate_limiter.register_client_identity( 93 98 client_key.clone(), 94 99 crate::rate_limit::ClientIdentity { 95 - secret_hash: client_secret_hash, 100 + secret_hash: client_secret_hash.clone(), 96 101 client_uri: body.client_uri.clone(), 97 102 }, 98 103 ); ··· 138 143 client_secret, 139 144 name: body.name, 140 145 client_id_url: body.client_id_url, 146 + client_type: body.client_type, 141 147 }), 142 148 )) 143 149 } ··· 317 323 318 324 // Read current values 319 325 let select_sql = adapt_sql( 320 - "SELECT client_key, client_secret_hash, name, client_id_url, client_uri, redirect_uris, scopes, rate_limit_capacity, rate_limit_refill_rate, is_active FROM api_clients WHERE id = ?", 326 + "SELECT client_key, client_secret_hash, name, client_id_url, client_uri, redirect_uris, scopes, allowed_origins, rate_limit_capacity, rate_limit_refill_rate, is_active FROM api_clients WHERE id = ?", 321 327 state.db_backend, 322 328 ); 323 329 ··· 329 335 String, 330 336 String, 331 337 String, 338 + Option<String>, 332 339 Option<i32>, 333 340 Option<f64>, 334 341 i32, ··· 347 354 cur_client_uri, 348 355 cur_redirect_uris, 349 356 cur_scopes, 357 + cur_allowed_origins, 350 358 cur_capacity, 351 359 cur_refill, 352 360 cur_active, ··· 362 370 .map(|uris| serde_json::to_string(&uris).unwrap_or_else(|_| "[]".to_string())) 363 371 .unwrap_or(cur_redirect_uris); 364 372 let scopes = body.scopes.unwrap_or(cur_scopes); 373 + let allowed_origins_json: Option<String> = match body.allowed_origins { 374 + Some(Some(origins)) => { 375 + Some(serde_json::to_string(&origins).unwrap_or_else(|_| "[]".to_string())) 376 + } 377 + Some(None) => None, 378 + None => cur_allowed_origins, 379 + }; 365 380 let capacity = body.rate_limit_capacity.unwrap_or(cur_capacity); 366 381 let refill_rate = body.rate_limit_refill_rate.unwrap_or(cur_refill); 367 382 let is_active = body ··· 371 386 let now = now_rfc3339(); 372 387 373 388 let update_sql = adapt_sql( 374 - "UPDATE api_clients SET name = ?, client_uri = ?, redirect_uris = ?, scopes = ?, rate_limit_capacity = ?, rate_limit_refill_rate = ?, is_active = ?, updated_at = ? WHERE id = ?", 389 + "UPDATE api_clients SET name = ?, client_uri = ?, redirect_uris = ?, scopes = ?, allowed_origins = ?, rate_limit_capacity = ?, rate_limit_refill_rate = ?, is_active = ?, updated_at = ? WHERE id = ?", 375 390 state.db_backend, 376 391 ); 377 392 ··· 380 395 .bind(&client_uri) 381 396 .bind(&redirect_uris_json) 382 397 .bind(&scopes) 398 + .bind(&allowed_origins_json) 383 399 .bind(capacity) 384 400 .bind(refill_rate) 385 401 .bind(is_active)
+4 -1
src/admin/types.rs
··· 375 375 pub(super) client_uri: Option<String>, 376 376 pub(super) redirect_uris: Option<Vec<String>>, 377 377 pub(super) scopes: Option<String>, 378 + pub(super) allowed_origins: Option<Option<Vec<String>>>, 378 379 pub(super) rate_limit_capacity: Option<Option<i32>>, 379 380 pub(super) rate_limit_refill_rate: Option<Option<f64>>, 380 381 pub(super) is_active: Option<bool>, ··· 403 404 pub(super) struct CreateApiClientResponse { 404 405 pub(super) id: String, 405 406 pub(super) client_key: String, 406 - pub(super) client_secret: String, 407 + #[serde(skip_serializing_if = "Option::is_none")] 408 + pub(super) client_secret: Option<String>, 407 409 pub(super) name: String, 408 410 pub(super) client_id_url: String, 411 + pub(super) client_type: String, 409 412 }
+78 -38
src/auth/client_registry.rs
··· 5 5 use atrium_identity::did::{CommonDidResolver, CommonDidResolverConfig}; 6 6 use atrium_identity::handle::{AtprotoHandleResolver, AtprotoHandleResolverConfig}; 7 7 use atrium_oauth::{ 8 - AtprotoClientMetadata, AuthMethod, DefaultHttpClient, GrantType, OAuthClientConfig, 9 - OAuthResolverConfig, 8 + AtprotoClientMetadata, AtprotoLocalhostClientMetadata, AuthMethod, DefaultHttpClient, 9 + GrantType, OAuthClientConfig, OAuthResolverConfig, 10 10 }; 11 11 12 12 use crate::HappyViewOAuthClient; 13 13 use crate::auth::oauth_store::{DbSessionStore, DbStateStore}; 14 14 use crate::db::{DatabaseBackend, adapt_sql}; 15 15 use crate::dns::NativeDnsResolver; 16 + 17 + fn is_loopback_url(url: &str) -> bool { 18 + url.contains("127.0.0.1") || url.contains("[::1]") || url.contains("localhost") 19 + } 16 20 17 21 /// Parameters needed to build an OAuth client for an API client registration. 18 22 pub struct ApiClientOAuthParams { ··· 55 59 /// Look up a client by `client_id_url`. 56 60 pub fn get(&self, client_id_url: &str) -> Option<Arc<HappyViewOAuthClient>> { 57 61 self.clients.get(client_id_url).map(|r| r.value().clone()) 62 + } 63 + 64 + /// Get the resolved OAuth `client_id` for a registered client. 65 + /// 66 + /// For loopback clients this returns `http://localhost?scope=...` (the format 67 + /// auth servers expect), not the original `client_id_url` key. 68 + pub fn get_resolved_client_id(&self, client_id_url: &str) -> Option<String> { 69 + self.clients 70 + .get(client_id_url) 71 + .map(|r| r.value().client_metadata.client_id.clone()) 58 72 } 59 73 60 74 /// Look up a client by `client_id_url`, falling back to the primary client. ··· 139 153 )] 140 154 } else { 141 155 scopes 142 - }; 143 - 144 - let metadata = AtprotoClientMetadata { 145 - client_id: client_id_url.to_string(), 146 - client_uri: Some(client_uri.to_string()), 147 - redirect_uris, 148 - token_endpoint_auth_method: AuthMethod::None, 149 - grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 150 - scopes, 151 - jwks_uri: None, 152 - token_endpoint_auth_signing_alg: None, 153 156 }; 154 157 155 158 let http = Arc::new(DefaultHttpClient::default()); ··· 166 169 protected_resource_metadata: Default::default(), 167 170 }; 168 171 169 - match atrium_oauth::OAuthClient::new(OAuthClientConfig { 170 - client_metadata: metadata, 171 - keys: None, 172 - state_store: state_store.clone(), 173 - session_store: DbSessionStore::new(session_store_pool.clone(), *db_backend), 174 - resolver, 175 - }) { 172 + let client = if is_loopback_url(client_id_url) { 173 + atrium_oauth::OAuthClient::new(OAuthClientConfig { 174 + client_metadata: AtprotoLocalhostClientMetadata { 175 + redirect_uris: None, 176 + scopes: Some(scopes), 177 + }, 178 + keys: None, 179 + state_store: state_store.clone(), 180 + session_store: DbSessionStore::new(session_store_pool.clone(), *db_backend), 181 + resolver, 182 + }) 183 + } else { 184 + atrium_oauth::OAuthClient::new(OAuthClientConfig { 185 + client_metadata: AtprotoClientMetadata { 186 + client_id: client_id_url.to_string(), 187 + client_uri: Some(client_uri.to_string()), 188 + redirect_uris, 189 + token_endpoint_auth_method: AuthMethod::None, 190 + grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 191 + scopes, 192 + jwks_uri: None, 193 + token_endpoint_auth_signing_alg: None, 194 + }, 195 + keys: None, 196 + state_store: state_store.clone(), 197 + session_store: DbSessionStore::new(session_store_pool.clone(), *db_backend), 198 + resolver, 199 + }) 200 + }; 201 + 202 + match client { 176 203 Ok(client) => { 177 204 self.register(client_id_url.to_string(), Arc::new(client)); 178 205 Ok(()) ··· 217 244 scopes 218 245 }; 219 246 220 - let metadata = AtprotoClientMetadata { 221 - client_id: client_id_url.clone(), 222 - client_uri: Some(client_uri), 223 - redirect_uris, 224 - token_endpoint_auth_method: AuthMethod::None, 225 - grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 226 - scopes, 227 - jwks_uri: None, 228 - token_endpoint_auth_signing_alg: None, 229 - }; 230 - 231 247 // Each OAuthClient needs its own resolver instances (they're not Clone) 232 248 let http = Arc::new(DefaultHttpClient::default()); 233 249 let resolver = OAuthResolverConfig { ··· 243 259 protected_resource_metadata: Default::default(), 244 260 }; 245 261 246 - match atrium_oauth::OAuthClient::new(OAuthClientConfig { 247 - client_metadata: metadata, 248 - keys: None, 249 - state_store: state_store.clone(), 250 - session_store: DbSessionStore::new(session_store_pool.clone(), db_backend), 251 - resolver, 252 - }) { 262 + let client = if is_loopback_url(&client_id_url) { 263 + atrium_oauth::OAuthClient::new(OAuthClientConfig { 264 + client_metadata: AtprotoLocalhostClientMetadata { 265 + redirect_uris: None, 266 + scopes: Some(scopes), 267 + }, 268 + keys: None, 269 + state_store: state_store.clone(), 270 + session_store: DbSessionStore::new(session_store_pool.clone(), db_backend), 271 + resolver, 272 + }) 273 + } else { 274 + atrium_oauth::OAuthClient::new(OAuthClientConfig { 275 + client_metadata: AtprotoClientMetadata { 276 + client_id: client_id_url.clone(), 277 + client_uri: Some(client_uri), 278 + redirect_uris, 279 + token_endpoint_auth_method: AuthMethod::None, 280 + grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken], 281 + scopes, 282 + jwks_uri: None, 283 + token_endpoint_auth_signing_alg: None, 284 + }, 285 + keys: None, 286 + state_store: state_store.clone(), 287 + session_store: DbSessionStore::new(session_store_pool.clone(), db_backend), 288 + resolver, 289 + }) 290 + }; 291 + 292 + match client { 253 293 Ok(client) => { 254 294 tracing::info!(client_id = %client_id_url, "Registered API client OAuth identity"); 255 295 self.register(client_id_url, Arc::new(client));
+64 -26
src/lua/execute.rs
··· 57 57 let script_source = script.to_string(); 58 58 let input_json = input.clone(); 59 59 60 - let session = match repo::get_oauth_session(state, claims.did()).await { 61 - Ok(s) => s, 62 - Err(e) => { 63 - let error_message = format!("{e}"); 64 - log_event( 65 - &state.db, 66 - EventLog { 67 - event_type: "script.error".to_string(), 68 - severity: Severity::Error, 69 - actor_did: Some(claims.did().to_string()), 70 - subject: Some(method.to_string()), 71 - detail: serde_json::json!({ 72 - "error": error_message, 73 - "script_source": script_source, 74 - "input": input_json, 75 - "caller_did": claims.did(), 76 - "method": method, 77 - "duration_ms": start.elapsed().as_millis() as u64, 78 - }), 79 - }, 80 - backend, 81 - ) 82 - .await; 83 - return Err(e); 60 + let pds_auth = if let Some(client_key) = claims.client_key() { 61 + let encryption_key = state 62 + .config 63 + .token_encryption_key 64 + .as_ref() 65 + .ok_or_else(|| AppError::Internal("TOKEN_ENCRYPTION_KEY not configured".into()))?; 66 + let api_client_id = match repo::get_dpop_client_id(state, client_key).await { 67 + Ok(id) => id, 68 + Err(e) => { 69 + let error_message = format!("{e}"); 70 + log_event( 71 + &state.db, 72 + EventLog { 73 + event_type: "script.error".to_string(), 74 + severity: Severity::Error, 75 + actor_did: Some(claims.did().to_string()), 76 + subject: Some(method.to_string()), 77 + detail: serde_json::json!({ 78 + "error": error_message, 79 + "script_source": script_source, 80 + "input": input_json, 81 + "caller_did": claims.did(), 82 + "method": method, 83 + "duration_ms": start.elapsed().as_millis() as u64, 84 + }), 85 + }, 86 + backend, 87 + ) 88 + .await; 89 + return Err(e); 90 + } 91 + }; 92 + repo::PdsAuth::Dpop { 93 + api_client_id, 94 + encryption_key: *encryption_key, 95 + } 96 + } else { 97 + match repo::get_oauth_session(state, claims.did()).await { 98 + Ok(s) => repo::PdsAuth::OAuth(Arc::new(s)), 99 + Err(e) => { 100 + let error_message = format!("{e}"); 101 + log_event( 102 + &state.db, 103 + EventLog { 104 + event_type: "script.error".to_string(), 105 + severity: Severity::Error, 106 + actor_did: Some(claims.did().to_string()), 107 + subject: Some(method.to_string()), 108 + detail: serde_json::json!({ 109 + "error": error_message, 110 + "script_source": script_source, 111 + "input": input_json, 112 + "caller_did": claims.did(), 113 + "method": method, 114 + "duration_ms": start.elapsed().as_millis() as u64, 115 + }), 116 + }, 117 + backend, 118 + ) 119 + .await; 120 + return Err(e); 121 + } 84 122 } 85 123 }; 86 124 ··· 113 151 114 152 let state_arc = Arc::new(state.clone()); 115 153 let claims_arc = Arc::new(claims.clone()); 116 - let session_arc = Arc::new(session); 154 + let pds_auth_arc = Arc::new(pds_auth); 117 155 118 156 if let Err(e) = db_api::register_db_api(&lua, state_arc.clone()) { 119 157 let error_message = format!("failed to register db API: {e}"); ··· 213 251 return Err(AppError::Internal(error_message)); 214 252 } 215 253 216 - if let Err(e) = record::register_record_api(&lua, state_arc, claims_arc, session_arc) { 254 + if let Err(e) = record::register_record_api(&lua, state_arc, claims_arc, pds_auth_arc) { 217 255 let error_message = format!("failed to register Record API: {e}"); 218 256 log_event( 219 257 &state.db,
+35 -56
src/lua/record.rs
··· 4 4 use std::sync::Arc; 5 5 6 6 use crate::AppState; 7 - use crate::HappyViewOAuthSession; 8 7 use crate::auth::Claims; 9 8 use crate::db::{adapt_sql, now_rfc3339}; 10 9 use crate::record_refs::sync_refs; 11 - use crate::repo; 10 + use crate::repo::PdsAuth; 12 11 13 12 use super::tid::generate_tid; 14 13 ··· 28 27 lua: &Lua, 29 28 state: Arc<AppState>, 30 29 claims: Arc<Claims>, 31 - session: Arc<HappyViewOAuthSession>, 30 + pds_auth: Arc<PdsAuth>, 32 31 ) -> LuaResult<()> { 33 32 // -- methods table (shared by all Record instances) -- 34 33 let methods = lua.create_table()?; ··· 37 36 { 38 37 let state = state.clone(); 39 38 let claims = claims.clone(); 40 - let session = session.clone(); 39 + let pds_auth = pds_auth.clone(); 41 40 let save_fn = lua.create_async_function(move |lua, this: mlua::Table| { 42 41 let state = state.clone(); 43 42 let claims = claims.clone(); 44 - let session = session.clone(); 43 + let pds_auth = pds_auth.clone(); 45 44 async move { 46 45 let backend = state.db_backend; 47 46 let collection: String = this.raw_get("_collection")?; ··· 74 73 "record": data, 75 74 }); 76 75 77 - let resp = repo::pds_post_json_raw( 78 - &state, 79 - &session, 80 - "com.atproto.repo.putRecord", 81 - &pds_body, 82 - ) 83 - .await 84 - .map_err(|e| mlua::Error::runtime(format!("PDS putRecord failed: {e}")))?; 76 + let resp = pds_auth 77 + .post_json(&state, repo, "com.atproto.repo.putRecord", &pds_body) 78 + .await 79 + .map_err(|e| mlua::Error::runtime(format!("PDS putRecord failed: {e}")))?; 85 80 86 81 if !resp.status().is_success() { 87 82 let status = resp.status(); ··· 141 136 pds_body["rkey"] = json!(rkey); 142 137 } 143 138 144 - let resp = repo::pds_post_json_raw( 145 - &state, 146 - &session, 147 - "com.atproto.repo.createRecord", 148 - &pds_body, 149 - ) 150 - .await 151 - .map_err(|e| mlua::Error::runtime(format!("PDS createRecord failed: {e}")))?; 139 + let resp = pds_auth 140 + .post_json(&state, repo, "com.atproto.repo.createRecord", &pds_body) 141 + .await 142 + .map_err(|e| mlua::Error::runtime(format!("PDS createRecord failed: {e}")))?; 152 143 153 144 if !resp.status().is_success() { 154 145 let status = resp.status(); ··· 215 206 { 216 207 let state = state.clone(); 217 208 let claims = claims.clone(); 218 - let session = session.clone(); 209 + let pds_auth = pds_auth.clone(); 219 210 let delete_fn = lua.create_async_function(move |_lua, this: mlua::Table| { 220 211 let state = state.clone(); 221 212 let claims = claims.clone(); 222 - let session = session.clone(); 213 + let pds_auth = pds_auth.clone(); 223 214 async move { 224 215 let backend = state.db_backend; 225 216 let uri: String = this.raw_get::<Option<String>>("_uri")?.ok_or_else(|| { ··· 241 232 "rkey": rkey, 242 233 }); 243 234 244 - let resp = repo::pds_post_json_raw( 245 - &state, 246 - &session, 247 - "com.atproto.repo.deleteRecord", 248 - &pds_body, 249 - ) 250 - .await 251 - .map_err(|e| mlua::Error::runtime(format!("PDS deleteRecord failed: {e}")))?; 235 + let resp = pds_auth 236 + .post_json(&state, repo, "com.atproto.repo.deleteRecord", &pds_body) 237 + .await 238 + .map_err(|e| mlua::Error::runtime(format!("PDS deleteRecord failed: {e}")))?; 252 239 253 240 if !resp.status().is_success() { 254 241 let status = resp.status(); ··· 451 438 { 452 439 let state = state.clone(); 453 440 let claims = claims.clone(); 454 - let session = session.clone(); 441 + let pds_auth = pds_auth.clone(); 455 442 let save_all_fn = 456 443 lua.create_async_function(move |lua, records_table: mlua::Table| { 457 444 let state = state.clone(); 458 445 let claims = claims.clone(); 459 - let session = session.clone(); 446 + let pds_auth = pds_auth.clone(); 460 447 async move { 461 448 let backend = state.db_backend; 462 449 // Extract save data from each record (sync) ··· 484 471 let futs = save_items.iter().map(|(_, collection, existing_uri, rkey, repo_override, data)| { 485 472 let state = state.clone(); 486 473 let claims = claims.clone(); 487 - let session = session.clone(); 474 + let pds_auth = pds_auth.clone(); 488 475 let collection = collection.clone(); 489 476 let existing_uri = existing_uri.clone(); 490 477 let rkey = rkey.clone(); ··· 506 493 "record": data, 507 494 }); 508 495 509 - let resp = repo::pds_post_json_raw( 510 - &state, 511 - &session, 512 - "com.atproto.repo.putRecord", 513 - &pds_body, 514 - ) 515 - .await 516 - .map_err(|e| { 517 - mlua::Error::runtime(format!("PDS putRecord failed: {e}")) 518 - })?; 496 + let resp = pds_auth 497 + .post_json(&state, repo, "com.atproto.repo.putRecord", &pds_body) 498 + .await 499 + .map_err(|e| { 500 + mlua::Error::runtime(format!("PDS putRecord failed: {e}")) 501 + })?; 519 502 520 503 if !resp.status().is_success() { 521 504 let status = resp.status(); ··· 575 558 pds_body["rkey"] = json!(rkey); 576 559 } 577 560 578 - let resp = repo::pds_post_json_raw( 579 - &state, 580 - &session, 581 - "com.atproto.repo.createRecord", 582 - &pds_body, 583 - ) 584 - .await 585 - .map_err(|e| { 586 - mlua::Error::runtime(format!( 587 - "PDS createRecord failed: {e}" 588 - )) 589 - })?; 561 + let resp = pds_auth 562 + .post_json(&state, repo, "com.atproto.repo.createRecord", &pds_body) 563 + .await 564 + .map_err(|e| { 565 + mlua::Error::runtime(format!( 566 + "PDS createRecord failed: {e}" 567 + )) 568 + })?; 590 569 591 570 if !resp.status().is_success() { 592 571 let status = resp.status();
+158 -20
src/oauth/client_auth.rs
··· 156 156 /// Rules: 157 157 /// - `atproto` must be present in token scopes (always implicitly allowed) 158 158 /// - Every non-`atproto` scope in the token must appear in the client's registered scopes 159 - pub fn validate_scopes(token_scopes: &str, client_scopes: &str) -> Result<(), AppError> { 159 + /// - `include:X` client scopes are expanded by looking up the permission set 160 + /// lexicon `X` and extracting its `rpc:` and `repo:` permissions 161 + pub async fn validate_scopes( 162 + token_scopes: &str, 163 + client_scopes: &str, 164 + lexicons: &crate::lexicon::LexiconRegistry, 165 + ) -> Result<(), AppError> { 160 166 let token_set: std::collections::HashSet<&str> = token_scopes.split_whitespace().collect(); 161 - let client_set: std::collections::HashSet<&str> = client_scopes.split_whitespace().collect(); 167 + let mut client_set: std::collections::HashSet<String> = std::collections::HashSet::new(); 168 + 169 + for scope in client_scopes.split_whitespace() { 170 + if let Some(perm_set_id) = scope.strip_prefix("include:") { 171 + expand_permission_set(perm_set_id, lexicons, &mut client_set).await; 172 + } 173 + client_set.insert(scope.to_string()); 174 + } 162 175 163 176 if !token_set.contains("atproto") { 164 177 return Err(AppError::BadRequest( ··· 168 181 169 182 for scope in &token_set { 170 183 if *scope == "atproto" { 171 - continue; // always allowed 184 + continue; 172 185 } 173 - if !client_set.contains(scope) { 186 + if !client_set.contains(*scope) { 174 187 return Err(AppError::BadRequest(format!( 175 188 "scope '{}' is not allowed for this client", 176 189 scope ··· 181 194 Ok(()) 182 195 } 183 196 197 + /// Expand a permission set lexicon into individual `rpc:` and `repo:` scopes. 198 + async fn expand_permission_set( 199 + nsid: &str, 200 + lexicons: &crate::lexicon::LexiconRegistry, 201 + out: &mut std::collections::HashSet<String>, 202 + ) { 203 + let lexicon = match lexicons.get(nsid).await { 204 + Some(l) => l, 205 + None => { 206 + tracing::warn!(nsid = %nsid, "permission set lexicon not found in registry"); 207 + return; 208 + } 209 + }; 210 + 211 + let permissions = match lexicon 212 + .raw 213 + .get("defs") 214 + .and_then(|d| d.get("main")) 215 + .and_then(|m| m.get("permissions")) 216 + .and_then(|p| p.as_array()) 217 + { 218 + Some(p) => p, 219 + None => return, 220 + }; 221 + 222 + for perm in permissions { 223 + let resource = perm.get("resource").and_then(|r| r.as_str()).unwrap_or(""); 224 + match resource { 225 + "rpc" => { 226 + if let Some(lxms) = perm.get("lxm").and_then(|l| l.as_array()) { 227 + for lxm in lxms { 228 + if let Some(s) = lxm.as_str() { 229 + out.insert(format!("rpc:{s}")); 230 + } 231 + } 232 + } 233 + } 234 + "repo" => { 235 + if let Some(collections) = perm.get("collection").and_then(|c| c.as_array()) { 236 + for col in collections { 237 + if let Some(s) = col.as_str() { 238 + out.insert(format!("repo:{s}?action=create")); 239 + out.insert(format!("repo:{s}?action=update")); 240 + out.insert(format!("repo:{s}?action=delete")); 241 + } 242 + } 243 + } 244 + } 245 + _ => {} 246 + } 247 + } 248 + } 249 + 184 250 /// Verify a PKCE challenge against a verifier. 185 251 pub fn verify_pkce(challenge: &str, verifier: &str) -> bool { 186 252 use base64::Engine; ··· 194 260 mod tests { 195 261 use super::*; 196 262 197 - #[test] 198 - fn validate_scopes_requires_atproto() { 199 - let result = validate_scopes("transition:generic", "atproto transition:generic"); 263 + fn empty_registry() -> crate::lexicon::LexiconRegistry { 264 + crate::lexicon::LexiconRegistry::new() 265 + } 266 + 267 + #[tokio::test] 268 + async fn validate_scopes_requires_atproto() { 269 + let reg = empty_registry(); 270 + let result = 271 + validate_scopes("transition:generic", "atproto transition:generic", &reg).await; 200 272 assert!(result.is_err()); 201 273 } 202 274 203 - #[test] 204 - fn validate_scopes_atproto_only_always_passes() { 205 - let result = validate_scopes("atproto", "com.example.whatever"); 275 + #[tokio::test] 276 + async fn validate_scopes_atproto_only_always_passes() { 277 + let reg = empty_registry(); 278 + let result = validate_scopes("atproto", "com.example.whatever", &reg).await; 206 279 assert!(result.is_ok()); 207 280 } 208 281 209 - #[test] 210 - fn validate_scopes_subset_passes() { 282 + #[tokio::test] 283 + async fn validate_scopes_subset_passes() { 284 + let reg = empty_registry(); 211 285 let result = validate_scopes( 212 286 "atproto com.example.basic", 213 287 "atproto com.example.basic com.example.advanced", 214 - ); 288 + &reg, 289 + ) 290 + .await; 215 291 assert!(result.is_ok()); 216 292 } 217 293 218 - #[test] 219 - fn validate_scopes_excess_scope_fails() { 294 + #[tokio::test] 295 + async fn validate_scopes_excess_scope_fails() { 296 + let reg = empty_registry(); 220 297 let result = validate_scopes( 221 298 "atproto com.example.basic com.example.advanced", 222 299 "atproto com.example.basic", 223 - ); 300 + &reg, 301 + ) 302 + .await; 224 303 assert!(result.is_err()); 225 304 } 226 305 227 - #[test] 228 - fn validate_scopes_transition_generic_requires_registration() { 229 - let result = validate_scopes("atproto transition:generic", "atproto"); 306 + #[tokio::test] 307 + async fn validate_scopes_transition_generic_requires_registration() { 308 + let reg = empty_registry(); 309 + let result = validate_scopes("atproto transition:generic", "atproto", &reg).await; 230 310 assert!(result.is_err()); 231 311 232 - let result = validate_scopes("atproto transition:generic", "atproto transition:generic"); 312 + let result = validate_scopes( 313 + "atproto transition:generic", 314 + "atproto transition:generic", 315 + &reg, 316 + ) 317 + .await; 318 + assert!(result.is_ok()); 319 + } 320 + 321 + #[tokio::test] 322 + async fn validate_scopes_expands_include_permission_set() { 323 + let reg = empty_registry(); 324 + let raw = serde_json::json!({ 325 + "lexicon": 1, 326 + "id": "com.example.authBasic", 327 + "defs": { 328 + "main": { 329 + "type": "permission-set", 330 + "permissions": [ 331 + { 332 + "type": "permission", 333 + "resource": "rpc", 334 + "lxm": ["com.example.getProfile", "com.example.putProfile"] 335 + }, 336 + { 337 + "type": "permission", 338 + "resource": "repo", 339 + "collection": ["com.example.profile"] 340 + } 341 + ] 342 + } 343 + } 344 + }); 345 + let parsed = crate::lexicon::ParsedLexicon::parse( 346 + raw, 347 + 1, 348 + None, 349 + crate::lexicon::ProcedureAction::Upsert, 350 + None, 351 + None, 352 + None, 353 + ) 354 + .unwrap(); 355 + reg.upsert(parsed).await; 356 + 357 + let result = validate_scopes( 358 + "atproto rpc:com.example.getProfile repo:com.example.profile?action=create", 359 + "atproto include:com.example.authBasic", 360 + &reg, 361 + ) 362 + .await; 233 363 assert!(result.is_ok()); 364 + 365 + let result = validate_scopes( 366 + "atproto rpc:com.example.notAllowed", 367 + "atproto include:com.example.authBasic", 368 + &reg, 369 + ) 370 + .await; 371 + assert!(result.is_err()); 234 372 } 235 373 236 374 #[test]
+497 -52
src/oauth/pds_write.rs
··· 3 3 use p256::ecdsa::{SigningKey, signature::Signer}; 4 4 use sha2::{Digest, Sha256}; 5 5 6 + use std::sync::Arc; 7 + 8 + use crate::auth::OAuthClientRegistry; 6 9 use crate::db::DatabaseBackend; 7 10 use crate::error::AppError; 8 11 use crate::plugin::encryption::decrypt; 9 12 10 - /// Make an authenticated POST to a PDS XRPC endpoint using a DPoP session. 11 - #[allow(clippy::too_many_arguments)] 12 - pub async fn dpop_pds_post( 13 + use super::sessions::DpopSession; 14 + 15 + /// Resolved DPoP credentials needed to make authenticated PDS requests. 16 + struct DpopCredentials { 17 + session: DpopSession, 18 + pds_url: String, 19 + private_jwk: serde_json::Value, 20 + } 21 + 22 + /// Resolve DPoP credentials: session, PDS URL, and decrypted private key. 23 + async fn resolve_credentials( 13 24 http: &reqwest::Client, 14 25 pool: &sqlx::AnyPool, 15 26 backend: DatabaseBackend, ··· 17 28 plc_url: &str, 18 29 api_client_id: &str, 19 30 user_did: &str, 20 - xrpc_method: &str, 21 - body: &serde_json::Value, 22 - ) -> Result<reqwest::Response, AppError> { 31 + ) -> Result<DpopCredentials, AppError> { 23 32 let session = 24 33 super::sessions::get_dpop_session(pool, backend, encryption_key, api_client_id, user_did) 25 34 .await?; ··· 29 38 None => resolve_pds_from_did(http, plc_url, user_did).await?, 30 39 }; 31 40 32 - let target_url = format!("{}/xrpc/{}", pds_url.trim_end_matches('/'), xrpc_method); 33 - 34 - // Decrypt the DPoP private key 35 41 let key_sql = crate::db::adapt_sql( 36 42 "SELECT private_key_enc FROM dpop_keys WHERE id = ?", 37 43 backend, ··· 50 56 let private_jwk: serde_json::Value = serde_json::from_slice(&key_bytes) 51 57 .map_err(|e| AppError::Internal(format!("failed to parse DPoP key: {e}")))?; 52 58 53 - let proof = generate_dpop_proof(&private_jwk, "POST", &target_url, &session.access_token)?; 59 + Ok(DpopCredentials { 60 + session, 61 + pds_url, 62 + private_jwk, 63 + }) 64 + } 65 + 66 + /// Make an authenticated POST, handling DPoP nonce negotiation and token refresh. 67 + #[allow(clippy::too_many_arguments)] 68 + async fn dpop_post_with_retry( 69 + http: &reqwest::Client, 70 + pool: &sqlx::AnyPool, 71 + backend: DatabaseBackend, 72 + encryption_key: &[u8; 32], 73 + oauth_registry: &Arc<OAuthClientRegistry>, 74 + creds: &mut DpopCredentials, 75 + target_url: &str, 76 + request_builder: impl Fn(&reqwest::Client, &str, &str) -> reqwest::RequestBuilder, 77 + ) -> Result<reqwest::Response, AppError> { 78 + let proof = generate_dpop_proof( 79 + &creds.private_jwk, 80 + "POST", 81 + target_url, 82 + &creds.session.access_token, 83 + None, 84 + )?; 54 85 55 - let resp = http 56 - .post(&target_url) 57 - .header("Authorization", format!("DPoP {}", session.access_token)) 58 - .header("DPoP", proof) 59 - .header("Content-Type", "application/json") 60 - .json(body) 86 + let resp = request_builder(http, &creds.session.access_token, &proof) 61 87 .send() 62 88 .await 63 89 .map_err(|e| AppError::Internal(format!("PDS request failed: {e}")))?; 64 90 91 + // Handle DPoP nonce requirement 92 + if let Some(nonce) = extract_dpop_nonce(&resp) { 93 + let proof = generate_dpop_proof( 94 + &creds.private_jwk, 95 + "POST", 96 + target_url, 97 + &creds.session.access_token, 98 + Some(&nonce), 99 + )?; 100 + 101 + let resp = request_builder(http, &creds.session.access_token, &proof) 102 + .send() 103 + .await 104 + .map_err(|e| AppError::Internal(format!("PDS request failed: {e}")))?; 105 + 106 + // If we still get invalid_token after nonce, try refresh 107 + if is_expired_token(&resp) { 108 + return retry_after_refresh( 109 + http, 110 + pool, 111 + backend, 112 + encryption_key, 113 + oauth_registry, 114 + creds, 115 + target_url, 116 + Some(&nonce), 117 + &request_builder, 118 + ) 119 + .await; 120 + } 121 + 122 + return Ok(resp); 123 + } 124 + 125 + // Handle expired token 126 + if is_expired_token(&resp) { 127 + return retry_after_refresh( 128 + http, 129 + pool, 130 + backend, 131 + encryption_key, 132 + oauth_registry, 133 + creds, 134 + target_url, 135 + None, 136 + &request_builder, 137 + ) 138 + .await; 139 + } 140 + 65 141 Ok(resp) 66 142 } 67 143 144 + /// Refresh the access token and retry the PDS request. 145 + #[allow(clippy::too_many_arguments)] 146 + async fn retry_after_refresh( 147 + http: &reqwest::Client, 148 + pool: &sqlx::AnyPool, 149 + backend: DatabaseBackend, 150 + encryption_key: &[u8; 32], 151 + oauth_registry: &Arc<OAuthClientRegistry>, 152 + creds: &mut DpopCredentials, 153 + target_url: &str, 154 + nonce: Option<&str>, 155 + request_builder: &impl Fn(&reqwest::Client, &str, &str) -> reqwest::RequestBuilder, 156 + ) -> Result<reqwest::Response, AppError> { 157 + refresh_access_token(http, pool, backend, encryption_key, oauth_registry, creds).await?; 158 + 159 + let proof = generate_dpop_proof( 160 + &creds.private_jwk, 161 + "POST", 162 + target_url, 163 + &creds.session.access_token, 164 + nonce, 165 + )?; 166 + 167 + let resp = request_builder(http, &creds.session.access_token, &proof) 168 + .send() 169 + .await 170 + .map_err(|e| AppError::Internal(format!("PDS request failed after token refresh: {e}")))?; 171 + 172 + // One more nonce negotiation attempt after refresh 173 + if let Some(new_nonce) = extract_dpop_nonce(&resp) { 174 + let proof = generate_dpop_proof( 175 + &creds.private_jwk, 176 + "POST", 177 + target_url, 178 + &creds.session.access_token, 179 + Some(&new_nonce), 180 + )?; 181 + 182 + let resp = request_builder(http, &creds.session.access_token, &proof) 183 + .send() 184 + .await 185 + .map_err(|e| AppError::Internal(format!("PDS request failed: {e}")))?; 186 + 187 + return Ok(resp); 188 + } 189 + 190 + Ok(resp) 191 + } 192 + 193 + /// Make an authenticated POST to a PDS XRPC endpoint using a DPoP session. 194 + #[allow(clippy::too_many_arguments)] 195 + pub async fn dpop_pds_post( 196 + http: &reqwest::Client, 197 + pool: &sqlx::AnyPool, 198 + backend: DatabaseBackend, 199 + encryption_key: &[u8; 32], 200 + oauth_registry: &Arc<OAuthClientRegistry>, 201 + plc_url: &str, 202 + api_client_id: &str, 203 + user_did: &str, 204 + xrpc_method: &str, 205 + body: &serde_json::Value, 206 + ) -> Result<reqwest::Response, AppError> { 207 + let mut creds = resolve_credentials( 208 + http, 209 + pool, 210 + backend, 211 + encryption_key, 212 + plc_url, 213 + api_client_id, 214 + user_did, 215 + ) 216 + .await?; 217 + 218 + let target_url = format!( 219 + "{}/xrpc/{}", 220 + creds.pds_url.trim_end_matches('/'), 221 + xrpc_method 222 + ); 223 + 224 + let body = body.clone(); 225 + let target = target_url.clone(); 226 + dpop_post_with_retry( 227 + http, 228 + pool, 229 + backend, 230 + encryption_key, 231 + oauth_registry, 232 + &mut creds, 233 + &target_url, 234 + |http, access_token, proof| { 235 + http.post(&target) 236 + .header("Authorization", format!("DPoP {access_token}")) 237 + .header("DPoP", proof) 238 + .header("Content-Type", "application/json") 239 + .json(&body) 240 + }, 241 + ) 242 + .await 243 + } 244 + 68 245 /// Make an authenticated blob upload to a PDS using a DPoP session. 69 246 #[allow(clippy::too_many_arguments)] 70 247 pub async fn dpop_pds_post_blob( ··· 72 249 pool: &sqlx::AnyPool, 73 250 backend: DatabaseBackend, 74 251 encryption_key: &[u8; 32], 252 + oauth_registry: &Arc<OAuthClientRegistry>, 75 253 plc_url: &str, 76 254 api_client_id: &str, 77 255 user_did: &str, 78 256 content_type: &str, 79 257 blob: bytes::Bytes, 80 258 ) -> Result<reqwest::Response, AppError> { 81 - let session = 82 - super::sessions::get_dpop_session(pool, backend, encryption_key, api_client_id, user_did) 83 - .await?; 84 - 85 - let pds_url = match session.pds_url { 86 - Some(ref url) => url.clone(), 87 - None => resolve_pds_from_did(http, plc_url, user_did).await?, 88 - }; 259 + let mut creds = resolve_credentials( 260 + http, 261 + pool, 262 + backend, 263 + encryption_key, 264 + plc_url, 265 + api_client_id, 266 + user_did, 267 + ) 268 + .await?; 89 269 90 270 let target_url = format!( 91 271 "{}/xrpc/com.atproto.repo.uploadBlob", 92 - pds_url.trim_end_matches('/') 272 + creds.pds_url.trim_end_matches('/') 93 273 ); 94 274 95 - // Decrypt the DPoP private key 96 - let key_sql = crate::db::adapt_sql( 97 - "SELECT private_key_enc FROM dpop_keys WHERE id = ?", 275 + let content_type = content_type.to_string(); 276 + let target = target_url.clone(); 277 + dpop_post_with_retry( 278 + http, 279 + pool, 98 280 backend, 99 - ); 100 - let row: Option<(Vec<u8>,)> = sqlx::query_as(&key_sql) 101 - .bind(&session.dpop_key_id) 102 - .fetch_optional(pool) 103 - .await 104 - .map_err(|e| AppError::Internal(format!("failed to look up DPoP key: {e}")))?; 281 + encryption_key, 282 + oauth_registry, 283 + &mut creds, 284 + &target_url, 285 + |http, access_token, proof| { 286 + http.post(&target) 287 + .header("Authorization", format!("DPoP {access_token}")) 288 + .header("DPoP", proof) 289 + .header("Content-Type", &content_type) 290 + .body(blob.clone()) 291 + }, 292 + ) 293 + .await 294 + } 295 + 296 + /// Check if a response is a 401 with an expired/invalid token error. 297 + fn is_expired_token(resp: &reqwest::Response) -> bool { 298 + resp.status() == reqwest::StatusCode::UNAUTHORIZED 299 + } 300 + 301 + /// Check if a response indicates that a DPoP nonce is required, and extract it. 302 + fn extract_dpop_nonce(resp: &reqwest::Response) -> Option<String> { 303 + if resp.status() == reqwest::StatusCode::UNAUTHORIZED 304 + || resp.status() == reqwest::StatusCode::BAD_REQUEST 305 + { 306 + resp.headers() 307 + .get("dpop-nonce") 308 + .and_then(|v| v.to_str().ok()) 309 + .map(|s| s.to_string()) 310 + } else { 311 + None 312 + } 313 + } 314 + 315 + /// Refresh an expired access token using the session's refresh_token. 316 + /// 317 + /// Discovers the token endpoint from the issuer's OAuth metadata, sends a 318 + /// `grant_type=refresh_token` request with a DPoP proof, and updates the 319 + /// stored session with the new tokens. 320 + async fn refresh_access_token( 321 + http: &reqwest::Client, 322 + pool: &sqlx::AnyPool, 323 + backend: DatabaseBackend, 324 + encryption_key: &[u8; 32], 325 + oauth_registry: &Arc<OAuthClientRegistry>, 326 + creds: &mut DpopCredentials, 327 + ) -> Result<(), AppError> { 328 + let refresh_token = creds 329 + .session 330 + .refresh_token 331 + .as_deref() 332 + .ok_or_else(|| AppError::Auth("token expired and no refresh_token available".into()))?; 105 333 106 - let (encrypted_key,) = row.ok_or_else(|| AppError::Internal("DPoP key not found".into()))?; 334 + let issuer = creds 335 + .session 336 + .issuer 337 + .as_deref() 338 + .ok_or_else(|| AppError::Auth("token expired and no issuer URL stored".into()))?; 107 339 108 - let key_bytes = decrypt(encryption_key, &encrypted_key) 109 - .map_err(|e| AppError::Internal(format!("failed to decrypt DPoP key: {e}")))?; 340 + let token_endpoint = discover_token_endpoint(http, issuer).await?; 110 341 111 - let private_jwk: serde_json::Value = serde_json::from_slice(&key_bytes) 112 - .map_err(|e| AppError::Internal(format!("failed to parse DPoP key: {e}")))?; 342 + // Get the resolved client_id from the OAuth registry. For loopback clients 343 + // this returns `http://localhost?scope=...` which auth servers handle inline, 344 + // rather than the `client_id_url` from the DB which they'd try to fetch. 345 + let client_id_url = lookup_client_id_url(pool, backend, &creds.session.api_client_id).await?; 346 + let client_id = oauth_registry 347 + .get_resolved_client_id(&client_id_url) 348 + .unwrap_or(client_id_url); 113 349 114 - let proof = generate_dpop_proof(&private_jwk, "POST", &target_url, &session.access_token)?; 350 + let proof = generate_dpop_proof_no_ath(&creds.private_jwk, "POST", &token_endpoint, None)?; 115 351 116 352 let resp = http 117 - .post(&target_url) 118 - .header("Authorization", format!("DPoP {}", session.access_token)) 119 - .header("DPoP", proof) 120 - .header("Content-Type", content_type) 121 - .body(blob) 353 + .post(&token_endpoint) 354 + .header("DPoP", &proof) 355 + .header("Content-Type", "application/x-www-form-urlencoded") 356 + .form(&[ 357 + ("grant_type", "refresh_token"), 358 + ("refresh_token", refresh_token), 359 + ("client_id", &client_id), 360 + ]) 122 361 .send() 123 362 .await 124 - .map_err(|e| AppError::Internal(format!("PDS uploadBlob request failed: {e}")))?; 363 + .map_err(|e| AppError::Internal(format!("token refresh request failed: {e}")))?; 364 + 365 + // Handle nonce requirement on the token endpoint 366 + if let Some(nonce) = extract_dpop_nonce(&resp) { 367 + let proof = 368 + generate_dpop_proof_no_ath(&creds.private_jwk, "POST", &token_endpoint, Some(&nonce))?; 369 + 370 + let resp = http 371 + .post(&token_endpoint) 372 + .header("DPoP", &proof) 373 + .header("Content-Type", "application/x-www-form-urlencoded") 374 + .form(&[ 375 + ("grant_type", "refresh_token"), 376 + ("refresh_token", refresh_token), 377 + ("client_id", &client_id), 378 + ]) 379 + .send() 380 + .await 381 + .map_err(|e| AppError::Internal(format!("token refresh request failed: {e}")))?; 382 + 383 + return handle_refresh_response(http, pool, backend, encryption_key, creds, resp).await; 384 + } 125 385 126 - Ok(resp) 386 + handle_refresh_response(http, pool, backend, encryption_key, creds, resp).await 387 + } 388 + 389 + /// Parse the token refresh response and update the stored session. 390 + async fn handle_refresh_response( 391 + _http: &reqwest::Client, 392 + pool: &sqlx::AnyPool, 393 + backend: DatabaseBackend, 394 + encryption_key: &[u8; 32], 395 + creds: &mut DpopCredentials, 396 + resp: reqwest::Response, 397 + ) -> Result<(), AppError> { 398 + if !resp.status().is_success() { 399 + let status = resp.status(); 400 + let body = resp.text().await.unwrap_or_default(); 401 + return Err(AppError::Auth(format!( 402 + "token refresh failed ({status}): {body}" 403 + ))); 404 + } 405 + 406 + let token_resp: serde_json::Value = resp 407 + .json() 408 + .await 409 + .map_err(|e| AppError::Internal(format!("invalid token refresh response: {e}")))?; 410 + 411 + let new_access_token = token_resp["access_token"] 412 + .as_str() 413 + .ok_or_else(|| AppError::Internal("refresh response missing access_token".into()))?; 414 + 415 + let new_refresh_token = token_resp["refresh_token"].as_str(); 416 + 417 + let expires_in = token_resp["expires_in"].as_u64(); 418 + let new_expires_at = expires_in 419 + .map(|secs| (chrono::Utc::now() + chrono::Duration::seconds(secs as i64)).to_rfc3339()); 420 + 421 + // Update the stored session 422 + super::sessions::store_dpop_session( 423 + pool, 424 + backend, 425 + encryption_key, 426 + &creds.session.id, 427 + &creds.session.api_client_id, 428 + &creds.session.dpop_key_id, 429 + &creds.session.user_did, 430 + new_access_token, 431 + new_refresh_token.or(creds.session.refresh_token.as_deref()), 432 + new_expires_at 433 + .as_deref() 434 + .or(creds.session.token_expires_at.as_deref()), 435 + &creds.session.scopes, 436 + creds.session.pds_url.as_deref(), 437 + creds.session.issuer.as_deref(), 438 + ) 439 + .await?; 440 + 441 + // Update the in-memory credentials 442 + creds.session.access_token = new_access_token.to_string(); 443 + if let Some(rt) = new_refresh_token { 444 + creds.session.refresh_token = Some(rt.to_string()); 445 + } 446 + if let Some(ref exp) = new_expires_at { 447 + creds.session.token_expires_at = Some(exp.clone()); 448 + } 449 + 450 + tracing::info!( 451 + user_did = %creds.session.user_did, 452 + api_client_id = %creds.session.api_client_id, 453 + "refreshed DPoP access token" 454 + ); 455 + 456 + Ok(()) 457 + } 458 + 459 + /// Discover the token endpoint from an OAuth authorization server's metadata. 460 + async fn discover_token_endpoint(http: &reqwest::Client, issuer: &str) -> Result<String, AppError> { 461 + let metadata_url = format!( 462 + "{}/.well-known/oauth-authorization-server", 463 + issuer.trim_end_matches('/') 464 + ); 465 + 466 + let resp = 467 + http.get(&metadata_url).send().await.map_err(|e| { 468 + AppError::Internal(format!("failed to fetch auth server metadata: {e}")) 469 + })?; 470 + 471 + if !resp.status().is_success() { 472 + return Err(AppError::Internal(format!( 473 + "auth server metadata returned {}", 474 + resp.status() 475 + ))); 476 + } 477 + 478 + let metadata: serde_json::Value = resp 479 + .json() 480 + .await 481 + .map_err(|e| AppError::Internal(format!("invalid auth server metadata: {e}")))?; 482 + 483 + metadata["token_endpoint"] 484 + .as_str() 485 + .map(|s| s.to_string()) 486 + .ok_or_else(|| AppError::Internal("auth server metadata missing token_endpoint".into())) 487 + } 488 + 489 + /// Look up the client_id_url for an API client by its internal ID. 490 + async fn lookup_client_id_url( 491 + pool: &sqlx::AnyPool, 492 + backend: DatabaseBackend, 493 + api_client_id: &str, 494 + ) -> Result<String, AppError> { 495 + let sql = crate::db::adapt_sql( 496 + "SELECT client_id_url FROM api_clients WHERE id = ?", 497 + backend, 498 + ); 499 + let row: Option<(String,)> = sqlx::query_as(&sql) 500 + .bind(api_client_id) 501 + .fetch_optional(pool) 502 + .await 503 + .map_err(|e| AppError::Internal(format!("failed to look up API client: {e}")))?; 504 + 505 + row.map(|(url,)| url) 506 + .ok_or_else(|| AppError::Internal("API client not found".into())) 127 507 } 128 508 129 509 /// Generate a DPoP proof JWT for a PDS request. ··· 132 512 method: &str, 133 513 url: &str, 134 514 access_token: &str, 515 + nonce: Option<&str>, 516 + ) -> Result<String, AppError> { 517 + let ath = URL_SAFE_NO_PAD.encode(Sha256::digest(access_token.as_bytes())); 518 + generate_dpop_proof_inner(private_jwk, method, url, Some(&ath), nonce) 519 + } 520 + 521 + /// Generate a DPoP proof JWT without an `ath` claim (for token endpoint requests). 522 + fn generate_dpop_proof_no_ath( 523 + private_jwk: &serde_json::Value, 524 + method: &str, 525 + url: &str, 526 + nonce: Option<&str>, 527 + ) -> Result<String, AppError> { 528 + generate_dpop_proof_inner(private_jwk, method, url, None, nonce) 529 + } 530 + 531 + fn generate_dpop_proof_inner( 532 + private_jwk: &serde_json::Value, 533 + method: &str, 534 + url: &str, 535 + ath: Option<&str>, 536 + nonce: Option<&str>, 135 537 ) -> Result<String, AppError> { 136 538 let d_b64 = private_jwk["d"] 137 539 .as_str() ··· 161 563 .duration_since(std::time::UNIX_EPOCH) 162 564 .unwrap() 163 565 .as_secs(); 164 - 165 - let ath = URL_SAFE_NO_PAD.encode(Sha256::digest(access_token.as_bytes())); 166 566 167 567 let header = serde_json::json!({ 168 568 "alg": "ES256", ··· 170 570 "jwk": public_jwk, 171 571 }); 172 572 173 - let payload = serde_json::json!({ 573 + let mut payload = serde_json::json!({ 174 574 "htm": method, 175 575 "htu": url, 176 576 "iat": now, 177 - "ath": ath, 178 577 "jti": format!("{:x}", rand::random::<u64>()), 179 578 }); 579 + if let Some(ath) = ath { 580 + payload["ath"] = serde_json::json!(ath); 581 + } 582 + if let Some(nonce) = nonce { 583 + payload["nonce"] = serde_json::json!(nonce); 584 + } 180 585 181 586 let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).unwrap()); 182 587 let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&payload).unwrap()); ··· 245 650 "POST", 246 651 "https://pds.example.com/xrpc/com.atproto.repo.createRecord", 247 652 "test-access-token", 653 + None, 248 654 ) 249 655 .unwrap(); 250 656 ··· 270 676 } 271 677 272 678 #[test] 679 + fn generate_dpop_proof_includes_nonce() { 680 + let keypair = super::super::keys::generate_dpop_keypair().unwrap(); 681 + 682 + let proof = generate_dpop_proof( 683 + &keypair.private_jwk, 684 + "POST", 685 + "https://pds.example.com/xrpc/test", 686 + "token", 687 + Some("server-nonce-123"), 688 + ) 689 + .unwrap(); 690 + 691 + let parts: Vec<&str> = proof.split('.').collect(); 692 + let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 693 + let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).unwrap(); 694 + assert_eq!(payload["nonce"], "server-nonce-123"); 695 + } 696 + 697 + #[test] 698 + fn generate_dpop_proof_no_ath_omits_ath() { 699 + let keypair = super::super::keys::generate_dpop_keypair().unwrap(); 700 + 701 + let proof = generate_dpop_proof_no_ath( 702 + &keypair.private_jwk, 703 + "POST", 704 + "https://auth.example.com/oauth/token", 705 + None, 706 + ) 707 + .unwrap(); 708 + 709 + let parts: Vec<&str> = proof.split('.').collect(); 710 + let payload_bytes = URL_SAFE_NO_PAD.decode(parts[1]).unwrap(); 711 + let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).unwrap(); 712 + assert!(payload.get("ath").is_none()); 713 + assert!(payload["htm"].is_string()); 714 + assert!(payload["htu"].is_string()); 715 + } 716 + 717 + #[test] 273 718 fn generated_proof_validates_against_own_key() { 274 719 let keypair = super::super::keys::generate_dpop_keypair().unwrap(); 275 720 let url = "https://pds.example.com/xrpc/test.method"; 276 721 let token = "my-access-token"; 277 722 278 - let proof = generate_dpop_proof(&keypair.private_jwk, "POST", url, token).unwrap(); 723 + let proof = generate_dpop_proof(&keypair.private_jwk, "POST", url, token, None).unwrap(); 279 724 280 725 let result = super::super::dpop_proof::validate_dpop_proof( 281 726 &proof,
+1 -1
src/oauth/routes.rs
··· 224 224 } 225 225 226 226 // Validate scopes 227 - client_auth::validate_scopes(&body.scopes, &client.scopes)?; 227 + client_auth::validate_scopes(&body.scopes, &client.scopes, &state.lexicons).await?; 228 228 229 229 // Clean up any existing session's DPoP key before upserting 230 230 // (the ON CONFLICT upsert would orphan the old key otherwise)
+1 -1
src/repo/mod.rs
··· 2 2 pub(crate) mod session; 3 3 mod upload_blob; 4 4 5 - pub(crate) use pds::{forward_pds_response, pds_post_json_raw}; 5 + pub(crate) use pds::{PdsAuth, forward_pds_response, pds_post_json_raw}; 6 6 pub(crate) use session::{get_dpop_client_id, get_oauth_session}; 7 7 pub use upload_blob::upload_blob;
+47
src/repo/pds.rs
··· 1 + use std::sync::Arc; 2 + 1 3 use atrium_xrpc::{InputDataOrBytes, OutputDataOrBytes, XrpcClient, XrpcRequest, http::Method}; 2 4 use axum::body::Bytes; 3 5 use axum::http::StatusCode; ··· 7 9 use crate::AppState; 8 10 use crate::HappyViewOAuthSession; 9 11 use crate::error::AppError; 12 + 13 + /// Abstraction over the two PDS authentication paths. 14 + /// 15 + /// - `OAuth`: uses atrium's OAuthSession (dashboard cookie auth) 16 + /// - `Dpop`: uses the manual DPoP session from `dpop_sessions` table (third-party apps) 17 + #[derive(Clone)] 18 + pub(crate) enum PdsAuth { 19 + OAuth(Arc<HappyViewOAuthSession>), 20 + Dpop { 21 + api_client_id: String, 22 + encryption_key: [u8; 32], 23 + }, 24 + } 25 + 26 + impl PdsAuth { 27 + pub async fn post_json( 28 + &self, 29 + state: &AppState, 30 + user_did: &str, 31 + xrpc_method: &str, 32 + body: &Value, 33 + ) -> Result<reqwest::Response, AppError> { 34 + match self { 35 + PdsAuth::OAuth(session) => pds_post_json_raw(state, session, xrpc_method, body).await, 36 + PdsAuth::Dpop { 37 + api_client_id, 38 + encryption_key, 39 + } => { 40 + crate::oauth::pds_write::dpop_pds_post( 41 + &state.http, 42 + &state.db, 43 + state.db_backend, 44 + encryption_key, 45 + &state.oauth, 46 + &state.config.plc_url, 47 + api_client_id, 48 + user_did, 49 + xrpc_method, 50 + body, 51 + ) 52 + .await 53 + } 54 + } 55 + } 56 + } 10 57 11 58 /// Forward a PDS response back to the client, preserving status and body. 12 59 pub(crate) async fn forward_pds_response(resp: reqwest::Response) -> Result<Response, AppError> {
+1
src/repo/upload_blob.rs
··· 61 61 &state.db, 62 62 state.db_backend, 63 63 encryption_key, 64 + &state.oauth, 64 65 &state.config.plc_url, 65 66 &api_client_id, 66 67 claims.did(),
+1
src/xrpc/procedure.rs
··· 401 401 &state.db, 402 402 state.db_backend, 403 403 encryption_key, 404 + &state.oauth, 404 405 &state.config.plc_url, 405 406 api_client_id, 406 407 claims.did(),
+1 -1
tests/dpop_auth.rs
··· 378 378 379 379 // 3. Generate a DPoP proof for an XRPC GET request 380 380 let request_url = "http://127.0.0.1:0/xrpc/com.example.test.getStuff"; 381 - let proof = generate_dpop_proof(dpop_key, "GET", request_url, access_token) 381 + let proof = generate_dpop_proof(dpop_key, "GET", request_url, access_token, None) 382 382 .expect("failed to generate DPoP proof"); 383 383 384 384 // 4. Make an XRPC request with DPoP auth
+131 -28
web/src/app/dashboard/settings/api-clients/page.tsx
··· 1 1 "use client"; 2 2 3 3 import { useCallback, useEffect, useState } from "react"; 4 - import { Copy, Check, Trash2, X } from "lucide-react"; 4 + import { Copy, Check, Trash2, X, ExternalLink } from "lucide-react"; 5 5 6 6 import { useConfig } from "@/lib/config-context"; 7 7 import { useCurrentUser } from "@/hooks/use-current-user"; ··· 17 17 import { Button } from "@/components/ui/button"; 18 18 import { Input } from "@/components/ui/input"; 19 19 import { Label } from "@/components/ui/label"; 20 + import { RadioGroup, RadioGroupItem } from "@/components/ui/radio-group"; 20 21 import { Switch } from "@/components/ui/switch"; 21 22 import { 22 23 ResponsiveDialog, ··· 151 152 <TableHeader> 152 153 <TableRow> 153 154 <TableHead>Name</TableHead> 155 + <TableHead>Type</TableHead> 154 156 <TableHead>Client Key</TableHead> 155 157 <TableHead>Client ID URL</TableHead> 156 158 <TableHead>Scopes</TableHead> ··· 163 165 {clients.length === 0 && ( 164 166 <TableRow> 165 167 <TableCell 166 - colSpan={7} 168 + colSpan={8} 167 169 className="text-muted-foreground text-center" 168 170 > 169 171 No API clients yet. ··· 176 178 className={!client.is_active ? "opacity-50" : undefined} 177 179 > 178 180 <TableCell className="font-medium">{client.name}</TableCell> 181 + <TableCell> 182 + <Badge variant="outline"> 183 + {client.client_type === "public" ? "Public" : "Confidential"} 184 + </Badge> 185 + </TableCell> 179 186 <TableCell className="font-mono text-sm"> 180 - {client.client_key.slice(0, 12)}... 187 + {client.client_key} 181 188 </TableCell> 182 189 <TableCell className="max-w-48 truncate text-sm"> 183 190 {client.client_id_url} ··· 221 228 const config = useConfig(); 222 229 const happyviewCallbackUri = `${config.public_url.replace(/\/$/, "")}/auth/callback`; 223 230 231 + const [clientType, setClientType] = useState<"confidential" | "public">("confidential"); 224 232 const [name, setName] = useState(""); 225 233 const [clientIdUrl, setClientIdUrl] = useState(""); 226 234 const [clientUri, setClientUri] = useState(""); 227 235 const [redirectUris, setRedirectUris] = useState<string[]>([""]); 236 + const [allowedOrigins, setAllowedOrigins] = useState<string[]>([""]); 228 237 const [scopes, setScopes] = useState<string[]>([""]); 229 238 const [rateLimitEnabled, setRateLimitEnabled] = useState(true); 230 239 const [rateLimitCapacity, setRateLimitCapacity] = useState( ··· 241 250 function handleOpenChange(nextOpen: boolean) { 242 251 setOpen(nextOpen); 243 252 if (!nextOpen) { 253 + setClientType("confidential"); 244 254 setName(""); 245 255 setClientIdUrl(""); 246 256 setClientUri(""); 247 257 setRedirectUris([""]); 258 + setAllowedOrigins([""]); 248 259 setScopes([""]); 249 260 setRateLimitEnabled(true); 250 261 setRateLimitCapacity(String(config.default_rate_limit_capacity)); ··· 279 290 return; 280 291 } 281 292 try { 293 + const filteredOrigins = allowedOrigins.map((o) => o.trim()).filter(Boolean); 282 294 const result = await createApiClient({ 283 295 name: name.trim(), 284 296 client_id_url: clientIdUrl.trim(), 285 297 client_uri: clientUri.trim(), 286 298 redirect_uris: allUris, 287 299 scopes: allScopes, 300 + client_type: clientType, 301 + allowed_origins: clientType === "public" && filteredOrigins.length > 0 ? filteredOrigins : undefined, 288 302 rate_limit_capacity: rateLimitEnabled ? Number(rateLimitCapacity) : null, 289 303 rate_limit_refill_rate: rateLimitEnabled ? Number(rateLimitRefillRate) : null, 290 304 }); ··· 306 320 </ResponsiveDialogTitle> 307 321 <ResponsiveDialogDescription> 308 322 {created 309 - ? "Save the credentials below. The secret will not be shown again." 323 + ? created.client_type === "public" 324 + ? "Your public client has been created. Use PKCE for authentication." 325 + : "Save the credentials below. The secret will not be shown again." 310 326 : "Register a new application that authenticates through this AppView."} 311 327 </ResponsiveDialogDescription> 312 328 </ResponsiveDialogHeader> ··· 339 355 or <code className="bg-muted px-1 rounded">client_key</code> query parameter. 340 356 </p> 341 357 </div> 342 - <div className="flex flex-col gap-2"> 343 - <Label>Client Secret</Label> 344 - <div className="flex gap-2"> 345 - <Input 346 - readOnly 347 - value={created.client_secret} 348 - className="font-mono text-sm" 349 - /> 350 - <Button 351 - variant="outline" 352 - size="icon" 353 - onClick={() => handleCopy(created.client_secret, "secret")} 354 - title="Copy to clipboard" 358 + {created.client_secret ? ( 359 + <div className="flex flex-col gap-2"> 360 + <Label>Client Secret</Label> 361 + <div className="flex gap-2"> 362 + <Input 363 + readOnly 364 + value={created.client_secret} 365 + className="font-mono text-sm" 366 + /> 367 + <Button 368 + variant="outline" 369 + size="icon" 370 + onClick={() => handleCopy(created.client_secret!, "secret")} 371 + title="Copy to clipboard" 372 + > 373 + {copiedField === "secret" ? ( 374 + <Check className="size-4" /> 375 + ) : ( 376 + <Copy className="size-4" /> 377 + )} 378 + </Button> 379 + </div> 380 + <p className="text-muted-foreground text-xs"> 381 + Keep this secret. Send as the <code className="bg-muted px-1 rounded">X-Client-Secret</code> header 382 + for server-to-server requests. Browser requests are validated by Origin instead. 383 + </p> 384 + </div> 385 + ) : ( 386 + <div className="flex flex-col gap-2 rounded-lg border p-4 bg-muted/50"> 387 + <p className="text-sm"> 388 + This is a public client. Authenticate using PKCE instead of a client secret. 389 + </p> 390 + <a 391 + href="/docs/getting-started/authentication#pkce" 392 + target="_blank" 393 + rel="noopener noreferrer" 394 + className="inline-flex items-center gap-1 text-sm text-primary hover:underline" 355 395 > 356 - {copiedField === "secret" ? ( 357 - <Check className="size-4" /> 358 - ) : ( 359 - <Copy className="size-4" /> 360 - )} 361 - </Button> 396 + PKCE authentication docs 397 + <ExternalLink className="size-3" /> 398 + </a> 362 399 </div> 363 - <p className="text-muted-foreground text-xs"> 364 - Keep this secret. Send as the <code className="bg-muted px-1 rounded">X-Client-Secret</code> header 365 - for server-to-server requests. Browser requests are validated by Origin instead. 366 - </p> 367 - </div> 400 + )} 368 401 </div> 369 402 ) : ( 370 403 <div className="flex flex-col gap-4 max-h-[60vh] overflow-y-auto"> 371 404 {error && <p className="text-destructive text-sm">{error}</p>} 405 + <fieldset className="flex flex-col gap-3 rounded-lg border p-4"> 406 + <legend className="text-sm font-medium px-1">Client Type</legend> 407 + <RadioGroup 408 + value={clientType} 409 + onValueChange={(v) => setClientType(v as "confidential" | "public")} 410 + className="flex flex-col gap-3" 411 + > 412 + <div className="flex items-start gap-3"> 413 + <RadioGroupItem value="confidential" id="type-confidential" className="mt-0.5" /> 414 + <div className="flex flex-col gap-0.5"> 415 + <Label htmlFor="type-confidential" className="cursor-pointer font-medium">Confidential</Label> 416 + <p className="text-muted-foreground text-xs"> 417 + Server-side applications that can securely store a client secret. 418 + </p> 419 + </div> 420 + </div> 421 + <div className="flex items-start gap-3"> 422 + <RadioGroupItem value="public" id="type-public" className="mt-0.5" /> 423 + <div className="flex flex-col gap-0.5"> 424 + <Label htmlFor="type-public" className="cursor-pointer font-medium">Public</Label> 425 + <p className="text-muted-foreground text-xs"> 426 + Browser or native apps that authenticate using PKCE (no secret). 427 + </p> 428 + </div> 429 + </div> 430 + </RadioGroup> 431 + </fieldset> 372 432 <fieldset className="flex flex-col gap-3 rounded-lg border p-4"> 373 433 <legend className="text-sm font-medium px-1">Application</legend> 374 434 <div className="flex flex-col gap-2"> ··· 418 478 readonlyValues={[happyviewCallbackUri]} 419 479 /> 420 480 </fieldset> 481 + {clientType === "public" && ( 482 + <fieldset className="flex flex-col gap-3 rounded-lg border p-4"> 483 + <legend className="text-sm font-medium px-1">Allowed Origins</legend> 484 + <p className="text-muted-foreground text-xs"> 485 + Origins permitted to use this client. Requests from unlisted origins will be 486 + rejected. Leave empty to allow any origin. 487 + </p> 488 + <MultiInput 489 + id="allowed-origins" 490 + values={allowedOrigins} 491 + onChange={setAllowedOrigins} 492 + placeholder="https://myapp.com" 493 + /> 494 + </fieldset> 495 + )} 421 496 <fieldset className="flex flex-col gap-3 rounded-lg border p-4"> 422 497 <legend className="text-sm font-medium px-1">Scopes</legend> 423 498 <p className="text-muted-foreground text-xs"> ··· 521 596 return parts.length > 0 ? [...parts, ""] : [""]; 522 597 } 523 598 599 + function parseAllowedOrigins(origins: string[] | null): string[] { 600 + if (!origins || origins.length === 0) return [""]; 601 + return [...origins, ""]; 602 + } 603 + 524 604 const [name, setName] = useState(client.name); 525 605 const [redirectUris, setRedirectUris] = useState<string[]>( 526 606 parseRedirectUris(client.redirect_uris) 607 + ); 608 + const [allowedOrigins, setAllowedOrigins] = useState<string[]>( 609 + parseAllowedOrigins(client.allowed_origins) 527 610 ); 528 611 const [scopes, setScopes] = useState<string[]>(parseScopes(client.scopes)); 529 612 const [isActive, setIsActive] = useState(client.is_active); ··· 545 628 if (nextOpen) { 546 629 setName(client.name); 547 630 setRedirectUris(parseRedirectUris(client.redirect_uris)); 631 + setAllowedOrigins(parseAllowedOrigins(client.allowed_origins)); 548 632 setScopes(parseScopes(client.scopes)); 549 633 setIsActive(client.is_active); 550 634 setRateLimitEnabled( ··· 573 657 const extraScopes = scopes.map((s) => s.trim()).filter(Boolean); 574 658 const allScopes = ["atproto", ...extraScopes].join(" "); 575 659 660 + const filteredOrigins = allowedOrigins.map((o) => o.trim()).filter(Boolean); 576 661 await updateApiClient(client.id, { 577 662 name: name.trim() || undefined, 578 663 redirect_uris: allUris, 579 664 scopes: allScopes, 665 + allowed_origins: client.client_type === "public" 666 + ? (filteredOrigins.length > 0 ? filteredOrigins : null) 667 + : undefined, 580 668 is_active: isActive, 581 669 rate_limit_capacity: rateLimitEnabled ? Number(rateLimitCapacity) : null, 582 670 rate_limit_refill_rate: rateLimitEnabled ? Number(rateLimitRefillRate) : null, ··· 639 727 readonlyValues={[happyviewCallbackUri]} 640 728 /> 641 729 </fieldset> 730 + {client.client_type === "public" && ( 731 + <fieldset className="flex flex-col gap-3 rounded-lg border p-4"> 732 + <legend className="text-sm font-medium px-1">Allowed Origins</legend> 733 + <p className="text-muted-foreground text-xs"> 734 + Origins permitted to use this client. Requests from unlisted origins will be 735 + rejected. Leave empty to allow any origin. 736 + </p> 737 + <MultiInput 738 + id="edit-allowed-origins" 739 + values={allowedOrigins} 740 + onChange={setAllowedOrigins} 741 + placeholder="https://myapp.com" 742 + /> 743 + </fieldset> 744 + )} 642 745 <fieldset className="flex flex-col gap-3 rounded-lg border p-4"> 643 746 <legend className="text-sm font-medium px-1">Scopes</legend> 644 747 <p className="text-muted-foreground text-xs">
+3
web/src/lib/api.ts
··· 373 373 client_uri: string 374 374 redirect_uris: string[] 375 375 scopes?: string 376 + client_type?: string 377 + allowed_origins?: string[] 376 378 rate_limit_capacity: number | null 377 379 rate_limit_refill_rate: number | null 378 380 } ··· 390 392 client_uri?: string 391 393 redirect_uris?: string[] 392 394 scopes?: string 395 + allowed_origins?: string[] | null 393 396 rate_limit_capacity?: number | null 394 397 rate_limit_refill_rate?: number | null 395 398 is_active?: boolean
+4 -1
web/src/types/api-clients.ts
··· 6 6 client_uri: string 7 7 redirect_uris: string[] 8 8 scopes: string 9 + client_type: string 10 + allowed_origins: string[] | null 9 11 rate_limit_capacity: number | null 10 12 rate_limit_refill_rate: number | null 11 13 is_active: boolean ··· 17 19 export interface CreateApiClientResponse { 18 20 id: string 19 21 client_key: string 20 - client_secret: string 22 + client_secret?: string 21 23 name: string 22 24 client_id_url: string 25 + client_type: string 23 26 }