CLI app for developers prototyping atproto functionality
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(oauth-client): add JwksFetcher trait and RealJwksFetcher

Introduces the JwksFetcher trait for fetching JWKS documents from external
URIs, along with RealJwksFetcher implementation. The trait separates JWKS
fetching from the generic HttpClient seam to allow independent mocking in
tests.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>

+619
+1
src/commands/test/oauth/client/pipeline.rs
··· 1 1 //! OAuth client conformance test pipeline and target parsing. 2 2 3 3 pub mod discovery; 4 + pub mod jwks; 4 5 pub mod metadata; 5 6 6 7 use miette::{Diagnostic, NamedSource, SourceSpan};
+618
src/commands/test/oauth/client/pipeline/jwks.rs
··· 1 + //! OAuth client JWKS validation stage. 2 + //! 3 + //! Fetches and validates the client's JWKS (JSON Web Key Set) for confidential 4 + //! clients using either an inline document or an external URI. 5 + 6 + use async_trait::async_trait; 7 + use miette::Diagnostic; 8 + use reqwest::Client as ReqwestClient; 9 + use std::borrow::Cow; 10 + use std::sync::Arc; 11 + use thiserror::Error; 12 + use url::Url; 13 + 14 + use crate::common::oauth::jws::ParsedJwk; 15 + use crate::common::report::{CheckResult, CheckStatus, Stage}; 16 + 17 + // Re-export JwksSource from metadata. 18 + pub use super::metadata::JwksSource; 19 + 20 + /// Response from fetching a JWKS document via the JwksFetcher seam. 21 + pub struct JwksFetchResponse { 22 + pub status: u16, 23 + pub body: Vec<u8>, 24 + pub content_type: Option<String>, 25 + } 26 + 27 + /// Error from fetching a JWKS document. 28 + #[derive(Debug, Error, Diagnostic)] 29 + #[error("network error fetching JWKS at `{url}`: {message}")] 30 + #[diagnostic(code = "oauth_client::jws::jwks_uri_unreachable")] 31 + pub struct JwksFetchError { 32 + pub url: Url, 33 + pub message: String, 34 + } 35 + 36 + /// Trait for fetching JWKS documents from a URI. 37 + /// 38 + /// Kept separate from the generic `HttpClient` seam because JWKS fetching is 39 + /// the one place the tool performs a content-type-aware fetch against an 40 + /// arbitrary third-party URL, and tests want to mock it separately from the 41 + /// client metadata document fetch. 42 + #[async_trait] 43 + pub trait JwksFetcher: Send + Sync { 44 + /// Fetch a JWKS document from the given URL. 45 + async fn fetch(&self, url: &Url) -> Result<JwksFetchResponse, JwksFetchError>; 46 + } 47 + 48 + /// Real JWKS fetcher using reqwest. 49 + pub struct RealJwksFetcher { 50 + client: ReqwestClient, 51 + } 52 + 53 + impl RealJwksFetcher { 54 + /// Create a new RealJwksFetcher from a reqwest Client. 55 + pub fn new(client: ReqwestClient) -> Self { 56 + Self { client } 57 + } 58 + } 59 + 60 + #[async_trait] 61 + impl JwksFetcher for RealJwksFetcher { 62 + async fn fetch(&self, url: &Url) -> Result<JwksFetchResponse, JwksFetchError> { 63 + let resp = self 64 + .client 65 + .get(url.clone()) 66 + .send() 67 + .await 68 + .map_err(|e| JwksFetchError { 69 + url: url.clone(), 70 + message: e.to_string(), 71 + })?; 72 + 73 + let status = resp.status().as_u16(); 74 + let content_type = resp 75 + .headers() 76 + .get(reqwest::header::CONTENT_TYPE) 77 + .and_then(|v| v.to_str().ok()) 78 + .map(str::to_string); 79 + 80 + let body = resp 81 + .bytes() 82 + .await 83 + .map_err(|e| JwksFetchError { 84 + url: url.clone(), 85 + message: e.to_string(), 86 + })? 87 + .to_vec(); 88 + 89 + Ok(JwksFetchResponse { 90 + status, 91 + body, 92 + content_type, 93 + }) 94 + } 95 + } 96 + 97 + /// Facts extracted from the JWKS validation stage. 98 + pub struct JwksFacts { 99 + /// The successfully parsed JWK keys. 100 + pub keys: Vec<ParsedJwk>, 101 + /// The source of the JWKS (inline or URI). 102 + pub source: JwksSource, 103 + } 104 + 105 + /// A single check performed by the JWKS validation stage. 106 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 107 + pub enum Check { 108 + /// JWKS is present (inline or URI). 109 + JwksPresent, 110 + /// JWKS URI is fetchable and returns 2xx. 111 + JwksUriFetchable, 112 + /// JWKS document is valid JSON with a "keys" array. 113 + JwksIsJson, 114 + /// All keys have unique `kid` values. 115 + KeysHaveUniqueKids, 116 + /// All keys have an `alg` field. 117 + KeysHaveAlg, 118 + /// All keys have `use == "sig"` (or absent, which defaults to sig). 119 + KeysUseSigningUse, 120 + /// All algorithms are modern EC (ES256 or ES256K only). 121 + AlgsAreModernEc, 122 + } 123 + 124 + impl Check { 125 + /// The stable check ID string. 126 + pub fn id(self) -> &'static str { 127 + match self { 128 + Check::JwksPresent => "oauth_client::jws::jwks_present", 129 + Check::JwksUriFetchable => "oauth_client::jws::jwks_uri_fetchable", 130 + Check::JwksIsJson => "oauth_client::jws::jwks_is_json", 131 + Check::KeysHaveUniqueKids => "oauth_client::jws::keys_have_unique_kids", 132 + Check::KeysHaveAlg => "oauth_client::jws::keys_have_alg", 133 + Check::KeysUseSigningUse => "oauth_client::jws::keys_use_signing_use", 134 + Check::AlgsAreModernEc => "oauth_client::jws::algs_are_modern_ec", 135 + } 136 + } 137 + 138 + /// Human-readable summary of the check. 139 + pub fn summary(self) -> &'static str { 140 + match self { 141 + Check::JwksPresent => "JWKS is present", 142 + Check::JwksUriFetchable => "JWKS URI is fetchable", 143 + Check::JwksIsJson => "JWKS is valid JSON", 144 + Check::KeysHaveUniqueKids => "Keys have unique kid values", 145 + Check::KeysHaveAlg => "Keys declare alg field", 146 + Check::KeysUseSigningUse => "Keys use signing use", 147 + Check::AlgsAreModernEc => "Algorithms are modern EC", 148 + } 149 + } 150 + 151 + /// Return a passing CheckResult. 152 + pub fn pass(self) -> CheckResult { 153 + CheckResult { 154 + id: self.id(), 155 + stage: Stage::JWKS, 156 + status: CheckStatus::Pass, 157 + summary: Cow::Borrowed(self.summary()), 158 + diagnostic: None, 159 + skipped_reason: None, 160 + } 161 + } 162 + 163 + /// Return a skipped CheckResult with a reason. 164 + pub fn skipped(self, reason: impl Into<Cow<'static, str>>) -> CheckResult { 165 + CheckResult { 166 + id: self.id(), 167 + stage: Stage::JWKS, 168 + status: CheckStatus::Skipped, 169 + summary: Cow::Borrowed(self.summary()), 170 + diagnostic: None, 171 + skipped_reason: Some(reason.into()), 172 + } 173 + } 174 + 175 + /// Return a network error CheckResult. 176 + pub fn network_error( 177 + self, 178 + diagnostic: Box<dyn miette::Diagnostic + Send + Sync>, 179 + ) -> CheckResult { 180 + CheckResult { 181 + id: self.id(), 182 + stage: Stage::JWKS, 183 + status: CheckStatus::NetworkError, 184 + summary: Cow::Borrowed(self.summary()), 185 + diagnostic: Some(diagnostic), 186 + skipped_reason: None, 187 + } 188 + } 189 + 190 + /// Return a spec violation CheckResult. 191 + pub fn spec_violation( 192 + self, 193 + diagnostic: Box<dyn miette::Diagnostic + Send + Sync>, 194 + ) -> CheckResult { 195 + CheckResult { 196 + id: self.id(), 197 + stage: Stage::JWKS, 198 + status: CheckStatus::SpecViolation, 199 + summary: Cow::Borrowed(self.summary()), 200 + diagnostic: Some(diagnostic), 201 + skipped_reason: None, 202 + } 203 + } 204 + } 205 + 206 + /// Output from the JWKS validation stage. 207 + pub struct JwksStageOutput { 208 + /// Extracted facts (None if validation failed early). 209 + pub facts: Option<JwksFacts>, 210 + /// All check results from this stage. 211 + pub results: Vec<CheckResult>, 212 + } 213 + 214 + /// Emit all JWKS checks as blocked by a prerequisite check. 215 + pub async fn emit_all_blocked_by(blocker_id: &'static str) -> JwksStageOutput { 216 + let checks = [ 217 + Check::JwksPresent, 218 + Check::JwksUriFetchable, 219 + Check::JwksIsJson, 220 + Check::KeysHaveUniqueKids, 221 + Check::KeysHaveAlg, 222 + Check::KeysUseSigningUse, 223 + Check::AlgsAreModernEc, 224 + ]; 225 + 226 + let results = checks 227 + .iter() 228 + .map(|&check| { 229 + crate::common::report::blocked_by(check.id(), Stage::JWKS, check.summary(), blocker_id) 230 + }) 231 + .collect(); 232 + 233 + JwksStageOutput { 234 + facts: None, 235 + results, 236 + } 237 + } 238 + 239 + /// Run the JWKS validation stage. 240 + /// 241 + /// Validates inline or external JWKS documents for confidential clients. 242 + /// For other client kinds, emits Skipped results with kind-specific reasons. 243 + pub async fn run( 244 + facts: &super::metadata::MetadataFacts, 245 + fetcher: &dyn JwksFetcher, 246 + _raw_source_name: &str, 247 + ) -> JwksStageOutput { 248 + use super::metadata::ClientKind; 249 + 250 + // For non-confidential clients, skip all JWKS checks. 251 + match facts.kind { 252 + ClientKind::Loopback => { 253 + return JwksStageOutput { 254 + facts: None, 255 + results: vec![ 256 + Check::JwksPresent.skipped("jwks not applicable to loopback clients"), 257 + Check::JwksUriFetchable.skipped("jwks not applicable to loopback clients"), 258 + Check::JwksIsJson.skipped("jwks not applicable to loopback clients"), 259 + Check::KeysHaveUniqueKids.skipped("jwks not applicable to loopback clients"), 260 + Check::KeysHaveAlg.skipped("jwks not applicable to loopback clients"), 261 + Check::KeysUseSigningUse.skipped("jwks not applicable to loopback clients"), 262 + Check::AlgsAreModernEc.skipped("jwks not applicable to loopback clients"), 263 + ], 264 + }; 265 + } 266 + ClientKind::WebPublic => { 267 + return JwksStageOutput { 268 + facts: None, 269 + results: vec![ 270 + Check::JwksPresent.skipped("jwks not required for public clients"), 271 + Check::JwksUriFetchable.skipped("jwks not required for public clients"), 272 + Check::JwksIsJson.skipped("jwks not required for public clients"), 273 + Check::KeysHaveUniqueKids.skipped("jwks not required for public clients"), 274 + Check::KeysHaveAlg.skipped("jwks not required for public clients"), 275 + Check::KeysUseSigningUse.skipped("jwks not required for public clients"), 276 + Check::AlgsAreModernEc.skipped("jwks not required for public clients"), 277 + ], 278 + }; 279 + } 280 + ClientKind::Native => { 281 + return JwksStageOutput { 282 + facts: None, 283 + results: vec![ 284 + Check::JwksPresent.skipped("jwks not required for native clients"), 285 + Check::JwksUriFetchable.skipped("jwks not required for native clients"), 286 + Check::JwksIsJson.skipped("jwks not required for native clients"), 287 + Check::KeysHaveUniqueKids.skipped("jwks not required for native clients"), 288 + Check::KeysHaveAlg.skipped("jwks not required for native clients"), 289 + Check::KeysUseSigningUse.skipped("jwks not required for native clients"), 290 + Check::AlgsAreModernEc.skipped("jwks not required for native clients"), 291 + ], 292 + }; 293 + } 294 + ClientKind::WebConfidential => {} // Continue to validation below. 295 + } 296 + 297 + // For confidential clients: inspect jwks_source. 298 + let mut results = vec![]; 299 + 300 + if facts.jwks_source.is_none() { 301 + // Metadata stage flagged a confidential-without-jwks violation. 302 + // Skip all remaining JWKS checks blocked by the metadata check. 303 + return emit_all_blocked_by("oauth_client::metadata::confidential_requires_jwks").await; 304 + } 305 + 306 + let jwks_source = facts.jwks_source.as_ref().unwrap(); 307 + 308 + // Determine the JWKS bytes and emit the appropriate initial checks. 309 + let jwks_bytes = match jwks_source { 310 + JwksSource::Inline(v) => { 311 + results.push(Check::JwksPresent.pass()); 312 + results.push(Check::JwksUriFetchable.skipped("jwks is inline")); 313 + 314 + // Serialize back to bytes for consistent spans. 315 + match serde_json::to_vec(v) { 316 + Ok(bytes) => bytes, 317 + Err(_) => { 318 + // This should never happen for a valid Value that came from parsing. 319 + results.push(Check::JwksIsJson.spec_violation(Box::new(JwksJsonError( 320 + "failed to re-serialize inline JWKS".to_string(), 321 + )))); 322 + return JwksStageOutput { 323 + facts: None, 324 + results, 325 + }; 326 + } 327 + } 328 + } 329 + JwksSource::Uri(u) => { 330 + results.push(Check::JwksPresent.pass()); 331 + 332 + // Fetch the JWKS from the URI. 333 + let fetch_result = fetcher.fetch(u).await; 334 + let response = match fetch_result { 335 + Ok(resp) => resp, 336 + Err(e) => { 337 + results.push(Check::JwksUriFetchable.network_error(Box::new(e))); 338 + // Skip remaining checks. 339 + for check in &[ 340 + Check::JwksIsJson, 341 + Check::KeysHaveUniqueKids, 342 + Check::KeysHaveAlg, 343 + Check::KeysUseSigningUse, 344 + Check::AlgsAreModernEc, 345 + ] { 346 + results.push(crate::common::report::blocked_by( 347 + check.id(), 348 + Stage::JWKS, 349 + check.summary(), 350 + Check::JwksUriFetchable.id(), 351 + )); 352 + } 353 + return JwksStageOutput { 354 + facts: None, 355 + results, 356 + }; 357 + } 358 + }; 359 + 360 + // Check for non-2xx status. 361 + if response.status < 200 || response.status >= 300 { 362 + results.push( 363 + Check::JwksUriFetchable.network_error(Box::new(JwksStatusError { 364 + url: u.clone(), 365 + status: response.status, 366 + })), 367 + ); 368 + // Skip remaining checks. 369 + for check in &[ 370 + Check::JwksIsJson, 371 + Check::KeysHaveUniqueKids, 372 + Check::KeysHaveAlg, 373 + Check::KeysUseSigningUse, 374 + Check::AlgsAreModernEc, 375 + ] { 376 + results.push(crate::common::report::blocked_by( 377 + check.id(), 378 + Stage::JWKS, 379 + check.summary(), 380 + Check::JwksUriFetchable.id(), 381 + )); 382 + } 383 + return JwksStageOutput { 384 + facts: None, 385 + results, 386 + }; 387 + } 388 + 389 + results.push(Check::JwksUriFetchable.pass()); 390 + response.body 391 + } 392 + }; 393 + 394 + // Parse the JWKS JSON and extract the keys array. 395 + let keys_array = match serde_json::from_slice::<serde_json::Value>(&jwks_bytes) { 396 + Ok(value) => { 397 + // Try to extract the "keys" array. 398 + match value.get("keys").and_then(|v| v.as_array()) { 399 + Some(arr) => { 400 + results.push(Check::JwksIsJson.pass()); 401 + arr.clone() 402 + } 403 + None => { 404 + results.push(Check::JwksIsJson.spec_violation(Box::new(JwksJsonError( 405 + "JWKS document missing required `keys` array".to_string(), 406 + )))); 407 + // Skip remaining checks. 408 + for check in &[ 409 + Check::KeysHaveUniqueKids, 410 + Check::KeysHaveAlg, 411 + Check::KeysUseSigningUse, 412 + Check::AlgsAreModernEc, 413 + ] { 414 + results.push(crate::common::report::blocked_by( 415 + check.id(), 416 + Stage::JWKS, 417 + check.summary(), 418 + Check::JwksIsJson.id(), 419 + )); 420 + } 421 + return JwksStageOutput { 422 + facts: None, 423 + results, 424 + }; 425 + } 426 + } 427 + } 428 + Err(_) => { 429 + results.push(Check::JwksIsJson.spec_violation(Box::new(JwksJsonError( 430 + "JWKS document is not valid JSON".to_string(), 431 + )))); 432 + // Skip remaining checks. 433 + for check in &[ 434 + Check::KeysHaveUniqueKids, 435 + Check::KeysHaveAlg, 436 + Check::KeysUseSigningUse, 437 + Check::AlgsAreModernEc, 438 + ] { 439 + results.push(crate::common::report::blocked_by( 440 + check.id(), 441 + Stage::JWKS, 442 + check.summary(), 443 + Check::JwksIsJson.id(), 444 + )); 445 + } 446 + return JwksStageOutput { 447 + facts: None, 448 + results, 449 + }; 450 + } 451 + }; 452 + 453 + // Parse each key and track violations. 454 + let source_bytes: Arc<[u8]> = Arc::from(jwks_bytes); 455 + let mut parsed_keys = Vec::new(); 456 + let mut has_key_alg_violation = false; 457 + let mut has_key_use_violation = false; 458 + let mut has_alg_violation = false; 459 + let mut kid_map: std::collections::HashMap<Option<Arc<str>>, Vec<usize>> = 460 + std::collections::HashMap::new(); 461 + 462 + for (i, key_value) in keys_array.iter().enumerate() { 463 + match crate::common::oauth::jws::parse_jwk(key_value, "<jwks>", source_bytes.clone()) { 464 + Err(e) => { 465 + // Map JwsError to the appropriate check and set violation flags. 466 + match e { 467 + crate::common::oauth::jws::JwsError::NotJson { .. } => { 468 + results.push(Check::JwksIsJson.spec_violation(Box::new(e))); 469 + return JwksStageOutput { 470 + facts: None, 471 + results, 472 + }; 473 + } 474 + crate::common::oauth::jws::JwsError::JwkMissingField { field, .. } => { 475 + match field { 476 + "kty" | "crv" | "x" | "y" => { 477 + results.push(Check::JwksIsJson.spec_violation(Box::new(e))); 478 + return JwksStageOutput { 479 + facts: None, 480 + results, 481 + }; 482 + } 483 + "alg" => { 484 + has_key_alg_violation = true; 485 + } 486 + _ => { 487 + // Unknown field. 488 + } 489 + } 490 + } 491 + crate::common::oauth::jws::JwsError::JwkKtyMismatch { .. } => { 492 + has_alg_violation = true; 493 + } 494 + _ => { 495 + // Other errors. 496 + } 497 + } 498 + } 499 + Ok(parsed) => { 500 + // Track kid for uniqueness check. 501 + kid_map.entry(parsed.kid.clone()).or_default().push(i); 502 + 503 + // Check for missing alg. 504 + if parsed.alg.is_none() && parsed.alg_raw.is_none() { 505 + has_key_alg_violation = true; 506 + } 507 + 508 + // Check for non-sig use. 509 + if parsed.r#use != crate::common::oauth::jws::JwkUse::Sig { 510 + has_key_use_violation = true; 511 + } 512 + 513 + // Check for non-modern-EC algorithm. 514 + if parsed.alg_raw.is_some() { 515 + has_alg_violation = true; 516 + } 517 + 518 + parsed_keys.push(parsed); 519 + } 520 + } 521 + } 522 + 523 + // Emit results for structural checks (alg, use). 524 + if has_key_alg_violation { 525 + results.push(Check::KeysHaveAlg.spec_violation(Box::new(JwksJsonError( 526 + "One or more keys missing required `alg` field".to_string(), 527 + )))); 528 + } else { 529 + results.push(Check::KeysHaveAlg.pass()); 530 + } 531 + 532 + if has_key_use_violation { 533 + results.push( 534 + Check::KeysUseSigningUse.spec_violation(Box::new(JwksJsonError( 535 + "One or more keys have `use` other than `sig`".to_string(), 536 + ))), 537 + ); 538 + } else { 539 + results.push(Check::KeysUseSigningUse.pass()); 540 + } 541 + 542 + if has_alg_violation { 543 + results.push( 544 + Check::AlgsAreModernEc.spec_violation(Box::new(JwksJsonError( 545 + "One or more keys declare non-modern algorithms".to_string(), 546 + ))), 547 + ); 548 + } else { 549 + results.push(Check::AlgsAreModernEc.pass()); 550 + } 551 + 552 + // Check for unique kids. 553 + let mut has_duplicate_kids = false; 554 + for (_kid, indices) in kid_map.iter() { 555 + if indices.len() > 1 { 556 + has_duplicate_kids = true; 557 + break; 558 + } 559 + } 560 + 561 + if has_duplicate_kids { 562 + results.push( 563 + Check::KeysHaveUniqueKids.spec_violation(Box::new(JwksJsonError( 564 + "Two or more keys share the same `kid` value".to_string(), 565 + ))), 566 + ); 567 + } else { 568 + results.push(Check::KeysHaveUniqueKids.pass()); 569 + } 570 + 571 + // Return the output with parsed keys. 572 + JwksStageOutput { 573 + facts: Some(JwksFacts { 574 + keys: parsed_keys, 575 + source: jwks_source.clone(), 576 + }), 577 + results, 578 + } 579 + } 580 + 581 + /// Simple diagnostic for JWKS JSON/structural errors. 582 + #[derive(Debug)] 583 + struct JwksJsonError(String); 584 + 585 + impl std::fmt::Display for JwksJsonError { 586 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 587 + write!(f, "{}", self.0) 588 + } 589 + } 590 + 591 + impl std::error::Error for JwksJsonError {} 592 + 593 + impl miette::Diagnostic for JwksJsonError { 594 + fn code<'a>(&'a self) -> Option<Box<dyn std::fmt::Display + 'a>> { 595 + Some(Box::new("oauth_client::jws::jwks_json")) 596 + } 597 + } 598 + 599 + /// Diagnostic for JWKS HTTP status errors. 600 + #[derive(Debug)] 601 + struct JwksStatusError { 602 + url: Url, 603 + status: u16, 604 + } 605 + 606 + impl std::fmt::Display for JwksStatusError { 607 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 608 + write!(f, "JWKS URI returned {}: {}", self.status, self.url) 609 + } 610 + } 611 + 612 + impl std::error::Error for JwksStatusError {} 613 + 614 + impl miette::Diagnostic for JwksStatusError { 615 + fn code<'a>(&'a self) -> Option<Box<dyn std::fmt::Display + 'a>> { 616 + Some(Box::new("oauth_client::jws::jwks_uri_unreachable")) 617 + } 618 + }