Microservice to bring 2FA to self-hosted PDSes
91
fork

Configure Feed

Select the types of activity you want to include in your feed.

at sendmail 456 lines 15 kB view raw
1use super::{AuthRules, HandleCache, SessionData}; 2use crate::helpers::json_error_response; 3use crate::AppState; 4use axum::extract::{Request, State}; 5use axum::http::{HeaderMap, StatusCode}; 6use axum::middleware::Next; 7use axum::response::{IntoResponse, Response}; 8use jacquard_identity::resolver::IdentityResolver; 9use jacquard_identity::PublicResolver; 10use jwt_compact::alg::{Hs256, Hs256Key}; 11use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; 12use serde::{Deserialize, Serialize}; 13use std::env; 14use std::sync::Arc; 15use tracing::log; 16 17#[derive(Clone, Copy, Debug, PartialEq, Eq)] 18pub enum AuthScheme { 19 Bearer, 20 DPoP, 21} 22 23#[derive(Serialize, Deserialize)] 24pub struct TokenClaims { 25 pub sub: String, 26 /// OAuth scopes as space-separated string (per OAuth 2.0 spec) 27 #[serde(default)] 28 pub scope: Option<String>, 29} 30 31/// State passed to the auth middleware containing both AppState and auth rules. 32#[derive(Clone)] 33pub struct AuthMiddlewareState { 34 pub app_state: AppState, 35 pub rules: AuthRules, 36} 37 38/// Core middleware function that validates authentication and applies auth rules. 39/// 40/// Use this with `axum::middleware::from_fn_with_state`: 41/// ```ignore 42/// use axum::middleware::from_fn_with_state; 43/// 44/// let mw_state = AuthMiddlewareState { 45/// app_state: state.clone(), 46/// rules: AuthRules::HandleEndsWith(".blacksky.team".into()), 47/// }; 48/// 49/// .route("/protected", get(handler).layer(from_fn_with_state(mw_state, auth_middleware))) 50/// ``` 51pub async fn auth_middleware( 52 State(mw_state): State<AuthMiddlewareState>, 53 req: Request, 54 next: Next, 55) -> Response { 56 let AuthMiddlewareState { app_state, rules } = mw_state; 57 58 // 1. 
Extract DID and scopes from JWT (Bearer token) 59 let extracted = match extract_auth_from_request(req.headers()) { 60 Ok(Some(auth)) => auth, 61 Ok(None) => { 62 return json_error_response(StatusCode::UNAUTHORIZED, "AuthRequired", "Authentication required") 63 .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response()); 64 } 65 Err(e) => { 66 log::error!("Token extraction error: {}", e); 67 return json_error_response(StatusCode::UNAUTHORIZED, "InvalidToken", &e) 68 .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response()); 69 } 70 }; 71 72 // 2. Resolve DID to handle (check cache first) 73 let handle = match resolve_did_to_handle(&app_state.resolver, &app_state.handle_cache, &extracted.did).await { 74 Ok(handle) => handle, 75 Err(e) => { 76 log::error!("Failed to resolve DID {} to handle: {}", extracted.did, e); 77 return json_error_response( 78 StatusCode::INTERNAL_SERVER_ERROR, 79 "ResolutionError", 80 "Failed to resolve identity", 81 ) 82 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response()); 83 } 84 }; 85 86 // 3. Build session data and validate rules 87 let session = SessionData { 88 did: extracted.did, 89 handle, 90 scopes: extracted.scopes, 91 }; 92 93 if !rules.validate(&session) { 94 return json_error_response(StatusCode::FORBIDDEN, "AccessDenied", "Access denied by authorization rules") 95 .unwrap_or_else(|_| StatusCode::FORBIDDEN.into_response()); 96 } 97 98 // 4. Pass through on success 99 next.run(req).await 100} 101 102/// Extracted authentication data from JWT 103struct ExtractedAuth { 104 did: String, 105 scopes: Vec<String>, 106} 107 108/// Extracts the DID and scopes from the Authorization header (Bearer JWT). 
109fn extract_auth_from_request(headers: &HeaderMap) -> Result<Option<ExtractedAuth>, String> { 110 let auth = extract_auth(headers)?; 111 112 match auth { 113 None => Ok(None), 114 Some((scheme, token_str)) => { 115 match scheme { 116 AuthScheme::Bearer => { 117 let token = UntrustedToken::new(&token_str) 118 .map_err(|_| "Invalid token format".to_string())?; 119 120 let _claims: Claims<TokenClaims> = token 121 .deserialize_claims_unchecked() 122 .map_err(|_| "Failed to parse token claims".to_string())?; 123 124 let key = Hs256Key::new( 125 env::var("PDS_JWT_SECRET") 126 .map_err(|_| "PDS_JWT_SECRET not configured".to_string())?, 127 ); 128 129 let validated: Token<TokenClaims> = Hs256 130 .validator(&key) 131 .validate(&token) 132 .map_err(|e: ValidationError| format!("Token validation failed: {:?}", e))?; 133 134 let custom = &validated.claims().custom; 135 136 // Parse scopes from space-separated string (OAuth 2.0 spec) 137 let scopes: Vec<String> = custom.scope 138 .as_ref() 139 .map(|s| s.split_whitespace().map(|s| s.to_string()).collect()) 140 .unwrap_or_default(); 141 142 Ok(Some(ExtractedAuth { 143 did: custom.sub.clone(), 144 scopes, 145 })) 146 } 147 AuthScheme::DPoP => { 148 // DPoP tokens are not validated here; pass through without auth data 149 Ok(None) 150 } 151 } 152 } 153 } 154} 155 156/// Extracts the authentication scheme and token from the Authorization header. 
157fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> { 158 match headers.get(axum::http::header::AUTHORIZATION) { 159 None => Ok(None), 160 Some(hv) => { 161 let s = hv 162 .to_str() 163 .map_err(|_| "Authorization header is not valid UTF-8".to_string())?; 164 165 let mut parts = s.splitn(2, ' '); 166 match (parts.next(), parts.next()) { 167 (Some("Bearer"), Some(tok)) if !tok.is_empty() => { 168 Ok(Some((AuthScheme::Bearer, tok.to_string()))) 169 } 170 (Some("DPoP"), Some(tok)) if !tok.is_empty() => { 171 Ok(Some((AuthScheme::DPoP, tok.to_string()))) 172 } 173 _ => Err( 174 "Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'" 175 .to_string(), 176 ), 177 } 178 } 179 } 180} 181 182/// Resolves a DID to its handle using the PublicResolver, with caching. 183async fn resolve_did_to_handle( 184 resolver: &Arc<PublicResolver>, 185 cache: &HandleCache, 186 did: &str, 187) -> Result<String, String> { 188 // Check cache first 189 if let Some(handle) = cache.get(did) { 190 return Ok(handle); 191 } 192 193 // Parse the DID 194 let did_parsed = jacquard_common::types::did::Did::new(did) 195 .map_err(|e| format!("Invalid DID: {:?}", e))?; 196 197 // Resolve the DID document 198 let did_doc_response = resolver 199 .resolve_did_doc(&did_parsed) 200 .await 201 .map_err(|e| format!("DID resolution failed: {:?}", e))?; 202 203 let doc = did_doc_response 204 .parse() 205 .map_err(|e| format!("Failed to parse DID document: {:?}", e))?; 206 207 // Extract handle from alsoKnownAs field 208 // Format is typically: ["at://handle.example.com"] 209 let handle: String = doc 210 .also_known_as 211 .as_ref() 212 .and_then(|aka| { 213 aka.iter() 214 .find(|uri| uri.starts_with("at://")) 215 .map(|uri| uri.strip_prefix("at://").unwrap_or(uri.as_ref()).to_string()) 216 }) 217 .ok_or_else(|| "No ATProto handle found in DID document".to_string())?; 218 219 // Cache the result 220 cache.insert(did.to_string(), handle.clone()); 221 222 
Ok(handle) 223} 224 225// ============================================================================ 226// Helper Functions for Creating Middleware State 227// ============================================================================ 228 229/// Creates an `AuthMiddlewareState` for requiring the handle to end with a specific suffix. 230/// 231/// # Example 232/// ```ignore 233/// use axum::middleware::from_fn_with_state; 234/// use crate::auth::{auth_middleware, handle_ends_with}; 235/// 236/// .route("/protected", get(handler).layer( 237/// from_fn_with_state(handle_ends_with(".blacksky.team", &state), auth_middleware) 238/// )) 239/// ``` 240pub fn handle_ends_with(suffix: impl Into<String>, state: &AppState) -> AuthMiddlewareState { 241 AuthMiddlewareState { 242 app_state: state.clone(), 243 rules: AuthRules::HandleEndsWith(suffix.into()), 244 } 245} 246 247/// Creates an `AuthMiddlewareState` for requiring the handle to end with any of the specified suffixes. 248pub fn handle_ends_with_any<I, T>(suffixes: I, state: &AppState) -> AuthMiddlewareState 249where 250 I: IntoIterator<Item = T>, 251 T: Into<String>, 252{ 253 AuthMiddlewareState { 254 app_state: state.clone(), 255 rules: AuthRules::HandleEndsWithAny(suffixes.into_iter().map(|s| s.into()).collect()), 256 } 257} 258 259/// Creates an `AuthMiddlewareState` for requiring the DID to equal a specific value. 260pub fn did_equals(did: impl Into<String>, state: &AppState) -> AuthMiddlewareState { 261 AuthMiddlewareState { 262 app_state: state.clone(), 263 rules: AuthRules::DidEquals(did.into()), 264 } 265} 266 267/// Creates an `AuthMiddlewareState` for requiring the DID to be one of the specified values. 
268pub fn did_equals_any<I, T>(dids: I, state: &AppState) -> AuthMiddlewareState 269where 270 I: IntoIterator<Item = T>, 271 T: Into<String>, 272{ 273 AuthMiddlewareState { 274 app_state: state.clone(), 275 rules: AuthRules::DidEqualsAny(dids.into_iter().map(|d| d.into()).collect()), 276 } 277} 278 279/// Creates an `AuthMiddlewareState` with custom auth rules. 280pub fn with_rules(rules: AuthRules, state: &AppState) -> AuthMiddlewareState { 281 AuthMiddlewareState { 282 app_state: state.clone(), 283 rules, 284 } 285} 286 287// ============================================================================ 288// Scope Helper Functions 289// ============================================================================ 290 291/// Creates an `AuthMiddlewareState` requiring a specific OAuth scope. 292/// 293/// # Example 294/// ```ignore 295/// .route("/xrpc/com.atproto.repo.createRecord", 296/// post(handler).layer(from_fn_with_state( 297/// scope_equals("repo:app.bsky.feed.post", &state), 298/// auth_middleware 299/// ))) 300/// ``` 301pub fn scope_equals(scope: impl Into<String>, state: &AppState) -> AuthMiddlewareState { 302 AuthMiddlewareState { 303 app_state: state.clone(), 304 rules: AuthRules::ScopeEquals(scope.into()), 305 } 306} 307 308/// Creates an `AuthMiddlewareState` requiring ANY of the specified scopes (OR logic). 309/// 310/// # Example 311/// ```ignore 312/// .route("/xrpc/com.atproto.repo.putRecord", 313/// post(handler).layer(from_fn_with_state( 314/// scope_any(["repo:app.bsky.feed.post", "transition:generic"], &state), 315/// auth_middleware 316/// ))) 317/// ``` 318pub fn scope_any<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState 319where 320 I: IntoIterator<Item = T>, 321 T: Into<String>, 322{ 323 AuthMiddlewareState { 324 app_state: state.clone(), 325 rules: AuthRules::ScopeEqualsAny(scopes.into_iter().map(|s| s.into()).collect()), 326 } 327} 328 329/// Creates an `AuthMiddlewareState` requiring ALL of the specified scopes (AND logic). 
330/// 331/// # Example 332/// ```ignore 333/// .route("/xrpc/com.atproto.admin.updateAccount", 334/// post(handler).layer(from_fn_with_state( 335/// scope_all(["account:email", "account:repo?action=manage"], &state), 336/// auth_middleware 337/// ))) 338/// ``` 339pub fn scope_all<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState 340where 341 I: IntoIterator<Item = T>, 342 T: Into<String>, 343{ 344 AuthMiddlewareState { 345 app_state: state.clone(), 346 rules: AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), 347 } 348} 349 350// ============================================================================ 351// Combined Rule Helpers (Identity + Scope) 352// ============================================================================ 353 354/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have a specific scope. 355/// 356/// # Example 357/// ```ignore 358/// .route("/xrpc/community.blacksky.feed.generator", 359/// post(handler).layer(from_fn_with_state( 360/// handle_ends_with_and_scope(".blacksky.team", "transition:generic", &state), 361/// auth_middleware 362/// ))) 363/// ``` 364pub fn handle_ends_with_and_scope( 365 suffix: impl Into<String>, 366 scope: impl Into<String>, 367 state: &AppState, 368) -> AuthMiddlewareState { 369 AuthMiddlewareState { 370 app_state: state.clone(), 371 rules: AuthRules::All(vec![ 372 AuthRules::HandleEndsWith(suffix.into()), 373 AuthRules::ScopeEquals(scope.into()), 374 ]), 375 } 376} 377 378/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have ALL specified scopes. 
379/// 380/// # Example 381/// ```ignore 382/// .route("/xrpc/community.blacksky.admin.manage", 383/// post(handler).layer(from_fn_with_state( 384/// handle_ends_with_and_scopes(".blacksky.team", ["transition:generic", "identity:*"], &state), 385/// auth_middleware 386/// ))) 387/// ``` 388pub fn handle_ends_with_and_scopes<I, T>( 389 suffix: impl Into<String>, 390 scopes: I, 391 state: &AppState, 392) -> AuthMiddlewareState 393where 394 I: IntoIterator<Item = T>, 395 T: Into<String>, 396{ 397 AuthMiddlewareState { 398 app_state: state.clone(), 399 rules: AuthRules::All(vec![ 400 AuthRules::HandleEndsWith(suffix.into()), 401 AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), 402 ]), 403 } 404} 405 406/// Creates an `AuthMiddlewareState` requiring DID to equal value AND have a specific scope. 407/// 408/// # Example 409/// ```ignore 410/// .route("/xrpc/com.atproto.admin.deleteAccount", 411/// post(handler).layer(from_fn_with_state( 412/// did_with_scope("did:plc:rnpkyqnmsw4ipey6eotbdnnf", "transition:generic", &state), 413/// auth_middleware 414/// ))) 415/// ``` 416pub fn did_with_scope( 417 did: impl Into<String>, 418 scope: impl Into<String>, 419 state: &AppState, 420) -> AuthMiddlewareState { 421 AuthMiddlewareState { 422 app_state: state.clone(), 423 rules: AuthRules::All(vec![ 424 AuthRules::DidEquals(did.into()), 425 AuthRules::ScopeEquals(scope.into()), 426 ]), 427 } 428} 429 430/// Creates an `AuthMiddlewareState` requiring DID to equal value AND have ALL specified scopes. 
431/// 432/// # Example 433/// ```ignore 434/// .route("/xrpc/com.atproto.admin.fullAccess", 435/// post(handler).layer(from_fn_with_state( 436/// did_with_scopes("did:plc:rnpkyqnmsw4ipey6eotbdnnf", ["transition:generic", "identity:*"], &state), 437/// auth_middleware 438/// ))) 439/// ``` 440pub fn did_with_scopes<I, T>( 441 did: impl Into<String>, 442 scopes: I, 443 state: &AppState, 444) -> AuthMiddlewareState 445where 446 I: IntoIterator<Item = T>, 447 T: Into<String>, 448{ 449 AuthMiddlewareState { 450 app_state: state.clone(), 451 rules: AuthRules::All(vec![ 452 AuthRules::DidEquals(did.into()), 453 AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), 454 ]), 455 } 456}