A better Rust ATProto crate
1use crate::{
2 atproto::atproto_client_metadata,
3 authstore::ClientAuthStore,
4 dpop::DpopExt,
5 error::{CallbackError, Result},
6 request::{OAuthMetadata, exchange_code, par},
7 resolver::OAuthResolver,
8 scopes::Scope,
9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry},
10 types::{AuthorizeOptions, CallbackParams},
11};
12use jacquard_common::{
13 AuthorizationToken, CowStr, IntoStatic,
14 cowstr::ToCowStr,
15 deps::fluent_uri::Uri,
16 error::{AuthError, ClientError, XrpcResult},
17 http_client::HttpClient,
18 types::{did::Did, string::Handle},
19 xrpc::{
20 CallOptions, Response, XrpcClient, XrpcError, XrpcExt, XrpcRequest, XrpcResp, XrpcResponse,
21 build_http_request, process_response,
22 },
23};
24
25#[cfg(feature = "websocket")]
26use jacquard_common::websocket::{WebSocketClient, WebSocketConnection};
27#[cfg(feature = "websocket")]
28use jacquard_common::xrpc::XrpcSubscription;
29use jacquard_identity::{
30 JacquardResolver,
31 resolver::{DidDocResponse, IdentityError, IdentityResolver, ResolverOptions},
32};
33use jose_jwk::JwkSet;
34use std::{future::Future, sync::Arc};
35use tokio::sync::RwLock;
36
37/// The top-level OAuth client responsible for driving the authorization flow.
38pub struct OAuthClient<T, S>
39where
40 T: OAuthResolver,
41 S: ClientAuthStore,
42{
43 /// Shared session registry that mediates access to the backing auth store.
44 pub registry: Arc<SessionRegistry<T, S>>,
45 /// Default call options applied to every outgoing XRPC request.
46 pub options: RwLock<CallOptions<'static>>,
47 /// Override for the XRPC base URI; falls back to the public Bluesky AppView when `None`.
48 pub endpoint: RwLock<Option<Uri<String>>>,
49 /// Underlying HTTP/identity/OAuth resolver used for all network operations.
50 pub client: Arc<T>,
51}
52
53impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> {
54 /// Create an `OAuthClient` using the default [`JacquardResolver`] for identity and metadata resolution.
55 pub fn new(store: S, client_data: ClientData<'static>) -> Self {
56 let client = JacquardResolver::default();
57 Self::new_from_resolver(store, client, client_data)
58 }
59
60 /// Create an OAuth client with the provided store and default localhost client metadata.
61 ///
62 /// This is a convenience constructor for quickly setting up an OAuth client
63 /// with default localhost redirect URIs and "atproto transition:generic" scopes.
64 ///
65 /// # Example
66 ///
67 /// ```no_run
68 /// # use jacquard_oauth::client::OAuthClient;
69 /// # use jacquard_oauth::authstore::MemoryAuthStore;
70 /// # #[tokio::main]
71 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
72 /// let store = MemoryAuthStore::new();
73 /// let oauth = OAuthClient::with_default_config(store);
74 /// # Ok(())
75 /// # }
76 /// ```
77 pub fn with_default_config(store: S) -> Self {
78 let client_data = ClientData {
79 keyset: None,
80 config: crate::atproto::AtprotoClientMetadata::default_localhost(),
81 };
82 Self::new(store, client_data)
83 }
84}
85
86impl OAuthClient<JacquardResolver, crate::authstore::MemoryAuthStore> {
87 /// Create an OAuth client with an in-memory auth store and default localhost client metadata.
88 ///
89 /// This is a convenience constructor for simple testing and development.
90 /// The session will not persist across restarts.
91 ///
92 /// # Example
93 ///
94 /// ```no_run
95 /// # use jacquard_oauth::client::OAuthClient;
96 /// # #[tokio::main]
97 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
98 /// let oauth = OAuthClient::with_memory_store();
99 /// # Ok(())
100 /// # }
101 /// ```
102 pub fn with_memory_store() -> Self {
103 Self::with_default_config(crate::authstore::MemoryAuthStore::new())
104 }
105}
106
107impl<T, S> OAuthClient<T, S>
108where
109 T: OAuthResolver,
110 S: ClientAuthStore,
111{
112 /// Create an OAuth client from an explicit resolver instance, taking ownership of both.
113 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self {
114 // #[cfg(feature = "tracing")]
115 // tracing::info!(
116 // redirect_uris = ?client_data.config.redirect_uris,
117 // scopes = ?client_data.config.scopes,
118 // has_keyset = client_data.keyset.is_some(),
119 // "oauth client created:"
120 // );
121
122 let client = Arc::new(client);
123 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data));
124 Self {
125 registry,
126 client,
127 options: RwLock::new(CallOptions::default()),
128 endpoint: RwLock::new(None),
129 }
130 }
131
132 /// Create an OAuth client from already-`Arc`-wrapped store and resolver.
133 pub fn new_with_shared(
134 store: Arc<S>,
135 client: Arc<T>,
136 client_data: ClientData<'static>,
137 ) -> Self {
138 let registry = Arc::new(SessionRegistry::new_shared(
139 store,
140 client.clone(),
141 client_data,
142 ));
143 Self {
144 registry,
145 client,
146 options: RwLock::new(CallOptions::default()),
147 endpoint: RwLock::new(None),
148 }
149 }
150}
151
152impl<T, S> OAuthClient<T, S>
153where
154 S: ClientAuthStore + Send + Sync + 'static,
155 T: OAuthResolver + DpopExt + Send + Sync + 'static,
156{
157 /// Return the public JWK set for this client's keyset, or an empty set if no keyset is configured.
158 pub fn jwks(&self) -> JwkSet {
159 self.registry
160 .client_data
161 .keyset
162 .as_ref()
163 .map(|keyset| keyset.public_jwks())
164 .unwrap_or_default()
165 }
166 /// Begin an OAuth authorization flow and return the URL to which the user should be redirected.
167 ///
168 /// This resolves OAuth metadata for the given `input` (a handle, DID, or PDS/entryway URL),
169 /// performs a Pushed Authorization Request (PAR) to the authorization server, persists the
170 /// resulting state for later callback verification, and returns a fully-constructed
171 /// authorization endpoint URL.
172 ///
173 /// The caller is responsible for redirecting the user's browser to the returned URL.
174 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, input), fields(input = input.as_ref())))]
175 pub async fn start_auth(
176 &self,
177 input: impl AsRef<str>,
178 options: AuthorizeOptions<'_>,
179 ) -> Result<String> {
180 let client_metadata = atproto_client_metadata(
181 self.registry.client_data.config.clone(),
182 &self.registry.client_data.keyset,
183 )?;
184 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?;
185 let login_hint = if identity.is_some() {
186 Some(input.as_ref().into())
187 } else {
188 None
189 };
190 let metadata = OAuthMetadata {
191 server_metadata,
192 client_metadata,
193 keyset: self.registry.client_data.keyset.clone(),
194 };
195
196 let auth_req_info = par(
197 self.client.as_ref(),
198 login_hint,
199 options.prompt,
200 &metadata,
201 options.state,
202 )
203 .await?;
204
205 // Persist state for callback handling
206 self.registry
207 .store
208 .save_auth_req_info(&auth_req_info)
209 .await?;
210
211 #[derive(serde::Serialize)]
212 struct Parameters<'s> {
213 client_id: CowStr<'s>,
214 request_uri: CowStr<'s>,
215 }
216 Ok(metadata.server_metadata.authorization_endpoint.to_string()
217 + "?"
218 + &serde_html_form::to_string(Parameters {
219 client_id: metadata.client_metadata.client_id,
220 request_uri: auth_req_info.request_uri,
221 })
222 .unwrap())
223 }
224
225 /// Complete the OAuth authorization flow after the authorization server redirects back to the client.
226 ///
227 /// Validates the `state` and optional `iss` parameters, exchanges the authorization code for
228 /// tokens via the token endpoint, verifies the `sub` claim against the expected issuer, and
229 /// persists the resulting session. On success returns an [`OAuthSession`] ready for API calls.
230 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_ref()))))]
231 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> {
232 let Some(state_key) = params.state else {
233 return Err(CallbackError::MissingState.into());
234 };
235
236 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else {
237 return Err(CallbackError::MissingState.into());
238 };
239
240 self.registry.store.delete_auth_req_info(&state_key).await?;
241
242 let metadata = self
243 .client
244 .get_authorization_server_metadata(&auth_req_info.authserver_url.to_cowstr())
245 .await?;
246
247 if let Some(iss) = params.iss {
248 if iss != metadata.issuer {
249 return Err(CallbackError::IssuerMismatch {
250 expected: metadata.issuer.to_string(),
251 got: iss.to_string(),
252 }
253 .into());
254 }
255 } else if metadata.authorization_response_iss_parameter_supported == Some(true) {
256 return Err(CallbackError::MissingIssuer.into());
257 }
258 let metadata = OAuthMetadata {
259 server_metadata: metadata,
260 client_metadata: atproto_client_metadata(
261 self.registry.client_data.config.clone(),
262 &self.registry.client_data.keyset,
263 )?,
264 keyset: self.registry.client_data.keyset.clone(),
265 };
266 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone();
267
268 match exchange_code(
269 self.client.as_ref(),
270 &mut auth_req_info.dpop_data.clone(),
271 ¶ms.code,
272 &auth_req_info.pkce_verifier,
273 &metadata,
274 )
275 .await
276 {
277 Ok(token_set) => {
278 let scopes = if let Some(scope) = &token_set.scope {
279 Scope::parse_multiple_reduced(&scope)
280 .expect("Failed to parse scopes")
281 .into_static()
282 } else {
283 vec![]
284 };
285 let client_data = ClientSessionData {
286 account_did: token_set.sub.clone(),
287 session_id: auth_req_info.state,
288 host_url: Uri::parse(token_set.aud.as_ref())?.to_owned(),
289 authserver_url: auth_req_info.authserver_url.to_cowstr(),
290 authserver_token_endpoint: auth_req_info.authserver_token_endpoint,
291 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint,
292 scopes,
293 dpop_data: DpopClientData {
294 dpop_key: auth_req_info.dpop_data.dpop_key.clone(),
295 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()),
296 dpop_host_nonce: auth_req_info
297 .dpop_data
298 .dpop_authserver_nonce
299 .unwrap_or(CowStr::default()),
300 },
301 token_set,
302 };
303
304 self.create_session(client_data).await
305 }
306 Err(e) => Err(e.into()),
307 }
308 }
309
310 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> {
311 self.registry.set(data.clone()).await?;
312 Ok(OAuthSession::new(
313 self.registry.clone(),
314 self.client.clone(),
315 data.into_static(),
316 ))
317 }
318
319 /// Restore a previously created session from the backing store, refreshing tokens if needed.
320 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> {
321 self.create_session(self.registry.get(did, session_id, true).await?)
322 .await
323 }
324
325 /// Revoke a session by deleting it from the backing store.
326 ///
327 /// Note: this removes the session from local storage but does **not** call the authorization
328 /// server's revocation endpoint. To also invalidate the token server-side, prefer
329 /// [`OAuthSession::logout`], which calls `revoke` on the token before deleting the session.
330 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> {
331 Ok(self.registry.del(did, session_id).await?)
332 }
333}
334
335impl<T, S> HttpClient for OAuthClient<T, S>
336where
337 S: ClientAuthStore + Send + Sync + 'static,
338 T: OAuthResolver + DpopExt + Send + Sync + 'static,
339{
340 type Error = T::Error;
341
342 async fn send_http(
343 &self,
344 request: http::Request<Vec<u8>>,
345 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
346 self.client.send_http(request).await
347 }
348}
349
350impl<T, S> IdentityResolver for OAuthClient<T, S>
351where
352 S: ClientAuthStore + Send + Sync + 'static,
353 T: OAuthResolver + DpopExt + Send + Sync + 'static,
354{
355 fn options(&self) -> &ResolverOptions {
356 self.client.options()
357 }
358
359 async fn resolve_handle(
360 &self,
361 handle: &Handle<'_>,
362 ) -> jacquard_identity::resolver::Result<Did<'static>> {
363 self.client.resolve_handle(handle).await
364 }
365
366 async fn resolve_did_doc(
367 &self,
368 did: &Did<'_>,
369 ) -> jacquard_identity::resolver::Result<DidDocResponse> {
370 self.client.resolve_did_doc(did).await
371 }
372}
373
374impl<T, S> XrpcClient for OAuthClient<T, S>
375where
376 S: ClientAuthStore + Send + Sync + 'static,
377 T: OAuthResolver + DpopExt + Send + Sync + 'static,
378{
379 async fn base_uri(&self) -> Uri<String> {
380 self.endpoint.read().await.clone().unwrap_or_else(|| {
381 Uri::parse("https://public.api.bsky.app")
382 .expect("hardcoded URI is valid")
383 .to_owned()
384 })
385 }
386
387 async fn opts(&self) -> CallOptions<'_> {
388 self.options.read().await.clone()
389 }
390
391 async fn set_opts(&self, opts: CallOptions<'_>) {
392 let mut guard = self.options.write().await;
393 *guard = opts.into_static();
394 }
395
396 async fn set_base_uri(&self, uri: Uri<String>) {
397 let normalized = jacquard_common::xrpc::normalize_base_uri(uri);
398 let mut guard = self.endpoint.write().await;
399 *guard = Some(normalized);
400 }
401
402 async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>>
403 where
404 R: XrpcRequest + Send + Sync,
405 <R as XrpcRequest>::Response: Send + Sync,
406 {
407 let opts = self.options.read().await.clone();
408 self.send_with_opts(request, opts).await
409 }
410
411 async fn send_with_opts<R>(
412 &self,
413 request: R,
414 opts: CallOptions<'_>,
415 ) -> XrpcResult<XrpcResponse<R>>
416 where
417 R: XrpcRequest + Send + Sync,
418 <R as XrpcRequest>::Response: Send + Sync,
419 {
420 let base_uri = self.base_uri().await;
421 self.client
422 .xrpc(base_uri)
423 .with_options(opts.clone())
424 .send(&request)
425 .await
426 }
427}
428
429/// An active OAuth session for a specific account, used to make authenticated API requests.
430///
431/// `OAuthSession` holds the DPoP-bound token set for one account and handles transparent
432/// token refresh on `401 invalid_token` responses. The optional `W` type parameter allows
433/// attaching a WebSocket client (defaults to `()` when WebSocket support is not needed).
434///
435/// Obtain an `OAuthSession` from [`OAuthClient::callback`] or [`OAuthClient::restore`].
436pub struct OAuthSession<T, S, W = ()>
437where
438 T: OAuthResolver,
439 S: ClientAuthStore,
440{
441 /// Shared registry used to persist and retrieve session data across refresh operations.
442 pub registry: Arc<SessionRegistry<T, S>>,
443 /// Underlying HTTP/identity/OAuth resolver shared with the parent `OAuthClient`.
444 pub client: Arc<T>,
445 /// Optional WebSocket client; `()` when WebSocket support is not required.
446 pub ws_client: W,
447 /// Mutable session data including DPoP key, nonces, and token set.
448 pub data: RwLock<ClientSessionData<'static>>,
449 /// Default call options applied to every outgoing XRPC request from this session.
450 pub options: RwLock<CallOptions<'static>>,
451}
452
453impl<T, S> OAuthSession<T, S, ()>
454where
455 T: OAuthResolver,
456 S: ClientAuthStore,
457{
458 /// Create a new session without a WebSocket client.
459 ///
460 /// This is the standard constructor used by [`OAuthClient::callback`] and
461 /// [`OAuthClient::restore`]. For WebSocket support use [`OAuthSession::new_with_ws`].
462 pub fn new(
463 registry: Arc<SessionRegistry<T, S>>,
464 client: Arc<T>,
465 data: ClientSessionData<'static>,
466 ) -> Self {
467 Self {
468 registry,
469 client,
470 ws_client: (),
471 data: RwLock::new(data),
472 options: RwLock::new(CallOptions::default()),
473 }
474 }
475}
476
477impl<T, S, W> OAuthSession<T, S, W>
478where
479 T: OAuthResolver,
480 S: ClientAuthStore,
481{
482 /// Create a new session with an attached WebSocket client.
483 ///
484 /// Use this variant when the session needs to support WebSocket subscriptions in addition
485 /// to standard XRPC calls. The `ws_client` is exposed via [`OAuthSession::ws_client`] and
486 /// is used by the `WebSocketClient` impl when the `websocket` feature is enabled.
487 pub fn new_with_ws(
488 registry: Arc<SessionRegistry<T, S>>,
489 client: Arc<T>,
490 ws_client: W,
491 data: ClientSessionData<'static>,
492 ) -> Self {
493 Self {
494 registry,
495 client,
496 ws_client,
497 data: RwLock::new(data),
498 options: RwLock::new(CallOptions::default()),
499 }
500 }
501
502 /// Consume this session and return a new one with the given call options pre-applied.
503 ///
504 /// Useful for setting request-level defaults (e.g., `atproto-proxy` or custom headers) once
505 /// at construction time rather than passing them to every individual XRPC call.
506 pub fn with_options(self, options: CallOptions<'_>) -> Self {
507 Self {
508 registry: self.registry,
509 client: self.client,
510 ws_client: self.ws_client,
511 data: self.data,
512 options: RwLock::new(options.into_static()),
513 }
514 }
515
516 /// Get a reference to the WebSocket client.
517 pub fn ws_client(&self) -> &W {
518 &self.ws_client
519 }
520
521 /// Replace the default call options for this session without consuming it.
522 pub async fn set_options(&self, options: CallOptions<'_>) {
523 *self.options.write().await = options.into_static();
524 }
525
526 /// Return the DID and session ID for this session.
527 ///
528 /// The session ID is the random `state` token generated during the PAR flow and can
529 /// be used together with the DID to restore the session via [`OAuthClient::restore`].
530 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) {
531 let data = self.data.read().await;
532 (data.account_did.clone(), data.session_id.clone())
533 }
534
535 /// Return the resource server (PDS) base URI for this session.
536 pub async fn endpoint(&self) -> Uri<String> {
537 self.data.read().await.host_url.clone()
538 }
539
540 /// Return the current DPoP-bound access token for this session.
541 ///
542 /// The token may be stale if it has expired; use [`OAuthSession::refresh`] or
543 /// rely on the automatic refresh performed by `send_with_opts` to obtain a fresh one.
544 pub async fn access_token(&self) -> AuthorizationToken<'_> {
545 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone())
546 }
547
548 /// Return the current refresh token for this session, if one is present.
549 ///
550 /// Not all authorization servers issue refresh tokens. When `None` is returned,
551 /// the session cannot be silently renewed and the user must re-authenticate.
552 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> {
553 self.data
554 .read()
555 .await
556 .token_set
557 .refresh_token
558 .as_ref()
559 .map(|t| AuthorizationToken::Dpop(t.clone()))
560 }
561
562 /// Derive an unauthenticated [`OAuthClient`] that shares the same registry and resolver.
563 ///
564 /// Useful when you need to initiate a new authorization flow from within an existing
565 /// session context (e.g., to add a second account) without constructing a fresh client.
566 pub fn to_client(&self) -> OAuthClient<T, S> {
567 OAuthClient::from_session(self)
568 }
569}
570impl<T, S, W> OAuthSession<T, S, W>
571where
572 S: ClientAuthStore + Send + Sync + 'static,
573 T: OAuthResolver + DpopExt + Send + Sync + 'static,
574{
575 /// Revoke the access token at the authorization server and delete the session from the store.
576 ///
577 /// Revocation is best-effort: if the server does not advertise a revocation endpoint, or if
578 /// the revocation call fails, the session is still deleted locally. This prevents a dangling
579 /// session record from blocking future logins for the same account.
580 pub async fn logout(&self) -> Result<()> {
581 use crate::request::{OAuthMetadata, revoke};
582 let mut data = self.data.write().await;
583 let meta =
584 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?;
585 if meta.server_metadata.revocation_endpoint.is_some() {
586 let token = data.token_set.access_token.clone();
587 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta)
588 .await
589 .ok();
590 }
591 // Remove from store
592 self.registry
593 .del(&data.account_did, &data.session_id)
594 .await?;
595 Ok(())
596 }
597}
598
599impl<T, S> OAuthClient<T, S>
600where
601 T: OAuthResolver,
602 S: ClientAuthStore,
603{
604 /// Construct an `OAuthClient` that shares the registry and resolver of an existing session.
605 ///
606 /// Equivalent to [`OAuthSession::to_client`]; provided on `OAuthClient` for symmetry so
607 /// callers can obtain an unauthenticated client without holding a session reference.
608 pub fn from_session<W>(session: &OAuthSession<T, S, W>) -> Self {
609 Self {
610 registry: session.registry.clone(),
611 client: session.client.clone(),
612 options: RwLock::new(CallOptions::default()),
613 endpoint: RwLock::new(None),
614 }
615 }
616}
617impl<T, S, W> OAuthSession<T, S, W>
618where
619 S: ClientAuthStore + Send + Sync + 'static,
620 T: OAuthResolver + DpopExt + Send + Sync + 'static,
621{
622 /// Explicitly refresh the access token using the stored refresh token.
623 ///
624 /// On success the new token set is written back into both the in-memory session data and
625 /// the backing store. The returned `AuthorizationToken` is the new access token, which
626 /// callers can immediately use to retry a failed request.
627 ///
628 /// The actual token exchange is serialized per `(DID, session_id)` pair via a `Mutex` inside
629 /// the registry, so concurrent refresh attempts will not result in duplicate token exchanges.
630 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))]
631 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> {
632 // Read identifiers without holding the lock across await
633 let (did, sid) = {
634 let data = self.data.read().await;
635 (data.account_did.clone(), data.session_id.clone())
636 };
637 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?;
638 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone());
639 // Write back updated session
640 *self.data.write().await = refreshed.clone().into_static();
641 // Store in the registry
642 self.registry.set(refreshed).await?;
643 Ok(token)
644 }
645}
646
647impl<T, S, W> HttpClient for OAuthSession<T, S, W>
648where
649 S: ClientAuthStore + Send + Sync + 'static,
650 T: OAuthResolver + DpopExt + Send + Sync + 'static,
651 W: Send + Sync,
652{
653 type Error = T::Error;
654
655 async fn send_http(
656 &self,
657 request: http::Request<Vec<u8>>,
658 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> {
659 self.client.send_http(request).await
660 }
661}
662
/// Authenticated XRPC client for an OAuth session.
///
/// Requests go to the session's resource server (`host_url`) with the session's DPoP-bound
/// access token. If the response is recognised as an invalid-token rejection (see
/// `is_invalid_token_response`), the request is retried once. The retry uses either a token
/// another request already refreshed, or a freshly refreshed one.
impl<T, S, W> XrpcClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static,
    W: Send + Sync,
{
    /// The session's resource server (PDS) URI, taken from the session data.
    async fn base_uri(&self) -> Uri<String> {
        self.data.read().await.host_url.clone()
    }

    /// A copy of the session's default call options.
    async fn opts(&self) -> CallOptions<'_> {
        self.options.read().await.clone()
    }

    /// Replace the session's default call options.
    async fn set_opts(&self, opts: CallOptions<'_>) {
        let mut guard = self.options.write().await;
        *guard = opts.into_static();
    }

    /// Override the resource server URI in the in-memory session data, after normalizing it
    /// with `normalize_base_uri`.
    ///
    /// NOTE(review): this is not written to the registry. `OAuthSession::refresh` replaces
    /// the whole in-memory session data with the registry's copy, which would discard this
    /// override.
    async fn set_base_uri(&self, uri: Uri<String>) {
        let normalized = jacquard_common::xrpc::normalize_base_uri(uri);
        let mut guard = self.data.write().await;
        guard.host_url = normalized;
    }

    /// Send `request` using the session's default call options.
    async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        let opts = self.options.read().await.clone();
        self.send_with_opts(request, opts).await
    }

    /// Send `request` with `opts`, authenticated by the session's DPoP-bound access token.
    ///
    /// Any `auth` already set in `opts` is overwritten with the session token. After each
    /// attempt, the DPoP host nonce left in the local DPoP copy is written back to the session.
    /// At most one retry is made after an invalid-token response; the retry's result is
    /// returned as-is.
    ///
    /// # Errors
    ///
    /// - Request-building and transport errors, tagged with the request's NSID.
    /// - Refresh failures, wrapped as transport errors.
    async fn send_with_opts<R>(
        &self,
        request: R,
        mut opts: CallOptions<'_>,
    ) -> XrpcResult<XrpcResponse<R>>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        let base_uri = self.base_uri().await;
        // Remember which token this attempt used, to detect a concurrent refresh later.
        let original_token = self.access_token().await;
        opts.auth = Some(original_token.clone());
        // Clone dpop_data and release read lock before the await point
        let mut dpop = self.data.read().await.dpop_data.clone();
        let http_response = self
            .client
            .dpop_call(&mut dpop)
            .send(build_http_request(&base_uri, &request, &opts)?)
            .await
            .map_err(|e| ClientError::from(e).for_nsid(R::NSID))?;
        let resp = process_response(http_response);

        // Write back updated nonce to session data (dpop_call may have updated it)
        {
            let mut guard = self.data.write().await;
            guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
        }

        if is_invalid_token_response(&resp) {
            // Optimistic refresh: check if another request already refreshed the token
            let current_token = self.access_token().await;
            if current_token != original_token {
                // Token was already refreshed by another concurrent request, use it
                opts.auth = Some(current_token);
            } else {
                // We need to refresh. Per the `refresh` docs, this is serialized by a mutex
                // inside the registry (not visible in this file).
                opts.auth = Some(
                    self.refresh()
                        .await
                        .map_err(|e| ClientError::transport(e))?,
                );
            }
            // Re-read dpop_data after refresh (refresh may have updated it)
            let mut dpop = self.data.read().await.dpop_data.clone();
            let http_response = self
                .client
                .dpop_call(&mut dpop)
                .send(build_http_request(&base_uri, &request, &opts)?)
                .await
                .map_err(|e| {
                    ClientError::from(e)
                        .for_nsid(R::NSID)
                        .append_context("after token refresh")
                })?;
            let resp = process_response(http_response);

            // Write back updated nonce after retry
            {
                let mut guard = self.data.write().await;
                guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
            }

            resp
        } else {
            resp
        }
    }
}
765
#[cfg(feature = "streaming")]
/// Streaming HTTP transport for sessions.
///
/// Each method forwards directly to the resolver's `HttpClientExt` implementation. No session
/// token or DPoP proof is added at this layer.
impl<T, S, W> jacquard_common::http_client::HttpClientExt for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver
        + DpopExt
        + XrpcExt
        + jacquard_common::http_client::HttpClientExt
        + Send
        + Sync
        + 'static,
    W: Send + Sync,
{
    /// Send `request` and return the response with a streamed body.
    async fn send_http_streaming(
        &self,
        request: http::Request<Vec<u8>>,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    {
        self.client.send_http_streaming(request).await
    }

    /// Send a request whose body is a stream, returning a streamed response (non-wasm
    /// targets, where the body stream must also be `Send`).
    #[cfg(not(target_arch = "wasm32"))]
    async fn send_http_bidirectional<Str>(
        &self,
        parts: http::request::Parts,
        body: Str,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    where
        Str: n0_future::Stream<
                Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>,
            > + Send
            + 'static,
    {
        self.client.send_http_bidirectional(parts, body).await
    }

    /// wasm32 variant of `send_http_bidirectional`; the body stream need not be `Send`.
    #[cfg(target_arch = "wasm32")]
    async fn send_http_bidirectional<Str>(
        &self,
        parts: http::request::Parts,
        body: Str,
    ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error>
    where
        Str: n0_future::Stream<
                Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>,
            > + 'static,
    {
        self.client.send_http_bidirectional(parts, body).await
    }
}
816
#[cfg(feature = "streaming")]
/// Streaming XRPC for sessions.
///
/// Requests carry the session's DPoP-bound access token and are sent through the resolver's
/// DPoP wrapper. As in `XrpcClient::send_with_opts`:
/// - the session's DPoP data is copied out of its lock before each network call, so no
///   session lock is held while waiting on the server;
/// - the host nonce left in that copy is written back to the session after a successful call.
impl<T, S, W> jacquard_common::xrpc::XrpcStreamingClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver
        + DpopExt
        + XrpcExt
        + jacquard_common::http_client::HttpClientExt
        + Send
        + Sync
        + 'static,
    W: Send + Sync,
{
    /// Perform `request` and return the response body as a byte stream.
    ///
    /// If sending fails for any reason, the access token is refreshed once and the request is
    /// retried.
    ///
    /// # Errors
    ///
    /// - A protocol error if the HTTP request cannot be built.
    /// - A transport error if the refresh fails or the retried request fails.
    async fn download<R>(
        &self,
        request: R,
    ) -> core::result::Result<jacquard_common::xrpc::StreamingResponse, jacquard_common::StreamError>
    where
        R: XrpcRequest + Send + Sync,
        <R as XrpcRequest>::Response: Send + Sync,
    {
        use jacquard_common::StreamError;

        let base_uri = <Self as XrpcClient>::base_uri(self).await;
        let mut opts = self.options.read().await.clone();
        opts.auth = Some(self.access_token().await);
        let http_request = build_http_request(&base_uri, &request, &opts)
            .map_err(|e| StreamError::protocol(e.to_string()))?;

        // Copy the DPoP data; the temporary read guard is dropped before the network await.
        let mut dpop = self.data.read().await.dpop_data.clone();
        let result = self
            .client
            .dpop_call(&mut dpop)
            .send_streaming(http_request)
            .await;

        match result {
            Ok(response) => {
                {
                    let mut guard = self.data.write().await;
                    guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
                }
                Ok(response)
            }
            Err(_e) => {
                // NOTE(review): the error is not inspected, so any failure (not only an auth
                // failure) triggers a token refresh and a single retry. Kept to preserve the
                // existing behaviour.
                opts.auth = Some(
                    self.refresh()
                        .await
                        .map_err(|e| StreamError::transport(e))?,
                );
                let http_request = build_http_request(&base_uri, &request, &opts)
                    .map_err(|e| StreamError::protocol(e.to_string()))?;
                let mut dpop = self.data.read().await.dpop_data.clone();
                let response = self
                    .client
                    .dpop_call(&mut dpop)
                    .send_streaming(http_request)
                    .await
                    .map_err(StreamError::transport)?;
                {
                    let mut guard = self.data.write().await;
                    guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
                }
                Ok(response)
            }
        }
    }

    /// Open a procedure call with a streamed request body and a streamed, typed response.
    ///
    /// Headers are built here from the session's call options:
    /// - `Authorization` with the session token;
    /// - `atproto-proxy`, if set;
    /// - a non-empty `atproto-accept-labelers`, if set;
    /// - any extra headers.
    ///
    /// The request is POSTed to `<base>/xrpc/<NSID>`. No refresh-and-retry is attempted here.
    ///
    /// # Errors
    ///
    /// - A protocol error if a header value or the request cannot be built.
    /// - A transport error if sending fails.
    async fn stream<Str>(
        &self,
        stream: jacquard_common::xrpc::streaming::XrpcProcedureSend<Str::Frame<'static>>,
    ) -> core::result::Result<
        jacquard_common::xrpc::streaming::XrpcResponseStream<
            <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<'static>,
        >,
        jacquard_common::StreamError,
    >
    where
        Str: jacquard_common::xrpc::streaming::XrpcProcedureStream + 'static,
        <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<'static>: jacquard_common::xrpc::streaming::XrpcStreamResp,
    {
        use jacquard_common::StreamError;
        use n0_future::TryStreamExt;

        let base_uri = self.base_uri().await;
        let mut opts = self.options.read().await.clone();
        opts.auth = Some(self.access_token().await);

        let mut path = String::from(base_uri.as_str().trim_end_matches('/'));
        path.push_str("/xrpc/");
        path.push_str(<Str::Request as jacquard_common::xrpc::XrpcRequest>::NSID);

        let mut builder = http::Request::post(path);

        if let Some(token) = &opts.auth {
            use jacquard_common::AuthorizationToken;
            let hv = match token {
                AuthorizationToken::Bearer(t) => {
                    http::HeaderValue::from_str(&format!("Bearer {}", t.as_ref()))
                }
                AuthorizationToken::Dpop(t) => {
                    http::HeaderValue::from_str(&format!("DPoP {}", t.as_ref()))
                }
            }
            .map_err(|e| StreamError::protocol(format!("Invalid authorization token: {}", e)))?;
            builder = builder.header(http::header::AUTHORIZATION, hv);
        }

        if let Some(proxy) = &opts.atproto_proxy {
            builder = builder.header("atproto-proxy", proxy.as_ref());
        }
        if let Some(labelers) = &opts.atproto_accept_labelers {
            if !labelers.is_empty() {
                let joined = labelers
                    .iter()
                    .map(|s| s.as_ref())
                    .collect::<Vec<_>>()
                    .join(", ");
                builder = builder.header("atproto-accept-labelers", joined);
            }
        }
        for (name, value) in &opts.extra_headers {
            builder = builder.header(name, value);
        }

        let (parts, _) = builder
            .body(())
            .map_err(|e| StreamError::protocol(e.to_string()))?
            .into_parts();

        let body_stream =
            jacquard_common::stream::ByteStream::new(Box::pin(stream.0.map_ok(|f| f.buffer)));

        // Copy the DPoP data so no session lock is held while the request body streams.
        // Otherwise, a body producer that sends other requests on this session could block on
        // their nonce write-back.
        let mut dpop = self.data.read().await.dpop_data.clone();
        let result = self
            .client
            .dpop_call(&mut dpop)
            .send_bidirectional(parts, body_stream)
            .await;

        match result {
            Ok(response) => {
                {
                    let mut guard = self.data.write().await;
                    guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone();
                }
                let (resp_parts, resp_body) = response.into_parts();
                Ok(
                    jacquard_common::xrpc::streaming::XrpcResponseStream::from_typed_parts(
                        resp_parts, resp_body,
                    ),
                )
            }
            // Any retry (for example a DPoP nonce retry) is up to the `dpop_call` wrapper;
            // no access-token refresh is attempted for streamed procedure calls.
            Err(e) => Err(StreamError::transport(e)),
        }
    }
}
967
968fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool {
969 use jacquard_common::error::ClientErrorKind;
970
971 match response {
972 Err(e) => match e.kind() {
973 ClientErrorKind::Auth(AuthError::InvalidToken) => true,
974 ClientErrorKind::Auth(AuthError::Other(value)) => value
975 .to_str()
976 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")),
977 _ => false,
978 },
979 Ok(resp) => match resp.parse() {
980 Err(XrpcError::Auth(AuthError::InvalidToken)) => true,
981 _ => false,
982 },
983 }
984}
985
986impl<T, S, W> IdentityResolver for OAuthSession<T, S, W>
987where
988 S: ClientAuthStore + Send + Sync + 'static,
989 T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static,
990 W: Send + Sync,
991{
992 fn options(&self) -> &ResolverOptions {
993 self.client.options()
994 }
995
996 fn resolve_handle(
997 &self,
998 handle: &Handle<'_>,
999 ) -> impl Future<Output = std::result::Result<Did<'static>, IdentityError>> {
1000 async { self.client.resolve_handle(handle).await }
1001 }
1002
1003 fn resolve_did_doc(
1004 &self,
1005 did: &Did<'_>,
1006 ) -> impl Future<Output = std::result::Result<DidDocResponse, IdentityError>> {
1007 async { self.client.resolve_did_doc(did).await }
1008 }
1009}
1010
#[cfg(feature = "websocket")]
/// WebSocket support for sessions created with [`OAuthSession::new_with_ws`].
///
/// Both methods forward directly to the attached `ws_client`. No session credentials are added
/// at this layer. Callers that need an `Authorization` header must pass it through
/// `connect_with_headers`; `subscription_opts` in this file's `SubscriptionClient` impl builds
/// such a header.
impl<T, S, W> WebSocketClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: WebSocketClient + Send + Sync,
{
    type Error = W::Error;

    /// Open a WebSocket connection to `uri` with the attached client and no extra headers.
    async fn connect(
        &self,
        uri: Uri<&str>,
    ) -> std::result::Result<WebSocketConnection, Self::Error> {
        self.ws_client.connect(uri).await
    }

    /// Open a WebSocket connection to `uri`, sending the given request headers.
    async fn connect_with_headers(
        &self,
        uri: Uri<&str>,
        headers: Vec<(CowStr<'_>, CowStr<'_>)>,
    ) -> std::result::Result<WebSocketConnection, Self::Error> {
        self.ws_client.connect_with_headers(uri, headers).await
    }
}
1035
#[cfg(feature = "websocket")]
/// XRPC subscriptions (event streams) over the session's attached WebSocket client.
///
/// Subscriptions target the session's resource server (`host_url`). By default they carry the
/// session's access token in an `Authorization` header.
impl<T, S, W> jacquard_common::xrpc::SubscriptionClient for OAuthSession<T, S, W>
where
    S: ClientAuthStore + Send + Sync + 'static,
    T: OAuthResolver + Send + Sync + 'static,
    W: WebSocketClient + Send + Sync,
{
    /// The session's resource server (PDS) URI.
    async fn base_uri(&self) -> Uri<String> {
        self.data.read().await.host_url.clone()
    }

    /// Default subscription options: one `Authorization` header with the current access
    /// token, formatted as `Bearer <token>` or `DPoP <token>` according to its kind.
    ///
    /// NOTE(review): no DPoP proof header is attached, and the token is not refreshed if it
    /// has expired. Servers that require a proof for DPoP-bound tokens may reject the
    /// connection — confirm against the target server.
    async fn subscription_opts(&self) -> jacquard_common::xrpc::SubscriptionOptions<'_> {
        let mut opts = jacquard_common::xrpc::SubscriptionOptions::default();
        let token = self.access_token().await;
        let auth_value = match token {
            AuthorizationToken::Bearer(t) => format!("Bearer {}", t.as_ref()),
            AuthorizationToken::Dpop(t) => format!("DPoP {}", t.as_ref()),
        };
        opts.headers
            .push((CowStr::from("Authorization"), CowStr::from(auth_value)));
        opts
    }

    /// Subscribe to `params` using the options from `subscription_opts`.
    async fn subscribe<Sub>(
        &self,
        params: &Sub,
    ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync,
    {
        let opts = self.subscription_opts().await;
        self.subscribe_with_opts(params, opts).await
    }

    /// Subscribe to `params` on the session's resource server with explicit options.
    async fn subscribe_with_opts<Sub>(
        &self,
        params: &Sub,
        opts: jacquard_common::xrpc::SubscriptionOptions<'_>,
    ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error>
    where
        Sub: XrpcSubscription + Send + Sync,
    {
        use jacquard_common::xrpc::SubscriptionExt;
        let base = self.base_uri().await;
        self.subscription(base)
            .with_options(opts)
            .subscribe(params)
            .await
    }
}