A better Rust ATProto crate
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at pretty-codegen 1116 lines 41 kB view raw
1use chrono::{TimeDelta, Utc}; 2use http::{Method, Request, StatusCode}; 3use jacquard_common::{ 4 CowStr, IntoStatic, 5 cowstr::ToCowStr, 6 http_client::HttpClient, 7 session::SessionStoreError, 8 types::{ 9 did::Did, 10 string::{AtStrError, Datetime}, 11 }, 12}; 13use jacquard_identity::resolver::IdentityError; 14use serde::Serialize; 15use serde_json::Value; 16use smol_str::ToSmolStr; 17 18use jose_jwa::Signing; 19 20use crate::{ 21 FALLBACK_ALG, 22 atproto::atproto_client_metadata, 23 dpop::DpopExt, 24 jose::jwt::{RegisteredClaims, RegisteredClaimsAud}, 25 keyset::Keyset, 26 resolver::OAuthResolver, 27 scopes::Scope, 28 session::{ 29 AuthRequestData, ClientData, ClientSessionData, DpopClientData, DpopDataSource, DpopReqData, 30 }, 31 types::{ 32 AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptionPrompt, 33 OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthParResponse, 34 OAuthTokenResponse, ParParameters, RefreshRequestParameters, RevocationRequestParameters, 35 TokenGrantType, TokenRequestParameters, TokenSet, 36 }, 37 utils::{compare_algos, generate_dpop_key, generate_nonce, generate_pkce}, 38}; 39 40// https://datatracker.ietf.org/doc/html/rfc7523#section-2.2 41const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str = 42 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; 43 44use smol_str::SmolStr; 45 46/// Convenience alias for a heap-allocated, thread-safe, `'static` error value. 
47pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; 48 49/// OAuth request error for token operations and auth flows 50#[derive(Debug, thiserror::Error, miette::Diagnostic)] 51#[error("{kind}")] 52pub struct RequestError { 53 #[diagnostic_source] 54 kind: RequestErrorKind, 55 #[source] 56 source: Option<BoxError>, 57 #[help] 58 help: Option<SmolStr>, 59 context: Option<SmolStr>, 60 url: Option<SmolStr>, 61 details: Option<SmolStr>, 62 location: Option<SmolStr>, 63} 64 65/// Error categories for OAuth request operations 66#[derive(Debug, thiserror::Error, miette::Diagnostic)] 67#[non_exhaustive] 68pub enum RequestErrorKind { 69 /// No endpoint available 70 #[error("no {0} endpoint available")] 71 #[diagnostic( 72 code(jacquard_oauth::request::no_endpoint), 73 help("server does not advertise this endpoint") 74 )] 75 NoEndpoint(SmolStr), 76 77 /// Token response verification failed 78 #[error("token response verification failed")] 79 #[diagnostic(code(jacquard_oauth::request::token_verification))] 80 TokenVerification, 81 82 /// Unsupported authentication method 83 #[error("unsupported authentication method")] 84 #[diagnostic( 85 code(jacquard_oauth::request::unsupported_auth_method), 86 help( 87 "server must support `private_key_jwt` or `none`; configure client metadata accordingly" 88 ) 89 )] 90 UnsupportedAuthMethod, 91 92 /// No refresh token available 93 #[error("no refresh token available")] 94 #[diagnostic(code(jacquard_oauth::request::no_refresh_token))] 95 NoRefreshToken, 96 97 /// Invalid DID 98 #[error("failed to parse DID")] 99 #[diagnostic(code(jacquard_oauth::request::invalid_did))] 100 InvalidDid, 101 102 /// DPoP client error 103 #[error("dpop error")] 104 #[diagnostic(code(jacquard_oauth::request::dpop))] 105 Dpop, 106 107 /// Session storage error 108 #[error("storage error")] 109 #[diagnostic(code(jacquard_oauth::request::storage))] 110 Storage, 111 112 /// Resolver error 113 #[error("resolver error")] 114 
#[diagnostic(code(jacquard_oauth::request::resolver))] 115 Resolver, 116 117 /// HTTP build error 118 #[error("http build error")] 119 #[diagnostic(code(jacquard_oauth::request::http_build))] 120 HttpBuild, 121 122 /// HTTP status error 123 #[error("http status: {0}")] 124 #[diagnostic( 125 code(jacquard_oauth::request::http_status), 126 help("see server response for details") 127 )] 128 HttpStatus(StatusCode), 129 130 /// HTTP status with error body 131 #[error("http status: {status}, body: {body:?}")] 132 #[diagnostic( 133 code(jacquard_oauth::request::http_status_body), 134 help("server returned error JSON; inspect fields like `error`, `error_description`") 135 )] 136 HttpStatusWithBody { 137 /// HTTP status code returned by the server. 138 status: StatusCode, 139 /// Parsed JSON body containing OAuth error fields such as `error` and `error_description`. 140 body: Value, 141 }, 142 143 /// Identity resolution error 144 #[error("identity error")] 145 #[diagnostic(code(jacquard_oauth::request::identity))] 146 Identity, 147 148 /// Keyset error 149 #[error("keyset error")] 150 #[diagnostic(code(jacquard_oauth::request::keyset))] 151 Keyset, 152 153 /// Form serialization error 154 #[error("form serialization error")] 155 #[diagnostic(code(jacquard_oauth::request::serde_form))] 156 SerdeHtmlForm, 157 158 /// JSON error 159 #[error("json error")] 160 #[diagnostic(code(jacquard_oauth::request::serde_json))] 161 SerdeJson, 162 163 /// Atproto metadata error 164 #[error("atproto error")] 165 #[diagnostic(code(jacquard_oauth::request::atproto))] 166 Atproto, 167} 168 169impl RequestError { 170 /// Create a new error with the given kind and optional source 171 pub fn new(kind: RequestErrorKind, source: Option<BoxError>) -> Self { 172 Self { 173 kind, 174 source, 175 help: None, 176 context: None, 177 url: None, 178 details: None, 179 location: None, 180 } 181 } 182 183 /// Get the error kind 184 pub fn kind(&self) -> &RequestErrorKind { 185 &self.kind 186 } 187 188 /// 
Get the source error if present 189 pub fn source_err(&self) -> Option<&BoxError> { 190 self.source.as_ref() 191 } 192 193 /// Get the context string if present 194 pub fn context(&self) -> Option<&str> { 195 self.context.as_ref().map(|s| s.as_str()) 196 } 197 198 /// Get the URL if present 199 pub fn url(&self) -> Option<&str> { 200 self.url.as_ref().map(|s| s.as_str()) 201 } 202 203 /// Get the details if present 204 pub fn details(&self) -> Option<&str> { 205 self.details.as_ref().map(|s| s.as_str()) 206 } 207 208 /// Get the location if present 209 pub fn location(&self) -> Option<&str> { 210 self.location.as_ref().map(|s| s.as_str()) 211 } 212 213 /// Add help text to this error 214 pub fn with_help(mut self, help: impl Into<SmolStr>) -> Self { 215 self.help = Some(help.into()); 216 self 217 } 218 219 /// Add context to this error 220 pub fn with_context(mut self, context: impl Into<SmolStr>) -> Self { 221 self.context = Some(context.into()); 222 self 223 } 224 225 /// Add URL to this error 226 pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self { 227 self.url = Some(url.into()); 228 self 229 } 230 231 /// Add details to this error 232 pub fn with_details(mut self, details: impl Into<SmolStr>) -> Self { 233 self.details = Some(details.into()); 234 self 235 } 236 237 /// Add location to this error 238 pub fn with_location(mut self, location: impl Into<SmolStr>) -> Self { 239 self.location = Some(location.into()); 240 self 241 } 242 243 // Constructors for each kind 244 245 /// Create a no endpoint error 246 pub fn no_endpoint(endpoint: impl Into<SmolStr>) -> Self { 247 Self::new(RequestErrorKind::NoEndpoint(endpoint.into()), None) 248 } 249 250 /// Create a token verification error 251 pub fn token_verification() -> Self { 252 Self::new(RequestErrorKind::TokenVerification, None) 253 } 254 255 /// Create an unsupported authentication method error 256 pub fn unsupported_auth_method() -> Self { 257 Self::new(RequestErrorKind::UnsupportedAuthMethod, None) 
258 } 259 260 /// Create a no refresh token error 261 pub fn no_refresh_token() -> Self { 262 Self::new(RequestErrorKind::NoRefreshToken, None) 263 } 264 265 /// Create an invalid DID error 266 pub fn invalid_did(source: impl std::error::Error + Send + Sync + 'static) -> Self { 267 Self::new(RequestErrorKind::InvalidDid, Some(Box::new(source))) 268 } 269 270 /// Create a DPoP error 271 pub fn dpop(source: impl std::error::Error + Send + Sync + 'static) -> Self { 272 Self::new(RequestErrorKind::Dpop, Some(Box::new(source))) 273 } 274 275 /// Create a storage error 276 pub fn storage(source: impl std::error::Error + Send + Sync + 'static) -> Self { 277 Self::new(RequestErrorKind::Storage, Some(Box::new(source))) 278 } 279 280 /// Create a resolver error 281 pub fn resolver(source: impl std::error::Error + Send + Sync + 'static) -> Self { 282 Self::new(RequestErrorKind::Resolver, Some(Box::new(source))) 283 } 284 285 /// Create an HTTP build error 286 pub fn http_build(source: impl std::error::Error + Send + Sync + 'static) -> Self { 287 Self::new(RequestErrorKind::HttpBuild, Some(Box::new(source))) 288 } 289 290 /// Create an HTTP status error 291 pub fn http_status(status: StatusCode) -> Self { 292 Self::new(RequestErrorKind::HttpStatus(status), None) 293 } 294 295 /// Create an HTTP status with body error 296 pub fn http_status_with_body(status: StatusCode, body: Value) -> Self { 297 Self::new(RequestErrorKind::HttpStatusWithBody { status, body }, None) 298 } 299 300 /// Create an identity error 301 pub fn identity(source: impl std::error::Error + Send + Sync + 'static) -> Self { 302 Self::new(RequestErrorKind::Identity, Some(Box::new(source))) 303 } 304 305 /// Create a keyset error 306 pub fn keyset(source: impl std::error::Error + Send + Sync + 'static) -> Self { 307 Self::new(RequestErrorKind::Keyset, Some(Box::new(source))) 308 } 309 310 /// Create an atproto metadata error 311 pub fn atproto(source: impl std::error::Error + Send + Sync + 'static) -> Self { 
312 Self::new(RequestErrorKind::Atproto, Some(Box::new(source))) 313 } 314 315 /// Returns true if this error indicates permanent auth failure 316 /// (token revoked, refresh_token expired, etc.) 317 /// 318 /// When this returns true, the session should be cleared from storage 319 /// rather than retried. 320 pub fn is_permanent(&self) -> bool { 321 match &self.kind { 322 RequestErrorKind::NoRefreshToken => true, 323 RequestErrorKind::HttpStatusWithBody { body, .. } => body 324 .get("error") 325 .and_then(|e| e.as_str()) 326 .is_some_and(|e| matches!(e, "invalid_grant" | "access_denied")), 327 _ => false, 328 } 329 } 330} 331 332// From impls for common error types 333 334impl From<AtStrError> for RequestError { 335 fn from(e: AtStrError) -> Self { 336 let msg = smol_str::format_smolstr!("{:?}", e); 337 Self::new(RequestErrorKind::InvalidDid, Some(Box::new(e))) 338 .with_context(msg) 339 .with_help("ensure DID is correctly formatted (e.g., did:plc:abc123)") 340 } 341} 342 343impl From<crate::dpop::DpopError> for RequestError { 344 fn from(e: crate::dpop::DpopError) -> Self { 345 let msg = smol_str::format_smolstr!("{:?}", e); 346 Self::new(RequestErrorKind::Dpop, Some(Box::new(e))) 347 .with_context(msg) 348 .with_help("check DPoP key configuration and nonce handling") 349 } 350} 351 352impl From<SessionStoreError> for RequestError { 353 fn from(e: SessionStoreError) -> Self { 354 let msg = smol_str::format_smolstr!("{:?}", e); 355 Self::new(RequestErrorKind::Storage, Some(Box::new(e))) 356 .with_context(msg) 357 .with_help("verify session store is accessible and writable") 358 } 359} 360 361impl From<crate::resolver::ResolverError> for RequestError { 362 fn from(e: crate::resolver::ResolverError) -> Self { 363 let msg = smol_str::format_smolstr!("{:?}", e); 364 Self::new(RequestErrorKind::Resolver, Some(Box::new(e))) 365 .with_context(msg) 366 .with_help("check identity resolution and OAuth metadata endpoints") 367 } 368} 369 370impl From<http::Error> for 
RequestError { 371 fn from(e: http::Error) -> Self { 372 let msg = smol_str::format_smolstr!("{:?}", e); 373 Self::new(RequestErrorKind::HttpBuild, Some(Box::new(e))) 374 .with_context(msg) 375 .with_help("verify request URIs and headers are valid") 376 } 377} 378 379impl From<IdentityError> for RequestError { 380 fn from(e: IdentityError) -> Self { 381 let msg = smol_str::format_smolstr!("{:?}", e); 382 Self::new(RequestErrorKind::Identity, Some(Box::new(e))) 383 .with_context(msg) 384 .with_help("check handle/DID is valid and identity resolver is configured") 385 } 386} 387 388impl From<crate::keyset::Error> for RequestError { 389 fn from(e: crate::keyset::Error) -> Self { 390 let msg = smol_str::format_smolstr!("{:?}", e); 391 Self::new(RequestErrorKind::Keyset, Some(Box::new(e))) 392 .with_context(msg) 393 .with_help("verify keyset configuration and signing algorithm support") 394 } 395} 396 397impl From<serde_html_form::ser::Error> for RequestError { 398 fn from(e: serde_html_form::ser::Error) -> Self { 399 let msg = smol_str::format_smolstr!("{:?}", e); 400 Self::new(RequestErrorKind::SerdeHtmlForm, Some(Box::new(e))) 401 .with_context(msg) 402 .with_help("check OAuth request parameters are serializable") 403 } 404} 405 406impl From<serde_json::Error> for RequestError { 407 fn from(e: serde_json::Error) -> Self { 408 let msg = smol_str::format_smolstr!("{:?}", e); 409 Self::new(RequestErrorKind::SerdeJson, Some(Box::new(e))) 410 .with_context(msg) 411 .with_help("verify OAuth response body is valid JSON") 412 } 413} 414 415impl From<crate::atproto::Error> for RequestError { 416 fn from(e: crate::atproto::Error) -> Self { 417 let msg = smol_str::format_smolstr!("{:?}", e); 418 Self::new(RequestErrorKind::Atproto, Some(Box::new(e))) 419 .with_context(msg) 420 .with_help("ensure client metadata matches atproto requirements") 421 } 422} 423 424/// Convenience `Result` type for OAuth request operations, defaulting to [`RequestError`]. 
425pub type Result<T> = core::result::Result<T, RequestError>; 426 427/// Represents the different OAuth token-endpoint request types sent by this crate. 428#[allow(dead_code)] 429pub enum OAuthRequest<'a> { 430 /// Standard authorization-code token exchange. 431 Token(TokenRequestParameters<'a>), 432 /// Refresh-token grant to obtain a fresh access token. 433 Refresh(RefreshRequestParameters<'a>), 434 /// Token revocation request (RFC 7009). 435 Revocation(RevocationRequestParameters<'a>), 436 /// Token introspection request (RFC 7662). 437 Introspection, 438 /// Pushed authorization request (RFC 9126) for pre-registering auth parameters. 439 PushedAuthorizationRequest(ParParameters<'a>), 440} 441 442impl OAuthRequest<'_> { 443 /// Return a human-readable name for this request variant, used in error messages. 444 pub fn name(&self) -> CowStr<'static> { 445 CowStr::new_static(match self { 446 Self::Token(_) => "token", 447 Self::Refresh(_) => "refresh", 448 Self::Revocation(_) => "revocation", 449 Self::Introspection => "introspection", 450 Self::PushedAuthorizationRequest(_) => "pushed_authorization_request", 451 }) 452 } 453 /// Returns the HTTP status code that a successful response to this request should carry. 454 pub fn expected_status(&self) -> StatusCode { 455 match self { 456 Self::Token(_) | Self::Refresh(_) => StatusCode::OK, 457 Self::PushedAuthorizationRequest(_) => StatusCode::CREATED, 458 // Unlike https://datatracker.ietf.org/doc/html/rfc7009#section-2.2, oauth-provider seems to return `204`. 459 Self::Revocation(_) => StatusCode::NO_CONTENT, 460 _ => unimplemented!(), 461 } 462 } 463} 464 465/// The serialized body of an OAuth token-endpoint request. 466#[derive(Debug, Serialize)] 467pub struct RequestPayload<'a, T> 468where 469 T: Serialize, 470{ 471 /// The OAuth `client_id` advertised in the client metadata document. 
472 client_id: CowStr<'a>, 473 /// The assertion type URI; set to `urn:ietf:params:oauth:client-assertion-type:jwt-bearer` 474 /// when using `private_key_jwt` client authentication. 475 #[serde(skip_serializing_if = "Option::is_none")] 476 client_assertion_type: Option<CowStr<'a>>, 477 /// A JWT signed with the client's private key, proving client identity to the server. 478 #[serde(skip_serializing_if = "Option::is_none")] 479 client_assertion: Option<CowStr<'a>>, 480 /// The grant-specific parameters (token request, refresh, PAR, etc.) flattened into the body. 481 #[serde(flatten)] 482 parameters: T, 483} 484 485/// Bundled OAuth metadata needed to perform token-endpoint operations. 486/// 487/// Aggregates the server's authorization server metadata, the client's own registered metadata, 488/// and the optional signing keyset into a single value that is passed to helper functions such 489/// as [`par`], [`exchange_code`], [`refresh`], and [`revoke`]. 490#[derive(Debug, Clone)] 491pub struct OAuthMetadata { 492 /// Metadata fetched from the authorization server's `/.well-known/oauth-authorization-server` document. 493 pub server_metadata: OAuthAuthorizationServerMetadata<'static>, 494 /// This client's registered metadata, derived from [`crate::atproto::AtprotoClientMetadata`]. 495 pub client_metadata: OAuthClientMetadata<'static>, 496 /// Optional signing keyset; required for `private_key_jwt` client authentication. 497 pub keyset: Option<Keyset>, 498} 499 500impl OAuthMetadata { 501 /// Fetch server metadata and assemble an `OAuthMetadata` from an active session context. 502 /// 503 /// Contacts the authorization server recorded in `session_data` to retrieve its current 504 /// metadata, then combines it with the client configuration. This is the preferred way to 505 /// build an `OAuthMetadata` during token refresh or revocation. 
506 pub async fn new<'r, T: HttpClient + OAuthResolver + Send + Sync>( 507 client: &T, 508 ClientData { keyset, config }: &ClientData<'r>, 509 session_data: &ClientSessionData<'r>, 510 ) -> Result<Self> { 511 Ok(OAuthMetadata { 512 server_metadata: client 513 .get_authorization_server_metadata(&session_data.authserver_url) 514 .await?, 515 client_metadata: atproto_client_metadata(config.clone(), &keyset) 516 .unwrap() 517 .into_static(), 518 keyset: keyset.clone(), 519 }) 520 } 521} 522 523/// Perform a Pushed Authorization Request (PAR) and return the resulting state for the auth flow. 524/// 525/// Generates a PKCE code challenge, a fresh DPoP key, and a random `state` token, then POSTs 526/// them to the authorization server's PAR endpoint. The returned [`AuthRequestData`] must be 527/// persisted (e.g., in the auth store) so it can be retrieved and verified during 528/// [`crate::client::OAuthClient::callback`]. 529#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all, fields(login_hint = login_hint.as_ref().map(|h| h.as_ref()))))] 530pub async fn par<'r, T: OAuthResolver + DpopExt + Send + Sync + 'static>( 531 client: &T, 532 login_hint: Option<CowStr<'r>>, 533 prompt: Option<AuthorizeOptionPrompt>, 534 metadata: &OAuthMetadata, 535 state: Option<CowStr<'r>>, 536) -> crate::request::Result<AuthRequestData<'r>> { 537 let state = if let Some(state) = state { 538 state 539 } else { 540 generate_nonce() 541 }; 542 let (code_challenge, verifier) = generate_pkce(); 543 544 let Some(dpop_key) = generate_dpop_key(&metadata.server_metadata) else { 545 return Err(RequestError::token_verification()); 546 }; 547 let mut dpop_data = DpopReqData { 548 dpop_key, 549 dpop_authserver_nonce: None, 550 }; 551 let parameters = ParParameters { 552 response_type: AuthorizationResponseType::Code, 553 redirect_uri: metadata.client_metadata.redirect_uris[0].to_cowstr(), 554 state: state.clone(), 555 scope: metadata.client_metadata.scope.clone(), 556 
response_mode: None, 557 code_challenge, 558 code_challenge_method: AuthorizationCodeChallengeMethod::S256, 559 login_hint: login_hint, 560 prompt: prompt.map(CowStr::from), 561 }; 562 563 if metadata 564 .server_metadata 565 .pushed_authorization_request_endpoint 566 .is_some() 567 { 568 let par_response = oauth_request::<OAuthParResponse, T, DpopReqData>( 569 &client, 570 &mut dpop_data, 571 OAuthRequest::PushedAuthorizationRequest(parameters), 572 metadata, 573 ) 574 .await?; 575 576 let scopes = if let Some(scope) = &metadata.client_metadata.scope { 577 Scope::parse_multiple_reduced(&scope) 578 .expect("Failed to parse scopes") 579 .into_static() 580 } else { 581 vec![] 582 }; 583 let auth_req_data = AuthRequestData { 584 state, 585 authserver_url: metadata.server_metadata.issuer.clone(), 586 account_did: None, 587 scopes, 588 request_uri: par_response.request_uri.to_cowstr().into_static(), 589 authserver_token_endpoint: metadata.server_metadata.token_endpoint.clone(), 590 authserver_revocation_endpoint: metadata.server_metadata.revocation_endpoint.clone(), 591 pkce_verifier: verifier, 592 dpop_data, 593 }; 594 595 Ok(auth_req_data) 596 } else if metadata 597 .server_metadata 598 .require_pushed_authorization_requests 599 == Some(true) 600 { 601 Err(RequestError::no_endpoint("pushed_authorization_request")) 602 } else { 603 todo!("use of PAR is mandatory") 604 } 605} 606 607/// Exchange a refresh token for a fresh token set and update the session data in place. 
608#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all, fields(did = %session_data.account_did)))] 609pub async fn refresh<'r, T>( 610 client: &T, 611 mut session_data: ClientSessionData<'r>, 612 metadata: &OAuthMetadata, 613) -> Result<ClientSessionData<'r>> 614where 615 T: OAuthResolver + DpopExt + Send + Sync + 'static, 616{ 617 let Some(refresh_token) = session_data.token_set.refresh_token.as_ref() else { 618 return Err(RequestError::no_refresh_token()); 619 }; 620 621 // /!\ IMPORTANT /!\ 622 // 623 // The "sub" MUST be a DID, whose issuer authority is indeed the server we 624 // are trying to obtain credentials from. Note that we are doing this 625 // *before* we actually try to refresh the token: 626 // 1) To avoid unnecessary refresh 627 // 2) So that the refresh is the last async operation, ensuring as few 628 // async operations happen before the result gets a chance to be stored. 629 let aud = client 630 .verify_issuer(&metadata.server_metadata, &session_data.token_set.sub) 631 .await?; 632 let iss = metadata.server_metadata.issuer.clone(); 633 634 let response = oauth_request::<OAuthTokenResponse, T, DpopClientData>( 635 client, 636 &mut session_data.dpop_data, 637 OAuthRequest::Refresh(RefreshRequestParameters { 638 grant_type: TokenGrantType::RefreshToken, 639 refresh_token: refresh_token.clone(), 640 scope: None, 641 }), 642 metadata, 643 ) 644 .await?; 645 646 let expires_at = response.expires_in.and_then(|expires_in| { 647 let now = Datetime::now(); 648 now.as_ref() 649 .checked_add_signed(TimeDelta::seconds(expires_in)) 650 .map(Datetime::new) 651 }); 652 653 session_data.update_with_tokens(TokenSet { 654 iss, 655 sub: session_data.token_set.sub.clone(), 656 aud: CowStr::Owned(aud.to_smolstr()), 657 scope: response.scope.map(CowStr::Owned), 658 access_token: CowStr::Owned(response.access_token), 659 refresh_token: response.refresh_token.map(CowStr::Owned), 660 token_type: response.token_type, 661 expires_at, 662 }); 663 
664 Ok(session_data) 665} 666 667/// Exchange an authorization code for a token set and return a fully-verified [`TokenSet`]. 668/// 669/// Per the AT Protocol OAuth spec, the `sub` claim in the token response **must** be verified 670/// against the expected authorization server issuer before the token can be trusted. This 671/// function performs that verification as part of the exchange, so callers receive a token 672/// set that is safe to persist. 673#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))] 674pub async fn exchange_code<'r, T, D>( 675 client: &T, 676 data_source: &'r mut D, 677 code: &str, 678 verifier: &str, 679 metadata: &OAuthMetadata, 680) -> Result<TokenSet<'r>> 681where 682 T: OAuthResolver + DpopExt + Send + Sync + 'static, 683 D: DpopDataSource, 684{ 685 let token_response = oauth_request::<OAuthTokenResponse, T, D>( 686 client, 687 data_source, 688 OAuthRequest::Token(TokenRequestParameters { 689 grant_type: TokenGrantType::AuthorizationCode, 690 code: code.into(), 691 redirect_uri: CowStr::Owned( 692 metadata.client_metadata.redirect_uris[0] 693 .clone() 694 .to_smolstr(), 695 ), 696 code_verifier: verifier.into(), 697 }), 698 metadata, 699 ) 700 .await?; 701 let Some(sub) = token_response.sub else { 702 return Err(RequestError::token_verification()); 703 }; 704 let sub = Did::new_owned(sub)?; 705 let iss = metadata.server_metadata.issuer.clone(); 706 // /!\ IMPORTANT /!\ 707 // 708 // The token_response MUST always be valid before the "sub" it contains 709 // can be trusted (see Atproto's OAuth spec for details). 
710 let aud = client 711 .verify_issuer(&metadata.server_metadata, &sub) 712 .await?; 713 714 let expires_at = token_response.expires_in.and_then(|expires_in| { 715 Datetime::now() 716 .as_ref() 717 .checked_add_signed(TimeDelta::seconds(expires_in)) 718 .map(Datetime::new) 719 }); 720 Ok(TokenSet { 721 iss, 722 sub, 723 aud: CowStr::Owned(aud.to_smolstr()), 724 scope: token_response.scope.map(CowStr::Owned), 725 access_token: CowStr::Owned(token_response.access_token), 726 refresh_token: token_response.refresh_token.map(CowStr::Owned), 727 token_type: token_response.token_type, 728 expires_at, 729 }) 730} 731 732/// Send a token revocation request (RFC 7009) to the authorization server. 733/// 734/// This function is called by [`crate::client::OAuthSession::logout`] when a revocation endpoint is advertised 735/// by the server. The caller is responsible for deleting the session from local storage regardless 736/// of whether revocation succeeds. 737#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))] 738pub async fn revoke<'r, T, D>( 739 client: &T, 740 data_source: &'r mut D, 741 token: &str, 742 metadata: &OAuthMetadata, 743) -> Result<()> 744where 745 T: OAuthResolver + DpopExt + Send + Sync + 'static, 746 D: DpopDataSource, 747{ 748 oauth_request::<(), T, D>( 749 client, 750 data_source, 751 OAuthRequest::Revocation(RevocationRequestParameters { 752 token: token.into(), 753 }), 754 metadata, 755 ) 756 .await?; 757 Ok(()) 758} 759 760/// Low-level function for sending an OAuth token-endpoint request and deserializing the response. 761/// 762/// Selects the correct server endpoint for `request`, builds the form-encoded body with 763/// client authentication, performs the DPoP-wrapped HTTP POST, and deserializes the response 764/// body into `O`. The type parameter `O` is inferred from the call site; use `()` for requests 765/// where the response body is empty (e.g., revocation). 
766pub async fn oauth_request<'de: 'r, 'r, O, T, D>( 767 client: &T, 768 data_source: &'r mut D, 769 request: OAuthRequest<'r>, 770 metadata: &OAuthMetadata, 771) -> Result<O> 772where 773 T: OAuthResolver + DpopExt + Send + Sync + 'static, 774 O: serde::de::DeserializeOwned, 775 D: DpopDataSource, 776{ 777 let Some(url) = endpoint_for_req(&metadata.server_metadata, &request) else { 778 return Err(RequestError::no_endpoint(request.name())); 779 }; 780 let client_assertions = build_auth( 781 metadata.keyset.as_ref(), 782 &metadata.server_metadata, 783 &metadata.client_metadata, 784 )?; 785 let body = match &request { 786 OAuthRequest::Token(params) => build_oauth_req_body(client_assertions, params)?, 787 OAuthRequest::Refresh(params) => build_oauth_req_body(client_assertions, params)?, 788 OAuthRequest::Revocation(params) => build_oauth_req_body(client_assertions, params)?, 789 OAuthRequest::PushedAuthorizationRequest(params) => { 790 build_oauth_req_body(client_assertions, params)? 791 } 792 _ => unimplemented!(), 793 }; 794 let req = Request::builder() 795 .uri(url.to_string()) 796 .method(Method::POST) 797 .header("Content-Type", "application/x-www-form-urlencoded") 798 .body(body.into_bytes())?; 799 let res = client.dpop_server_call(data_source).send(req).await?; 800 if res.status() == request.expected_status() { 801 let body = res.body(); 802 if body.is_empty() { 803 // since an empty body cannot be deserialized, use “null” temporarily to allow deserialization to `()`. 804 Ok(serde_json::from_slice(b"null")?) 
805 } else { 806 let output: O = serde_json::from_slice(body)?; 807 Ok(output) 808 } 809 } else if res.status().is_client_error() { 810 Err(RequestError::http_status_with_body( 811 res.status(), 812 serde_json::from_slice(res.body())?, 813 )) 814 } else { 815 Err(RequestError::http_status(res.status())) 816 } 817} 818 819#[inline] 820fn endpoint_for_req<'a, 'r>( 821 server_metadata: &'r OAuthAuthorizationServerMetadata<'a>, 822 request: &'r OAuthRequest, 823) -> Option<&'r CowStr<'a>> { 824 match request { 825 OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => Some(&server_metadata.token_endpoint), 826 OAuthRequest::Revocation(_) => server_metadata.revocation_endpoint.as_ref(), 827 OAuthRequest::Introspection => server_metadata.introspection_endpoint.as_ref(), 828 OAuthRequest::PushedAuthorizationRequest(_) => server_metadata 829 .pushed_authorization_request_endpoint 830 .as_ref(), 831 } 832} 833 834#[inline] 835fn build_oauth_req_body<'a, S>(client_assertions: ClientAuth<'a>, parameters: S) -> Result<String> 836where 837 S: Serialize, 838{ 839 Ok(serde_html_form::to_string(RequestPayload { 840 client_id: client_assertions.client_id, 841 client_assertion_type: client_assertions.assertion_type, 842 client_assertion: client_assertions.assertion, 843 parameters, 844 })?) 845} 846 847/// Client identity fields appended to every token-endpoint request body. 848/// 849/// Encapsulates the result of choosing a client authentication method (`none` vs. 850/// `private_key_jwt`). The `build_auth` helper selects the appropriate variant based 851/// on server capabilities and client configuration. 852#[derive(Debug, Clone, Default)] 853pub struct ClientAuth<'a> { 854 /// The OAuth `client_id` for this client. 855 client_id: CowStr<'a>, 856 /// Either absent (for `none` auth) or `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`. 857 assertion_type: Option<CowStr<'a>>, 858 /// A signed JWT proving client identity; present only for `private_key_jwt` auth. 
assertion: Option<CowStr<'a>>,
}

impl<'s> ClientAuth<'s> {
    /// Construct a `ClientAuth` with only a `client_id` and no assertion (the `none` method).
    pub fn new_id(client_id: CowStr<'s>) -> Self {
        Self {
            client_id,
            assertion_type: None,
            assertion: None,
        }
    }
}

/// Build the client-authentication parameters for a token-endpoint request.
///
/// The method named in `client_metadata.token_endpoint_auth_method` is used only when the
/// authorization server lists it in `token_endpoint_auth_methods_supported`:
///
/// - `private_key_jwt`: requires `keyset`. A client assertion JWT is signed with `keyset`,
///   with `iss`/`sub` set to the client id, `aud` set to the server's `issuer`, a fresh `jti`,
///   and `exp` 60 seconds after `iat`. Candidate algorithms come from
///   `token_endpoint_auth_signing_alg_values_supported` (or `FALLBACK_ALG` when absent),
///   ordered with `compare_algos`.
/// - `none`: only the `client_id` is sent.
///
/// # Errors
///
/// Returns an `UnsupportedAuthMethod` error when the client names no method, names a method
/// the server does not advertise, names an unrecognized method, or names `private_key_jwt`
/// without a `keyset`. Errors from `Keyset::create_jwt` are propagated with `?`.
fn build_auth<'a>(
    keyset: Option<&Keyset>,
    server_metadata: &OAuthAuthorizationServerMetadata<'a>,
    client_metadata: &OAuthClientMetadata<'a>,
) -> Result<ClientAuth<'a>> {
    let Some(method) = client_metadata.token_endpoint_auth_method.as_ref() else {
        return Err(RequestError::unsupported_auth_method());
    };

    // `Option<&_>` is `Copy`, so it can be consulted from several match guards without
    // further `.as_ref()` calls.
    let methods_supported = server_metadata
        .token_endpoint_auth_methods_supported
        .as_ref();
    let server_supports = |name: &'static str| {
        methods_supported.is_some_and(|v| v.contains(&CowStr::new_static(name)))
    };

    let client_id = client_metadata.client_id.to_cowstr().into_static();
    match (*method).as_ref() {
        "private_key_jwt" if server_supports("private_key_jwt") => {
            // Without signing keys no assertion can be produced; report the method as
            // unsupported (same outcome as an unadvertised method).
            let Some(keyset) = keyset else {
                return Err(RequestError::unsupported_auth_method());
            };

            // Only build the fallback list when the server did not advertise algorithms.
            let mut alg_strs = server_metadata
                .token_endpoint_auth_signing_alg_values_supported
                .clone()
                .unwrap_or_else(|| vec![FALLBACK_ALG.into()]);
            alg_strs.sort_by(compare_algos);
            // Names that `parse_signing_alg` does not recognize are dropped.
            let algs: Vec<Signing> = alg_strs
                .iter()
                .filter_map(|s| crate::keyset::parse_signing_alg(s))
                .collect();

            let iat = Utc::now().timestamp();
            let assertion = keyset.create_jwt(
                &algs,
                // https://datatracker.ietf.org/doc/html/rfc7523#section-3
                RegisteredClaims {
                    iss: Some(client_id.clone()),
                    sub: Some(client_id.clone()),
                    aud: Some(RegisteredClaimsAud::Single(server_metadata.issuer.clone())),
                    // The assertion expires 60 seconds after it is issued.
                    exp: Some(iat + 60),
                    // "iat" is required and **MUST** be less than one minute
                    // https://datatracker.ietf.org/doc/html/rfc9101
                    iat: Some(iat),
                    // atproto oauth-provider requires "jti" to be present
                    jti: Some(generate_nonce()),
                    ..Default::default()
                }
                .into(),
            )?;

            Ok(ClientAuth {
                client_id,
                assertion_type: Some(CowStr::new_static(CLIENT_ASSERTION_TYPE_JWT_BEARER)),
                assertion: Some(assertion),
            })
        }
        "none" if server_supports("none") => Ok(ClientAuth::new_id(client_id)),
        _ => Err(RequestError::unsupported_auth_method()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata};
    use bytes::Bytes;
    use http::{Response as HttpResponse, StatusCode};
    use jacquard_common::{deps::fluent_uri::Uri, http_client::HttpClient, types::string::Did};
    use jacquard_identity::resolver::IdentityResolver;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Test HTTP client: each `send_http` call takes the canned response stored in `resp`
    /// (panicking if none was set).
    #[derive(Clone, Default)]
    struct MockClient {
        resp: Arc<Mutex<Option<HttpResponse<Vec<u8>>>>>,
    }

    impl HttpClient for MockClient {
        type Error = std::convert::Infallible;
        fn send_http(
            &self,
            _request: http::Request<Vec<u8>>,
        ) -> impl core::future::Future<
            Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
        > + Send {
            let resp = self.resp.clone();
            async move { Ok(resp.lock().await.take().unwrap()) }
        }
    }

    // IdentityResolver methods won't be called in these tests; provide stubs.
    impl IdentityResolver for MockClient {
        fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
            use std::sync::LazyLock;
            static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> =
                LazyLock::new(jacquard_identity::resolver::ResolverOptions::default);
            &OPTS
        }
        async fn resolve_handle(
            &self,
            _handle: &jacquard_common::types::string::Handle<'_>,
        ) -> std::result::Result<Did<'static>, jacquard_identity::resolver::IdentityError> {
            Ok(Did::new_static("did:plc:alice").unwrap())
        }
        async fn resolve_did_doc(
            &self,
            _did: &Did<'_>,
        ) -> std::result::Result<
            jacquard_identity::resolver::DidDocResponse,
            jacquard_identity::resolver::IdentityError,
        > {
            let doc = serde_json::json!({
                "id": "did:plc:alice",
                "service": [{
                    "id": "#pds",
                    "type": "AtprotoPersonalDataServer",
                    "serviceEndpoint": "https://pds"
                }]
            });
            let buf = Bytes::from(serde_json::to_vec(&doc).unwrap());
            Ok(jacquard_identity::resolver::DidDocResponse {
                buffer: buf,
                status: StatusCode::OK,
                requested: None,
            })
        }
    }

    // Allow using DPoP helpers on MockClient
    impl crate::dpop::DpopExt for MockClient {}
    impl crate::resolver::OAuthResolver for MockClient {}

    /// Metadata where both client and server use the `none` token-endpoint auth method.
    fn base_metadata() -> OAuthMetadata {
        let mut server = OAuthAuthorizationServerMetadata::default();
        server.issuer = CowStr::from("https://issuer");
        server.authorization_endpoint = CowStr::from("https://issuer/authorize");
        server.token_endpoint = CowStr::from("https://issuer/token");
        server.token_endpoint_auth_methods_supported = Some(vec![CowStr::from("none")]);
        OAuthMetadata {
            server_metadata: server,
            client_metadata: OAuthClientMetadata {
                client_id: CowStr::new_static("https://client"),
                client_uri: None,
                redirect_uris: vec![CowStr::new_static("https://client/cb")],
                scope: Some(CowStr::from("atproto")),
                grant_types: None,
                response_types: vec![CowStr::new_static("code")],
                application_type: Some(CowStr::new_static("web")),
                token_endpoint_auth_method: Some(CowStr::from("none")),
                dpop_bound_access_tokens: None,
                jwks_uri: None,
                jwks: None,
                token_endpoint_auth_signing_alg: None,
                client_name: None,
                privacy_policy_uri: None,
                tos_uri: None,
                logo_uri: None,
            },
            keyset: None,
        }
    }

    #[test]
    fn build_auth_none_method() {
        let meta = base_metadata();
        let Ok(auth) = super::build_auth(None, &meta.server_metadata, &meta.client_metadata)
        else {
            panic!("`none` is requested by the client and advertised by the server");
        };
        assert!(auth.assertion_type.is_none());
        assert!(auth.assertion.is_none());
    }

    #[test]
    fn build_auth_private_key_jwt_without_keyset() {
        let mut meta = base_metadata();
        meta.server_metadata.token_endpoint_auth_methods_supported =
            Some(vec![CowStr::from("private_key_jwt")]);
        meta.client_metadata.token_endpoint_auth_method = Some(CowStr::from("private_key_jwt"));
        let Err(err) = super::build_auth(None, &meta.server_metadata, &meta.client_metadata)
        else {
            panic!("private_key_jwt without a keyset must be rejected");
        };
        assert!(matches!(err.kind(), RequestErrorKind::UnsupportedAuthMethod));
    }

    #[test]
    fn build_auth_method_not_advertised() {
        let mut meta = base_metadata();
        meta.server_metadata.token_endpoint_auth_methods_supported = None;
        let Err(err) = super::build_auth(None, &meta.server_metadata, &meta.client_metadata)
        else {
            panic!("a method the server does not advertise must be rejected");
        };
        assert!(matches!(err.kind(), RequestErrorKind::UnsupportedAuthMethod));
    }

    #[tokio::test]
    async fn par_missing_endpoint() {
        let mut meta = base_metadata();
        meta.server_metadata.require_pushed_authorization_requests = Some(true);
        meta.server_metadata.pushed_authorization_request_endpoint = None;
        // require_pushed_authorization_requests is true and no endpoint
        let err = super::par(&MockClient::default(), None, None, &meta, None)
            .await
            .unwrap_err();
        assert!(
            matches!(err.kind(), RequestErrorKind::NoEndpoint(name) if name == "pushed_authorization_request")
        );
    }

    #[tokio::test]
    async fn refresh_no_refresh_token() {
        let client = MockClient::default();
        let meta = base_metadata();
        let session = ClientSessionData {
            account_did: Did::new_static("did:plc:alice").unwrap(),
            session_id: CowStr::from("state"),
            host_url: Uri::parse("https://pds").expect("valid").to_owned(),
            authserver_url: CowStr::new_static("https://issuer"),
            authserver_token_endpoint: CowStr::from("https://issuer/token"),
            authserver_revocation_endpoint: None,
            scopes: vec![],
            dpop_data: DpopClientData {
                dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(),
                dpop_authserver_nonce: CowStr::from(""),
                dpop_host_nonce: CowStr::from(""),
            },
            token_set: crate::types::TokenSet {
                iss: CowStr::from("https://issuer"),
                sub: Did::new_static("did:plc:alice").unwrap(),
                aud: CowStr::from("https://pds"),
                scope: None,
                refresh_token: None,
                access_token: CowStr::from("abc"),
                token_type: crate::types::OAuthTokenType::DPoP,
                expires_at: None,
            },
        };
        let err = super::refresh(&client, session, &meta).await.unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::NoRefreshToken));
    }

    #[tokio::test]
    async fn exchange_code_missing_sub() {
        let client = MockClient::default();
        // set mock HTTP response body: token response without `sub`
        *client.resp.lock().await = Some(
            HttpResponse::builder()
                .status(StatusCode::OK)
                .body(
                    serde_json::to_vec(&serde_json::json!({
                        "access_token":"tok",
                        "token_type":"DPoP",
                        "expires_in": 3600
                    }))
                    .unwrap(),
                )
                .unwrap(),
        );
        let meta = base_metadata();
        let mut dpop = DpopReqData {
            dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(),
            dpop_authserver_nonce: None,
        };
        let err = super::exchange_code(&client, &mut dpop, "abc", "verifier", &meta)
            .await
            .unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::TokenVerification));
    }
}