Learn and share notes on atproto (WIP) — malfestio.stormlightlabs.org/
readability solid axum atproto srs
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 399 lines 14 kB view raw
1//! OAuth API endpoints for AT Protocol authentication. 2//! 3//! Provides endpoints for: 4//! - Starting the OAuth authorization flow 5//! - Handling OAuth callbacks 6//! - Refreshing tokens 7 8use crate::db::DbPool; 9use crate::oauth::flow::{OAuthFlow, SessionStore, generate_state, new_session_store}; 10use crate::repository::oauth::{DbOAuthRepository, OAuthRepository, StoreTokensRequest}; 11use axum::{ 12 Json, 13 extract::{Query, State}, 14 http::StatusCode, 15 response::{IntoResponse, Redirect}, 16}; 17use chrono::{Duration, Utc}; 18use serde::{Deserialize, Serialize}; 19use serde_json::json; 20use std::sync::Arc; 21 22/// Shared OAuth state with database repository. 23pub struct OAuthState { 24 pub flow: OAuthFlow, 25 pub sessions: SessionStore, 26 pub repo: Arc<dyn OAuthRepository>, 27} 28 29impl OAuthState { 30 /// Create OAuth state with database connection. 31 pub fn with_pool(pool: DbPool) -> Self { 32 Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(DbOAuthRepository::new(pool)) } 33 } 34 35 /// Create OAuth state without database (for testing). 36 pub fn new() -> Self { 37 Self { flow: OAuthFlow::new(), sessions: new_session_store(), repo: Arc::new(MockOAuthRepository) } 38 } 39} 40 41impl Default for OAuthState { 42 fn default() -> Self { 43 Self::new() 44 } 45} 46 47/// Mock repository for testing. 
48struct MockOAuthRepository; 49 50#[async_trait::async_trait] 51impl OAuthRepository for MockOAuthRepository { 52 async fn store_tokens(&self, _req: StoreTokensRequest<'_>) -> Result<(), crate::repository::oauth::OAuthRepoError> { 53 Ok(()) 54 } 55 56 async fn store_app_password_session( 57 &self, _req: crate::repository::oauth::StoreAppPasswordSessionRequest<'_>, 58 ) -> Result<(), crate::repository::oauth::OAuthRepoError> { 59 Ok(()) 60 } 61 62 async fn get_tokens( 63 &self, did: &str, 64 ) -> Result<crate::repository::oauth::StoredToken, crate::repository::oauth::OAuthRepoError> { 65 Err(crate::repository::oauth::OAuthRepoError::NotFound(did.to_string())) 66 } 67 68 async fn get_token_by_access_token( 69 &self, _access_token: &str, 70 ) -> Result<crate::repository::oauth::StoredToken, crate::repository::oauth::OAuthRepoError> { 71 Err(crate::repository::oauth::OAuthRepoError::NotFound( 72 "Mock impl".to_string(), 73 )) 74 } 75 76 async fn update_tokens( 77 &self, _did: &str, _access_token: &str, _refresh_token: Option<&str>, 78 _expires_at: Option<chrono::DateTime<Utc>>, 79 ) -> Result<(), crate::repository::oauth::OAuthRepoError> { 80 Ok(()) 81 } 82 83 async fn delete_tokens(&self, _did: &str) -> Result<(), crate::repository::oauth::OAuthRepoError> { 84 Ok(()) 85 } 86} 87 88/// Request to start OAuth authorization. 89#[derive(Deserialize)] 90pub struct AuthorizeRequest { 91 /// Handle or DID to authenticate 92 pub handle: String, 93} 94 95/// Response from starting authorization. 96#[derive(Serialize)] 97pub struct AuthorizeResponse { 98 /// URL to redirect the user to 99 pub authorization_url: String, 100 /// State parameter (for CSRF protection) 101 pub state: String, 102} 103 104/// Query parameters from OAuth callback. 
105#[derive(Deserialize)] 106pub struct CallbackQuery { 107 pub code: Option<String>, 108 pub state: String, 109 #[serde(default)] 110 pub error: Option<String>, 111 #[serde(default)] 112 pub error_description: Option<String>, 113} 114 115/// Start the OAuth authorization flow. 116/// 117/// POST /api/oauth/authorize 118/// Body: { "handle": "alice.bsky.social" } 119pub async fn authorize( 120 State(oauth): State<Arc<OAuthState>>, Json(payload): Json<AuthorizeRequest>, 121) -> impl IntoResponse { 122 tracing::info!("OAuth authorization request received for handle: {}", payload.handle); 123 124 let state = generate_state(); 125 tracing::debug!("Generated state parameter: {}", state); 126 127 match oauth 128 .flow 129 .start_authorization(&payload.handle, &state, &oauth.sessions) 130 .await 131 { 132 Ok(auth_url) => { 133 tracing::info!( 134 "OAuth authorization started successfully for handle: {}", 135 payload.handle 136 ); 137 ( 138 StatusCode::OK, 139 Json(AuthorizeResponse { authorization_url: auth_url, state }), 140 ) 141 .into_response() 142 } 143 Err(e) => { 144 tracing::error!("OAuth authorization failed for handle {}: {}", payload.handle, e); 145 (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response() 146 } 147 } 148} 149 150/// Handle OAuth callback from authorization server. 151/// 152/// GET /api/oauth/callback?code=...&state=... 
153pub async fn callback(State(oauth): State<Arc<OAuthState>>, Query(params): Query<CallbackQuery>) -> impl IntoResponse { 154 tracing::info!("OAuth callback received with state: {}", params.state); 155 156 if let Some(error) = params.error { 157 let description = params.error_description.unwrap_or_default(); 158 tracing::error!("OAuth authorization error: {} - {}", error, description); 159 return Redirect::to(&format!( 160 "/login?error={}&description={}", 161 urlencoding::encode(&error), 162 urlencoding::encode(&description) 163 )) 164 .into_response(); 165 } 166 167 let code = match params.code { 168 Some(c) => c, 169 None => { 170 tracing::error!("OAuth callback missing authorization code"); 171 return Redirect::to("/login?error=missing_code").into_response(); 172 } 173 }; 174 175 tracing::debug!("Retrieving session for state: {}", params.state); 176 let session = { 177 let sessions = oauth.sessions.read().unwrap(); 178 sessions.get(&params.state).cloned() 179 }; 180 181 let session = match session { 182 Some(s) => { 183 tracing::debug!("Session found for state: {}", params.state); 184 s 185 } 186 None => { 187 tracing::error!("Session not found for state: {}", params.state); 188 return Redirect::to("/login?error=session_not_found").into_response(); 189 } 190 }; 191 192 match oauth.flow.exchange_code(&code, &params.state, &oauth.sessions).await { 193 Ok(tokens) => { 194 let did = session.did.clone().unwrap_or_default(); 195 let pds_url = session.pds_url.unwrap_or_default(); 196 let expires_at = tokens 197 .expires_in 198 .map(|secs| Utc::now() + Duration::seconds(secs as i64)); 199 200 tracing::info!("Storing tokens for DID: {}", did); 201 202 if let Err(e) = oauth 203 .repo 204 .store_tokens(StoreTokensRequest { 205 did: &did, 206 pds_url: &pds_url, 207 access_token: &tokens.access_token, 208 refresh_token: tokens.refresh_token.as_deref(), 209 token_type: &tokens.token_type, 210 expires_at, 211 dpop_keypair: &session.dpop_keypair, 212 }) 213 .await 214 { 215 
tracing::error!("Failed to store tokens for DID {}: {}", did, e); 216 return Redirect::to(&format!("/login?error={}", urlencoding::encode("token_storage_failed"))) 217 .into_response(); 218 } 219 220 tracing::info!("OAuth flow completed successfully for DID: {}", did); 221 222 let handle = match oauth.flow.resolve_did(&did).await { 223 Ok(identity) => identity.handle.unwrap_or(did.clone()), 224 Err(e) => { 225 tracing::warn!("Failed to resolve handle for DID {}: {}", did, e); 226 did.clone() 227 } 228 }; 229 230 let fragment = format!( 231 "accessJwt={}&refreshJwt={}&did={}&handle={}", 232 urlencoding::encode(&tokens.access_token), 233 urlencoding::encode(tokens.refresh_token.as_deref().unwrap_or("")), 234 urlencoding::encode(&did), 235 urlencoding::encode(&handle) 236 ); 237 Redirect::to(&format!("/login/success#{}", fragment)).into_response() 238 } 239 Err(e) => { 240 tracing::error!("Token exchange failed: {}", e); 241 Redirect::to(&format!("/login?error={}", urlencoding::encode(&e.to_string()))).into_response() 242 } 243 } 244} 245 246/// Request to refresh tokens. 247#[derive(Deserialize)] 248pub struct RefreshRequest { 249 pub did: String, 250} 251 252/// Response from token refresh. 253#[derive(Serialize)] 254pub struct RefreshResponse { 255 pub success: bool, 256 pub expires_at: Option<String>, 257} 258 259/// Refresh an access token. 260/// 261/// POST /api/oauth/refresh 262/// Body: { "did": "did:plc:..." 
} 263pub async fn refresh(State(oauth): State<Arc<OAuthState>>, Json(payload): Json<RefreshRequest>) -> impl IntoResponse { 264 tracing::info!("Token refresh request for DID: {}", payload.did); 265 266 tracing::debug!("Retrieving stored tokens from database for DID: {}", payload.did); 267 let stored = match oauth.repo.get_tokens(&payload.did).await { 268 Ok(t) => { 269 tracing::debug!("Found stored tokens for DID: {}", payload.did); 270 t 271 } 272 Err(e) => { 273 tracing::error!("Failed to retrieve stored tokens for DID {}: {}", payload.did, e); 274 return (StatusCode::NOT_FOUND, Json(json!({ "error": e.to_string() }))).into_response(); 275 } 276 }; 277 278 tracing::debug!("Reconstructing DPoP keypair from stored data"); 279 let dpop_keypair = match stored.dpop_keypair() { 280 Some(kp) => kp, 281 None => { 282 tracing::error!("Failed to reconstruct DPoP keypair for DID: {}", payload.did); 283 return ( 284 StatusCode::INTERNAL_SERVER_ERROR, 285 Json(json!({ "error": "Invalid stored keypair" })), 286 ) 287 .into_response(); 288 } 289 }; 290 291 let refresh_token = match &stored.refresh_token { 292 Some(rt) => rt.clone(), 293 None => { 294 tracing::error!("No refresh token available for DID: {}", payload.did); 295 return ( 296 StatusCode::BAD_REQUEST, 297 Json(json!({ "error": "No refresh token available" })), 298 ) 299 .into_response(); 300 } 301 }; 302 303 match oauth 304 .flow 305 .refresh_token(&refresh_token, &stored.pds_url, &dpop_keypair) 306 .await 307 { 308 Ok(new_tokens) => { 309 let expires_at = new_tokens 310 .expires_in 311 .map(|secs| Utc::now() + Duration::seconds(secs as i64)); 312 313 tracing::info!( 314 "Token refresh successful, updating stored tokens for DID: {}", 315 payload.did 316 ); 317 if let Err(e) = oauth 318 .repo 319 .update_tokens( 320 &payload.did, 321 &new_tokens.access_token, 322 new_tokens.refresh_token.as_deref(), 323 expires_at, 324 ) 325 .await 326 { 327 tracing::error!("Failed to update tokens in database for DID {}: {}", 
payload.did, e); 328 return ( 329 StatusCode::INTERNAL_SERVER_ERROR, 330 Json(json!({ "error": "Failed to update tokens" })), 331 ) 332 .into_response(); 333 } 334 335 tracing::info!("Token refresh completed successfully for DID: {}", payload.did); 336 ( 337 StatusCode::OK, 338 Json(RefreshResponse { success: true, expires_at: expires_at.map(|dt| dt.to_rfc3339()) }), 339 ) 340 .into_response() 341 } 342 Err(e) => { 343 tracing::error!("Token refresh failed for DID {}: {}", payload.did, e); 344 (StatusCode::BAD_REQUEST, Json(json!({ "error": e.to_string() }))).into_response() 345 } 346 } 347} 348 349#[cfg(test)] 350mod tests { 351 use super::*; 352 353 #[test] 354 fn test_oauth_state_creation() { 355 let state = OAuthState::new(); 356 assert!(state.sessions.read().unwrap().is_empty()); 357 } 358 359 #[test] 360 fn test_authorize_request_deserialization() { 361 let json = r#"{"handle": "alice.bsky.social"}"#; 362 let request: AuthorizeRequest = serde_json::from_str(json).unwrap(); 363 assert_eq!(request.handle, "alice.bsky.social"); 364 } 365 366 #[test] 367 fn test_authorize_response_serialization() { 368 let response = AuthorizeResponse { 369 authorization_url: "https://example.com/oauth".to_string(), 370 state: "abc123".to_string(), 371 }; 372 let json = serde_json::to_string(&response).unwrap(); 373 assert!(json.contains("authorization_url")); 374 assert!(json.contains("state")); 375 } 376 377 #[test] 378 fn test_callback_query_deserialization() { 379 let query = "code=abc123&state=xyz789"; 380 let parsed: CallbackQuery = serde_qs::from_str(query).unwrap(); 381 assert_eq!(parsed.code, Some("abc123".to_string())); 382 assert_eq!(parsed.state, "xyz789"); 383 assert!(parsed.error.is_none()); 384 } 385 386 #[test] 387 fn test_callback_query_with_error() { 388 let query = "code=&state=xyz789&error=access_denied&error_description=User+denied"; 389 let parsed: CallbackQuery = serde_qs::from_str(query).unwrap(); 390 assert_eq!(parsed.error, 
Some("access_denied".to_string())); 391 } 392 393 #[test] 394 fn test_refresh_request_deserialization() { 395 let json = r#"{"did": "did:plc:abc123"}"#; 396 let request: RefreshRequest = serde_json::from_str(json).unwrap(); 397 assert_eq!(request.did, "did:plc:abc123"); 398 } 399}