A better Rust ATProto crate
102
fork

Configure Feed

Select the types of activity you want to include in your feed.

at pretty-codegen 273 lines 10 kB view raw
//!
//! Helpers for the local loopback server method of atproto OAuth.
//!
//! `OAuthClient::login_with_local_server()` is the high-level helper. It and its
//! components live here. The example below shows what it does, so you can have
//! more granular control without having to write your own loopback server.
//!
//! ```ignore
//! let input = "your_handle_here";
//! let cfg = LoopbackConfig::default();
//! let opts = AuthorizeOptions::default();
//! let port = match cfg.port {
//!     LoopbackPort::Fixed(p) => p,
//!     LoopbackPort::Ephemeral => 0,
//! };
//! // TODO: fix this so it also accepts ipv6 and properly finds a free port
//! let bind_addr: SocketAddr = format!("0.0.0.0:{}", port)
//!     .parse()
//!     .expect("invalid loopback host/port");
//! let oauth = OAuthClient::with_default_config(FileAuthStore::new(&args.store));
//!
//! let (local_addr, handle) = one_shot_server(bind_addr);
//! println!("Listening on {}", local_addr);
//!
//! let client_data = oauth.build_localhost_client_data(&cfg, &opts, local_addr);
//! // Build client using store and resolver
//! let flow_client = OAuthClient::new_with_shared(
//!     oauth.registry.store.clone(),
//!     oauth.client.clone(),
//!     client_data,
//! );
//!
//! // Start auth and get authorization URL
//! let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
//! // Print URL for copy/paste
//! println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
//! // Optionally open browser
//! if cfg.open_browser {
//!     let _ = try_open_in_browser(&auth_url);
//! }
//!
//! handle_localhost_callback(handle, &flow_client, &cfg).await
//! ```
//!
46#![cfg(feature = "loopback")] 47use crate::{ 48 atproto::AtprotoClientMetadata, 49 authstore::ClientAuthStore, 50 client::OAuthClient, 51 dpop::DpopExt, 52 error::{CallbackError, OAuthError}, 53 resolver::OAuthResolver, 54 types::{AuthorizeOptions, CallbackParams}, 55}; 56use jacquard_common::deps::fluent_uri::Uri; 57use jacquard_common::{IntoStatic, cowstr::ToCowStr}; 58use rouille::Server; 59use std::net::SocketAddr; 60use tokio::sync::mpsc; 61 62/// Port selection strategy for the loopback OAuth callback server. 63#[derive(Clone, Debug)] 64pub enum LoopbackPort { 65 /// Bind to a specific port number. 66 Fixed(u16), 67 /// Let the OS assign an available port. 68 Ephemeral, 69} 70 71/// Configuration for the loopback OAuth callback server. 72#[derive(Clone, Debug)] 73pub struct LoopbackConfig { 74 /// The host address to bind to (e.g., `"127.0.0.1"`). 75 pub host: String, 76 /// Port selection strategy. 77 pub port: LoopbackPort, 78 /// Whether to attempt opening the authorization URL in the user's browser. 79 pub open_browser: bool, 80 /// How long to wait for the callback before timing out, in milliseconds. 81 pub timeout_ms: u64, 82} 83 84impl Default for LoopbackConfig { 85 fn default() -> Self { 86 Self { 87 host: "127.0.0.1".into(), 88 port: LoopbackPort::Fixed(4000), 89 open_browser: true, 90 timeout_ms: 5 * 60 * 1000, 91 } 92 } 93} 94 95/// Attempts to open the given URL in the user's default browser. 96/// 97/// Returns `true` if the browser was opened successfully, `false` otherwise. 98#[cfg(feature = "browser-open")] 99pub fn try_open_in_browser(url: &str) -> bool { 100 webbrowser::open(url).is_ok() 101} 102/// Stub for when the `browser-open` feature is disabled. Always returns `false`. 
103#[cfg(not(feature = "browser-open"))] 104pub fn try_open_in_browser(_url: &str) -> bool { 105 false 106} 107 108fn create_callback_router( 109 request: &rouille::Request, 110 tx: mpsc::Sender<CallbackParams>, 111) -> rouille::Response { 112 rouille::router!(request, 113 (GET) (/oauth/callback) => { 114 let state = request.get_param("state").unwrap(); 115 let code = request.get_param("code").unwrap(); 116 let iss = request.get_param("iss").unwrap(); 117 let callback_params = CallbackParams { 118 state: Some(state.to_cowstr().into_static()), 119 code: code.to_cowstr().into_static(), 120 iss: Some(iss.to_cowstr().into_static()), 121 }; 122 tx.try_send(callback_params).unwrap(); 123 rouille::Response::text("Logged in!") 124 }, 125 _ => rouille::Response::empty_404() 126 ) 127} 128 129/// Handle to a running loopback callback server, used to await the OAuth redirect. 130pub struct CallbackHandle { 131 #[allow(dead_code)] 132 server_handle: std::thread::JoinHandle<()>, 133 server_stop: std::sync::mpsc::Sender<()>, 134 callback_rx: mpsc::Receiver<CallbackParams<'static>>, 135} 136 137/// One-shot OAuth callback server. 138/// 139/// Starts an ephemeral in-process web server that listens for the OAuth 140/// callback redirect. Returns the server address and a [`CallbackHandle`] 141/// that can be used to wait for the callback and stop the server. 142/// 143/// Use in combination with [`handle_localhost_callback`] to handle the 144/// callback for the localhost loopback server. 
145pub fn one_shot_server(addr: SocketAddr) -> (SocketAddr, CallbackHandle) { 146 let (tx, callback_rx) = mpsc::channel(5); 147 let server = Server::new(addr, move |request| { 148 create_callback_router(request, tx.clone()) 149 }) 150 .expect("Could not start server"); 151 let (server_handle, server_stop) = server.stoppable(); 152 let handle = CallbackHandle { 153 server_handle, 154 server_stop, 155 callback_rx, 156 }; 157 (addr, handle) 158} 159 160/// Handles the OAuth callback for the localhost loopback server. 161/// 162/// Returns a session if the callback succeeds within the configured timeout 163/// and shuts down the server. 164pub async fn handle_localhost_callback<T, S>( 165 handle: CallbackHandle, 166 flow_client: &super::client::OAuthClient<T, S>, 167 cfg: &LoopbackConfig, 168) -> crate::error::Result<super::client::OAuthSession<T, S>> 169where 170 T: OAuthResolver + DpopExt + Send + Sync + 'static, 171 S: ClientAuthStore + Send + Sync + 'static, 172{ 173 // Await callback or timeout 174 let mut callback_rx = handle.callback_rx; 175 let cb = tokio::time::timeout( 176 std::time::Duration::from_millis(cfg.timeout_ms), 177 callback_rx.recv(), 178 ) 179 .await; 180 // trigger shutdown 181 let _ = handle.server_stop.send(()); 182 if let Ok(Some(cb)) = cb { 183 // Handle callback and create a session 184 Ok(flow_client.callback(cb).await?) 185 } else { 186 Err(OAuthError::Callback(CallbackError::Timeout)) 187 } 188} 189 190impl<T, S> OAuthClient<T, S> 191where 192 T: OAuthResolver + DpopExt + Send + Sync + 'static, 193 S: ClientAuthStore + Send + Sync + 'static, 194{ 195 /// Drive the full OAuth flow using a local loopback server. 196 /// 197 /// This uses localhost OAuth and an ephemeral in-process web server to 198 /// handle the OAuth callback redirect. It has a bunch of nice friendly 199 /// defaults to help you get started and will basically drive the *entire* 200 /// callback flow itself. 
201 /// 202 /// Best used for development and for small CLI applications that don't 203 /// require long session lengths. For long-running unattended sessions, 204 /// app passwords (via CredentialSession in the jacquard crate) remain 205 /// the best option. For more complex OAuth, or if you want more control 206 /// over the process, use the other methods on OAuthClient. 207 /// 208 /// 'input' parameter is what you type in the login box (usually, your handle) 209 /// for it to look up your PDS and redirect to its authentication interface. 210 /// 211 /// If the `browser-open` feature is enabled, this will open a web browser 212 /// for you to authenticate with your PDS. It will also print the 213 /// callback url to the console for you to copy. 214 pub async fn login_with_local_server( 215 &self, 216 input: impl AsRef<str>, 217 opts: AuthorizeOptions<'_>, 218 cfg: LoopbackConfig, 219 ) -> crate::error::Result<super::client::OAuthSession<T, S>> { 220 let port = match cfg.port { 221 LoopbackPort::Fixed(p) => p, 222 LoopbackPort::Ephemeral => 0, 223 }; 224 // TODO: fix this to it also accepts ipv6 and properly finds a free port 225 let bind_addr: SocketAddr = format!("0.0.0.0:{}", port) 226 .parse() 227 .expect("invalid loopback host/port"); 228 let (local_addr, handle) = one_shot_server(bind_addr); 229 println!("Listening on {}", local_addr); 230 231 let client_data = self.build_localhost_client_data(&cfg, &opts, local_addr); 232 // Build client using store and resolver 233 let flow_client = OAuthClient::new_with_shared( 234 self.registry.store.clone(), 235 self.client.clone(), 236 client_data, 237 ); 238 239 // Start auth and get authorization URL 240 let auth_url = flow_client.start_auth(input.as_ref(), opts).await?; 241 // Print URL for copy/paste 242 println!("To authenticate with your PDS, visit:\n{}\n", auth_url); 243 // Optionally open browser 244 if cfg.open_browser { 245 let _ = try_open_in_browser(&auth_url); 246 } 247 248 
handle_localhost_callback(handle, &flow_client, &cfg).await 249 } 250 251 /// Builds a [`crate::session::ClientData`] for use with the local loopback server method of OAuth. 252 pub fn build_localhost_client_data( 253 &self, 254 cfg: &LoopbackConfig, 255 opts: &AuthorizeOptions<'_>, 256 local_addr: SocketAddr, 257 ) -> crate::session::ClientData<'static> { 258 let redirect_uri = format!("http://{}:{}/oauth/callback", cfg.host, local_addr.port(),); 259 let redirect = Uri::parse(redirect_uri).unwrap(); 260 261 let scopes = if opts.scopes.is_empty() { 262 Some(self.registry.client_data.config.scopes.clone()) 263 } else { 264 Some(opts.scopes.clone().into_static()) 265 }; 266 267 crate::session::ClientData { 268 keyset: self.registry.client_data.keyset.clone(), 269 config: AtprotoClientMetadata::new_localhost(Some(vec![redirect]), scopes), 270 } 271 .into_static() 272 } 273}