A better Rust ATProto crate
1//!
2//! Helpers for the local loopback server method of atproto OAuth.
3//!
//! [`OAuthClient::login_with_local_server`] is the all-in-one helper; it and
//! the building blocks it is made of live in this module. The example below
//! spells out what that helper does step by step, so you can take more
//! granular control of the flow without writing your own loopback server.
7//!
8//! ```ignore
9//! let input = "your_handle_here";
10//! let cfg = LoopbackConfig::default();
11//! let opts = AuthorizeOptions::default();
12//! let port = match cfg.port {
13//! LoopbackPort::Fixed(p) => p,
14//! LoopbackPort::Ephemeral => 0,
15//! };
//! // TODO: fix this so it also accepts IPv6 and properly finds a free port
17//! let bind_addr: SocketAddr = format!("0.0.0.0:{}", port)
18//! .parse()
19//! .expect("invalid loopback host/port");
20//! let oauth = OAuthClient::with_default_config(FileAuthStore::new(&args.store));
21//!
22//! let (local_addr, handle) = one_shot_server(bind_addr);
23//! println!("Listening on {}", local_addr);
24//!
25//! let client_data = oauth.build_localhost_client_data(&cfg, &opts, local_addr);
26//! // Build client using store and resolver
27//! let flow_client = OAuthClient::new_with_shared(
28//! self.registry.store.clone(),
29//! self.client.clone(),
30//! client_data,
31//! );
32//!
33//! // Start auth and get authorization URL
34//! let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
35//! // Print URL for copy/paste
36//! println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
37//! // Optionally open browser
38//! if cfg.open_browser {
39//! let _ = try_open_in_browser(&auth_url);
40//! }
41//!
42//! handle_localhost_callback(handle, &flow_client, &cfg).await
43//! ```
44//!
45//!
46#![cfg(feature = "loopback")]
47use crate::{
48 atproto::AtprotoClientMetadata,
49 authstore::ClientAuthStore,
50 client::OAuthClient,
51 dpop::DpopExt,
52 error::{CallbackError, OAuthError},
53 resolver::OAuthResolver,
54 types::{AuthorizeOptions, CallbackParams},
55};
56use jacquard_common::deps::fluent_uri::Uri;
57use jacquard_common::{IntoStatic, cowstr::ToCowStr};
58use rouille::Server;
59use std::net::SocketAddr;
60use tokio::sync::mpsc;
61
/// Port selection strategy for the loopback OAuth callback server.
///
/// This is a small value type, so it is `Copy` and can be compared and hashed,
/// which makes it convenient to match on, store in sets, or assert against.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoopbackPort {
    /// Bind to this specific port number.
    Fixed(u16),
    /// Let the OS assign an available port (binds to port `0`).
    Ephemeral,
}
70
71/// Configuration for the loopback OAuth callback server.
72#[derive(Clone, Debug)]
73pub struct LoopbackConfig {
74 /// The host address to bind to (e.g., `"127.0.0.1"`).
75 pub host: String,
76 /// Port selection strategy.
77 pub port: LoopbackPort,
78 /// Whether to attempt opening the authorization URL in the user's browser.
79 pub open_browser: bool,
80 /// How long to wait for the callback before timing out, in milliseconds.
81 pub timeout_ms: u64,
82}
83
84impl Default for LoopbackConfig {
85 fn default() -> Self {
86 Self {
87 host: "127.0.0.1".into(),
88 port: LoopbackPort::Fixed(4000),
89 open_browser: true,
90 timeout_ms: 5 * 60 * 1000,
91 }
92 }
93}
94
/// Tries to open `url` in the user's default web browser.
///
/// Yields `true` when the browser launch succeeded and `false` when it
/// failed; the underlying error is discarded.
#[cfg(feature = "browser-open")]
pub fn try_open_in_browser(url: &str) -> bool {
    matches!(webbrowser::open(url), Ok(_))
}
/// Fallback compiled when the `browser-open` feature is disabled.
///
/// No browser is launched; the function unconditionally reports `false`.
#[cfg(not(feature = "browser-open"))]
pub fn try_open_in_browser(_url: &str) -> bool {
    // Nothing to open without the `browser-open` feature.
    false
}
107
108fn create_callback_router(
109 request: &rouille::Request,
110 tx: mpsc::Sender<CallbackParams>,
111) -> rouille::Response {
112 rouille::router!(request,
113 (GET) (/oauth/callback) => {
114 let state = request.get_param("state").unwrap();
115 let code = request.get_param("code").unwrap();
116 let iss = request.get_param("iss").unwrap();
117 let callback_params = CallbackParams {
118 state: Some(state.to_cowstr().into_static()),
119 code: code.to_cowstr().into_static(),
120 iss: Some(iss.to_cowstr().into_static()),
121 };
122 tx.try_send(callback_params).unwrap();
123 rouille::Response::text("Logged in!")
124 },
125 _ => rouille::Response::empty_404()
126 )
127}
128
129/// Handle to a running loopback callback server, used to await the OAuth redirect.
130pub struct CallbackHandle {
131 #[allow(dead_code)]
132 server_handle: std::thread::JoinHandle<()>,
133 server_stop: std::sync::mpsc::Sender<()>,
134 callback_rx: mpsc::Receiver<CallbackParams<'static>>,
135}
136
137/// One-shot OAuth callback server.
138///
139/// Starts an ephemeral in-process web server that listens for the OAuth
140/// callback redirect. Returns the server address and a [`CallbackHandle`]
141/// that can be used to wait for the callback and stop the server.
142///
143/// Use in combination with [`handle_localhost_callback`] to handle the
144/// callback for the localhost loopback server.
145pub fn one_shot_server(addr: SocketAddr) -> (SocketAddr, CallbackHandle) {
146 let (tx, callback_rx) = mpsc::channel(5);
147 let server = Server::new(addr, move |request| {
148 create_callback_router(request, tx.clone())
149 })
150 .expect("Could not start server");
151 let (server_handle, server_stop) = server.stoppable();
152 let handle = CallbackHandle {
153 server_handle,
154 server_stop,
155 callback_rx,
156 };
157 (addr, handle)
158}
159
160/// Handles the OAuth callback for the localhost loopback server.
161///
162/// Returns a session if the callback succeeds within the configured timeout
163/// and shuts down the server.
164pub async fn handle_localhost_callback<T, S>(
165 handle: CallbackHandle,
166 flow_client: &super::client::OAuthClient<T, S>,
167 cfg: &LoopbackConfig,
168) -> crate::error::Result<super::client::OAuthSession<T, S>>
169where
170 T: OAuthResolver + DpopExt + Send + Sync + 'static,
171 S: ClientAuthStore + Send + Sync + 'static,
172{
173 // Await callback or timeout
174 let mut callback_rx = handle.callback_rx;
175 let cb = tokio::time::timeout(
176 std::time::Duration::from_millis(cfg.timeout_ms),
177 callback_rx.recv(),
178 )
179 .await;
180 // trigger shutdown
181 let _ = handle.server_stop.send(());
182 if let Ok(Some(cb)) = cb {
183 // Handle callback and create a session
184 Ok(flow_client.callback(cb).await?)
185 } else {
186 Err(OAuthError::Callback(CallbackError::Timeout))
187 }
188}
189
190impl<T, S> OAuthClient<T, S>
191where
192 T: OAuthResolver + DpopExt + Send + Sync + 'static,
193 S: ClientAuthStore + Send + Sync + 'static,
194{
195 /// Drive the full OAuth flow using a local loopback server.
196 ///
197 /// This uses localhost OAuth and an ephemeral in-process web server to
198 /// handle the OAuth callback redirect. It has a bunch of nice friendly
199 /// defaults to help you get started and will basically drive the *entire*
200 /// callback flow itself.
201 ///
202 /// Best used for development and for small CLI applications that don't
203 /// require long session lengths. For long-running unattended sessions,
204 /// app passwords (via CredentialSession in the jacquard crate) remain
205 /// the best option. For more complex OAuth, or if you want more control
206 /// over the process, use the other methods on OAuthClient.
207 ///
208 /// 'input' parameter is what you type in the login box (usually, your handle)
209 /// for it to look up your PDS and redirect to its authentication interface.
210 ///
211 /// If the `browser-open` feature is enabled, this will open a web browser
212 /// for you to authenticate with your PDS. It will also print the
213 /// callback url to the console for you to copy.
214 pub async fn login_with_local_server(
215 &self,
216 input: impl AsRef<str>,
217 opts: AuthorizeOptions<'_>,
218 cfg: LoopbackConfig,
219 ) -> crate::error::Result<super::client::OAuthSession<T, S>> {
220 let port = match cfg.port {
221 LoopbackPort::Fixed(p) => p,
222 LoopbackPort::Ephemeral => 0,
223 };
224 // TODO: fix this to it also accepts ipv6 and properly finds a free port
225 let bind_addr: SocketAddr = format!("0.0.0.0:{}", port)
226 .parse()
227 .expect("invalid loopback host/port");
228 let (local_addr, handle) = one_shot_server(bind_addr);
229 println!("Listening on {}", local_addr);
230
231 let client_data = self.build_localhost_client_data(&cfg, &opts, local_addr);
232 // Build client using store and resolver
233 let flow_client = OAuthClient::new_with_shared(
234 self.registry.store.clone(),
235 self.client.clone(),
236 client_data,
237 );
238
239 // Start auth and get authorization URL
240 let auth_url = flow_client.start_auth(input.as_ref(), opts).await?;
241 // Print URL for copy/paste
242 println!("To authenticate with your PDS, visit:\n{}\n", auth_url);
243 // Optionally open browser
244 if cfg.open_browser {
245 let _ = try_open_in_browser(&auth_url);
246 }
247
248 handle_localhost_callback(handle, &flow_client, &cfg).await
249 }
250
251 /// Builds a [`crate::session::ClientData`] for use with the local loopback server method of OAuth.
252 pub fn build_localhost_client_data(
253 &self,
254 cfg: &LoopbackConfig,
255 opts: &AuthorizeOptions<'_>,
256 local_addr: SocketAddr,
257 ) -> crate::session::ClientData<'static> {
258 let redirect_uri = format!("http://{}:{}/oauth/callback", cfg.host, local_addr.port(),);
259 let redirect = Uri::parse(redirect_uri).unwrap();
260
261 let scopes = if opts.scopes.is_empty() {
262 Some(self.registry.client_data.config.scopes.clone())
263 } else {
264 Some(opts.scopes.clone().into_static())
265 };
266
267 crate::session::ClientData {
268 keyset: self.registry.client_data.keyset.clone(),
269 config: AtprotoClientMetadata::new_localhost(Some(vec![redirect]), scopes),
270 }
271 .into_static()
272 }
273}