use crate::AppState; use crate::helpers::{ AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json, }; use crate::middleware::Did; use axum::body::Body; use axum::extract::State; use axum::http::{HeaderMap, StatusCode}; use axum::response::{IntoResponse, Response}; use axum::{Extension, Json, debug_handler, extract, extract::Request}; use serde::{Deserialize, Serialize}; use serde_json; use tracing::log; #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] enum AccountStatus { Takendown, Suspended, Deactivated, } #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] struct GetSessionResponse { handle: String, did: String, #[serde(skip_serializing_if = "Option::is_none")] email: Option, #[serde(skip_serializing_if = "Option::is_none")] email_confirmed: Option, #[serde(skip_serializing_if = "Option::is_none")] email_auth_factor: Option, #[serde(skip_serializing_if = "Option::is_none")] did_doc: Option, #[serde(skip_serializing_if = "Option::is_none")] active: Option, #[serde(skip_serializing_if = "Option::is_none")] status: Option, } #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(rename_all = "camelCase")] pub struct UpdateEmailResponse { email: String, #[serde(skip_serializing_if = "Option::is_none")] email_auth_factor: Option, #[serde(skip_serializing_if = "Option::is_none")] token: Option, } #[allow(dead_code)] #[derive(Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct CreateSessionRequest { identifier: String, password: String, #[serde(skip_serializing_if = "Option::is_none")] auth_factor_token: Option, #[serde(skip_serializing_if = "Option::is_none")] allow_takendown: Option, } pub async fn create_session( State(state): State, headers: HeaderMap, Json(payload): extract::Json, ) -> Result, StatusCode> { let identifier = payload.identifier.clone(); let password = payload.password.clone(); let auth_factor_token = payload.auth_factor_token.clone(); // Run the shared pre-auth logic to validate and check 2FA requirement match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { Ok(result) => match result { AuthResult::WrongIdentityOrPassword => json_error_response( StatusCode::UNAUTHORIZED, "AuthenticationRequired", "Invalid identifier or password", ), AuthResult::TwoFactorRequired(_) => { // Email sending step can be handled here if needed in the future. json_error_response( StatusCode::UNAUTHORIZED, "AuthFactorTokenRequired", "A sign in code has been sent to your email address", ) } AuthResult::ProxyThrough => { log::info!("Proxying through"); //No 2FA or already passed let uri = format!( "{}{}", state.pds_base_url, "/xrpc/com.atproto.server.createSession" ); let mut req = axum::http::Request::post(uri); if let Some(req_headers) = req.headers_mut() { req_headers.extend(headers.clone()); } let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; let req = req .body(Body::from(payload_bytes)) .map_err(|_| StatusCode::BAD_REQUEST)?; let proxied = state .reverse_proxy_client .request(req) .await .map_err(|_| StatusCode::BAD_REQUEST)? .into_response(); Ok(proxied) } AuthResult::TokenCheckFailed(err) => match err { TokenCheckError::InvalidToken => { json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid") } TokenCheckError::ExpiredToken => { json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired") } }, }, Err(err) => { log::error!( "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" ); json_error_response( StatusCode::INTERNAL_SERVER_ERROR, "InternalServerError", "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", ) } } } #[debug_handler] pub async fn update_email( State(state): State, Extension(did): Extension, headers: HeaderMap, Json(payload): extract::Json, ) -> Result, StatusCode> { //If email auth is not set at all it is a update email address let email_auth_not_set = payload.email_auth_factor.is_none(); //If email auth is set it is to either turn on or off 2fa let email_auth_update = payload.email_auth_factor.unwrap_or(false); // Email update asked for if email_auth_update { let email = payload.email.clone(); let email_confirmed = sqlx::query_as::<_, (String,)>( "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", ) .bind(&email) .fetch_optional(&state.account_pool) .await .map_err(|_| StatusCode::BAD_REQUEST)?; //Since the email is already confirmed we can enable 2fa return match email_confirmed { None => Err(StatusCode::BAD_REQUEST), Some(did_row) => { let _ = sqlx::query( "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", ) .bind(&did_row.0) .execute(&state.pds_gatekeeper_pool) .await .map_err(|_| StatusCode::BAD_REQUEST)?; Ok(StatusCode::OK.into_response()) } }; } // User wants auth turned off if !email_auth_update && !email_auth_not_set { //User wants auth turned off and has a token if let Some(token) = &payload.token { let token_found = sqlx::query_as::<_, (String,)>( "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", ) .bind(token) .bind(&did.0) .fetch_optional(&state.account_pool) .await .map_err(|_| StatusCode::BAD_REQUEST)?; if token_found.is_some() { let _ = sqlx::query( "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", ) .bind(&did.0) .execute(&state.pds_gatekeeper_pool) .await .map_err(|_| StatusCode::BAD_REQUEST)?; return Ok(StatusCode::OK.into_response()); } else { return Err(StatusCode::BAD_REQUEST); } } } // Updating the actual email address by sending it on to the PDS let uri = format!( "{}{}", state.pds_base_url, "/xrpc/com.atproto.server.updateEmail" ); let mut req = axum::http::Request::post(uri); if let Some(req_headers) = req.headers_mut() { req_headers.extend(headers.clone()); } let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; let req = req .body(Body::from(payload_bytes)) .map_err(|_| StatusCode::BAD_REQUEST)?; let proxied = state .reverse_proxy_client .request(req) .await .map_err(|_| StatusCode::BAD_REQUEST)? .into_response(); Ok(proxied) } pub async fn get_session( State(state): State, req: Request, ) -> Result, StatusCode> { match proxy_get_json::(&state, req, "/xrpc/com.atproto.server.getSession") .await? { ProxiedResult::Parsed { value: mut session, .. } => { let did = session.did.clone(); let required_opt = sqlx::query_as::<_, (u8,)>( "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", ) .bind(&did) .fetch_optional(&state.pds_gatekeeper_pool) .await .map_err(|_| StatusCode::BAD_REQUEST)?; let email_auth_factor = match required_opt { Some(row) => row.0 != 0, None => false, }; session.email_auth_factor = Some(email_auth_factor); Ok(Json(session).into_response()) } ProxiedResult::Passthrough(resp) => Ok(resp), } }