A better Rust ATProto crate
102
fork

Configure Feed

Select the types of activity you want to include in your feed.

at pretty-codegen 484 lines 15 kB view raw
1//! Service authentication JWT parsing and verification for AT Protocol. 2//! 3//! Service auth is atproto's inter-service authentication mechanism. When a backend 4//! service (feed generator, labeler, etc.) receives requests, the PDS signs a 5//! short-lived JWT with the user's signing key and includes it as a Bearer token. 6//! 7//! # JWT Structure 8//! 9//! - Header: `alg` (ES256K for k256, ES256 for p256), `typ` ("JWT") 10//! - Payload: 11//! - `iss`: user's DID (issuer) 12//! - `aud`: target service DID (audience) 13//! - `exp`: expiration unix timestamp 14//! - `iat`: issued at unix timestamp 15//! - `jti`: random nonce (128-bit hex) for replay protection 16//! - `lxm`: lexicon method NSID (method binding) 17//! - Signature: signed with user's signing key from DID doc (ES256 or ES256K) 18 19use crate::CowStr; 20use crate::IntoStatic; 21use crate::types::string::{Did, Nsid}; 22use alloc::string::String; 23use alloc::string::ToString; 24use alloc::vec::Vec; 25use base64::Engine; 26use base64::engine::general_purpose::URL_SAFE_NO_PAD; 27use ouroboros::self_referencing; 28use serde::{Deserialize, Serialize}; 29use signature::Verifier; 30use smol_str::SmolStr; 31use smol_str::format_smolstr; 32use thiserror::Error; 33 34#[cfg(feature = "crypto-p256")] 35use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey}; 36 37#[cfg(feature = "crypto-k256")] 38use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey}; 39 40/// Errors that can occur during JWT parsing and verification. 41#[derive(Debug, Error, miette::Diagnostic)] 42#[non_exhaustive] 43pub enum ServiceAuthError { 44 /// JWT format is invalid (not three base64-encoded parts separated by dots) 45 #[error("malformed JWT: {0}")] 46 MalformedToken(CowStr<'static>), 47 48 /// Base64 decoding failed 49 #[error("base64 decode error: {0}")] 50 Base64Decode(#[from] base64::DecodeError), 51 52 /// JSON parsing failed 53 #[error("JSON parsing error: {0}")] 54 JsonParse(#[from] serde_json::Error), 55 56 /// Signature verification failed 57 #[error("invalid signature")] 58 InvalidSignature, 59 60 /// Unsupported algorithm 61 #[error("unsupported algorithm: {alg}")] 62 UnsupportedAlgorithm { 63 /// Algorithm name from JWT header 64 alg: SmolStr, 65 }, 66 67 /// Token has expired 68 #[error("token expired at {exp} (current time: {now})")] 69 Expired { 70 /// Expiration timestamp from token 71 exp: i64, 72 /// Current timestamp 73 now: i64, 74 }, 75 76 /// Audience mismatch 77 #[error("audience mismatch: expected {expected}, got {actual}")] 78 AudienceMismatch { 79 /// Expected audience DID 80 expected: Did<'static>, 81 /// Actual audience DID in token 82 actual: Did<'static>, 83 }, 84 85 /// Method mismatch (lxm field) 86 #[error("method mismatch: expected {expected}, got {actual:?}")] 87 MethodMismatch { 88 /// Expected method NSID 89 expected: Nsid<'static>, 90 /// Actual method NSID in token (if any) 91 actual: Option<Nsid<'static>>, 92 }, 93 94 /// Missing required field 95 #[error("missing required field: {0}")] 96 MissingField(&'static str), 97 98 /// Crypto error 99 #[error("crypto error: {0}")] 100 Crypto(CowStr<'static>), 101} 102 103/// JWT header for service auth tokens. 104#[derive(Debug, Clone, Serialize, Deserialize)] 105pub struct JwtHeader<'a> { 106 /// Algorithm used for signing 107 #[serde(borrow)] 108 pub alg: CowStr<'a>, 109 /// Type (always "JWT") 110 #[serde(borrow)] 111 pub typ: CowStr<'a>, 112} 113 114impl IntoStatic for JwtHeader<'_> { 115 type Output = JwtHeader<'static>; 116 117 fn into_static(self) -> Self::Output { 118 JwtHeader { 119 alg: self.alg.into_static(), 120 typ: self.typ.into_static(), 121 } 122 } 123} 124 125/// Service authentication claims. 126/// 127/// These are the payload fields in a service auth JWT. 128#[derive(Debug, Clone, Serialize, Deserialize)] 129pub struct ServiceAuthClaims<'a> { 130 /// Issuer (user's DID) 131 #[serde(borrow)] 132 pub iss: Did<'a>, 133 134 /// Audience (target service DID) 135 #[serde(borrow)] 136 pub aud: Did<'a>, 137 138 /// Expiration time (unix timestamp) 139 pub exp: i64, 140 141 /// Issued at (unix timestamp) 142 pub iat: i64, 143 144 /// JWT ID (nonce for replay protection) 145 #[serde(borrow, skip_serializing_if = "Option::is_none")] 146 pub jti: Option<CowStr<'a>>, 147 148 /// Lexicon method NSID (method binding) 149 #[serde(borrow, skip_serializing_if = "Option::is_none")] 150 pub lxm: Option<Nsid<'a>>, 151} 152 153impl<'a> IntoStatic for ServiceAuthClaims<'a> { 154 type Output = ServiceAuthClaims<'static>; 155 156 fn into_static(self) -> Self::Output { 157 ServiceAuthClaims { 158 iss: self.iss.into_static(), 159 aud: self.aud.into_static(), 160 exp: self.exp, 161 iat: self.iat, 162 jti: self.jti.map(|j| j.into_static()), 163 lxm: self.lxm.map(|l| l.into_static()), 164 } 165 } 166} 167 168impl<'a> ServiceAuthClaims<'a> { 169 /// Validate the claims against expected values. 170 /// 171 /// Checks: 172 /// - Audience matches expected DID 173 /// - Token is not expired 174 pub fn validate(&self, expected_aud: &Did) -> Result<(), ServiceAuthError> { 175 // Check audience 176 if self.aud.as_str() != expected_aud.as_str() { 177 return Err(ServiceAuthError::AudienceMismatch { 178 expected: expected_aud.clone().into_static(), 179 actual: self.aud.clone().into_static(), 180 }); 181 } 182 183 // Check expiration 184 if self.is_expired() { 185 let now = chrono::Utc::now().timestamp(); 186 return Err(ServiceAuthError::Expired { exp: self.exp, now }); 187 } 188 189 Ok(()) 190 } 191 192 /// Check if the token has expired. 193 pub fn is_expired(&self) -> bool { 194 let now = chrono::Utc::now().timestamp(); 195 self.exp <= now 196 } 197 198 /// Check if the method (lxm) matches the expected NSID. 199 pub fn check_method(&self, nsid: &Nsid) -> bool { 200 self.lxm 201 .as_ref() 202 .map(|lxm| lxm.as_str() == nsid.as_str()) 203 .unwrap_or(false) 204 } 205 206 /// Require that the method (lxm) matches the expected NSID. 207 pub fn require_method(&self, nsid: &Nsid) -> Result<(), ServiceAuthError> { 208 if !self.check_method(nsid) { 209 return Err(ServiceAuthError::MethodMismatch { 210 expected: nsid.clone().into_static(), 211 actual: self.lxm.as_ref().map(|l| l.clone().into_static()), 212 }); 213 } 214 Ok(()) 215 } 216} 217 218/// Parsed JWT components. 219/// 220/// This struct owns the decoded buffers and parsed components using ouroboros 221/// self-referencing. The header and claims borrow from their respective buffers. 222#[self_referencing] 223pub struct ParsedJwt { 224 /// Decoded header buffer (owned) 225 header_buf: Vec<u8>, 226 /// Decoded payload buffer (owned) 227 payload_buf: Vec<u8>, 228 /// Original token string for signing_input 229 token: String, 230 /// Signature bytes 231 signature: Vec<u8>, 232 /// Parsed header borrowing from header_buf 233 #[borrows(header_buf)] 234 #[covariant] 235 header: JwtHeader<'this>, 236 /// Parsed claims borrowing from payload_buf 237 #[borrows(payload_buf)] 238 #[covariant] 239 claims: ServiceAuthClaims<'this>, 240} 241 242impl ParsedJwt { 243 /// Get the signing input (header.payload) for signature verification. 244 pub fn signing_input(&self) -> &[u8] { 245 self.with_token(|token| { 246 let dot_pos = token.find('.').unwrap(); 247 let second_dot_pos = token[dot_pos + 1..].find('.').unwrap() + dot_pos + 1; 248 token[..second_dot_pos].as_bytes() 249 }) 250 } 251 252 /// Get a reference to the header. 253 pub fn header(&self) -> &JwtHeader<'_> { 254 self.borrow_header() 255 } 256 257 /// Get a reference to the claims. 258 pub fn claims(&self) -> &ServiceAuthClaims<'_> { 259 self.borrow_claims() 260 } 261 262 /// Get a reference to the signature. 263 pub fn signature(&self) -> &[u8] { 264 self.borrow_signature() 265 } 266 267 /// Get owned header with 'static lifetime. 268 pub fn into_header(self) -> JwtHeader<'static> { 269 self.with_header(|header| header.clone().into_static()) 270 } 271 272 /// Get owned claims with 'static lifetime. 273 pub fn into_claims(self) -> ServiceAuthClaims<'static> { 274 self.with_claims(|claims| claims.clone().into_static()) 275 } 276} 277 278/// Parse a JWT token into its components without verifying the signature. 279/// 280/// This extracts and decodes all JWT components. The header and claims are parsed 281/// and borrow from their respective owned buffers using ouroboros self-referencing. 282pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> { 283 let parts: Vec<&str> = token.split('.').collect(); 284 if parts.len() != 3 { 285 return Err(ServiceAuthError::MalformedToken(CowStr::new_static( 286 "JWT must have exactly 3 parts separated by dots", 287 ))); 288 } 289 290 let header_b64 = parts[0]; 291 let payload_b64 = parts[1]; 292 let signature_b64 = parts[2]; 293 294 // Decode all components 295 let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?; 296 let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?; 297 let signature = URL_SAFE_NO_PAD.decode(signature_b64)?; 298 299 // Validate that buffers contain valid JSON for their types 300 // We parse once here to validate, then again in the builder (unavoidable with ouroboros) 301 let _header: JwtHeader = serde_json::from_slice(&header_buf)?; 302 let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?; 303 304 Ok(ParsedJwtBuilder { 305 header_buf, 306 payload_buf, 307 token: token.to_string(), 308 signature, 309 header_builder: |buf| { 310 // Safe: we validated this succeeds above 311 serde_json::from_slice(buf).expect("header was validated") 312 }, 313 claims_builder: |buf| { 314 // Safe: we validated this succeeds above 315 serde_json::from_slice(buf).expect("claims were validated") 316 }, 317 } 318 .build()) 319} 320 321/// Public key types for signature verification. 322#[derive(Debug, Clone)] 323pub enum PublicKey { 324 /// P-256 (ES256) public key 325 #[cfg(feature = "crypto-p256")] 326 P256(P256VerifyingKey), 327 328 /// secp256k1 (ES256K) public key 329 #[cfg(feature = "crypto-k256")] 330 K256(K256VerifyingKey), 331} 332 333impl PublicKey { 334 /// Create a P-256 public key from compressed or uncompressed bytes. 335 #[cfg(feature = "crypto-p256")] 336 pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> { 337 let key = P256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| { 338 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid P-256 key: {}", e))) 339 })?; 340 Ok(PublicKey::P256(key)) 341 } 342 343 /// Create a secp256k1 public key from compressed or uncompressed bytes. 344 #[cfg(feature = "crypto-k256")] 345 pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> { 346 let key = K256VerifyingKey::from_sec1_bytes(bytes).map_err(|e| { 347 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!("invalid K-256 key: {}", e))) 348 })?; 349 Ok(PublicKey::K256(key)) 350 } 351} 352 353/// Verify a JWT signature using the provided public key. 354/// 355/// The algorithm is determined by the JWT header and must match the public key type. 356pub fn verify_signature( 357 parsed: &ParsedJwt, 358 public_key: &PublicKey, 359) -> Result<(), ServiceAuthError> { 360 let alg = parsed.header().alg.as_str(); 361 let signing_input = parsed.signing_input(); 362 let signature = parsed.signature(); 363 364 match (alg, public_key) { 365 #[cfg(feature = "crypto-p256")] 366 ("ES256", PublicKey::P256(key)) => { 367 let sig = P256Signature::from_slice(signature).map_err(|e| { 368 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!( 369 "invalid ES256 signature: {}", 370 e 371 ))) 372 })?; 373 key.verify(signing_input, &sig) 374 .map_err(|_| ServiceAuthError::InvalidSignature)?; 375 Ok(()) 376 } 377 378 #[cfg(feature = "crypto-k256")] 379 ("ES256K", PublicKey::K256(key)) => { 380 let sig = K256Signature::from_slice(signature).map_err(|e| { 381 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!( 382 "invalid ES256K signature: {}", 383 e 384 ))) 385 })?; 386 key.verify(signing_input, &sig) 387 .map_err(|_| ServiceAuthError::InvalidSignature)?; 388 Ok(()) 389 } 390 391 _ => Err(ServiceAuthError::UnsupportedAlgorithm { 392 alg: SmolStr::new(alg), 393 }), 394 } 395} 396 397/// Parse and verify a service auth JWT in one step, returning owned claims. 398/// 399/// This is a convenience function that combines parsing and signature verification. 400pub fn verify_service_jwt( 401 token: &str, 402 public_key: &PublicKey, 403) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> { 404 let parsed = parse_jwt(token)?; 405 verify_signature(&parsed, public_key)?; 406 Ok(parsed.into_claims()) 407} 408 409#[cfg(test)] 410mod tests { 411 use super::*; 412 413 #[test] 414 fn test_parse_jwt_invalid_format() { 415 let result = parse_jwt("not.a.valid.jwt.with.too.many.parts"); 416 assert!(matches!(result, Err(ServiceAuthError::MalformedToken(_)))); 417 } 418 419 #[test] 420 fn test_claims_expiration() { 421 let now = chrono::Utc::now().timestamp(); 422 let expired_claims = ServiceAuthClaims { 423 iss: Did::new("did:plc:test").unwrap(), 424 aud: Did::new("did:web:example.com").unwrap(), 425 exp: now - 100, 426 iat: now - 200, 427 jti: None, 428 lxm: None, 429 }; 430 431 assert!(expired_claims.is_expired()); 432 433 let valid_claims = ServiceAuthClaims { 434 iss: Did::new("did:plc:test").unwrap(), 435 aud: Did::new("did:web:example.com").unwrap(), 436 exp: now + 100, 437 iat: now, 438 jti: None, 439 lxm: None, 440 }; 441 442 assert!(!valid_claims.is_expired()); 443 } 444 445 #[test] 446 fn test_audience_validation() { 447 let now = chrono::Utc::now().timestamp(); 448 let claims = ServiceAuthClaims { 449 iss: Did::new("did:plc:test").unwrap(), 450 aud: Did::new("did:web:example.com").unwrap(), 451 exp: now + 100, 452 iat: now, 453 jti: None, 454 lxm: None, 455 }; 456 457 let expected_aud = Did::new("did:web:example.com").unwrap(); 458 assert!(claims.validate(&expected_aud).is_ok()); 459 460 let wrong_aud = Did::new("did:web:wrong.com").unwrap(); 461 assert!(matches!( 462 claims.validate(&wrong_aud), 463 Err(ServiceAuthError::AudienceMismatch { .. }) 464 )); 465 } 466 467 #[test] 468 fn test_method_check() { 469 let claims = ServiceAuthClaims { 470 iss: Did::new("did:plc:test").unwrap(), 471 aud: Did::new("did:web:example.com").unwrap(), 472 exp: chrono::Utc::now().timestamp() + 100, 473 iat: chrono::Utc::now().timestamp(), 474 jti: None, 475 lxm: Some(Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap()), 476 }; 477 478 let expected = Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap(); 479 assert!(claims.check_method(&expected)); 480 481 let wrong = Nsid::new("app.bsky.feed.getTimeline").unwrap(); 482 assert!(!claims.check_method(&wrong)); 483 } 484}