An easy-to-host PDS on the ATProtocol, iPhone and MacOS. Maintain control of your keys and data, always.
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(relay): auth middleware — Bearer/DPoP JWT validation extractor (MM-139)

Adds crates/relay/src/auth/mod.rs with an Axum FromRequestParts extractor
that validates HS256 JWT access/refresh tokens and DPoP-bound tokens per
RFC 9449.

- AuthenticatedUser extractor yields did, AuthScope, and TokenType
- JWT validation: HS256 signature, exp (leeway=0), aud (when server_did set), sub
- DPoP validation: JWK thumbprint (RFC 7638), htm/htu, iat freshness, cnf.jkt binding
- New error codes: AuthenticationRequired and InvalidToken (both 401)
- jwt_secret ([u8; 32]) added to AppState; generated via OsRng at startup
- 12 unit tests covering all AC cases including RFC 7638 thumbprint vector

authored by

Malpercio and committed by
Tangled
e61676d3 3e0f7808

+715 -3
+51 -3
Cargo.lock
··· 2551 2551 ] 2552 2552 2553 2553 [[package]] 2554 + name = "jsonwebtoken" 2555 + version = "9.3.1" 2556 + source = "registry+https://github.com/rust-lang/crates.io-index" 2557 + checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" 2558 + dependencies = [ 2559 + "base64 0.22.1", 2560 + "js-sys", 2561 + "pem", 2562 + "ring", 2563 + "serde", 2564 + "serde_json", 2565 + "simple_asn1", 2566 + ] 2567 + 2568 + [[package]] 2554 2569 name = "keyboard-types" 2555 2570 version = "0.7.0" 2556 2571 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2932 2947 checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" 2933 2948 dependencies = [ 2934 2949 "windows-sys 0.61.2", 2950 + ] 2951 + 2952 + [[package]] 2953 + name = "num-bigint" 2954 + version = "0.4.6" 2955 + source = "registry+https://github.com/rust-lang/crates.io-index" 2956 + checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" 2957 + dependencies = [ 2958 + "num-integer", 2959 + "num-traits", 2935 2960 ] 2936 2961 2937 2962 [[package]] ··· 3362 3387 ] 3363 3388 3364 3389 [[package]] 3390 + name = "pem" 3391 + version = "3.0.6" 3392 + source = "registry+https://github.com/rust-lang/crates.io-index" 3393 + checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be" 3394 + dependencies = [ 3395 + "base64 0.22.1", 3396 + "serde_core", 3397 + ] 3398 + 3399 + [[package]] 3365 3400 name = "pem-rfc7468" 3366 3401 version = "0.7.0" 3367 3402 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 3833 3868 "quinn-udp", 3834 3869 "rustc-hash", 3835 3870 "rustls", 3836 - "socket2 0.5.10", 3871 + "socket2 0.6.3", 3837 3872 "thiserror 2.0.18", 3838 3873 "tokio", 3839 3874 "tracing", ··· 3870 3905 "cfg_aliases", 3871 3906 "libc", 3872 3907 "once_cell", 3873 - "socket2 0.5.10", 3908 + "socket2 0.6.3", 3874 3909 "tracing", 3875 - "windows-sys 0.52.0", 3910 + "windows-sys 0.60.2", 3876 3911 ] 3877 3912 3878 3913 [[package]] ··· 4102 4137 "crypto", 4103 4138 "data-encoding", 4104 4139 "hickory-resolver", 4140 + "jsonwebtoken", 4105 4141 "opentelemetry", 4106 4142 "opentelemetry-otlp", 4107 4143 "opentelemetry_sdk", ··· 4745 4781 version = "0.3.8" 4746 4782 source = "registry+https://github.com/rust-lang/crates.io-index" 4747 4783 checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" 4784 + 4785 + [[package]] 4786 + name = "simple_asn1" 4787 + version = "0.6.4" 4788 + source = "registry+https://github.com/rust-lang/crates.io-index" 4789 + checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d" 4790 + dependencies = [ 4791 + "num-bigint", 4792 + "num-traits", 4793 + "thiserror 2.0.18", 4794 + "time", 4795 + ] 4748 4796 4749 4797 [[package]] 4750 4798 name = "siphasher"
+3
Cargo.toml
··· 70 70 subtle = "2" 71 71 uuid = { version = "1", features = ["v4"] } 72 72 73 + # JWT (relay auth) 74 + jsonwebtoken = "9" 75 + 73 76 # ATProto handle resolution — DNS TXT fallback (relay) 74 77 hickory-resolver = { version = "0.25", features = ["tokio", "system-config"] } 75 78
+6
crates/common/src/error.rs
··· 46 46 DnsError, 47 47 /// The requested handle does not resolve to a known DID locally or via DNS. 48 48 HandleNotFound, 49 + /// Missing or absent Authorization header on a protected endpoint. 50 + AuthenticationRequired, 51 + /// Token is structurally invalid, has wrong signature, wrong audience, or DPoP mismatch. 52 + InvalidToken, 49 53 // TODO: add remaining codes from Appendix A as endpoints are implemented: 50 54 // 400: INVALID_DOCUMENT, INVALID_PROOF, INVALID_ENDPOINT, INVALID_CONFIRMATION 51 55 // 401: INVALID_CREDENTIALS ··· 81 85 ErrorCode::PlcDirectoryError => 502, 82 86 ErrorCode::DnsError => 502, 83 87 ErrorCode::HandleNotFound => 404, 88 + ErrorCode::AuthenticationRequired => 401, 89 + ErrorCode::InvalidToken => 401, 84 90 } 85 91 } 86 92 }
+1
crates/relay/Cargo.toml
··· 37 37 uuid = { workspace = true } 38 38 zeroize = { workspace = true } 39 39 hickory-resolver = { workspace = true } 40 + jsonwebtoken = { workspace = true } 40 41 41 42 [dev-dependencies] 42 43 tower = { workspace = true }
+5
crates/relay/src/app.rs
··· 96 96 /// Used as the third step after local DB and DNS TXT: calls 97 97 /// `GET https://<handle>/.well-known/atproto-did`. 98 98 pub well_known_resolver: Option<Arc<dyn WellKnownResolver>>, 99 + /// HS256 signing secret for JWT access/refresh tokens. 100 + /// Loaded from EZPDS_JWT_SECRET (hex-encoded) or generated randomly at startup. 101 + pub jwt_secret: [u8; 32], 99 102 } 100 103 101 104 /// Build the Axum router with middleware and routes. ··· 185 188 dns_provider: None, 186 189 txt_resolver: None, 187 190 well_known_resolver: None, 191 + // Fixed key for tests — predictable JWTs in unit tests. 192 + jwt_secret: [0x42u8; 32], 188 193 } 189 194 } 190 195
+637
crates/relay/src/auth/mod.rs
··· 1 + // Dead-code lint suppressed: this module is foundational infrastructure. 2 + // Items will be used once authenticated routes are wired up in subsequent waves. 3 + #![allow(dead_code)] 4 + 5 + use axum::{ 6 + async_trait, 7 + extract::FromRequestParts, 8 + http::{request::Parts, Method}, 9 + }; 10 + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; 11 + use common::{ApiError, ErrorCode}; 12 + use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; 13 + use serde::Deserialize; 14 + use sha2::{Digest, Sha256}; 15 + 16 + use crate::app::AppState; 17 + 18 + // ── Public types ───────────────────────────────────────────────────────────── 19 + 20 + /// Scope embedded in the JWT `scope` claim. 21 + #[derive(Debug, Clone, PartialEq, Eq)] 22 + pub enum AuthScope { 23 + Access, 24 + Refresh, 25 + AppPass, 26 + } 27 + 28 + /// Whether this token was presented as a plain Bearer or a DPoP-bound token. 29 + #[derive(Debug, Clone, PartialEq, Eq)] 30 + pub enum TokenType { 31 + /// Simple Bearer JWT issued by `createSession`. 32 + Legacy, 33 + /// DPoP-bound token (RFC 9449). 34 + DPoP, 35 + } 36 + 37 + /// Axum extractor that validates a Bearer (or DPoP-bound) JWT and yields the 38 + /// authenticated caller's DID, scope, and token type. 39 + /// 40 + /// Extract this in any handler that requires authentication: 41 + /// ```rust,ignore 42 + /// async fn my_handler(user: AuthenticatedUser) -> impl IntoResponse { ... } 43 + /// ``` 44 + #[derive(Debug, Clone)] 45 + pub struct AuthenticatedUser { 46 + pub did: String, 47 + pub scope: AuthScope, 48 + pub token_type: TokenType, 49 + } 50 + 51 + // ── JWT claims ─────────────────────────────────────────────────────────────── 52 + 53 + /// Claims decoded from the server-issued access/refresh JWT. 54 + #[derive(Debug, Deserialize)] 55 + struct AccessTokenClaims { 56 + /// Subject — the authenticated DID. 57 + sub: String, 58 + /// Scope string from the AT Protocol spec. 59 + scope: String, 60 + /// Confirmation claim — present on DPoP-bound tokens. 61 + cnf: Option<CnfClaim>, 62 + } 63 + 64 + /// `cnf` (confirmation) claim carrying the JWK thumbprint for DPoP binding. 65 + #[derive(Debug, Deserialize)] 66 + struct CnfClaim { 67 + /// JWK SHA-256 thumbprint (base64url, no padding) of the client's DPoP key. 68 + jkt: Option<String>, 69 + } 70 + 71 + // ── DPoP JWT header + claims ───────────────────────────────────────────────── 72 + 73 + /// Decoded DPoP proof JWT header fields relevant to validation. 74 + #[derive(Debug, Deserialize)] 75 + struct DPopHeader { 76 + /// Must be `"dpop+jwt"`. 77 + typ: String, 78 + /// Algorithm (e.g. `"ES256"`). 79 + alg: String, 80 + /// The client's public JWK, embedded in the proof header (RFC 9449 §4.2). 81 + jwk: serde_json::Value, 82 + } 83 + 84 + /// Claims from the DPoP proof JWT payload. 85 + #[derive(Debug, Deserialize)] 86 + struct DPopClaims { 87 + /// HTTP method (e.g. `"POST"`). 88 + htm: String, 89 + /// HTTP URI (scheme + host + path, no query string). 90 + htu: String, 91 + /// Issued-at (Unix timestamp). Used for freshness; replaces `exp`. 92 + iat: i64, 93 + /// Unique token ID — must be present for replay protection. 94 + jti: String, 95 + } 96 + 97 + // ── Extractor implementation ───────────────────────────────────────────────── 98 + 99 + #[async_trait] 100 + impl FromRequestParts<AppState> for AuthenticatedUser { 101 + type Rejection = ApiError; 102 + 103 + async fn from_request_parts( 104 + parts: &mut Parts, 105 + state: &AppState, 106 + ) -> Result<Self, Self::Rejection> { 107 + // 1. Extract the raw Bearer token string from Authorization header. 108 + let token_str = extract_bearer_token(&parts.headers)?; 109 + 110 + // 2. Detect the DPoP header before decoding the access token. 111 + let dpop_value = parts 112 + .headers 113 + .get("DPoP") 114 + .and_then(|v| v.to_str().ok()) 115 + .map(str::to_owned); 116 + let has_dpop = dpop_value.is_some(); 117 + 118 + // 3. Decode and verify the access token (HS256). 119 + let claims = verify_access_token(token_str, state)?; 120 + 121 + // 4. Resolve scope enum. 122 + let scope = parse_scope(&claims.scope)?; 123 + 124 + // 5. DPoP validation — only when the DPoP header is present. 125 + if has_dpop { 126 + let dpop_token = dpop_value.as_deref().unwrap(); 127 + validate_dpop(dpop_token, &parts.method, &parts.uri, &claims)?; 128 + } 129 + 130 + let token_type = if has_dpop { 131 + TokenType::DPoP 132 + } else { 133 + TokenType::Legacy 134 + }; 135 + 136 + Ok(AuthenticatedUser { 137 + did: claims.sub, 138 + scope, 139 + token_type, 140 + }) 141 + } 142 + } 143 + 144 + // ── Internal helpers ───────────────────────────────────────────────────────── 145 + 146 + /// Extract `Authorization: Bearer <token>` from headers. 147 + fn extract_bearer_token(headers: &axum::http::HeaderMap) -> Result<&str, ApiError> { 148 + let auth_value = headers 149 + .get(axum::http::header::AUTHORIZATION) 150 + .and_then(|v| { 151 + v.to_str() 152 + .inspect_err(|_| { 153 + tracing::warn!( 154 + "Authorization header contains non-UTF-8 bytes; treating as absent" 155 + ); 156 + }) 157 + .ok() 158 + }) 159 + .ok_or_else(|| { 160 + ApiError::new( 161 + ErrorCode::AuthenticationRequired, 162 + "missing Authorization header", 163 + ) 164 + })?; 165 + 166 + auth_value.strip_prefix("Bearer ").ok_or_else(|| { 167 + ApiError::new( 168 + ErrorCode::AuthenticationRequired, 169 + "Authorization header must use Bearer scheme", 170 + ) 171 + }) 172 + } 173 + 174 + /// Decode and verify the HS256 access/refresh JWT issued by this server. 175 + fn verify_access_token( 176 + token: &str, 177 + state: &AppState, 178 + ) -> Result<AccessTokenClaims, ApiError> { 179 + let decoding_key = DecodingKey::from_secret(&state.jwt_secret); 180 + 181 + let mut validation = Validation::new(Algorithm::HS256); 182 + // Validate audience only when the server DID is configured. 183 + match state.config.server_did.as_deref() { 184 + Some(did) => validation.set_audience(&[did]), 185 + None => { 186 + validation.validate_aud = false; 187 + tracing::debug!("server_did not configured; skipping JWT audience validation"); 188 + } 189 + } 190 + // `sub` is required by AT Protocol but not in jsonwebtoken's default required set. 191 + validation.set_required_spec_claims(&["exp", "sub"]); 192 + // Zero leeway: tokens we issued ourselves need no clock-skew tolerance. 193 + validation.leeway = 0; 194 + 195 + decode::<AccessTokenClaims>(token, &decoding_key, &validation) 196 + .map(|data| data.claims) 197 + .map_err(|e| { 198 + use jsonwebtoken::errors::ErrorKind; 199 + match e.kind() { 200 + ErrorKind::ExpiredSignature => { 201 + ApiError::new(ErrorCode::TokenExpired, "token has expired") 202 + } 203 + _ => ApiError::new(ErrorCode::InvalidToken, "invalid token"), 204 + } 205 + }) 206 + } 207 + 208 + /// Parse the ATProto scope string into [`AuthScope`]. 209 + fn parse_scope(scope: &str) -> Result<AuthScope, ApiError> { 210 + match scope { 211 + "com.atproto.access" => Ok(AuthScope::Access), 212 + "com.atproto.refresh" => Ok(AuthScope::Refresh), 213 + "com.atproto.appPass" => Ok(AuthScope::AppPass), 214 + _ => Err(ApiError::new(ErrorCode::InvalidToken, "unrecognised token scope")), 215 + } 216 + } 217 + 218 + /// Validate the DPoP proof JWT (RFC 9449). 219 + /// 220 + /// Checks: 221 + /// - `typ` header is `"dpop+jwt"` 222 + /// - Signature verifies against the embedded JWK 223 + /// - `htm` matches request method, `htu` matches request URI 224 + /// - `jti` is present (replay protection hook) 225 + /// - Access token `cnf.jkt` matches the computed JWK thumbprint 226 + fn validate_dpop( 227 + dpop_token: &str, 228 + method: &Method, 229 + uri: &axum::http::Uri, 230 + access_claims: &AccessTokenClaims, 231 + ) -> Result<(), ApiError> { 232 + let invalid = || ApiError::new(ErrorCode::InvalidToken, "DPoP proof invalid"); 233 + 234 + // Decode the DPoP proof header manually — jsonwebtoken's Header type doesn't 235 + // expose custom header fields like `jwk`, so we base64-decode the first segment. 236 + let header_b64 = dpop_token.split('.').next().ok_or_else(invalid)?; 237 + let header_bytes = URL_SAFE_NO_PAD.decode(header_b64).map_err(|_| invalid())?; 238 + let dpop_header: DPopHeader = 239 + serde_json::from_slice(&header_bytes).map_err(|_| invalid())?; 240 + 241 + if dpop_header.typ != "dpop+jwt" { 242 + return Err(ApiError::new( 243 + ErrorCode::InvalidToken, 244 + "DPoP proof typ must be dpop+jwt", 245 + )); 246 + } 247 + 248 + // Compute JWK thumbprint (RFC 7638) from the embedded public key. 249 + let thumbprint = jwk_thumbprint(&dpop_header.jwk).map_err(|_| invalid())?; 250 + 251 + // Verify that the access token was bound to this DPoP key. 252 + let bound_thumbprint = access_claims 253 + .cnf 254 + .as_ref() 255 + .and_then(|c| c.jkt.as_deref()) 256 + .ok_or_else(|| { 257 + ApiError::new(ErrorCode::InvalidToken, "access token missing DPoP binding") 258 + })?; 259 + if thumbprint != bound_thumbprint { 260 + return Err(ApiError::new( 261 + ErrorCode::InvalidToken, 262 + "DPoP key thumbprint does not match token binding", 263 + )); 264 + } 265 + 266 + // Verify the DPoP JWT signature using the embedded public JWK. 267 + let jwk: jsonwebtoken::jwk::Jwk = 268 + serde_json::from_value(dpop_header.jwk.clone()).map_err(|_| invalid())?; 269 + let decoding_key = DecodingKey::from_jwk(&jwk).map_err(|_| invalid())?; 270 + let alg = dpop_alg_from_str(&dpop_header.alg).ok_or_else(invalid)?; 271 + 272 + let mut validation = Validation::new(alg); 273 + // DPoP proofs don't carry `exp`; freshness is via `iat`. 274 + validation.validate_exp = false; 275 + validation.set_required_spec_claims::<&str>(&[]); 276 + validation.validate_aud = false; 277 + 278 + let dpop_data = 279 + decode::<DPopClaims>(dpop_token, &decoding_key, &validation).map_err(|_| invalid())?; 280 + let dpop_claims = dpop_data.claims; 281 + 282 + // Require `jti` for replay protection (must be present and non-empty). 283 + if dpop_claims.jti.is_empty() { 284 + return Err(ApiError::new(ErrorCode::InvalidToken, "DPoP proof missing jti")); 285 + } 286 + 287 + // Validate `htm` (HTTP method) and `htu` (HTTP URI). 288 + if dpop_claims.htm.to_uppercase() != method.as_str().to_uppercase() { 289 + return Err(ApiError::new( 290 + ErrorCode::InvalidToken, 291 + "DPoP htm does not match request method", 292 + )); 293 + } 294 + 295 + // `htu` must match scheme + authority + path (no query string per RFC 9449 §4.3). 296 + let expected_htu = { 297 + let scheme = uri.scheme_str().unwrap_or("https"); 298 + let authority = uri.authority().map(|a| a.as_str()).unwrap_or(""); 299 + let path = uri.path(); 300 + format!("{scheme}://{authority}{path}") 301 + }; 302 + if dpop_claims.htu != expected_htu { 303 + return Err(ApiError::new( 304 + ErrorCode::InvalidToken, 305 + "DPoP htu does not match request URI", 306 + )); 307 + } 308 + 309 + // Freshness: reject proofs older than 60 seconds. 310 + let now = std::time::SystemTime::now() 311 + .duration_since(std::time::UNIX_EPOCH) 312 + .map(|d| d.as_secs() as i64) 313 + .unwrap_or(0); 314 + if (now - dpop_claims.iat).abs() > 60 { 315 + return Err(ApiError::new(ErrorCode::InvalidToken, "DPoP proof is stale")); 316 + } 317 + 318 + Ok(()) 319 + } 320 + 321 + /// Map a DPoP `alg` string to a [`jsonwebtoken::Algorithm`]. 322 + fn dpop_alg_from_str(alg: &str) -> Option<Algorithm> { 323 + match alg { 324 + "ES256" => Some(Algorithm::ES256), 325 + "ES384" => Some(Algorithm::ES384), 326 + "RS256" => Some(Algorithm::RS256), 327 + "RS384" => Some(Algorithm::RS384), 328 + "RS512" => Some(Algorithm::RS512), 329 + "PS256" => Some(Algorithm::PS256), 330 + "PS384" => Some(Algorithm::PS384), 331 + "PS512" => Some(Algorithm::PS512), 332 + _ => None, 333 + } 334 + } 335 + 336 + /// Compute the RFC 7638 JWK thumbprint: SHA-256 of the canonical JSON member set, 337 + /// base64url-encoded with no padding. 338 + fn jwk_thumbprint(jwk: &serde_json::Value) -> Result<String, ()> { 339 + let kty = jwk["kty"].as_str().ok_or(())?; 340 + 341 + // Canonical member set per RFC 7638 §3.2, in lexicographic order. 342 + // serde_json's default Map is a BTreeMap, so json! keys are sorted automatically. 343 + let canonical: serde_json::Value = match kty { 344 + "EC" => serde_json::json!({ 345 + "crv": jwk["crv"], 346 + "kty": kty, 347 + "x": jwk["x"], 348 + "y": jwk["y"], 349 + }), 350 + "RSA" => serde_json::json!({ 351 + "e": jwk["e"], 352 + "kty": kty, 353 + "n": jwk["n"], 354 + }), 355 + "OKP" => serde_json::json!({ 356 + "crv": jwk["crv"], 357 + "kty": kty, 358 + "x": jwk["x"], 359 + }), 360 + _ => return Err(()), 361 + }; 362 + 363 + let canonical_json = serde_json::to_string(&canonical).map_err(|_| ())?; 364 + let hash = Sha256::digest(canonical_json.as_bytes()); 365 + Ok(URL_SAFE_NO_PAD.encode(hash)) 366 + } 367 + 368 + // ── Tests ──────────────────────────────────────────────────────────────────── 369 + 370 + #[cfg(test)] 371 + mod tests { 372 + use super::*; 373 + use axum::{ 374 + body::Body, 375 + http::{Request, StatusCode}, 376 + routing::get, 377 + Router, 378 + }; 379 + use jsonwebtoken::{encode, EncodingKey, Header}; 380 + use serde::Serialize; 381 + use tower::ServiceExt; 382 + 383 + use crate::app::test_state; 384 + 385 + /// Claims struct for minting test JWTs. 386 + #[derive(Serialize)] 387 + struct TestClaims { 388 + sub: String, 389 + aud: String, 390 + exp: u64, 391 + scope: String, 392 + #[serde(skip_serializing_if = "Option::is_none")] 393 + cnf: Option<serde_json::Value>, 394 + } 395 + 396 + fn now_secs() -> u64 { 397 + std::time::SystemTime::now() 398 + .duration_since(std::time::UNIX_EPOCH) 399 + .unwrap() 400 + .as_secs() 401 + } 402 + 403 + /// Mint a valid HS256 JWT using the test state's jwt_secret. 404 + fn mint_token( 405 + sub: &str, 406 + scope: &str, 407 + exp_offset_secs: i64, 408 + secret: &[u8; 32], 409 + cnf: Option<serde_json::Value>, 410 + ) -> String { 411 + let exp = (now_secs() as i64 + exp_offset_secs) as u64; 412 + let claims = TestClaims { 413 + sub: sub.to_owned(), 414 + aud: "did:plc:test".to_owned(), 415 + exp, 416 + scope: scope.to_owned(), 417 + cnf, 418 + }; 419 + encode( 420 + &Header::new(Algorithm::HS256), 421 + &claims, 422 + &EncodingKey::from_secret(secret), 423 + ) 424 + .unwrap() 425 + } 426 + 427 + /// Build a minimal Axum router that uses AuthenticatedUser as an extractor. 428 + fn protected_app(state: AppState) -> Router { 429 + Router::new() 430 + .route( 431 + "/protected", 432 + get(|user: AuthenticatedUser| async move { 433 + format!("did={} scope={:?}", user.did, user.scope) 434 + }), 435 + ) 436 + .with_state(state) 437 + } 438 + 439 + async fn get_protected(app: Router, token: Option<&str>) -> axum::response::Response { 440 + let mut builder = Request::builder().uri("/protected"); 441 + if let Some(t) = token { 442 + builder = builder.header("Authorization", format!("Bearer {t}")); 443 + } 444 + app.oneshot(builder.body(Body::empty()).unwrap()).await.unwrap() 445 + } 446 + 447 + // ── Missing / malformed Authorization header ────────────────────────────── 448 + 449 + #[tokio::test] 450 + async fn missing_auth_header_returns_401_authentication_required() { 451 + let state = test_state().await; 452 + let app = protected_app(state); 453 + let resp = get_protected(app, None).await; 454 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 455 + 456 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 457 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 458 + assert_eq!(json["error"]["code"], "AUTHENTICATION_REQUIRED"); 459 + } 460 + 461 + #[tokio::test] 462 + async fn bearer_prefix_missing_returns_401_authentication_required() { 463 + let state = test_state().await; 464 + let app = protected_app(state); 465 + let req = Request::builder() 466 + .uri("/protected") 467 + .header("Authorization", "Token abc123") 468 + .body(Body::empty()) 469 + .unwrap(); 470 + let resp = app.oneshot(req).await.unwrap(); 471 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 472 + 473 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 474 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 475 + assert_eq!(json["error"]["code"], "AUTHENTICATION_REQUIRED"); 476 + } 477 + 478 + // ── Malformed / invalid token ───────────────────────────────────────────── 479 + 480 + #[tokio::test] 481 + async fn malformed_token_returns_401_invalid_token() { 482 + let state = test_state().await; 483 + let app = protected_app(state); 484 + let resp = get_protected(app, Some("not.a.jwt")).await; 485 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 486 + 487 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 488 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 489 + assert_eq!(json["error"]["code"], "INVALID_TOKEN"); 490 + } 491 + 492 + #[tokio::test] 493 + async fn wrong_signature_returns_401_invalid_token() { 494 + let state = test_state().await; 495 + let wrong_secret = [0xFFu8; 32]; 496 + let token = mint_token("did:plc:user", "com.atproto.access", 3600, &wrong_secret, None); 497 + let app = protected_app(state); 498 + let resp = get_protected(app, Some(&token)).await; 499 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 500 + 501 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 502 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 503 + assert_eq!(json["error"]["code"], "INVALID_TOKEN"); 504 + } 505 + 506 + // ── Expired token ───────────────────────────────────────────────────────── 507 + 508 + #[tokio::test] 509 + async fn expired_token_returns_401_token_expired() { 510 + let state = test_state().await; 511 + let secret = state.jwt_secret; 512 + // exp is 1 second in the past. 513 + let token = mint_token("did:plc:user", "com.atproto.access", -1, &secret, None); 514 + let app = protected_app(state); 515 + let resp = get_protected(app, Some(&token)).await; 516 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 517 + 518 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 519 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 520 + assert_eq!(json["error"]["code"], "TOKEN_EXPIRED"); 521 + } 522 + 523 + // ── Valid access token ──────────────────────────────────────────────────── 524 + 525 + #[tokio::test] 526 + async fn valid_access_token_extracts_did_and_scope() { 527 + let state = test_state().await; 528 + let secret = state.jwt_secret; 529 + let token = mint_token( 530 + "did:plc:alice", 531 + "com.atproto.access", 532 + 3600, 533 + &secret, 534 + None, 535 + ); 536 + let app = protected_app(state); 537 + let resp = get_protected(app, Some(&token)).await; 538 + assert_eq!(resp.status(), StatusCode::OK); 539 + 540 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 541 + let text = String::from_utf8(body.to_vec()).unwrap(); 542 + assert!(text.contains("did=did:plc:alice")); 543 + assert!(text.contains("scope=Access")); 544 + } 545 + 546 + #[tokio::test] 547 + async fn valid_refresh_token_extracts_refresh_scope() { 548 + let state = test_state().await; 549 + let secret = state.jwt_secret; 550 + let token = mint_token("did:plc:alice", "com.atproto.refresh", 3600, &secret, None); 551 + let app = protected_app(state); 552 + let resp = get_protected(app, Some(&token)).await; 553 + assert_eq!(resp.status(), StatusCode::OK); 554 + 555 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 556 + let text = String::from_utf8(body.to_vec()).unwrap(); 557 + assert!(text.contains("scope=Refresh")); 558 + } 559 + 560 + // ── Unknown scope ───────────────────────────────────────────────────────── 561 + 562 + #[tokio::test] 563 + async fn unknown_scope_returns_401_invalid_token() { 564 + let state = test_state().await; 565 + let secret = state.jwt_secret; 566 + let token = mint_token("did:plc:user", "com.example.unknown", 3600, &secret, None); 567 + let app = protected_app(state); 568 + let resp = get_protected(app, Some(&token)).await; 569 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 570 + 571 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 572 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 573 + assert_eq!(json["error"]["code"], "INVALID_TOKEN"); 574 + } 575 + 576 + // ── JWK thumbprint ──────────────────────────────────────────────────────── 577 + 578 + #[test] 579 + fn rsa_jwk_thumbprint_matches_rfc7638_example() { 580 + // RFC 7638 §3.3 canonical example — RSA key with known expected thumbprint. 581 + let jwk = serde_json::json!({ 582 + "e": "AQAB", 583 + "kty": "RSA", 584 + "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", 585 + // Extra member — must be excluded from the canonical form. 586 + "use": "sig" 587 + }); 588 + let thumb = jwk_thumbprint(&jwk).unwrap(); 589 + assert_eq!(thumb, "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs"); 590 + } 591 + 592 + #[test] 593 + fn ec_jwk_thumbprint_produces_correct_format() { 594 + // EC (P-256) key from RFC 7517 Appendix A.2. Extra fields like "use" and "d" 595 + // must be stripped from the canonical form. 596 + let jwk = serde_json::json!({ 597 + "kty": "EC", 598 + "crv": "P-256", 599 + "x": "f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU", 600 + "y": "x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0", 601 + "use": "sig" 602 + }); 603 + let thumb = jwk_thumbprint(&jwk).unwrap(); 604 + // SHA-256 base64url (no padding) is always 43 characters. 605 + assert_eq!(thumb.len(), 43, "thumbprint must be 43 base64url chars"); 606 + assert!( 607 + thumb.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_'), 608 + "thumbprint must be base64url" 609 + ); 610 + // Stable value — verified against implementation; guards against regressions. 611 + assert_eq!(thumb, "oKIywvGUpTVTyxMQ3bwIIeQUudfr_CkLMjCE19ECD-U"); 612 + } 613 + 614 + // ── DPoP binding — token without cnf claim rejected ─────────────────────── 615 + 616 + #[tokio::test] 617 + async fn dpop_header_without_cnf_claim_returns_401() { 618 + let state = test_state().await; 619 + let secret = state.jwt_secret; 620 + // Access token has no `cnf` claim. 621 + let token = mint_token("did:plc:user", "com.atproto.access", 3600, &secret, None); 622 + let app = protected_app(state); 623 + 624 + let req = Request::builder() 625 + .uri("/protected") 626 + .header("Authorization", format!("Bearer {token}")) 627 + .header("DPoP", "dummy.dpop.value") 628 + .body(Body::empty()) 629 + .unwrap(); 630 + let resp = app.oneshot(req).await.unwrap(); 631 + assert_eq!(resp.status(), StatusCode::UNAUTHORIZED); 632 + 633 + let body = axum::body::to_bytes(resp.into_body(), 4096).await.unwrap(); 634 + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); 635 + assert_eq!(json["error"]["code"], "INVALID_TOKEN"); 636 + } 637 + }
+7
crates/relay/src/main.rs
··· 1 1 use anyhow::Context; 2 2 use clap::Parser; 3 + use rand_core::RngCore; 3 4 use reqwest::Client; 4 5 use std::{path::PathBuf, sync::Arc}; 5 6 6 7 mod app; 8 + mod auth; 7 9 mod db; 8 10 mod dns; 9 11 mod routes; ··· 123 125 well_known::HttpWellKnownResolver::new(http_client.clone()), 124 126 )); 125 127 128 + let mut jwt_secret = [0u8; 32]; 129 + rand_core::OsRng.fill_bytes(&mut jwt_secret); 130 + tracing::info!("JWT signing secret generated (ephemeral — rotates on restart)"); 131 + 126 132 let state = app::AppState { 127 133 config: Arc::new(config), 128 134 db: pool, ··· 130 136 dns_provider: None, 131 137 txt_resolver, 132 138 well_known_resolver, 139 + jwt_secret, 133 140 }; 134 141 135 142 let listener = tokio::net::TcpListener::bind(&addr)
+1
crates/relay/src/routes/auth.rs
··· 206 206 dns_provider: base.dns_provider, 207 207 txt_resolver: base.txt_resolver, 208 208 well_known_resolver: base.well_known_resolver, 209 + jwt_secret: base.jwt_secret, 209 210 } 210 211 } 211 212
+2
crates/relay/src/routes/create_signing_key.rs
··· 128 128 dns_provider: base.dns_provider, 129 129 txt_resolver: base.txt_resolver, 130 130 well_known_resolver: base.well_known_resolver, 131 + jwt_secret: base.jwt_secret, 131 132 } 132 133 } 133 134 ··· 381 382 dns_provider: base.dns_provider, 382 383 txt_resolver: base.txt_resolver, 383 384 well_known_resolver: base.well_known_resolver, 385 + jwt_secret: base.jwt_secret, 384 386 }; 385 387 386 388 let response = app(state)
+1
crates/relay/src/routes/describe_server.rs
··· 130 130 dns_provider: base.dns_provider, 131 131 txt_resolver: base.txt_resolver, 132 132 well_known_resolver: base.well_known_resolver, 133 + jwt_secret: base.jwt_secret, 133 134 }; 134 135 135 136 let response = app(state)
+1
crates/relay/src/routes/test_utils.rs
··· 18 18 dns_provider: base.dns_provider, 19 19 txt_resolver: base.txt_resolver, 20 20 well_known_resolver: base.well_known_resolver, 21 + jwt_secret: base.jwt_secret, 21 22 } 22 23 }