A better Rust ATProto crate
102
fork

Configure Feed

Select the types of activity you want to include in your feed.

at pretty-codegen 1085 lines 39 kB view raw
1use crate::{ 2 atproto::atproto_client_metadata, 3 authstore::ClientAuthStore, 4 dpop::DpopExt, 5 error::{CallbackError, Result}, 6 request::{OAuthMetadata, exchange_code, par}, 7 resolver::OAuthResolver, 8 scopes::Scope, 9 session::{ClientData, ClientSessionData, DpopClientData, SessionRegistry}, 10 types::{AuthorizeOptions, CallbackParams}, 11}; 12use jacquard_common::{ 13 AuthorizationToken, CowStr, IntoStatic, 14 cowstr::ToCowStr, 15 deps::fluent_uri::Uri, 16 error::{AuthError, ClientError, XrpcResult}, 17 http_client::HttpClient, 18 types::{did::Did, string::Handle}, 19 xrpc::{ 20 CallOptions, Response, XrpcClient, XrpcError, XrpcExt, XrpcRequest, XrpcResp, XrpcResponse, 21 build_http_request, process_response, 22 }, 23}; 24 25#[cfg(feature = "websocket")] 26use jacquard_common::websocket::{WebSocketClient, WebSocketConnection}; 27#[cfg(feature = "websocket")] 28use jacquard_common::xrpc::XrpcSubscription; 29use jacquard_identity::{ 30 JacquardResolver, 31 resolver::{DidDocResponse, IdentityError, IdentityResolver, ResolverOptions}, 32}; 33use jose_jwk::JwkSet; 34use std::{future::Future, sync::Arc}; 35use tokio::sync::RwLock; 36 37/// The top-level OAuth client responsible for driving the authorization flow. 38pub struct OAuthClient<T, S> 39where 40 T: OAuthResolver, 41 S: ClientAuthStore, 42{ 43 /// Shared session registry that mediates access to the backing auth store. 44 pub registry: Arc<SessionRegistry<T, S>>, 45 /// Default call options applied to every outgoing XRPC request. 46 pub options: RwLock<CallOptions<'static>>, 47 /// Override for the XRPC base URI; falls back to the public Bluesky AppView when `None`. 48 pub endpoint: RwLock<Option<Uri<String>>>, 49 /// Underlying HTTP/identity/OAuth resolver used for all network operations. 50 pub client: Arc<T>, 51} 52 53impl<S: ClientAuthStore> OAuthClient<JacquardResolver, S> { 54 /// Create an `OAuthClient` using the default [`JacquardResolver`] for identity and metadata resolution. 55 pub fn new(store: S, client_data: ClientData<'static>) -> Self { 56 let client = JacquardResolver::default(); 57 Self::new_from_resolver(store, client, client_data) 58 } 59 60 /// Create an OAuth client with the provided store and default localhost client metadata. 61 /// 62 /// This is a convenience constructor for quickly setting up an OAuth client 63 /// with default localhost redirect URIs and "atproto transition:generic" scopes. 64 /// 65 /// # Example 66 /// 67 /// ```no_run 68 /// # use jacquard_oauth::client::OAuthClient; 69 /// # use jacquard_oauth::authstore::MemoryAuthStore; 70 /// # #[tokio::main] 71 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 72 /// let store = MemoryAuthStore::new(); 73 /// let oauth = OAuthClient::with_default_config(store); 74 /// # Ok(()) 75 /// # } 76 /// ``` 77 pub fn with_default_config(store: S) -> Self { 78 let client_data = ClientData { 79 keyset: None, 80 config: crate::atproto::AtprotoClientMetadata::default_localhost(), 81 }; 82 Self::new(store, client_data) 83 } 84} 85 86impl OAuthClient<JacquardResolver, crate::authstore::MemoryAuthStore> { 87 /// Create an OAuth client with an in-memory auth store and default localhost client metadata. 88 /// 89 /// This is a convenience constructor for simple testing and development. 90 /// The session will not persist across restarts. 91 /// 92 /// # Example 93 /// 94 /// ```no_run 95 /// # use jacquard_oauth::client::OAuthClient; 96 /// # #[tokio::main] 97 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> { 98 /// let oauth = OAuthClient::with_memory_store(); 99 /// # Ok(()) 100 /// # } 101 /// ``` 102 pub fn with_memory_store() -> Self { 103 Self::with_default_config(crate::authstore::MemoryAuthStore::new()) 104 } 105} 106 107impl<T, S> OAuthClient<T, S> 108where 109 T: OAuthResolver, 110 S: ClientAuthStore, 111{ 112 /// Create an OAuth client from an explicit resolver instance, taking ownership of both. 113 pub fn new_from_resolver(store: S, client: T, client_data: ClientData<'static>) -> Self { 114 // #[cfg(feature = "tracing")] 115 // tracing::info!( 116 // redirect_uris = ?client_data.config.redirect_uris, 117 // scopes = ?client_data.config.scopes, 118 // has_keyset = client_data.keyset.is_some(), 119 // "oauth client created:" 120 // ); 121 122 let client = Arc::new(client); 123 let registry = Arc::new(SessionRegistry::new(store, client.clone(), client_data)); 124 Self { 125 registry, 126 client, 127 options: RwLock::new(CallOptions::default()), 128 endpoint: RwLock::new(None), 129 } 130 } 131 132 /// Create an OAuth client from already-`Arc`-wrapped store and resolver. 133 pub fn new_with_shared( 134 store: Arc<S>, 135 client: Arc<T>, 136 client_data: ClientData<'static>, 137 ) -> Self { 138 let registry = Arc::new(SessionRegistry::new_shared( 139 store, 140 client.clone(), 141 client_data, 142 )); 143 Self { 144 registry, 145 client, 146 options: RwLock::new(CallOptions::default()), 147 endpoint: RwLock::new(None), 148 } 149 } 150} 151 152impl<T, S> OAuthClient<T, S> 153where 154 S: ClientAuthStore + Send + Sync + 'static, 155 T: OAuthResolver + DpopExt + Send + Sync + 'static, 156{ 157 /// Return the public JWK set for this client's keyset, or an empty set if no keyset is configured. 158 pub fn jwks(&self) -> JwkSet { 159 self.registry 160 .client_data 161 .keyset 162 .as_ref() 163 .map(|keyset| keyset.public_jwks()) 164 .unwrap_or_default() 165 } 166 /// Begin an OAuth authorization flow and return the URL to which the user should be redirected. 167 /// 168 /// This resolves OAuth metadata for the given `input` (a handle, DID, or PDS/entryway URL), 169 /// performs a Pushed Authorization Request (PAR) to the authorization server, persists the 170 /// resulting state for later callback verification, and returns a fully-constructed 171 /// authorization endpoint URL. 172 /// 173 /// The caller is responsible for redirecting the user's browser to the returned URL. 174 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip(self, input), fields(input = input.as_ref())))] 175 pub async fn start_auth( 176 &self, 177 input: impl AsRef<str>, 178 options: AuthorizeOptions<'_>, 179 ) -> Result<String> { 180 let client_metadata = atproto_client_metadata( 181 self.registry.client_data.config.clone(), 182 &self.registry.client_data.keyset, 183 )?; 184 let (server_metadata, identity) = self.client.resolve_oauth(input.as_ref()).await?; 185 let login_hint = if identity.is_some() { 186 Some(input.as_ref().into()) 187 } else { 188 None 189 }; 190 let metadata = OAuthMetadata { 191 server_metadata, 192 client_metadata, 193 keyset: self.registry.client_data.keyset.clone(), 194 }; 195 196 let auth_req_info = par( 197 self.client.as_ref(), 198 login_hint, 199 options.prompt, 200 &metadata, 201 options.state, 202 ) 203 .await?; 204 205 // Persist state for callback handling 206 self.registry 207 .store 208 .save_auth_req_info(&auth_req_info) 209 .await?; 210 211 #[derive(serde::Serialize)] 212 struct Parameters<'s> { 213 client_id: CowStr<'s>, 214 request_uri: CowStr<'s>, 215 } 216 Ok(metadata.server_metadata.authorization_endpoint.to_string() 217 + "?" 218 + &serde_html_form::to_string(Parameters { 219 client_id: metadata.client_metadata.client_id, 220 request_uri: auth_req_info.request_uri, 221 }) 222 .unwrap()) 223 } 224 225 /// Complete the OAuth authorization flow after the authorization server redirects back to the client. 226 /// 227 /// Validates the `state` and optional `iss` parameters, exchanges the authorization code for 228 /// tokens via the token endpoint, verifies the `sub` claim against the expected issuer, and 229 /// persists the resulting session. On success returns an [`OAuthSession`] ready for API calls. 230 #[cfg_attr(feature = "tracing", tracing::instrument(level = "info", skip_all, fields(state = params.state.as_ref().map(|s| s.as_ref()))))] 231 pub async fn callback(&self, params: CallbackParams<'_>) -> Result<OAuthSession<T, S>> { 232 let Some(state_key) = params.state else { 233 return Err(CallbackError::MissingState.into()); 234 }; 235 236 let Some(auth_req_info) = self.registry.store.get_auth_req_info(&state_key).await? else { 237 return Err(CallbackError::MissingState.into()); 238 }; 239 240 self.registry.store.delete_auth_req_info(&state_key).await?; 241 242 let metadata = self 243 .client 244 .get_authorization_server_metadata(&auth_req_info.authserver_url.to_cowstr()) 245 .await?; 246 247 if let Some(iss) = params.iss { 248 if iss != metadata.issuer { 249 return Err(CallbackError::IssuerMismatch { 250 expected: metadata.issuer.to_string(), 251 got: iss.to_string(), 252 } 253 .into()); 254 } 255 } else if metadata.authorization_response_iss_parameter_supported == Some(true) { 256 return Err(CallbackError::MissingIssuer.into()); 257 } 258 let metadata = OAuthMetadata { 259 server_metadata: metadata, 260 client_metadata: atproto_client_metadata( 261 self.registry.client_data.config.clone(), 262 &self.registry.client_data.keyset, 263 )?, 264 keyset: self.registry.client_data.keyset.clone(), 265 }; 266 let authserver_nonce = auth_req_info.dpop_data.dpop_authserver_nonce.clone(); 267 268 match exchange_code( 269 self.client.as_ref(), 270 &mut auth_req_info.dpop_data.clone(), 271 &params.code, 272 &auth_req_info.pkce_verifier, 273 &metadata, 274 ) 275 .await 276 { 277 Ok(token_set) => { 278 let scopes = if let Some(scope) = &token_set.scope { 279 Scope::parse_multiple_reduced(&scope) 280 .expect("Failed to parse scopes") 281 .into_static() 282 } else { 283 vec![] 284 }; 285 let client_data = ClientSessionData { 286 account_did: token_set.sub.clone(), 287 session_id: auth_req_info.state, 288 host_url: Uri::parse(token_set.aud.as_ref())?.to_owned(), 289 authserver_url: auth_req_info.authserver_url.to_cowstr(), 290 authserver_token_endpoint: auth_req_info.authserver_token_endpoint, 291 authserver_revocation_endpoint: auth_req_info.authserver_revocation_endpoint, 292 scopes, 293 dpop_data: DpopClientData { 294 dpop_key: auth_req_info.dpop_data.dpop_key.clone(), 295 dpop_authserver_nonce: authserver_nonce.unwrap_or(CowStr::default()), 296 dpop_host_nonce: auth_req_info 297 .dpop_data 298 .dpop_authserver_nonce 299 .unwrap_or(CowStr::default()), 300 }, 301 token_set, 302 }; 303 304 self.create_session(client_data).await 305 } 306 Err(e) => Err(e.into()), 307 } 308 } 309 310 async fn create_session(&self, data: ClientSessionData<'_>) -> Result<OAuthSession<T, S>> { 311 self.registry.set(data.clone()).await?; 312 Ok(OAuthSession::new( 313 self.registry.clone(), 314 self.client.clone(), 315 data.into_static(), 316 )) 317 } 318 319 /// Restore a previously created session from the backing store, refreshing tokens if needed. 320 pub async fn restore(&self, did: &Did<'_>, session_id: &str) -> Result<OAuthSession<T, S>> { 321 self.create_session(self.registry.get(did, session_id, true).await?) 322 .await 323 } 324 325 /// Revoke a session by deleting it from the backing store. 326 /// 327 /// Note: this removes the session from local storage but does **not** call the authorization 328 /// server's revocation endpoint. To also invalidate the token server-side, prefer 329 /// [`OAuthSession::logout`], which calls `revoke` on the token before deleting the session. 330 pub async fn revoke(&self, did: &Did<'_>, session_id: &str) -> Result<()> { 331 Ok(self.registry.del(did, session_id).await?) 332 } 333} 334 335impl<T, S> HttpClient for OAuthClient<T, S> 336where 337 S: ClientAuthStore + Send + Sync + 'static, 338 T: OAuthResolver + DpopExt + Send + Sync + 'static, 339{ 340 type Error = T::Error; 341 342 async fn send_http( 343 &self, 344 request: http::Request<Vec<u8>>, 345 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 346 self.client.send_http(request).await 347 } 348} 349 350impl<T, S> IdentityResolver for OAuthClient<T, S> 351where 352 S: ClientAuthStore + Send + Sync + 'static, 353 T: OAuthResolver + DpopExt + Send + Sync + 'static, 354{ 355 fn options(&self) -> &ResolverOptions { 356 self.client.options() 357 } 358 359 async fn resolve_handle( 360 &self, 361 handle: &Handle<'_>, 362 ) -> jacquard_identity::resolver::Result<Did<'static>> { 363 self.client.resolve_handle(handle).await 364 } 365 366 async fn resolve_did_doc( 367 &self, 368 did: &Did<'_>, 369 ) -> jacquard_identity::resolver::Result<DidDocResponse> { 370 self.client.resolve_did_doc(did).await 371 } 372} 373 374impl<T, S> XrpcClient for OAuthClient<T, S> 375where 376 S: ClientAuthStore + Send + Sync + 'static, 377 T: OAuthResolver + DpopExt + Send + Sync + 'static, 378{ 379 async fn base_uri(&self) -> Uri<String> { 380 self.endpoint.read().await.clone().unwrap_or_else(|| { 381 Uri::parse("https://public.api.bsky.app") 382 .expect("hardcoded URI is valid") 383 .to_owned() 384 }) 385 } 386 387 async fn opts(&self) -> CallOptions<'_> { 388 self.options.read().await.clone() 389 } 390 391 async fn set_opts(&self, opts: CallOptions<'_>) { 392 let mut guard = self.options.write().await; 393 *guard = opts.into_static(); 394 } 395 396 async fn set_base_uri(&self, uri: Uri<String>) { 397 let normalized = jacquard_common::xrpc::normalize_base_uri(uri); 398 let mut guard = self.endpoint.write().await; 399 *guard = Some(normalized); 400 } 401 402 async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>> 403 where 404 R: XrpcRequest + Send + Sync, 405 <R as XrpcRequest>::Response: Send + Sync, 406 { 407 let opts = self.options.read().await.clone(); 408 self.send_with_opts(request, opts).await 409 } 410 411 async fn send_with_opts<R>( 412 &self, 413 request: R, 414 opts: CallOptions<'_>, 415 ) -> XrpcResult<XrpcResponse<R>> 416 where 417 R: XrpcRequest + Send + Sync, 418 <R as XrpcRequest>::Response: Send + Sync, 419 { 420 let base_uri = self.base_uri().await; 421 self.client 422 .xrpc(base_uri) 423 .with_options(opts.clone()) 424 .send(&request) 425 .await 426 } 427} 428 429/// An active OAuth session for a specific account, used to make authenticated API requests. 430/// 431/// `OAuthSession` holds the DPoP-bound token set for one account and handles transparent 432/// token refresh on `401 invalid_token` responses. The optional `W` type parameter allows 433/// attaching a WebSocket client (defaults to `()` when WebSocket support is not needed). 434/// 435/// Obtain an `OAuthSession` from [`OAuthClient::callback`] or [`OAuthClient::restore`]. 436pub struct OAuthSession<T, S, W = ()> 437where 438 T: OAuthResolver, 439 S: ClientAuthStore, 440{ 441 /// Shared registry used to persist and retrieve session data across refresh operations. 442 pub registry: Arc<SessionRegistry<T, S>>, 443 /// Underlying HTTP/identity/OAuth resolver shared with the parent `OAuthClient`. 444 pub client: Arc<T>, 445 /// Optional WebSocket client; `()` when WebSocket support is not required. 446 pub ws_client: W, 447 /// Mutable session data including DPoP key, nonces, and token set. 448 pub data: RwLock<ClientSessionData<'static>>, 449 /// Default call options applied to every outgoing XRPC request from this session. 450 pub options: RwLock<CallOptions<'static>>, 451} 452 453impl<T, S> OAuthSession<T, S, ()> 454where 455 T: OAuthResolver, 456 S: ClientAuthStore, 457{ 458 /// Create a new session without a WebSocket client. 459 /// 460 /// This is the standard constructor used by [`OAuthClient::callback`] and 461 /// [`OAuthClient::restore`]. For WebSocket support use [`OAuthSession::new_with_ws`]. 462 pub fn new( 463 registry: Arc<SessionRegistry<T, S>>, 464 client: Arc<T>, 465 data: ClientSessionData<'static>, 466 ) -> Self { 467 Self { 468 registry, 469 client, 470 ws_client: (), 471 data: RwLock::new(data), 472 options: RwLock::new(CallOptions::default()), 473 } 474 } 475} 476 477impl<T, S, W> OAuthSession<T, S, W> 478where 479 T: OAuthResolver, 480 S: ClientAuthStore, 481{ 482 /// Create a new session with an attached WebSocket client. 483 /// 484 /// Use this variant when the session needs to support WebSocket subscriptions in addition 485 /// to standard XRPC calls. The `ws_client` is exposed via [`OAuthSession::ws_client`] and 486 /// is used by the `WebSocketClient` impl when the `websocket` feature is enabled. 487 pub fn new_with_ws( 488 registry: Arc<SessionRegistry<T, S>>, 489 client: Arc<T>, 490 ws_client: W, 491 data: ClientSessionData<'static>, 492 ) -> Self { 493 Self { 494 registry, 495 client, 496 ws_client, 497 data: RwLock::new(data), 498 options: RwLock::new(CallOptions::default()), 499 } 500 } 501 502 /// Consume this session and return a new one with the given call options pre-applied. 503 /// 504 /// Useful for setting request-level defaults (e.g., `atproto-proxy` or custom headers) once 505 /// at construction time rather than passing them to every individual XRPC call. 506 pub fn with_options(self, options: CallOptions<'_>) -> Self { 507 Self { 508 registry: self.registry, 509 client: self.client, 510 ws_client: self.ws_client, 511 data: self.data, 512 options: RwLock::new(options.into_static()), 513 } 514 } 515 516 /// Get a reference to the WebSocket client. 517 pub fn ws_client(&self) -> &W { 518 &self.ws_client 519 } 520 521 /// Replace the default call options for this session without consuming it. 522 pub async fn set_options(&self, options: CallOptions<'_>) { 523 *self.options.write().await = options.into_static(); 524 } 525 526 /// Return the DID and session ID for this session. 527 /// 528 /// The session ID is the random `state` token generated during the PAR flow and can 529 /// be used together with the DID to restore the session via [`OAuthClient::restore`]. 530 pub async fn session_info(&self) -> (Did<'_>, CowStr<'_>) { 531 let data = self.data.read().await; 532 (data.account_did.clone(), data.session_id.clone()) 533 } 534 535 /// Return the resource server (PDS) base URI for this session. 536 pub async fn endpoint(&self) -> Uri<String> { 537 self.data.read().await.host_url.clone() 538 } 539 540 /// Return the current DPoP-bound access token for this session. 541 /// 542 /// The token may be stale if it has expired; use [`OAuthSession::refresh`] or 543 /// rely on the automatic refresh performed by `send_with_opts` to obtain a fresh one. 544 pub async fn access_token(&self) -> AuthorizationToken<'_> { 545 AuthorizationToken::Dpop(self.data.read().await.token_set.access_token.clone()) 546 } 547 548 /// Return the current refresh token for this session, if one is present. 549 /// 550 /// Not all authorization servers issue refresh tokens. When `None` is returned, 551 /// the session cannot be silently renewed and the user must re-authenticate. 552 pub async fn refresh_token(&self) -> Option<AuthorizationToken<'_>> { 553 self.data 554 .read() 555 .await 556 .token_set 557 .refresh_token 558 .as_ref() 559 .map(|t| AuthorizationToken::Dpop(t.clone())) 560 } 561 562 /// Derive an unauthenticated [`OAuthClient`] that shares the same registry and resolver. 563 /// 564 /// Useful when you need to initiate a new authorization flow from within an existing 565 /// session context (e.g., to add a second account) without constructing a fresh client. 566 pub fn to_client(&self) -> OAuthClient<T, S> { 567 OAuthClient::from_session(self) 568 } 569} 570impl<T, S, W> OAuthSession<T, S, W> 571where 572 S: ClientAuthStore + Send + Sync + 'static, 573 T: OAuthResolver + DpopExt + Send + Sync + 'static, 574{ 575 /// Revoke the access token at the authorization server and delete the session from the store. 576 /// 577 /// Revocation is best-effort: if the server does not advertise a revocation endpoint, or if 578 /// the revocation call fails, the session is still deleted locally. This prevents a dangling 579 /// session record from blocking future logins for the same account. 580 pub async fn logout(&self) -> Result<()> { 581 use crate::request::{OAuthMetadata, revoke}; 582 let mut data = self.data.write().await; 583 let meta = 584 OAuthMetadata::new(self.client.as_ref(), &self.registry.client_data, &data).await?; 585 if meta.server_metadata.revocation_endpoint.is_some() { 586 let token = data.token_set.access_token.clone(); 587 revoke(self.client.as_ref(), &mut data.dpop_data, &token, &meta) 588 .await 589 .ok(); 590 } 591 // Remove from store 592 self.registry 593 .del(&data.account_did, &data.session_id) 594 .await?; 595 Ok(()) 596 } 597} 598 599impl<T, S> OAuthClient<T, S> 600where 601 T: OAuthResolver, 602 S: ClientAuthStore, 603{ 604 /// Construct an `OAuthClient` that shares the registry and resolver of an existing session. 605 /// 606 /// Equivalent to [`OAuthSession::to_client`]; provided on `OAuthClient` for symmetry so 607 /// callers can obtain an unauthenticated client without holding a session reference. 608 pub fn from_session<W>(session: &OAuthSession<T, S, W>) -> Self { 609 Self { 610 registry: session.registry.clone(), 611 client: session.client.clone(), 612 options: RwLock::new(CallOptions::default()), 613 endpoint: RwLock::new(None), 614 } 615 } 616} 617impl<T, S, W> OAuthSession<T, S, W> 618where 619 S: ClientAuthStore + Send + Sync + 'static, 620 T: OAuthResolver + DpopExt + Send + Sync + 'static, 621{ 622 /// Explicitly refresh the access token using the stored refresh token. 623 /// 624 /// On success the new token set is written back into both the in-memory session data and 625 /// the backing store. The returned `AuthorizationToken` is the new access token, which 626 /// callers can immediately use to retry a failed request. 627 /// 628 /// The actual token exchange is serialized per `(DID, session_id)` pair via a `Mutex` inside 629 /// the registry, so concurrent refresh attempts will not result in duplicate token exchanges. 630 #[cfg_attr(feature = "tracing", tracing::instrument(level = "debug", skip_all))] 631 pub async fn refresh(&self) -> Result<AuthorizationToken<'_>> { 632 // Read identifiers without holding the lock across await 633 let (did, sid) = { 634 let data = self.data.read().await; 635 (data.account_did.clone(), data.session_id.clone()) 636 }; 637 let refreshed = self.registry.as_ref().get(&did, &sid, true).await?; 638 let token = AuthorizationToken::Dpop(refreshed.token_set.access_token.clone()); 639 // Write back updated session 640 *self.data.write().await = refreshed.clone().into_static(); 641 // Store in the registry 642 self.registry.set(refreshed).await?; 643 Ok(token) 644 } 645} 646 647impl<T, S, W> HttpClient for OAuthSession<T, S, W> 648where 649 S: ClientAuthStore + Send + Sync + 'static, 650 T: OAuthResolver + DpopExt + Send + Sync + 'static, 651 W: Send + Sync, 652{ 653 type Error = T::Error; 654 655 async fn send_http( 656 &self, 657 request: http::Request<Vec<u8>>, 658 ) -> core::result::Result<http::Response<Vec<u8>>, Self::Error> { 659 self.client.send_http(request).await 660 } 661} 662 663impl<T, S, W> XrpcClient for OAuthSession<T, S, W> 664where 665 S: ClientAuthStore + Send + Sync + 'static, 666 T: OAuthResolver + DpopExt + XrpcExt + Send + Sync + 'static, 667 W: Send + Sync, 668{ 669 async fn base_uri(&self) -> Uri<String> { 670 self.data.read().await.host_url.clone() 671 } 672 673 async fn opts(&self) -> CallOptions<'_> { 674 self.options.read().await.clone() 675 } 676 677 async fn set_opts(&self, opts: CallOptions<'_>) { 678 let mut guard = self.options.write().await; 679 *guard = opts.into_static(); 680 } 681 682 async fn set_base_uri(&self, uri: Uri<String>) { 683 let normalized = jacquard_common::xrpc::normalize_base_uri(uri); 684 let mut guard = self.data.write().await; 685 guard.host_url = normalized; 686 } 687 688 async fn send<R>(&self, request: R) -> XrpcResult<XrpcResponse<R>> 689 where 690 R: XrpcRequest + Send + Sync, 691 <R as XrpcRequest>::Response: Send + Sync, 692 { 693 let opts = self.options.read().await.clone(); 694 self.send_with_opts(request, opts).await 695 } 696 697 async fn send_with_opts<R>( 698 &self, 699 request: R, 700 mut opts: CallOptions<'_>, 701 ) -> XrpcResult<XrpcResponse<R>> 702 where 703 R: XrpcRequest + Send + Sync, 704 <R as XrpcRequest>::Response: Send + Sync, 705 { 706 let base_uri = self.base_uri().await; 707 let original_token = self.access_token().await; 708 opts.auth = Some(original_token.clone()); 709 // Clone dpop_data and release read lock before the await point 710 let mut dpop = self.data.read().await.dpop_data.clone(); 711 let http_response = self 712 .client 713 .dpop_call(&mut dpop) 714 .send(build_http_request(&base_uri, &request, &opts)?) 715 .await 716 .map_err(|e| ClientError::from(e).for_nsid(R::NSID))?; 717 let resp = process_response(http_response); 718 719 // Write back updated nonce to session data (dpop_call may have updated it) 720 { 721 let mut guard = self.data.write().await; 722 guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone(); 723 } 724 725 if is_invalid_token_response(&resp) { 726 // Optimistic refresh: check if another request already refreshed the token 727 let current_token = self.access_token().await; 728 if current_token != original_token { 729 // Token was already refreshed by another concurrent request, use it 730 opts.auth = Some(current_token); 731 } else { 732 // We need to refresh - this will be serialized by the registry's Mutex 733 opts.auth = Some( 734 self.refresh() 735 .await 736 .map_err(|e| ClientError::transport(e))?, 737 ); 738 } 739 // Re-read dpop_data after refresh (refresh may have updated it) 740 let mut dpop = self.data.read().await.dpop_data.clone(); 741 let http_response = self 742 .client 743 .dpop_call(&mut dpop) 744 .send(build_http_request(&base_uri, &request, &opts)?) 745 .await 746 .map_err(|e| { 747 ClientError::from(e) 748 .for_nsid(R::NSID) 749 .append_context("after token refresh") 750 })?; 751 let resp = process_response(http_response); 752 753 // Write back updated nonce after retry 754 { 755 let mut guard = self.data.write().await; 756 guard.dpop_data.dpop_host_nonce = dpop.dpop_host_nonce.clone(); 757 } 758 759 resp 760 } else { 761 resp 762 } 763 } 764} 765 766#[cfg(feature = "streaming")] 767impl<T, S, W> jacquard_common::http_client::HttpClientExt for OAuthSession<T, S, W> 768where 769 S: ClientAuthStore + Send + Sync + 'static, 770 T: OAuthResolver 771 + DpopExt 772 + XrpcExt 773 + jacquard_common::http_client::HttpClientExt 774 + Send 775 + Sync 776 + 'static, 777 W: Send + Sync, 778{ 779 async fn send_http_streaming( 780 &self, 781 request: http::Request<Vec<u8>>, 782 ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error> 783 { 784 self.client.send_http_streaming(request).await 785 } 786 787 #[cfg(not(target_arch = "wasm32"))] 788 async fn send_http_bidirectional<Str>( 789 &self, 790 parts: http::request::Parts, 791 body: Str, 792 ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error> 793 where 794 Str: n0_future::Stream< 795 Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>, 796 > + Send 797 + 'static, 798 { 799 self.client.send_http_bidirectional(parts, body).await 800 } 801 802 #[cfg(target_arch = "wasm32")] 803 async fn send_http_bidirectional<Str>( 804 &self, 805 parts: http::request::Parts, 806 body: Str, 807 ) -> core::result::Result<http::Response<jacquard_common::stream::ByteStream>, Self::Error> 808 where 809 Str: n0_future::Stream< 810 Item = core::result::Result<bytes::Bytes, jacquard_common::StreamError>, 811 > + 'static, 812 { 813 self.client.send_http_bidirectional(parts, body).await 814 } 815} 816 817#[cfg(feature = "streaming")] 818impl<T, S, W> jacquard_common::xrpc::XrpcStreamingClient for OAuthSession<T, S, W> 819where 820 S: ClientAuthStore + Send + Sync + 'static, 821 T: OAuthResolver 822 + DpopExt 823 + XrpcExt 824 + jacquard_common::http_client::HttpClientExt 825 + Send 826 + Sync 827 + 'static, 828 W: Send + Sync, 829{ 830 async fn download<R>( 831 &self, 832 request: R, 833 ) -> core::result::Result<jacquard_common::xrpc::StreamingResponse, jacquard_common::StreamError> 834 where 835 R: XrpcRequest + Send + Sync, 836 <R as XrpcRequest>::Response: Send + Sync, 837 { 838 use jacquard_common::StreamError; 839 840 let base_uri = <Self as XrpcClient>::base_uri(self).await; 841 let mut opts = self.options.read().await.clone(); 842 opts.auth = Some(self.access_token().await); 843 let http_request = build_http_request(&base_uri, &request, &opts) 844 .map_err(|e| StreamError::protocol(e.to_string()))?; 845 let guard = self.data.read().await; 846 let mut dpop = guard.dpop_data.clone(); 847 let result = self 848 .client 849 .dpop_call(&mut dpop) 850 .send_streaming(http_request) 851 .await; 852 drop(guard); 853 854 match result { 855 Ok(response) => Ok(response), 856 Err(_e) => { 857 // Check if it's an auth error and retry 858 opts.auth = Some( 859 self.refresh() 860 .await 861 .map_err(|e| StreamError::transport(e))?, 862 ); 863 let http_request = build_http_request(&base_uri, &request, &opts) 864 .map_err(|e| StreamError::protocol(e.to_string()))?; 865 let guard = self.data.read().await; 866 let mut dpop = guard.dpop_data.clone(); 867 self.client 868 .dpop_call(&mut dpop) 869 .send_streaming(http_request) 870 .await 871 .map_err(StreamError::transport) 872 } 873 } 874 } 875 876 async fn stream<Str>( 877 &self, 878 stream: jacquard_common::xrpc::streaming::XrpcProcedureSend<Str::Frame<'static>>, 879 ) -> core::result::Result< 880 jacquard_common::xrpc::streaming::XrpcResponseStream< 881 <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<'static>, 882 >, 883 jacquard_common::StreamError, 884 > 885 where 886 Str: jacquard_common::xrpc::streaming::XrpcProcedureStream + 'static, 887 <<Str as jacquard_common::xrpc::streaming::XrpcProcedureStream>::Response as jacquard_common::xrpc::streaming::XrpcStreamResp>::Frame<'static>: jacquard_common::xrpc::streaming::XrpcStreamResp, 888 { 889 use jacquard_common::StreamError; 890 use n0_future::TryStreamExt; 891 892 let base_uri = self.base_uri().await; 893 let mut opts = self.options.read().await.clone(); 894 opts.auth = Some(self.access_token().await); 895 896 let mut path = String::from(base_uri.as_str().trim_end_matches('/')); 897 path.push_str("/xrpc/"); 898 path.push_str(<Str::Request as jacquard_common::xrpc::XrpcRequest>::NSID); 899 900 let mut builder = http::Request::post(path); 901 902 if let Some(token) = &opts.auth { 903 use jacquard_common::AuthorizationToken; 904 let hv = match token { 905 AuthorizationToken::Bearer(t) => { 906 http::HeaderValue::from_str(&format!("Bearer {}", t.as_ref())) 907 } 908 AuthorizationToken::Dpop(t) => { 909 http::HeaderValue::from_str(&format!("DPoP {}", t.as_ref())) 910 } 911 } 912 .map_err(|e| StreamError::protocol(format!("Invalid authorization token: {}", e)))?; 913 builder = builder.header(http::header::AUTHORIZATION, hv); 914 } 915 916 if let Some(proxy) = &opts.atproto_proxy { 917 builder = builder.header("atproto-proxy", proxy.as_ref()); 918 } 919 if let Some(labelers) = &opts.atproto_accept_labelers { 920 if !labelers.is_empty() { 921 let joined = labelers 922 .iter() 923 .map(|s| s.as_ref()) 924 .collect::<Vec<_>>() 925 .join(", "); 926 builder = builder.header("atproto-accept-labelers", joined); 927 } 928 } 929 for (name, value) in &opts.extra_headers { 930 builder = builder.header(name, value); 931 } 932 933 let (parts, _) = builder 934 .body(()) 935 .map_err(|e| StreamError::protocol(e.to_string()))? 936 .into_parts(); 937 938 let body_stream = 939 jacquard_common::stream::ByteStream::new(Box::pin(stream.0.map_ok(|f| f.buffer))); 940 941 let guard = self.data.read().await; 942 let mut dpop = guard.dpop_data.clone(); 943 let result = self 944 .client 945 .dpop_call(&mut dpop) 946 .send_bidirectional(parts, body_stream) 947 .await; 948 drop(guard); 949 950 match result { 951 Ok(response) => { 952 let (resp_parts, resp_body) = response.into_parts(); 953 Ok( 954 jacquard_common::xrpc::streaming::XrpcResponseStream::from_typed_parts( 955 resp_parts, resp_body, 956 ), 957 ) 958 } 959 Err(e) => { 960 // OAuth token refresh and retry is handled by dpop wrapper 961 // If we get here, it's a real error 962 Err(StreamError::transport(e)) 963 } 964 } 965 } 966} 967 968fn is_invalid_token_response<R: XrpcResp>(response: &XrpcResult<Response<R>>) -> bool { 969 use jacquard_common::error::ClientErrorKind; 970 971 match response { 972 Err(e) => match e.kind() { 973 ClientErrorKind::Auth(AuthError::InvalidToken) => true, 974 ClientErrorKind::Auth(AuthError::Other(value)) => value 975 .to_str() 976 .is_ok_and(|s| s.starts_with("DPoP ") && s.contains("error=\"invalid_token\"")), 977 _ => false, 978 }, 979 Ok(resp) => match resp.parse() { 980 Err(XrpcError::Auth(AuthError::InvalidToken)) => true, 981 _ => false, 982 }, 983 } 984} 985 986impl<T, S, W> IdentityResolver for OAuthSession<T, S, W> 987where 988 S: ClientAuthStore + Send + Sync + 'static, 989 T: OAuthResolver + IdentityResolver + XrpcExt + Send + Sync + 'static, 990 W: Send + Sync, 991{ 992 fn options(&self) -> &ResolverOptions { 993 self.client.options() 994 } 995 996 fn resolve_handle( 997 &self, 998 handle: &Handle<'_>, 999 ) -> impl Future<Output = std::result::Result<Did<'static>, IdentityError>> { 1000 async { self.client.resolve_handle(handle).await } 1001 } 1002 1003 fn resolve_did_doc( 1004 &self, 1005 did: &Did<'_>, 1006 ) -> impl Future<Output = std::result::Result<DidDocResponse, IdentityError>> { 1007 async { self.client.resolve_did_doc(did).await } 1008 } 1009} 1010 1011#[cfg(feature = "websocket")] 1012impl<T, S, W> WebSocketClient for OAuthSession<T, S, W> 1013where 1014 S: ClientAuthStore + Send + Sync + 'static, 1015 T: OAuthResolver + Send + Sync + 'static, 1016 W: WebSocketClient + Send + Sync, 1017{ 1018 type Error = W::Error; 1019 1020 async fn connect( 1021 &self, 1022 uri: Uri<&str>, 1023 ) -> std::result::Result<WebSocketConnection, Self::Error> { 1024 self.ws_client.connect(uri).await 1025 } 1026 1027 async fn connect_with_headers( 1028 &self, 1029 uri: Uri<&str>, 1030 headers: Vec<(CowStr<'_>, CowStr<'_>)>, 1031 ) -> std::result::Result<WebSocketConnection, Self::Error> { 1032 self.ws_client.connect_with_headers(uri, headers).await 1033 } 1034} 1035 1036#[cfg(feature = "websocket")] 1037impl<T, S, W> jacquard_common::xrpc::SubscriptionClient for OAuthSession<T, S, W> 1038where 1039 S: ClientAuthStore + Send + Sync + 'static, 1040 T: OAuthResolver + Send + Sync + 'static, 1041 W: WebSocketClient + Send + Sync, 1042{ 1043 async fn base_uri(&self) -> Uri<String> { 1044 self.data.read().await.host_url.clone() 1045 } 1046 1047 async fn subscription_opts(&self) -> jacquard_common::xrpc::SubscriptionOptions<'_> { 1048 let mut opts = jacquard_common::xrpc::SubscriptionOptions::default(); 1049 let token = self.access_token().await; 1050 let auth_value = match token { 1051 AuthorizationToken::Bearer(t) => format!("Bearer {}", t.as_ref()), 1052 AuthorizationToken::Dpop(t) => format!("DPoP {}", t.as_ref()), 1053 }; 1054 opts.headers 1055 .push((CowStr::from("Authorization"), CowStr::from(auth_value))); 1056 opts 1057 } 1058 1059 async fn subscribe<Sub>( 1060 &self, 1061 params: &Sub, 1062 ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error> 1063 where 1064 Sub: XrpcSubscription + Send + Sync, 1065 { 1066 let opts = self.subscription_opts().await; 1067 self.subscribe_with_opts(params, opts).await 1068 } 1069 1070 async fn subscribe_with_opts<Sub>( 1071 &self, 1072 params: &Sub, 1073 opts: jacquard_common::xrpc::SubscriptionOptions<'_>, 1074 ) -> std::result::Result<jacquard_common::xrpc::SubscriptionStream<Sub::Stream>, Self::Error> 1075 where 1076 Sub: XrpcSubscription + Send + Sync, 1077 { 1078 use jacquard_common::xrpc::SubscriptionExt; 1079 let base = self.base_uri().await; 1080 self.subscription(base) 1081 .with_options(opts) 1082 .subscribe(params) 1083 .await 1084 } 1085}