A better Rust ATProto crate
1use chrono::{TimeDelta, Utc};
2use http::{Method, Request, StatusCode};
3use jacquard_common::{
4 CowStr, IntoStatic,
5 cowstr::ToCowStr,
6 http_client::HttpClient,
7 session::SessionStoreError,
8 types::{
9 did::Did,
10 string::{AtStrError, Datetime},
11 },
12};
13use jacquard_identity::resolver::IdentityError;
14use serde::Serialize;
15use serde_json::Value;
16use smol_str::ToSmolStr;
17
18use jose_jwa::Signing;
19
20use crate::{
21 FALLBACK_ALG,
22 atproto::atproto_client_metadata,
23 dpop::DpopExt,
24 jose::jwt::{RegisteredClaims, RegisteredClaimsAud},
25 keyset::Keyset,
26 resolver::OAuthResolver,
27 scopes::Scope,
28 session::{
29 AuthRequestData, ClientData, ClientSessionData, DpopClientData, DpopDataSource, DpopReqData,
30 },
31 types::{
32 AuthorizationCodeChallengeMethod, AuthorizationResponseType, AuthorizeOptionPrompt,
33 OAuthAuthorizationServerMetadata, OAuthClientMetadata, OAuthParResponse,
34 OAuthTokenResponse, ParParameters, RefreshRequestParameters, RevocationRequestParameters,
35 TokenGrantType, TokenRequestParameters, TokenSet,
36 },
37 utils::{compare_algos, generate_dpop_key, generate_nonce, generate_pkce},
38};
39
/// Value of the `client_assertion_type` form parameter for JWT-bearer client assertions, as
/// registered in [RFC 7523 §2.2](https://datatracker.ietf.org/doc/html/rfc7523#section-2.2).
const CLIENT_ASSERTION_TYPE_JWT_BEARER: &str =
    "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
43
44use smol_str::SmolStr;
45
/// Owned, type-erased error that may be sent and shared across threads.
///
/// A boxed trait object defaults to a `'static` bound, so this is the same type as
/// `Box<dyn std::error::Error + Send + Sync + 'static>`.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
48
/// OAuth request error for token operations and auth flows.
///
/// The `Display` output is only the [`RequestErrorKind`] message. The optional `context`,
/// `url`, `details` and `location` annotations are not rendered by `Display`; read them through
/// the accessor methods. `help` is surfaced as the `miette` diagnostic help text.
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
#[error("{kind}")]
pub struct RequestError {
    /// Error category; also exposed as the diagnostic source.
    #[diagnostic_source]
    kind: RequestErrorKind,
    /// Lower-level error that caused this one, if any.
    #[source]
    source: Option<BoxError>,
    /// Optional remediation hint shown as diagnostic help.
    #[help]
    help: Option<SmolStr>,
    /// Free-form context; the `From` conversions in this module store the source's `Debug`
    /// rendering here.
    context: Option<SmolStr>,
    /// URL associated with the failure, if one was attached.
    url: Option<SmolStr>,
    /// Extra details, if attached.
    details: Option<SmolStr>,
    /// Location information, if attached.
    location: Option<SmolStr>,
}
64
/// Error categories for OAuth request operations.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, thiserror::Error, miette::Diagnostic)]
#[non_exhaustive]
pub enum RequestErrorKind {
    /// The authorization server metadata does not advertise the named endpoint.
    #[error("no {0} endpoint available")]
    #[diagnostic(
        code(jacquard_oauth::request::no_endpoint),
        help("server does not advertise this endpoint")
    )]
    NoEndpoint(SmolStr),

    /// A token response could not be verified (e.g. `exchange_code` received no `sub`).
    /// `par` also reports this kind when it cannot generate a DPoP key.
    #[error("token response verification failed")]
    #[diagnostic(code(jacquard_oauth::request::token_verification))]
    TokenVerification,

    /// No client authentication method is both configured by the client and advertised by the
    /// server (or `private_key_jwt` was chosen without a keyset).
    #[error("unsupported authentication method")]
    #[diagnostic(
        code(jacquard_oauth::request::unsupported_auth_method),
        help(
            "server must support `private_key_jwt` or `none`; configure client metadata accordingly"
        )
    )]
    UnsupportedAuthMethod,

    /// The session holds no refresh token; `RequestError::is_permanent` treats this as final.
    #[error("no refresh token available")]
    #[diagnostic(code(jacquard_oauth::request::no_refresh_token))]
    NoRefreshToken,

    /// A DID failed to parse (converted from `AtStrError`).
    #[error("failed to parse DID")]
    #[diagnostic(code(jacquard_oauth::request::invalid_did))]
    InvalidDid,

    /// DPoP client error (converted from `crate::dpop::DpopError`).
    #[error("dpop error")]
    #[diagnostic(code(jacquard_oauth::request::dpop))]
    Dpop,

    /// Session storage error (converted from `SessionStoreError`).
    #[error("storage error")]
    #[diagnostic(code(jacquard_oauth::request::storage))]
    Storage,

    /// Resolver error (converted from `crate::resolver::ResolverError`).
    #[error("resolver error")]
    #[diagnostic(code(jacquard_oauth::request::resolver))]
    Resolver,

    /// Building the HTTP request failed (converted from `http::Error`).
    #[error("http build error")]
    #[diagnostic(code(jacquard_oauth::request::http_build))]
    HttpBuild,

    /// The server answered with an unexpected HTTP status and no parsed error body.
    #[error("http status: {0}")]
    #[diagnostic(
        code(jacquard_oauth::request::http_status),
        help("see server response for details")
    )]
    HttpStatus(StatusCode),

    /// The server answered with a client-error (4xx) status and a JSON error body.
    #[error("http status: {status}, body: {body:?}")]
    #[diagnostic(
        code(jacquard_oauth::request::http_status_body),
        help("server returned error JSON; inspect fields like `error`, `error_description`")
    )]
    HttpStatusWithBody {
        /// HTTP status code returned by the server.
        status: StatusCode,
        /// Parsed JSON body containing OAuth error fields such as `error` and `error_description`.
        body: Value,
    },

    /// Identity resolution error (converted from `IdentityError`).
    #[error("identity error")]
    #[diagnostic(code(jacquard_oauth::request::identity))]
    Identity,

    /// Keyset error (converted from `crate::keyset::Error`).
    #[error("keyset error")]
    #[diagnostic(code(jacquard_oauth::request::keyset))]
    Keyset,

    /// Encoding the form request body failed (converted from `serde_html_form::ser::Error`).
    #[error("form serialization error")]
    #[diagnostic(code(jacquard_oauth::request::serde_form))]
    SerdeHtmlForm,

    /// JSON (de)serialization error, e.g. while decoding a response body.
    #[error("json error")]
    #[diagnostic(code(jacquard_oauth::request::serde_json))]
    SerdeJson,

    /// Atproto client-metadata error (converted from `crate::atproto::Error`).
    #[error("atproto error")]
    #[diagnostic(code(jacquard_oauth::request::atproto))]
    Atproto,
}
168
169impl RequestError {
170 /// Create a new error with the given kind and optional source
171 pub fn new(kind: RequestErrorKind, source: Option<BoxError>) -> Self {
172 Self {
173 kind,
174 source,
175 help: None,
176 context: None,
177 url: None,
178 details: None,
179 location: None,
180 }
181 }
182
183 /// Get the error kind
184 pub fn kind(&self) -> &RequestErrorKind {
185 &self.kind
186 }
187
188 /// Get the source error if present
189 pub fn source_err(&self) -> Option<&BoxError> {
190 self.source.as_ref()
191 }
192
193 /// Get the context string if present
194 pub fn context(&self) -> Option<&str> {
195 self.context.as_ref().map(|s| s.as_str())
196 }
197
198 /// Get the URL if present
199 pub fn url(&self) -> Option<&str> {
200 self.url.as_ref().map(|s| s.as_str())
201 }
202
203 /// Get the details if present
204 pub fn details(&self) -> Option<&str> {
205 self.details.as_ref().map(|s| s.as_str())
206 }
207
208 /// Get the location if present
209 pub fn location(&self) -> Option<&str> {
210 self.location.as_ref().map(|s| s.as_str())
211 }
212
213 /// Add help text to this error
214 pub fn with_help(mut self, help: impl Into<SmolStr>) -> Self {
215 self.help = Some(help.into());
216 self
217 }
218
219 /// Add context to this error
220 pub fn with_context(mut self, context: impl Into<SmolStr>) -> Self {
221 self.context = Some(context.into());
222 self
223 }
224
225 /// Add URL to this error
226 pub fn with_url(mut self, url: impl Into<SmolStr>) -> Self {
227 self.url = Some(url.into());
228 self
229 }
230
231 /// Add details to this error
232 pub fn with_details(mut self, details: impl Into<SmolStr>) -> Self {
233 self.details = Some(details.into());
234 self
235 }
236
237 /// Add location to this error
238 pub fn with_location(mut self, location: impl Into<SmolStr>) -> Self {
239 self.location = Some(location.into());
240 self
241 }
242
243 // Constructors for each kind
244
245 /// Create a no endpoint error
246 pub fn no_endpoint(endpoint: impl Into<SmolStr>) -> Self {
247 Self::new(RequestErrorKind::NoEndpoint(endpoint.into()), None)
248 }
249
250 /// Create a token verification error
251 pub fn token_verification() -> Self {
252 Self::new(RequestErrorKind::TokenVerification, None)
253 }
254
255 /// Create an unsupported authentication method error
256 pub fn unsupported_auth_method() -> Self {
257 Self::new(RequestErrorKind::UnsupportedAuthMethod, None)
258 }
259
260 /// Create a no refresh token error
261 pub fn no_refresh_token() -> Self {
262 Self::new(RequestErrorKind::NoRefreshToken, None)
263 }
264
265 /// Create an invalid DID error
266 pub fn invalid_did(source: impl std::error::Error + Send + Sync + 'static) -> Self {
267 Self::new(RequestErrorKind::InvalidDid, Some(Box::new(source)))
268 }
269
270 /// Create a DPoP error
271 pub fn dpop(source: impl std::error::Error + Send + Sync + 'static) -> Self {
272 Self::new(RequestErrorKind::Dpop, Some(Box::new(source)))
273 }
274
275 /// Create a storage error
276 pub fn storage(source: impl std::error::Error + Send + Sync + 'static) -> Self {
277 Self::new(RequestErrorKind::Storage, Some(Box::new(source)))
278 }
279
280 /// Create a resolver error
281 pub fn resolver(source: impl std::error::Error + Send + Sync + 'static) -> Self {
282 Self::new(RequestErrorKind::Resolver, Some(Box::new(source)))
283 }
284
285 /// Create an HTTP build error
286 pub fn http_build(source: impl std::error::Error + Send + Sync + 'static) -> Self {
287 Self::new(RequestErrorKind::HttpBuild, Some(Box::new(source)))
288 }
289
290 /// Create an HTTP status error
291 pub fn http_status(status: StatusCode) -> Self {
292 Self::new(RequestErrorKind::HttpStatus(status), None)
293 }
294
295 /// Create an HTTP status with body error
296 pub fn http_status_with_body(status: StatusCode, body: Value) -> Self {
297 Self::new(RequestErrorKind::HttpStatusWithBody { status, body }, None)
298 }
299
300 /// Create an identity error
301 pub fn identity(source: impl std::error::Error + Send + Sync + 'static) -> Self {
302 Self::new(RequestErrorKind::Identity, Some(Box::new(source)))
303 }
304
305 /// Create a keyset error
306 pub fn keyset(source: impl std::error::Error + Send + Sync + 'static) -> Self {
307 Self::new(RequestErrorKind::Keyset, Some(Box::new(source)))
308 }
309
310 /// Create an atproto metadata error
311 pub fn atproto(source: impl std::error::Error + Send + Sync + 'static) -> Self {
312 Self::new(RequestErrorKind::Atproto, Some(Box::new(source)))
313 }
314
315 /// Returns true if this error indicates permanent auth failure
316 /// (token revoked, refresh_token expired, etc.)
317 ///
318 /// When this returns true, the session should be cleared from storage
319 /// rather than retried.
320 pub fn is_permanent(&self) -> bool {
321 match &self.kind {
322 RequestErrorKind::NoRefreshToken => true,
323 RequestErrorKind::HttpStatusWithBody { body, .. } => body
324 .get("error")
325 .and_then(|e| e.as_str())
326 .is_some_and(|e| matches!(e, "invalid_grant" | "access_denied")),
327 _ => false,
328 }
329 }
330}
331
332// From impls for common error types
333
334impl From<AtStrError> for RequestError {
335 fn from(e: AtStrError) -> Self {
336 let msg = smol_str::format_smolstr!("{:?}", e);
337 Self::new(RequestErrorKind::InvalidDid, Some(Box::new(e)))
338 .with_context(msg)
339 .with_help("ensure DID is correctly formatted (e.g., did:plc:abc123)")
340 }
341}
342
343impl From<crate::dpop::DpopError> for RequestError {
344 fn from(e: crate::dpop::DpopError) -> Self {
345 let msg = smol_str::format_smolstr!("{:?}", e);
346 Self::new(RequestErrorKind::Dpop, Some(Box::new(e)))
347 .with_context(msg)
348 .with_help("check DPoP key configuration and nonce handling")
349 }
350}
351
352impl From<SessionStoreError> for RequestError {
353 fn from(e: SessionStoreError) -> Self {
354 let msg = smol_str::format_smolstr!("{:?}", e);
355 Self::new(RequestErrorKind::Storage, Some(Box::new(e)))
356 .with_context(msg)
357 .with_help("verify session store is accessible and writable")
358 }
359}
360
361impl From<crate::resolver::ResolverError> for RequestError {
362 fn from(e: crate::resolver::ResolverError) -> Self {
363 let msg = smol_str::format_smolstr!("{:?}", e);
364 Self::new(RequestErrorKind::Resolver, Some(Box::new(e)))
365 .with_context(msg)
366 .with_help("check identity resolution and OAuth metadata endpoints")
367 }
368}
369
370impl From<http::Error> for RequestError {
371 fn from(e: http::Error) -> Self {
372 let msg = smol_str::format_smolstr!("{:?}", e);
373 Self::new(RequestErrorKind::HttpBuild, Some(Box::new(e)))
374 .with_context(msg)
375 .with_help("verify request URIs and headers are valid")
376 }
377}
378
379impl From<IdentityError> for RequestError {
380 fn from(e: IdentityError) -> Self {
381 let msg = smol_str::format_smolstr!("{:?}", e);
382 Self::new(RequestErrorKind::Identity, Some(Box::new(e)))
383 .with_context(msg)
384 .with_help("check handle/DID is valid and identity resolver is configured")
385 }
386}
387
388impl From<crate::keyset::Error> for RequestError {
389 fn from(e: crate::keyset::Error) -> Self {
390 let msg = smol_str::format_smolstr!("{:?}", e);
391 Self::new(RequestErrorKind::Keyset, Some(Box::new(e)))
392 .with_context(msg)
393 .with_help("verify keyset configuration and signing algorithm support")
394 }
395}
396
397impl From<serde_html_form::ser::Error> for RequestError {
398 fn from(e: serde_html_form::ser::Error) -> Self {
399 let msg = smol_str::format_smolstr!("{:?}", e);
400 Self::new(RequestErrorKind::SerdeHtmlForm, Some(Box::new(e)))
401 .with_context(msg)
402 .with_help("check OAuth request parameters are serializable")
403 }
404}
405
406impl From<serde_json::Error> for RequestError {
407 fn from(e: serde_json::Error) -> Self {
408 let msg = smol_str::format_smolstr!("{:?}", e);
409 Self::new(RequestErrorKind::SerdeJson, Some(Box::new(e)))
410 .with_context(msg)
411 .with_help("verify OAuth response body is valid JSON")
412 }
413}
414
415impl From<crate::atproto::Error> for RequestError {
416 fn from(e: crate::atproto::Error) -> Self {
417 let msg = smol_str::format_smolstr!("{:?}", e);
418 Self::new(RequestErrorKind::Atproto, Some(Box::new(e)))
419 .with_context(msg)
420 .with_help("ensure client metadata matches atproto requirements")
421 }
422}
423
424/// Convenience `Result` type for OAuth request operations, defaulting to [`RequestError`].
425pub type Result<T> = core::result::Result<T, RequestError>;
426
/// The different OAuth endpoint requests this crate can send.
///
/// Each variant carries the parameters that [`oauth_request`] serializes into the
/// form-encoded request body. Token and refresh requests go to the token endpoint; the other
/// variants use their own endpoints from the server metadata.
// NOTE(review): `dead_code` is presumably allowed because `Introspection` is never constructed
// and `oauth_request` cannot build a body for it. Confirm before removing.
#[allow(dead_code)]
pub enum OAuthRequest<'a> {
    /// Standard authorization-code token exchange.
    Token(TokenRequestParameters<'a>),
    /// Refresh-token grant to obtain a fresh access token.
    Refresh(RefreshRequestParameters<'a>),
    /// Token revocation request (RFC 7009).
    Revocation(RevocationRequestParameters<'a>),
    /// Token introspection request (RFC 7662). Carries no parameters; `oauth_request` does not
    /// implement it.
    Introspection,
    /// Pushed authorization request (RFC 9126) for pre-registering auth parameters.
    PushedAuthorizationRequest(ParParameters<'a>),
}
441
442impl OAuthRequest<'_> {
443 /// Return a human-readable name for this request variant, used in error messages.
444 pub fn name(&self) -> CowStr<'static> {
445 CowStr::new_static(match self {
446 Self::Token(_) => "token",
447 Self::Refresh(_) => "refresh",
448 Self::Revocation(_) => "revocation",
449 Self::Introspection => "introspection",
450 Self::PushedAuthorizationRequest(_) => "pushed_authorization_request",
451 })
452 }
453 /// Returns the HTTP status code that a successful response to this request should carry.
454 pub fn expected_status(&self) -> StatusCode {
455 match self {
456 Self::Token(_) | Self::Refresh(_) => StatusCode::OK,
457 Self::PushedAuthorizationRequest(_) => StatusCode::CREATED,
458 // Unlike https://datatracker.ietf.org/doc/html/rfc7009#section-2.2, oauth-provider seems to return `204`.
459 Self::Revocation(_) => StatusCode::NO_CONTENT,
460 _ => unimplemented!(),
461 }
462 }
463}
464
/// The serialized body of an OAuth endpoint request.
///
/// `build_oauth_req_body` encodes this with `serde_html_form`, and `oauth_request` sends the
/// result as `application/x-www-form-urlencoded`. The client-assertion fields are omitted
/// entirely when `None`.
#[derive(Debug, Serialize)]
pub struct RequestPayload<'a, T>
where
    T: Serialize,
{
    /// The OAuth `client_id` advertised in the client metadata document.
    client_id: CowStr<'a>,
    /// The assertion type URI; set to `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`
    /// when using `private_key_jwt` client authentication.
    #[serde(skip_serializing_if = "Option::is_none")]
    client_assertion_type: Option<CowStr<'a>>,
    /// A JWT signed with the client's private key, proving client identity to the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    client_assertion: Option<CowStr<'a>>,
    /// The request-specific parameters (token request, refresh, PAR, etc.) flattened into the body.
    #[serde(flatten)]
    parameters: T,
}
484
/// Bundled OAuth metadata needed to perform authorization-server requests.
///
/// Aggregates the server's authorization server metadata, the client's own registered metadata,
/// and the optional signing keyset into a single value. It is passed to helper functions such
/// as [`par`], [`exchange_code`], [`refresh`], and [`revoke`].
#[derive(Debug, Clone)]
pub struct OAuthMetadata {
    /// Metadata fetched from the authorization server's `/.well-known/oauth-authorization-server` document.
    pub server_metadata: OAuthAuthorizationServerMetadata<'static>,
    /// This client's registered metadata, derived from [`crate::atproto::AtprotoClientMetadata`].
    pub client_metadata: OAuthClientMetadata<'static>,
    /// Optional signing keyset; required when the client authenticates with `private_key_jwt`.
    pub keyset: Option<Keyset>,
}
499
500impl OAuthMetadata {
501 /// Fetch server metadata and assemble an `OAuthMetadata` from an active session context.
502 ///
503 /// Contacts the authorization server recorded in `session_data` to retrieve its current
504 /// metadata, then combines it with the client configuration. This is the preferred way to
505 /// build an `OAuthMetadata` during token refresh or revocation.
506 pub async fn new<'r, T: HttpClient + OAuthResolver + Send + Sync>(
507 client: &T,
508 ClientData { keyset, config }: &ClientData<'r>,
509 session_data: &ClientSessionData<'r>,
510 ) -> Result<Self> {
511 Ok(OAuthMetadata {
512 server_metadata: client
513 .get_authorization_server_metadata(&session_data.authserver_url)
514 .await?,
515 client_metadata: atproto_client_metadata(config.clone(), &keyset)
516 .unwrap()
517 .into_static(),
518 keyset: keyset.clone(),
519 })
520 }
521}
522
523/// Perform a Pushed Authorization Request (PAR) and return the resulting state for the auth flow.
524///
525/// Generates a PKCE code challenge, a fresh DPoP key, and a random `state` token, then POSTs
526/// them to the authorization server's PAR endpoint. The returned [`AuthRequestData`] must be
527/// persisted (e.g., in the auth store) so it can be retrieved and verified during
528/// [`crate::client::OAuthClient::callback`].
529#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all, fields(login_hint = login_hint.as_ref().map(|h| h.as_ref()))))]
530pub async fn par<'r, T: OAuthResolver + DpopExt + Send + Sync + 'static>(
531 client: &T,
532 login_hint: Option<CowStr<'r>>,
533 prompt: Option<AuthorizeOptionPrompt>,
534 metadata: &OAuthMetadata,
535 state: Option<CowStr<'r>>,
536) -> crate::request::Result<AuthRequestData<'r>> {
537 let state = if let Some(state) = state {
538 state
539 } else {
540 generate_nonce()
541 };
542 let (code_challenge, verifier) = generate_pkce();
543
544 let Some(dpop_key) = generate_dpop_key(&metadata.server_metadata) else {
545 return Err(RequestError::token_verification());
546 };
547 let mut dpop_data = DpopReqData {
548 dpop_key,
549 dpop_authserver_nonce: None,
550 };
551 let parameters = ParParameters {
552 response_type: AuthorizationResponseType::Code,
553 redirect_uri: metadata.client_metadata.redirect_uris[0].to_cowstr(),
554 state: state.clone(),
555 scope: metadata.client_metadata.scope.clone(),
556 response_mode: None,
557 code_challenge,
558 code_challenge_method: AuthorizationCodeChallengeMethod::S256,
559 login_hint: login_hint,
560 prompt: prompt.map(CowStr::from),
561 };
562
563 if metadata
564 .server_metadata
565 .pushed_authorization_request_endpoint
566 .is_some()
567 {
568 let par_response = oauth_request::<OAuthParResponse, T, DpopReqData>(
569 &client,
570 &mut dpop_data,
571 OAuthRequest::PushedAuthorizationRequest(parameters),
572 metadata,
573 )
574 .await?;
575
576 let scopes = if let Some(scope) = &metadata.client_metadata.scope {
577 Scope::parse_multiple_reduced(&scope)
578 .expect("Failed to parse scopes")
579 .into_static()
580 } else {
581 vec![]
582 };
583 let auth_req_data = AuthRequestData {
584 state,
585 authserver_url: metadata.server_metadata.issuer.clone(),
586 account_did: None,
587 scopes,
588 request_uri: par_response.request_uri.to_cowstr().into_static(),
589 authserver_token_endpoint: metadata.server_metadata.token_endpoint.clone(),
590 authserver_revocation_endpoint: metadata.server_metadata.revocation_endpoint.clone(),
591 pkce_verifier: verifier,
592 dpop_data,
593 };
594
595 Ok(auth_req_data)
596 } else if metadata
597 .server_metadata
598 .require_pushed_authorization_requests
599 == Some(true)
600 {
601 Err(RequestError::no_endpoint("pushed_authorization_request"))
602 } else {
603 todo!("use of PAR is mandatory")
604 }
605}
606
607/// Exchange a refresh token for a fresh token set and update the session data in place.
608#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all, fields(did = %session_data.account_did)))]
609pub async fn refresh<'r, T>(
610 client: &T,
611 mut session_data: ClientSessionData<'r>,
612 metadata: &OAuthMetadata,
613) -> Result<ClientSessionData<'r>>
614where
615 T: OAuthResolver + DpopExt + Send + Sync + 'static,
616{
617 let Some(refresh_token) = session_data.token_set.refresh_token.as_ref() else {
618 return Err(RequestError::no_refresh_token());
619 };
620
621 // /!\ IMPORTANT /!\
622 //
623 // The "sub" MUST be a DID, whose issuer authority is indeed the server we
624 // are trying to obtain credentials from. Note that we are doing this
625 // *before* we actually try to refresh the token:
626 // 1) To avoid unnecessary refresh
627 // 2) So that the refresh is the last async operation, ensuring as few
628 // async operations happen before the result gets a chance to be stored.
629 let aud = client
630 .verify_issuer(&metadata.server_metadata, &session_data.token_set.sub)
631 .await?;
632 let iss = metadata.server_metadata.issuer.clone();
633
634 let response = oauth_request::<OAuthTokenResponse, T, DpopClientData>(
635 client,
636 &mut session_data.dpop_data,
637 OAuthRequest::Refresh(RefreshRequestParameters {
638 grant_type: TokenGrantType::RefreshToken,
639 refresh_token: refresh_token.clone(),
640 scope: None,
641 }),
642 metadata,
643 )
644 .await?;
645
646 let expires_at = response.expires_in.and_then(|expires_in| {
647 let now = Datetime::now();
648 now.as_ref()
649 .checked_add_signed(TimeDelta::seconds(expires_in))
650 .map(Datetime::new)
651 });
652
653 session_data.update_with_tokens(TokenSet {
654 iss,
655 sub: session_data.token_set.sub.clone(),
656 aud: CowStr::Owned(aud.to_smolstr()),
657 scope: response.scope.map(CowStr::Owned),
658 access_token: CowStr::Owned(response.access_token),
659 refresh_token: response.refresh_token.map(CowStr::Owned),
660 token_type: response.token_type,
661 expires_at,
662 });
663
664 Ok(session_data)
665}
666
667/// Exchange an authorization code for a token set and return a fully-verified [`TokenSet`].
668///
669/// Per the AT Protocol OAuth spec, the `sub` claim in the token response **must** be verified
670/// against the expected authorization server issuer before the token can be trusted. This
671/// function performs that verification as part of the exchange, so callers receive a token
672/// set that is safe to persist.
673#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))]
674pub async fn exchange_code<'r, T, D>(
675 client: &T,
676 data_source: &'r mut D,
677 code: &str,
678 verifier: &str,
679 metadata: &OAuthMetadata,
680) -> Result<TokenSet<'r>>
681where
682 T: OAuthResolver + DpopExt + Send + Sync + 'static,
683 D: DpopDataSource,
684{
685 let token_response = oauth_request::<OAuthTokenResponse, T, D>(
686 client,
687 data_source,
688 OAuthRequest::Token(TokenRequestParameters {
689 grant_type: TokenGrantType::AuthorizationCode,
690 code: code.into(),
691 redirect_uri: CowStr::Owned(
692 metadata.client_metadata.redirect_uris[0]
693 .clone()
694 .to_smolstr(),
695 ),
696 code_verifier: verifier.into(),
697 }),
698 metadata,
699 )
700 .await?;
701 let Some(sub) = token_response.sub else {
702 return Err(RequestError::token_verification());
703 };
704 let sub = Did::new_owned(sub)?;
705 let iss = metadata.server_metadata.issuer.clone();
706 // /!\ IMPORTANT /!\
707 //
708 // The token_response MUST always be valid before the "sub" it contains
709 // can be trusted (see Atproto's OAuth spec for details).
710 let aud = client
711 .verify_issuer(&metadata.server_metadata, &sub)
712 .await?;
713
714 let expires_at = token_response.expires_in.and_then(|expires_in| {
715 Datetime::now()
716 .as_ref()
717 .checked_add_signed(TimeDelta::seconds(expires_in))
718 .map(Datetime::new)
719 });
720 Ok(TokenSet {
721 iss,
722 sub,
723 aud: CowStr::Owned(aud.to_smolstr()),
724 scope: token_response.scope.map(CowStr::Owned),
725 access_token: CowStr::Owned(token_response.access_token),
726 refresh_token: token_response.refresh_token.map(CowStr::Owned),
727 token_type: token_response.token_type,
728 expires_at,
729 })
730}
731
732/// Send a token revocation request (RFC 7009) to the authorization server.
733///
734/// This function is called by [`crate::client::OAuthSession::logout`] when a revocation endpoint is advertised
735/// by the server. The caller is responsible for deleting the session from local storage regardless
736/// of whether revocation succeeds.
737#[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))]
738pub async fn revoke<'r, T, D>(
739 client: &T,
740 data_source: &'r mut D,
741 token: &str,
742 metadata: &OAuthMetadata,
743) -> Result<()>
744where
745 T: OAuthResolver + DpopExt + Send + Sync + 'static,
746 D: DpopDataSource,
747{
748 oauth_request::<(), T, D>(
749 client,
750 data_source,
751 OAuthRequest::Revocation(RevocationRequestParameters {
752 token: token.into(),
753 }),
754 metadata,
755 )
756 .await?;
757 Ok(())
758}
759
760/// Low-level function for sending an OAuth token-endpoint request and deserializing the response.
761///
762/// Selects the correct server endpoint for `request`, builds the form-encoded body with
763/// client authentication, performs the DPoP-wrapped HTTP POST, and deserializes the response
764/// body into `O`. The type parameter `O` is inferred from the call site; use `()` for requests
765/// where the response body is empty (e.g., revocation).
766pub async fn oauth_request<'de: 'r, 'r, O, T, D>(
767 client: &T,
768 data_source: &'r mut D,
769 request: OAuthRequest<'r>,
770 metadata: &OAuthMetadata,
771) -> Result<O>
772where
773 T: OAuthResolver + DpopExt + Send + Sync + 'static,
774 O: serde::de::DeserializeOwned,
775 D: DpopDataSource,
776{
777 let Some(url) = endpoint_for_req(&metadata.server_metadata, &request) else {
778 return Err(RequestError::no_endpoint(request.name()));
779 };
780 let client_assertions = build_auth(
781 metadata.keyset.as_ref(),
782 &metadata.server_metadata,
783 &metadata.client_metadata,
784 )?;
785 let body = match &request {
786 OAuthRequest::Token(params) => build_oauth_req_body(client_assertions, params)?,
787 OAuthRequest::Refresh(params) => build_oauth_req_body(client_assertions, params)?,
788 OAuthRequest::Revocation(params) => build_oauth_req_body(client_assertions, params)?,
789 OAuthRequest::PushedAuthorizationRequest(params) => {
790 build_oauth_req_body(client_assertions, params)?
791 }
792 _ => unimplemented!(),
793 };
794 let req = Request::builder()
795 .uri(url.to_string())
796 .method(Method::POST)
797 .header("Content-Type", "application/x-www-form-urlencoded")
798 .body(body.into_bytes())?;
799 let res = client.dpop_server_call(data_source).send(req).await?;
800 if res.status() == request.expected_status() {
801 let body = res.body();
802 if body.is_empty() {
803 // since an empty body cannot be deserialized, use “null” temporarily to allow deserialization to `()`.
804 Ok(serde_json::from_slice(b"null")?)
805 } else {
806 let output: O = serde_json::from_slice(body)?;
807 Ok(output)
808 }
809 } else if res.status().is_client_error() {
810 Err(RequestError::http_status_with_body(
811 res.status(),
812 serde_json::from_slice(res.body())?,
813 ))
814 } else {
815 Err(RequestError::http_status(res.status()))
816 }
817}
818
819#[inline]
820fn endpoint_for_req<'a, 'r>(
821 server_metadata: &'r OAuthAuthorizationServerMetadata<'a>,
822 request: &'r OAuthRequest,
823) -> Option<&'r CowStr<'a>> {
824 match request {
825 OAuthRequest::Token(_) | OAuthRequest::Refresh(_) => Some(&server_metadata.token_endpoint),
826 OAuthRequest::Revocation(_) => server_metadata.revocation_endpoint.as_ref(),
827 OAuthRequest::Introspection => server_metadata.introspection_endpoint.as_ref(),
828 OAuthRequest::PushedAuthorizationRequest(_) => server_metadata
829 .pushed_authorization_request_endpoint
830 .as_ref(),
831 }
832}
833
834#[inline]
835fn build_oauth_req_body<'a, S>(client_assertions: ClientAuth<'a>, parameters: S) -> Result<String>
836where
837 S: Serialize,
838{
839 Ok(serde_html_form::to_string(RequestPayload {
840 client_id: client_assertions.client_id,
841 client_assertion_type: client_assertions.assertion_type,
842 client_assertion: client_assertions.assertion,
843 parameters,
844 })?)
845}
846
/// Client identity fields appended to every OAuth endpoint request body.
///
/// Encapsulates the result of choosing a client authentication method (`none` vs.
/// `private_key_jwt`). The `build_auth` helper selects the appropriate variant based on
/// server capabilities and client configuration. `build_oauth_req_body` copies these fields
/// into the [`RequestPayload`].
#[derive(Debug, Clone, Default)]
pub struct ClientAuth<'a> {
    /// The OAuth `client_id` for this client.
    client_id: CowStr<'a>,
    /// Either absent (for `none` auth) or `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`.
    assertion_type: Option<CowStr<'a>>,
    /// A signed JWT proving client identity; present only for `private_key_jwt` auth.
    assertion: Option<CowStr<'a>>,
}
861
862impl<'s> ClientAuth<'s> {
863 /// Construct a `ClientAuth` with only a `client_id` and no assertion (the `none` method).
864 pub fn new_id(client_id: CowStr<'s>) -> Self {
865 Self {
866 client_id,
867 assertion_type: None,
868 assertion: None,
869 }
870 }
871}
872
873fn build_auth<'a>(
874 keyset: Option<&Keyset>,
875 server_metadata: &OAuthAuthorizationServerMetadata<'a>,
876 client_metadata: &OAuthClientMetadata<'a>,
877) -> Result<ClientAuth<'a>> {
878 let method_supported = server_metadata
879 .token_endpoint_auth_methods_supported
880 .as_ref();
881
882 let client_id = client_metadata.client_id.to_cowstr().into_static();
883 if let Some(method) = client_metadata.token_endpoint_auth_method.as_ref() {
884 match (*method).as_ref() {
885 "private_key_jwt"
886 if method_supported
887 .as_ref()
888 .is_some_and(|v| v.contains(&CowStr::new_static("private_key_jwt"))) =>
889 {
890 if let Some(keyset) = &keyset {
891 let mut alg_strs = server_metadata
892 .token_endpoint_auth_signing_alg_values_supported
893 .clone()
894 .unwrap_or(vec![FALLBACK_ALG.into()]);
895 alg_strs.sort_by(compare_algos);
896 let algs: Vec<Signing> = alg_strs
897 .iter()
898 .filter_map(|s| crate::keyset::parse_signing_alg(s))
899 .collect();
900 let iat = Utc::now().timestamp();
901 return Ok(ClientAuth {
902 client_id: client_id.clone(),
903 assertion_type: Some(CowStr::new_static(CLIENT_ASSERTION_TYPE_JWT_BEARER)),
904 assertion: Some(
905 keyset.create_jwt(
906 &algs,
907 // https://datatracker.ietf.org/doc/html/rfc7523#section-3
908 RegisteredClaims {
909 iss: Some(client_id.clone()),
910 sub: Some(client_id),
911 aud: Some(RegisteredClaimsAud::Single(
912 server_metadata.issuer.clone(),
913 )),
914 exp: Some(iat + 60),
915 // "iat" is required and **MUST** be less than one minute
916 // https://datatracker.ietf.org/doc/html/rfc9101
917 iat: Some(iat),
918 // atproto oauth-provider requires "jti" to be present
919 jti: Some(generate_nonce()),
920 ..Default::default()
921 }
922 .into(),
923 )?,
924 ),
925 });
926 }
927 }
928 "none"
929 if method_supported
930 .as_ref()
931 .is_some_and(|v| v.contains(&CowStr::new_static("none"))) =>
932 {
933 return Ok(ClientAuth::new_id(client_id));
934 }
935 _ => {}
936 }
937 }
938
939 Err(RequestError::unsupported_auth_method())
940}
941
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{OAuthAuthorizationServerMetadata, OAuthClientMetadata};
    use bytes::Bytes;
    use http::{Response as HttpResponse, StatusCode};
    use jacquard_common::{deps::fluent_uri::Uri, http_client::HttpClient, types::string::Did};
    use jacquard_identity::resolver::IdentityResolver;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    // Test double that serves a single pre-loaded HTTP response.
    // `resp` is shared behind an async mutex so a test can queue a response
    // before calling the code under test.
    #[derive(Clone, Default)]
    struct MockClient {
        resp: Arc<Mutex<Option<HttpResponse<Vec<u8>>>>>,
    }

    impl HttpClient for MockClient {
        type Error = std::convert::Infallible;
        // Ignores the request and returns the queued response. `take()` empties the
        // slot, so each queued response is served at most once; the `unwrap()` panics
        // if no response was queued (i.e. the code under test made an unexpected request).
        fn send_http(
            &self,
            _request: http::Request<Vec<u8>>,
        ) -> impl core::future::Future<
            Output = core::result::Result<http::Response<Vec<u8>>, Self::Error>,
        > + Send {
            let resp = self.resp.clone();
            async move { Ok(resp.lock().await.take().unwrap()) }
        }
    }

    // IdentityResolver methods won't be called in these tests; provide stubs.
    impl IdentityResolver for MockClient {
        // Default resolver options, created lazily once and shared for the process.
        fn options(&self) -> &jacquard_identity::resolver::ResolverOptions {
            use std::sync::LazyLock;
            static OPTS: LazyLock<jacquard_identity::resolver::ResolverOptions> =
                LazyLock::new(|| jacquard_identity::resolver::ResolverOptions::default());
            &OPTS
        }
        // Every handle resolves to the fixed test DID.
        async fn resolve_handle(
            &self,
            _handle: &jacquard_common::types::string::Handle<'_>,
        ) -> std::result::Result<Did<'static>, jacquard_identity::resolver::IdentityError> {
            Ok(Did::new_static("did:plc:alice").unwrap())
        }
        // Every DID resolves to a minimal DID document whose only service is a PDS
        // at `https://pds`, returned with HTTP 200.
        async fn resolve_did_doc(
            &self,
            _did: &Did<'_>,
        ) -> std::result::Result<
            jacquard_identity::resolver::DidDocResponse,
            jacquard_identity::resolver::IdentityError,
        > {
            let doc = serde_json::json!({
                "id": "did:plc:alice",
                "service": [{
                    "id": "#pds",
                    "type": "AtprotoPersonalDataServer",
                    "serviceEndpoint": "https://pds"
                }]
            });
            let buf = Bytes::from(serde_json::to_vec(&doc).unwrap());
            Ok(jacquard_identity::resolver::DidDocResponse {
                buffer: buf,
                status: StatusCode::OK,
                requested: None,
            })
        }
    }

    // Allow using DPoP helpers on MockClient
    impl crate::dpop::DpopExt for MockClient {}
    impl crate::resolver::OAuthResolver for MockClient {}

    // Minimal metadata fixture for a public client: the server (issuer
    // `https://issuer`) advertises only the `none` token-endpoint auth method,
    // the client declares `none`, and no keyset is configured. Tests mutate the
    // returned value to set up specific failure cases.
    fn base_metadata() -> OAuthMetadata {
        let mut server = OAuthAuthorizationServerMetadata::default();
        server.issuer = CowStr::from("https://issuer");
        server.authorization_endpoint = CowStr::from("https://issuer/authorize");
        server.token_endpoint = CowStr::from("https://issuer/token");
        server.token_endpoint_auth_methods_supported = Some(vec![CowStr::from("none")]);
        OAuthMetadata {
            server_metadata: server,
            client_metadata: OAuthClientMetadata {
                client_id: CowStr::new_static("https://client"),
                client_uri: None,
                redirect_uris: vec![CowStr::new_static("https://client/cb")],
                scope: Some(CowStr::from("atproto")),
                grant_types: None,
                response_types: vec![CowStr::new_static("code")],
                application_type: Some(CowStr::new_static("web")),
                token_endpoint_auth_method: Some(CowStr::from("none")),
                dpop_bound_access_tokens: None,
                jwks_uri: None,
                jwks: None,
                token_endpoint_auth_signing_alg: None,
                client_name: None,
                privacy_policy_uri: None,
                tos_uri: None,
                logo_uri: None,
            },
            keyset: None,
        }
    }

    // `par` must fail with `NoEndpoint("pushed_authorization_request")` when the
    // server requires PAR but does not advertise a PAR endpoint.
    #[tokio::test]
    async fn par_missing_endpoint() {
        let mut meta = base_metadata();
        meta.server_metadata.require_pushed_authorization_requests = Some(true);
        meta.server_metadata.pushed_authorization_request_endpoint = None;
        // require_pushed_authorization_requests is true and no endpoint
        let err = super::par(&MockClient::default(), None, None, &meta, None)
            .await
            .unwrap_err();
        assert!(
            matches!(err.kind(), RequestErrorKind::NoEndpoint(name) if name == "pushed_authorization_request")
        );
    }

    // `refresh` must fail with `NoRefreshToken` when the session's token set has no
    // refresh token. No mock HTTP response is queued, so if `refresh` issued a
    // request anyway the mock client would panic.
    #[tokio::test]
    async fn refresh_no_refresh_token() {
        let client = MockClient::default();
        let meta = base_metadata();
        let session = ClientSessionData {
            account_did: Did::new_static("did:plc:alice").unwrap(),
            session_id: CowStr::from("state"),
            host_url: Uri::parse("https://pds").expect("valid").to_owned(),
            authserver_url: CowStr::new_static("https://issuer"),
            authserver_token_endpoint: CowStr::from("https://issuer/token"),
            authserver_revocation_endpoint: None,
            scopes: vec![],
            dpop_data: DpopClientData {
                dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(),
                dpop_authserver_nonce: CowStr::from(""),
                dpop_host_nonce: CowStr::from(""),
            },
            token_set: crate::types::TokenSet {
                iss: CowStr::from("https://issuer"),
                sub: Did::new_static("did:plc:alice").unwrap(),
                aud: CowStr::from("https://pds"),
                scope: None,
                refresh_token: None,
                access_token: CowStr::from("abc"),
                token_type: crate::types::OAuthTokenType::DPoP,
                expires_at: None,
            },
        };
        let err = super::refresh(&client, session, &meta).await.unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::NoRefreshToken));
    }

    // `exchange_code` must reject a successful (HTTP 200) token response that has
    // no `sub` claim, reporting `TokenVerification`.
    #[tokio::test]
    async fn exchange_code_missing_sub() {
        let client = MockClient::default();
        // set mock HTTP response body: token response without `sub`
        *client.resp.lock().await = Some(
            HttpResponse::builder()
                .status(StatusCode::OK)
                .body(
                    serde_json::to_vec(&serde_json::json!({
                        "access_token":"tok",
                        "token_type":"DPoP",
                        "expires_in": 3600
                    }))
                    .unwrap(),
                )
                .unwrap(),
        );
        let meta = base_metadata();
        let mut dpop = DpopReqData {
            dpop_key: crate::utils::generate_key(&[CowStr::from("ES256")]).unwrap(),
            dpop_authserver_nonce: None,
        };
        let err = super::exchange_code(&client, &mut dpop, "abc", "verifier", &meta)
            .await
            .unwrap_err();
        assert!(matches!(err.kind(), RequestErrorKind::TokenVerification));
    }
}