use crate::helpers::json_error_response; use axum::extract::Request; use axum::http::{HeaderMap, StatusCode}; use axum::middleware::Next; use axum::response::IntoResponse; use jwt_compact::alg::{Hs256, Hs256Key}; use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; use serde::{Deserialize, Serialize}; use std::env; use tracing::log; #[derive(Clone, Debug)] pub struct Did(pub Option); #[derive(Serialize, Deserialize)] pub struct TokenClaims { pub sub: String, } pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { let token = extract_bearer(req.headers()); match token { Ok(token) => { match token { None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") .expect("Error creating an error response"), Some(token) => { let token = UntrustedToken::new(&token); if token.is_err() { return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") .expect("Error creating an error response"); } let parsed_token = token.expect("Already checked for error"); let claims: Result, ValidationError> = parsed_token.deserialize_claims_unchecked(); if claims.is_err() { return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") .expect("Error creating an error response"); } let key = Hs256Key::new( env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), ); let token: Result, ValidationError> = Hs256.validator(&key).validate(&parsed_token); if token.is_err() { return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") .expect("Error creating an error response"); } let token = token.expect("Already checked for error,"); //Not going to worry about expiration since it still goes to the PDS req.extensions_mut() .insert(Did(Some(token.claims().custom.sub.clone()))); next.run(req).await } } } Err(err) => { log::error!("Error extracting token: {err}"); json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") .expect("Error creating an error response") } } } fn extract_bearer(headers: &HeaderMap) -> Result, String> { match headers.get(axum::http::header::AUTHORIZATION) { None => Ok(None), Some(hv) => match hv.to_str() { Err(_) => Err("Authorization header is not valid".into()), Ok(s) => { // Accept forms like: "Bearer " (case-sensitive for the scheme here) let mut parts = s.splitn(2, ' '); match (parts.next(), parts.next()) { (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())), _ => Err("Authorization header must be in format 'Bearer '".into()), } } }, } }