learn and share notes on atproto (wip)
malfestio.stormlightlabs.org/
readability
solid
axum
atproto
srs
//! OAuth API endpoints for AT Protocol authentication.
//!
//! Provides endpoints for:
//! - Starting the OAuth authorization flow
//! - Handling OAuth callbacks
//! - Refreshing tokens
7
8use crate::db::DbPool;
9use crate::oauth::flow::{OAuthFlow, SessionStore, generate_state, new_session_store};
10use crate::repository::oauth::{DbOAuthRepository, OAuthRepository, StoreTokensRequest};
11use axum::{
12 Json,
13 extract::{Query, State},
14 http::StatusCode,
15 response::{IntoResponse, Redirect},
16};
17use chrono::{Duration, Utc};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::sync::Arc;
21
22/// Shared OAuth state with database repository.
23pub struct OAuthState {
24 pub flow: OAuthFlow,
25 pub sessions: SessionStore,
26 pub repo: Arc<dyn OAuthRepository>,
27}
28
29impl OAuthState {
30 /// Create OAuth state with database connection.
31 pub fn with_pool(pool: DbPool) -> Self {
32 Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(DbOAuthRepository::new(pool)) }
33 }
34
35 /// Create OAuth state without database (for testing).
36 pub fn new() -> Self {
37 Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(MockOAuthRepository) }
38 }
39}
40
41impl Default for OAuthState {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47/// Mock repository for testing.
48struct MockOAuthRepository;
49
50#[async_trait::async_trait]
51impl OAuthRepository for MockOAuthRepository {
52 async fn store_tokens(&self, _req: StoreTokensRequest<'_>) -> Result<(), crate::repository::oauth::OAuthRepoError> {
53 Ok(())
54 }
55
56 async fn store_app_password_session(
57 &self, _req: crate::repository::oauth::StoreAppPasswordSessionRequest<'_>,
58 ) -> Result<(), crate::repository::oauth::OAuthRepoError> {
59 Ok(())
60 }
61
62 async fn get_tokens(
63 &self, did: &str,
64 ) -> Result<crate::repository::oauth::StoredToken, crate::repository::oauth::OAuthRepoError> {
65 Err(crate::repository::oauth::OAuthRepoError::NotFound(did.to_string()))
66 }
67
68 async fn get_token_by_access_token(
69 &self, _access_token: &str,
70 ) -> Result<crate::repository::oauth::StoredToken, crate::repository::oauth::OAuthRepoError> {
71 Err(crate::repository::oauth::OAuthRepoError::NotFound(
72 "Mock impl".to_string(),
73 ))
74 }
75
76 async fn update_tokens(
77 &self, _did: &str, _access_token: &str, _refresh_token: Option<&str>,
78 _expires_at: Option<chrono::DateTime<Utc>>,
79 ) -> Result<(), crate::repository::oauth::OAuthRepoError> {
80 Ok(())
81 }
82
83 async fn delete_tokens(&self, _did: &str) -> Result<(), crate::repository::oauth::OAuthRepoError> {
84 Ok(())
85 }
86}
87
88/// Request to start OAuth authorization.
89#[derive(Deserialize)]
90pub struct AuthorizeRequest {
91 /// Handle or DID to authenticate
92 pub handle: String,
93}
94
95/// Response from starting authorization.
96#[derive(Serialize)]
97pub struct AuthorizeResponse {
98 /// URL to redirect the user to
99 pub authorization_url: String,
100 /// State parameter (for CSRF protection)
101 pub state: String,
102}
103
104/// Query parameters from OAuth callback.
105#[derive(Deserialize)]
106pub struct CallbackQuery {
107 pub code: Option<String>,
108 pub state: String,
109 #[serde(default)]
110 pub error: Option<String>,
111 #[serde(default)]
112 pub error_description: Option<String>,
113}
114
115/// Start the OAuth authorization flow.
116///
117/// POST /api/oauth/authorize
118/// Body: { "handle": "alice.bsky.social" }
119pub async fn authorize(
120 State(oauth): State<Arc<OAuthState>>, Json(payload): Json<AuthorizeRequest>,
121) -> impl IntoResponse {
122 tracing::info!("OAuth authorization request received for handle: {}", payload.handle);
123
124 let state = generate_state();
125 tracing::debug!("Generated state parameter: {}", state);
126
127 match oauth
128 .flow
129 .start_authorization(&payload.handle, &state, &oauth.sessions)
130 .await
131 {
132 Ok(auth_url) => {
133 tracing::info!(
134 "OAuth authorization started successfully for handle: {}",
135 payload.handle
136 );
137 (
138 StatusCode::OK,
139 Json(AuthorizeResponse { authorization_url: auth_url, state }),
140 )
141 .into_response()
142 }
143 Err(e) => {
144 tracing::error!("OAuth authorization failed for handle {}: {}", payload.handle, e);
145 (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response()
146 }
147 }
148}
149
150/// Handle OAuth callback from authorization server.
151///
152/// GET /api/oauth/callback?code=...&state=...
153pub async fn callback(State(oauth): State<Arc<OAuthState>>, Query(params): Query<CallbackQuery>) -> impl IntoResponse {
154 tracing::info!("OAuth callback received with state: {}", params.state);
155
156 if let Some(error) = params.error {
157 let description = params.error_description.unwrap_or_default();
158 tracing::error!("OAuth authorization error: {} - {}", error, description);
159 return Redirect::to(&format!(
160 "/login?error={}&description={}",
161 urlencoding::encode(&error),
162 urlencoding::encode(&description)
163 ))
164 .into_response();
165 }
166
167 let code = match params.code {
168 Some(c) => c,
169 None => {
170 tracing::error!("OAuth callback missing authorization code");
171 return Redirect::to("/login?error=missing_code").into_response();
172 }
173 };
174
175 tracing::debug!("Retrieving session for state: {}", params.state);
176 let session = {
177 let sessions = oauth.sessions.read().unwrap();
178 sessions.get(¶ms.state).cloned()
179 };
180
181 let session = match session {
182 Some(s) => {
183 tracing::debug!("Session found for state: {}", params.state);
184 s
185 }
186 None => {
187 tracing::error!("Session not found for state: {}", params.state);
188 return Redirect::to("/login?error=session_not_found").into_response();
189 }
190 };
191
192 match oauth.flow.exchange_code(&code, ¶ms.state, &oauth.sessions).await {
193 Ok(tokens) => {
194 let did = session.did.clone().unwrap_or_default();
195 let pds_url = session.pds_url.unwrap_or_default();
196 let expires_at = tokens
197 .expires_in
198 .map(|secs| Utc::now() + Duration::seconds(secs as i64));
199
200 tracing::info!("Storing tokens for DID: {}", did);
201
202 if let Err(e) = oauth
203 .repo
204 .store_tokens(StoreTokensRequest {
205 did: &did,
206 pds_url: &pds_url,
207 access_token: &tokens.access_token,
208 refresh_token: tokens.refresh_token.as_deref(),
209 token_type: &tokens.token_type,
210 expires_at,
211 dpop_keypair: &session.dpop_keypair,
212 })
213 .await
214 {
215 tracing::error!("Failed to store tokens for DID {}: {}", did, e);
216 return Redirect::to(&format!("/login?error={}", urlencoding::encode("token_storage_failed")))
217 .into_response();
218 }
219
220 tracing::info!("OAuth flow completed successfully for DID: {}", did);
221
222 let handle = match oauth.flow.resolve_did(&did).await {
223 Ok(identity) => identity.handle.unwrap_or(did.clone()),
224 Err(e) => {
225 tracing::warn!("Failed to resolve handle for DID {}: {}", did, e);
226 did.clone()
227 }
228 };
229
230 let fragment = format!(
231 "accessJwt={}&refreshJwt={}&did={}&handle={}",
232 urlencoding::encode(&tokens.access_token),
233 urlencoding::encode(tokens.refresh_token.as_deref().unwrap_or("")),
234 urlencoding::encode(&did),
235 urlencoding::encode(&handle)
236 );
237 Redirect::to(&format!("/login/success#{}", fragment)).into_response()
238 }
239 Err(e) => {
240 tracing::error!("Token exchange failed: {}", e);
241 Redirect::to(&format!("/login?error={}", urlencoding::encode(&e.to_string()))).into_response()
242 }
243 }
244}
245
246/// Request to refresh tokens.
247#[derive(Deserialize)]
248pub struct RefreshRequest {
249 pub did: String,
250}
251
252/// Response from token refresh.
253#[derive(Serialize)]
254pub struct RefreshResponse {
255 pub success: bool,
256 pub expires_at: Option<String>,
257}
258
259/// Refresh an access token.
260///
261/// POST /api/oauth/refresh
262/// Body: { "did": "did:plc:..." }
263pub async fn refresh(State(oauth): State<Arc<OAuthState>>, Json(payload): Json<RefreshRequest>) -> impl IntoResponse {
264 tracing::info!("Token refresh request for DID: {}", payload.did);
265
266 tracing::debug!("Retrieving stored tokens from database for DID: {}", payload.did);
267 let stored = match oauth.repo.get_tokens(&payload.did).await {
268 Ok(t) => {
269 tracing::debug!("Found stored tokens for DID: {}", payload.did);
270 t
271 }
272 Err(e) => {
273 tracing::error!("Failed to retrieve stored tokens for DID {}: {}", payload.did, e);
274 return (StatusCode::NOT_FOUND, Json(json!({ "error": e.to_string() }))).into_response();
275 }
276 };
277
278 tracing::debug!("Reconstructing DPoP keypair from stored data");
279 let dpop_keypair = match stored.dpop_keypair() {
280 Some(kp) => kp,
281 None => {
282 tracing::error!("Failed to reconstruct DPoP keypair for DID: {}", payload.did);
283 return (
284 StatusCode::INTERNAL_SERVER_ERROR,
285 Json(json!({ "error": "Invalid stored keypair" })),
286 )
287 .into_response();
288 }
289 };
290
291 let refresh_token = match &stored.refresh_token {
292 Some(rt) => rt.clone(),
293 None => {
294 tracing::error!("No refresh token available for DID: {}", payload.did);
295 return (
296 StatusCode::BAD_REQUEST,
297 Json(json!({ "error": "No refresh token available" })),
298 )
299 .into_response();
300 }
301 };
302
303 match oauth
304 .flow
305 .refresh_token(&refresh_token, &stored.pds_url, &dpop_keypair)
306 .await
307 {
308 Ok(new_tokens) => {
309 let expires_at = new_tokens
310 .expires_in
311 .map(|secs| Utc::now() + Duration::seconds(secs as i64));
312
313 tracing::info!(
314 "Token refresh successful, updating stored tokens for DID: {}",
315 payload.did
316 );
317 if let Err(e) = oauth
318 .repo
319 .update_tokens(
320 &payload.did,
321 &new_tokens.access_token,
322 new_tokens.refresh_token.as_deref(),
323 expires_at,
324 )
325 .await
326 {
327 tracing::error!("Failed to update tokens in database for DID {}: {}", payload.did, e);
328 return (
329 StatusCode::INTERNAL_SERVER_ERROR,
330 Json(json!({ "error": "Failed to update tokens" })),
331 )
332 .into_response();
333 }
334
335 tracing::info!("Token refresh completed successfully for DID: {}", payload.did);
336 (
337 StatusCode::OK,
338 Json(RefreshResponse { success: true, expires_at: expires_at.map(|dt| dt.to_rfc3339()) }),
339 )
340 .into_response()
341 }
342 Err(e) => {
343 tracing::error!("Token refresh failed for DID {}: {}", payload.did, e);
344 (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response()
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built state starts with an empty session store.
    #[test]
    fn test_oauth_state_creation() {
        let oauth = OAuthState::new();
        let sessions = oauth.sessions.read().unwrap();
        assert!(sessions.is_empty());
    }

    /// The authorize body is parsed from its JSON form.
    #[test]
    fn test_authorize_request_deserialization() {
        let body = r#"{"handle": "alice.bsky.social"}"#;
        let parsed = serde_json::from_str::<AuthorizeRequest>(body).unwrap();
        assert_eq!(parsed.handle, "alice.bsky.social");
    }

    /// Both response fields appear in the serialized JSON.
    #[test]
    fn test_authorize_response_serialization() {
        let authorization_url = "https://example.com/oauth".to_string();
        let state = "abc123".to_string();
        let encoded = serde_json::to_string(&AuthorizeResponse { authorization_url, state }).unwrap();
        for key in ["authorization_url", "state"] {
            assert!(encoded.contains(key));
        }
    }

    /// A successful callback query carries a code and a state, and no error.
    #[test]
    fn test_callback_query_deserialization() {
        let parsed = serde_qs::from_str::<CallbackQuery>("code=abc123&state=xyz789").unwrap();
        assert_eq!(parsed.code.as_deref(), Some("abc123"));
        assert_eq!(parsed.state, "xyz789");
        assert!(parsed.error.is_none());
    }

    /// An error reported by the authorization server is surfaced in `error`.
    #[test]
    fn test_callback_query_with_error() {
        let raw = "code=&state=xyz789&error=access_denied&error_description=User+denied";
        let parsed = serde_qs::from_str::<CallbackQuery>(raw).unwrap();
        assert_eq!(parsed.error.as_deref(), Some("access_denied"));
    }

    /// The refresh body is parsed from its JSON form.
    #[test]
    fn test_refresh_request_deserialization() {
        let body = r#"{"did": "did:plc:abc123"}"#;
        let parsed = serde_json::from_str::<RefreshRequest>(body).unwrap();
        assert_eq!(parsed.did, "did:plc:abc123");
    }
}