A better Rust ATProto crate
1//! Service authentication JWT parsing and verification for AT Protocol.
2//!
//! Service auth is atproto's inter-service authentication mechanism. For requests sent
//! to a backend service (feed generator, labeler, etc.), the PDS signs a short-lived
//! JWT with the user's signing key and includes it as a Bearer token.
6//!
7//! # JWT Structure
8//!
//! - Header: `alg` (`ES256K` for k256, `ES256` for p256), `typ` (`"JWT"`)
//! - Payload:
//!   - `iss`: user's DID (issuer)
//!   - `aud`: target service DID (audience)
//!   - `exp`: expiration time (unix timestamp, seconds)
//!   - `iat`: issued-at time (unix timestamp, seconds)
//!   - `jti` (optional): random nonce (128-bit hex) for replay protection
//!   - `lxm` (optional): lexicon method NSID (method binding)
//! - Signature: made with the user's signing key from their DID document (ES256 or ES256K)
18
19use crate::CowStr;
20use crate::IntoStatic;
21use crate::types::string::{Did, Nsid};
22use alloc::string::String;
23use alloc::string::ToString;
24use alloc::vec::Vec;
25use base64::Engine;
26use base64::engine::general_purpose::URL_SAFE_NO_PAD;
27use ouroboros::self_referencing;
28use serde::{Deserialize, Serialize};
29use signature::Verifier;
30use smol_str::SmolStr;
31use smol_str::format_smolstr;
32use thiserror::Error;
33
34#[cfg(feature = "crypto-p256")]
35use p256::ecdsa::{Signature as P256Signature, VerifyingKey as P256VerifyingKey};
36
37#[cfg(feature = "crypto-k256")]
38use k256::ecdsa::{Signature as K256Signature, VerifyingKey as K256VerifyingKey};
39
40/// Errors that can occur during JWT parsing and verification.
41#[derive(Debug, Error, miette::Diagnostic)]
42#[non_exhaustive]
43pub enum ServiceAuthError {
44 /// JWT format is invalid (not three base64-encoded parts separated by dots)
45 #[error("malformed JWT: {0}")]
46 MalformedToken(CowStr<'static>),
47
48 /// Base64 decoding failed
49 #[error("base64 decode error: {0}")]
50 Base64Decode(#[from] base64::DecodeError),
51
52 /// JSON parsing failed
53 #[error("JSON parsing error: {0}")]
54 JsonParse(#[from] serde_json::Error),
55
56 /// Signature verification failed
57 #[error("invalid signature")]
58 InvalidSignature,
59
60 /// Unsupported algorithm
61 #[error("unsupported algorithm: {alg}")]
62 UnsupportedAlgorithm {
63 /// Algorithm name from JWT header
64 alg: SmolStr,
65 },
66
67 /// Token has expired
68 #[error("token expired at {exp} (current time: {now})")]
69 Expired {
70 /// Expiration timestamp from token
71 exp: i64,
72 /// Current timestamp
73 now: i64,
74 },
75
76 /// Audience mismatch
77 #[error("audience mismatch: expected {expected}, got {actual}")]
78 AudienceMismatch {
79 /// Expected audience DID
80 expected: Did<'static>,
81 /// Actual audience DID in token
82 actual: Did<'static>,
83 },
84
85 /// Method mismatch (lxm field)
86 #[error("method mismatch: expected {expected}, got {actual:?}")]
87 MethodMismatch {
88 /// Expected method NSID
89 expected: Nsid<'static>,
90 /// Actual method NSID in token (if any)
91 actual: Option<Nsid<'static>>,
92 },
93
94 /// Missing required field
95 #[error("missing required field: {0}")]
96 MissingField(&'static str),
97
98 /// Crypto error
99 #[error("crypto error: {0}")]
100 Crypto(CowStr<'static>),
101}
102
103/// JWT header for service auth tokens.
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct JwtHeader<'a> {
106 /// Algorithm used for signing
107 #[serde(borrow)]
108 pub alg: CowStr<'a>,
109 /// Type (always "JWT")
110 #[serde(borrow)]
111 pub typ: CowStr<'a>,
112}
113
114impl IntoStatic for JwtHeader<'_> {
115 type Output = JwtHeader<'static>;
116
117 fn into_static(self) -> Self::Output {
118 JwtHeader {
119 alg: self.alg.into_static(),
120 typ: self.typ.into_static(),
121 }
122 }
123}
124
125/// Service authentication claims.
126///
127/// These are the payload fields in a service auth JWT.
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct ServiceAuthClaims<'a> {
130 /// Issuer (user's DID)
131 #[serde(borrow)]
132 pub iss: Did<'a>,
133
134 /// Audience (target service DID)
135 #[serde(borrow)]
136 pub aud: Did<'a>,
137
138 /// Expiration time (unix timestamp)
139 pub exp: i64,
140
141 /// Issued at (unix timestamp)
142 pub iat: i64,
143
144 /// JWT ID (nonce for replay protection)
145 #[serde(borrow, skip_serializing_if = "Option::is_none")]
146 pub jti: Option<CowStr<'a>>,
147
148 /// Lexicon method NSID (method binding)
149 #[serde(borrow, skip_serializing_if = "Option::is_none")]
150 pub lxm: Option<Nsid<'a>>,
151}
152
153impl<'a> IntoStatic for ServiceAuthClaims<'a> {
154 type Output = ServiceAuthClaims<'static>;
155
156 fn into_static(self) -> Self::Output {
157 ServiceAuthClaims {
158 iss: self.iss.into_static(),
159 aud: self.aud.into_static(),
160 exp: self.exp,
161 iat: self.iat,
162 jti: self.jti.map(|j| j.into_static()),
163 lxm: self.lxm.map(|l| l.into_static()),
164 }
165 }
166}
167
168impl<'a> ServiceAuthClaims<'a> {
169 /// Validate the claims against expected values.
170 ///
171 /// Checks:
172 /// - Audience matches expected DID
173 /// - Token is not expired
174 pub fn validate(&self, expected_aud: &Did) -> Result<(), ServiceAuthError> {
175 // Check audience
176 if self.aud.as_str() != expected_aud.as_str() {
177 return Err(ServiceAuthError::AudienceMismatch {
178 expected: expected_aud.clone().into_static(),
179 actual: self.aud.clone().into_static(),
180 });
181 }
182
183 // Check expiration
184 if self.is_expired() {
185 let now = chrono::Utc::now().timestamp();
186 return Err(ServiceAuthError::Expired { exp: self.exp, now });
187 }
188
189 Ok(())
190 }
191
192 /// Check if the token has expired.
193 pub fn is_expired(&self) -> bool {
194 let now = chrono::Utc::now().timestamp();
195 self.exp <= now
196 }
197
198 /// Check if the method (lxm) matches the expected NSID.
199 pub fn check_method(&self, nsid: &Nsid) -> bool {
200 self.lxm
201 .as_ref()
202 .map(|lxm| lxm.as_str() == nsid.as_str())
203 .unwrap_or(false)
204 }
205
206 /// Require that the method (lxm) matches the expected NSID.
207 pub fn require_method(&self, nsid: &Nsid) -> Result<(), ServiceAuthError> {
208 if !self.check_method(nsid) {
209 return Err(ServiceAuthError::MethodMismatch {
210 expected: nsid.clone().into_static(),
211 actual: self.lxm.as_ref().map(|l| l.clone().into_static()),
212 });
213 }
214 Ok(())
215 }
216}
217
218/// Parsed JWT components.
219///
220/// This struct owns the decoded buffers and parsed components using ouroboros
221/// self-referencing. The header and claims borrow from their respective buffers.
222#[self_referencing]
223pub struct ParsedJwt {
224 /// Decoded header buffer (owned)
225 header_buf: Vec<u8>,
226 /// Decoded payload buffer (owned)
227 payload_buf: Vec<u8>,
228 /// Original token string for signing_input
229 token: String,
230 /// Signature bytes
231 signature: Vec<u8>,
232 /// Parsed header borrowing from header_buf
233 #[borrows(header_buf)]
234 #[covariant]
235 header: JwtHeader<'this>,
236 /// Parsed claims borrowing from payload_buf
237 #[borrows(payload_buf)]
238 #[covariant]
239 claims: ServiceAuthClaims<'this>,
240}
241
242impl ParsedJwt {
243 /// Get the signing input (header.payload) for signature verification.
244 pub fn signing_input(&self) -> &[u8] {
245 self.with_token(|token| {
246 let dot_pos = token.find('.').unwrap();
247 let second_dot_pos = token[dot_pos + 1..].find('.').unwrap() + dot_pos + 1;
248 token[..second_dot_pos].as_bytes()
249 })
250 }
251
252 /// Get a reference to the header.
253 pub fn header(&self) -> &JwtHeader<'_> {
254 self.borrow_header()
255 }
256
257 /// Get a reference to the claims.
258 pub fn claims(&self) -> &ServiceAuthClaims<'_> {
259 self.borrow_claims()
260 }
261
262 /// Get a reference to the signature.
263 pub fn signature(&self) -> &[u8] {
264 self.borrow_signature()
265 }
266
267 /// Get owned header with 'static lifetime.
268 pub fn into_header(self) -> JwtHeader<'static> {
269 self.with_header(|header| header.clone().into_static())
270 }
271
272 /// Get owned claims with 'static lifetime.
273 pub fn into_claims(self) -> ServiceAuthClaims<'static> {
274 self.with_claims(|claims| claims.clone().into_static())
275 }
276}
277
278/// Parse a JWT token into its components without verifying the signature.
279///
280/// This extracts and decodes all JWT components. The header and claims are parsed
281/// and borrow from their respective owned buffers using ouroboros self-referencing.
282pub fn parse_jwt(token: &str) -> Result<ParsedJwt, ServiceAuthError> {
283 let parts: Vec<&str> = token.split('.').collect();
284 if parts.len() != 3 {
285 return Err(ServiceAuthError::MalformedToken(CowStr::new_static(
286 "JWT must have exactly 3 parts separated by dots",
287 )));
288 }
289
290 let header_b64 = parts[0];
291 let payload_b64 = parts[1];
292 let signature_b64 = parts[2];
293
294 // Decode all components
295 let header_buf = URL_SAFE_NO_PAD.decode(header_b64)?;
296 let payload_buf = URL_SAFE_NO_PAD.decode(payload_b64)?;
297 let signature = URL_SAFE_NO_PAD.decode(signature_b64)?;
298
299 // Validate that buffers contain valid JSON for their types
300 // We parse once here to validate, then again in the builder (unavoidable with ouroboros)
301 let _header: JwtHeader = serde_json::from_slice(&header_buf)?;
302 let _claims: ServiceAuthClaims = serde_json::from_slice(&payload_buf)?;
303
304 Ok(ParsedJwtBuilder {
305 header_buf,
306 payload_buf,
307 token: token.to_string(),
308 signature,
309 header_builder: |buf| {
310 // Safe: we validated this succeeds above
311 serde_json::from_slice(buf).expect("header was validated")
312 },
313 claims_builder: |buf| {
314 // Safe: we validated this succeeds above
315 serde_json::from_slice(buf).expect("claims were validated")
316 },
317 }
318 .build())
319}
320
/// A verifying key for one of the supported service-auth signature schemes.
///
/// The available variants depend on which crypto features are enabled.
#[derive(Debug, Clone)]
pub enum PublicKey {
    /// P-256 key, used with `ES256`.
    #[cfg(feature = "crypto-p256")]
    P256(P256VerifyingKey),

    /// secp256k1 key, used with `ES256K`.
    #[cfg(feature = "crypto-k256")]
    K256(K256VerifyingKey),
}

impl PublicKey {
    /// Build a P-256 key from SEC1-encoded bytes (compressed or uncompressed).
    ///
    /// # Errors
    ///
    /// [`ServiceAuthError::Crypto`] if the bytes are not a valid P-256 public key.
    #[cfg(feature = "crypto-p256")]
    pub fn from_p256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
        match P256VerifyingKey::from_sec1_bytes(bytes) {
            Ok(verifying_key) => Ok(Self::P256(verifying_key)),
            Err(err) => Err(ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                "invalid P-256 key: {}",
                err
            )))),
        }
    }

    /// Build a secp256k1 key from SEC1-encoded bytes (compressed or uncompressed).
    ///
    /// # Errors
    ///
    /// [`ServiceAuthError::Crypto`] if the bytes are not a valid secp256k1 public key.
    #[cfg(feature = "crypto-k256")]
    pub fn from_k256_bytes(bytes: &[u8]) -> Result<Self, ServiceAuthError> {
        match K256VerifyingKey::from_sec1_bytes(bytes) {
            Ok(verifying_key) => Ok(Self::K256(verifying_key)),
            Err(err) => Err(ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
                "invalid K-256 key: {}",
                err
            )))),
        }
    }
}
352
353/// Verify a JWT signature using the provided public key.
354///
355/// The algorithm is determined by the JWT header and must match the public key type.
356pub fn verify_signature(
357 parsed: &ParsedJwt,
358 public_key: &PublicKey,
359) -> Result<(), ServiceAuthError> {
360 let alg = parsed.header().alg.as_str();
361 let signing_input = parsed.signing_input();
362 let signature = parsed.signature();
363
364 match (alg, public_key) {
365 #[cfg(feature = "crypto-p256")]
366 ("ES256", PublicKey::P256(key)) => {
367 let sig = P256Signature::from_slice(signature).map_err(|e| {
368 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
369 "invalid ES256 signature: {}",
370 e
371 )))
372 })?;
373 key.verify(signing_input, &sig)
374 .map_err(|_| ServiceAuthError::InvalidSignature)?;
375 Ok(())
376 }
377
378 #[cfg(feature = "crypto-k256")]
379 ("ES256K", PublicKey::K256(key)) => {
380 let sig = K256Signature::from_slice(signature).map_err(|e| {
381 ServiceAuthError::Crypto(CowStr::Owned(format_smolstr!(
382 "invalid ES256K signature: {}",
383 e
384 )))
385 })?;
386 key.verify(signing_input, &sig)
387 .map_err(|_| ServiceAuthError::InvalidSignature)?;
388 Ok(())
389 }
390
391 _ => Err(ServiceAuthError::UnsupportedAlgorithm {
392 alg: SmolStr::new(alg),
393 }),
394 }
395}
396
397/// Parse and verify a service auth JWT in one step, returning owned claims.
398///
399/// This is a convenience function that combines parsing and signature verification.
400pub fn verify_service_jwt(
401 token: &str,
402 public_key: &PublicKey,
403) -> Result<ServiceAuthClaims<'static>, ServiceAuthError> {
404 let parsed = parse_jwt(token)?;
405 verify_signature(&parsed, public_key)?;
406 Ok(parsed.into_claims())
407}
408
#[cfg(test)]
mod tests {
    use super::*;

    /// Claims issued by `did:plc:test` for `did:web:example.com`, with the given
    /// timestamps and method binding. `jti` is always absent.
    fn sample_claims(
        exp: i64,
        iat: i64,
        lxm: Option<Nsid<'static>>,
    ) -> ServiceAuthClaims<'static> {
        ServiceAuthClaims {
            iss: Did::new("did:plc:test").unwrap(),
            aud: Did::new("did:web:example.com").unwrap(),
            exp,
            iat,
            jti: None,
            lxm,
        }
    }

    #[test]
    fn test_parse_jwt_invalid_format() {
        // Eight dot-separated segments, far more than the three a JWT allows.
        let outcome = parse_jwt("not.a.valid.jwt.with.too.many.parts");
        assert!(matches!(outcome, Err(ServiceAuthError::MalformedToken(_))));
    }

    #[test]
    fn test_claims_expiration() {
        let now = chrono::Utc::now().timestamp();

        let stale = sample_claims(now - 100, now - 200, None);
        assert!(stale.is_expired());

        let live = sample_claims(now + 100, now, None);
        assert!(!live.is_expired());
    }

    #[test]
    fn test_audience_validation() {
        let now = chrono::Utc::now().timestamp();
        let claims = sample_claims(now + 100, now, None);

        let matching = Did::new("did:web:example.com").unwrap();
        assert!(claims.validate(&matching).is_ok());

        let mismatched = Did::new("did:web:wrong.com").unwrap();
        assert!(matches!(
            claims.validate(&mismatched),
            Err(ServiceAuthError::AudienceMismatch { .. })
        ));
    }

    #[test]
    fn test_method_check() {
        let now = chrono::Utc::now().timestamp();
        let claims = sample_claims(
            now + 100,
            now,
            Some(Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap()),
        );

        let bound = Nsid::new("app.bsky.feed.getFeedSkeleton").unwrap();
        assert!(claims.check_method(&bound));

        let other = Nsid::new("app.bsky.feed.getTimeline").unwrap();
        assert!(!claims.check_method(&other));
    }
}