Our Personal Data Server from scratch!
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

pds uses repository pattern, delete inline db code

+5434 -8712
+1
crates/tranquil-comms/Cargo.toml
··· 14 14 sqlx = { workspace = true } 15 15 thiserror = { workspace = true } 16 16 tokio = { workspace = true } 17 + tranquil-db-traits = { workspace = true } 17 18 urlencoding = { workspace = true } 18 19 uuid = { workspace = true }
+1 -58
crates/tranquil-comms/src/types.rs
··· 1 - use chrono::{DateTime, Utc}; 2 - use serde::{Deserialize, Serialize}; 3 1 use uuid::Uuid; 4 2 5 - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 6 - #[serde(rename_all = "lowercase")] 7 - #[sqlx(type_name = "comms_channel", rename_all = "lowercase")] 8 - pub enum CommsChannel { 9 - Email, 10 - Discord, 11 - Telegram, 12 - Signal, 13 - } 14 - 15 - #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)] 16 - #[serde(rename_all = "lowercase")] 17 - #[sqlx(type_name = "comms_status", rename_all = "lowercase")] 18 - pub enum CommsStatus { 19 - Pending, 20 - Processing, 21 - Sent, 22 - Failed, 23 - } 24 - 25 - #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)] 26 - #[serde(rename_all = "snake_case")] 27 - #[sqlx(type_name = "comms_type", rename_all = "snake_case")] 28 - pub enum CommsType { 29 - Welcome, 30 - EmailVerification, 31 - PasswordReset, 32 - EmailUpdate, 33 - AccountDeletion, 34 - AdminEmail, 35 - PlcOperation, 36 - TwoFactorCode, 37 - PasskeyRecovery, 38 - LegacyLoginAlert, 39 - MigrationVerification, 40 - } 41 - 42 - #[derive(Debug, Clone)] 43 - pub struct QueuedComms { 44 - pub id: Uuid, 45 - pub user_id: Uuid, 46 - pub channel: CommsChannel, 47 - pub comms_type: CommsType, 48 - pub status: CommsStatus, 49 - pub recipient: String, 50 - pub subject: Option<String>, 51 - pub body: String, 52 - pub metadata: Option<serde_json::Value>, 53 - pub attempts: i32, 54 - pub max_attempts: i32, 55 - pub last_error: Option<String>, 56 - pub created_at: DateTime<Utc>, 57 - pub updated_at: DateTime<Utc>, 58 - pub scheduled_for: DateTime<Utc>, 59 - pub processed_at: Option<DateTime<Utc>>, 60 - } 3 + pub use tranquil_db_traits::{CommsChannel, CommsStatus, CommsType, QueuedComms}; 61 4 62 5 pub struct NewComms { 63 6 pub user_id: Uuid,
+2
crates/tranquil-pds/Cargo.toml
··· 15 15 tranquil-auth = { workspace = true } 16 16 tranquil-oauth = { workspace = true } 17 17 tranquil-comms = { workspace = true } 18 + tranquil-db = { workspace = true } 19 + tranquil-db-traits = { workspace = true } 18 20 19 21 aes-gcm = { workspace = true } 20 22 backon = { workspace = true }
+38 -79
crates/tranquil-pds/src/api/actor/preferences.rs
··· 38 38 ) -> Response { 39 39 let auth_user = auth.0; 40 40 let has_full_access = auth_user.permissions().has_full_access(); 41 - let user_id: uuid::Uuid = 42 - match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &*auth_user.did) 43 - .fetch_optional(&state.db) 44 - .await 45 - { 46 - Ok(Some(id)) => id, 47 - _ => { 48 - return ApiError::InternalError(Some("User not found".into())).into_response(); 49 - } 50 - }; 51 - let prefs_result = sqlx::query!( 52 - "SELECT name, value_json FROM account_preferences WHERE user_id = $1", 53 - user_id 54 - ) 55 - .fetch_all(&state.db) 56 - .await; 57 - let prefs = match prefs_result { 41 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth_user.did).await { 42 + Ok(Some(id)) => id, 43 + _ => { 44 + return ApiError::InternalError(Some("User not found".into())).into_response(); 45 + } 46 + }; 47 + let prefs = match state.infra_repo.get_account_preferences(user_id).await { 58 48 Ok(rows) => rows, 59 49 Err(_) => { 60 50 return ApiError::InternalError(Some("Failed to fetch preferences".into())) ··· 64 54 let mut personal_details_pref: Option<Value> = None; 65 55 let mut preferences: Vec<Value> = prefs 66 56 .into_iter() 67 - .filter(|row| { 68 - row.name == APP_BSKY_NAMESPACE 69 - || row.name.starts_with(&format!("{}.", APP_BSKY_NAMESPACE)) 57 + .filter(|(name, _)| { 58 + name == APP_BSKY_NAMESPACE || name.starts_with(&format!("{}.", APP_BSKY_NAMESPACE)) 70 59 }) 71 - .filter_map(|row| { 72 - if row.name == DECLARED_AGE_PREF { 60 + .filter_map(|(name, value_json)| { 61 + if name == DECLARED_AGE_PREF { 73 62 return None; 74 63 } 75 - if row.name == PERSONAL_DETAILS_PREF { 64 + if name == PERSONAL_DETAILS_PREF { 76 65 if !has_full_access { 77 66 return None; 78 67 } 79 - personal_details_pref = serde_json::from_value(row.value_json.clone()).ok(); 68 + personal_details_pref = serde_json::from_value(value_json.clone()).ok(); 80 69 } 81 - serde_json::from_value(row.value_json).ok() 70 + serde_json::from_value(value_json).ok() 82 71 }) 83 72 .collect(); 84 73 if let Some(age) = personal_details_pref ··· 109 98 ) -> Response { 110 99 let auth_user = auth.0; 111 100 let has_full_access = auth_user.permissions().has_full_access(); 112 - let user_id: uuid::Uuid = 113 - match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &*auth_user.did) 114 - .fetch_optional(&state.db) 115 - .await 116 - { 117 - Ok(Some(id)) => id, 118 - _ => { 119 - return ApiError::InternalError(Some("User not found".into())).into_response(); 120 - } 121 - }; 101 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&auth_user.did).await { 102 + Ok(Some(id)) => id, 103 + _ => { 104 + return ApiError::InternalError(Some("User not found".into())).into_response(); 105 + } 106 + }; 122 107 if input.preferences.len() > MAX_PREFERENCES_COUNT { 123 108 return ApiError::InvalidRequest(format!( 124 109 "Too many preferences: {} exceeds limit of {}", ··· 195 180 )) 196 181 .into_response(); 197 182 } 198 - let mut tx = match state.db.begin().await { 199 - Ok(tx) => tx, 200 - Err(_) => { 201 - return ApiError::InternalError(Some("Failed to start transaction".into())) 202 - .into_response(); 203 - } 204 - }; 205 - let delete_result = sqlx::query!( 206 - "DELETE FROM account_preferences WHERE user_id = $1 AND (name = $2 OR name LIKE $3)", 207 - user_id, 208 - APP_BSKY_NAMESPACE, 209 - format!("{}.%", APP_BSKY_NAMESPACE) 210 - ) 211 - .execute(&mut *tx) 212 - .await; 213 - if delete_result.is_err() { 214 - let _ = tx.rollback().await; 215 - return ApiError::InternalError(Some("Failed to clear preferences".into())).into_response(); 216 - } 217 - for pref in input.preferences { 218 - let pref_type = match pref.get("$type").and_then(|t| t.as_str()) { 219 - Some(t) => t, 220 - None => continue, 221 - }; 222 - if pref_type == DECLARED_AGE_PREF { 223 - continue; 224 - } 225 - let insert_result = sqlx::query!( 226 - "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3)", 227 - user_id, 228 - pref_type, 229 - pref 230 - ) 231 - .execute(&mut *tx) 232 - .await; 233 - if insert_result.is_err() { 234 - let _ = tx.rollback().await; 235 - return ApiError::InternalError(Some("Failed to save preference".into())) 236 - .into_response(); 237 - } 238 - } 239 - if tx.commit().await.is_err() { 240 - return ApiError::InternalError(Some("Failed to commit transaction".into())) 241 - .into_response(); 183 + let prefs_to_save: Vec<(String, Value)> = input 184 + .preferences 185 + .into_iter() 186 + .filter_map(|pref| { 187 + let pref_type = pref.get("$type").and_then(|t| t.as_str())?; 188 + if pref_type == DECLARED_AGE_PREF { 189 + return None; 190 + } 191 + Some((pref_type.to_string(), pref)) 192 + }) 193 + .collect(); 194 + 195 + if let Err(_) = state 196 + .infra_repo 197 + .replace_namespace_preferences(user_id, APP_BSKY_NAMESPACE, prefs_to_save) 198 + .await 199 + { 200 + return ApiError::InternalError(Some("Failed to save preferences".into())).into_response(); 242 201 } 243 202 StatusCode::OK.into_response() 244 203 }
+6 -96
crates/tranquil-pds/src/api/admin/account/delete.rs
··· 22 22 Json(input): Json<DeleteAccountInput>, 23 23 ) -> Response { 24 24 let did = &input.did; 25 - let user = sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did.as_str()) 26 - .fetch_optional(&state.db) 27 - .await; 28 - let (user_id, handle) = match user { 25 + let (user_id, handle) = match state.user_repo.get_id_and_handle_by_did(did).await { 29 26 Ok(Some(row)) => (row.id, row.handle), 30 27 Ok(None) => { 31 28 return ApiError::AccountNotFound.into_response(); ··· 35 32 return ApiError::InternalError(None).into_response(); 36 33 } 37 34 }; 38 - let mut tx = match state.db.begin().await { 39 - Ok(tx) => tx, 40 - Err(e) => { 41 - error!("Failed to begin transaction for account deletion: {:?}", e); 42 - return ApiError::InternalError(None).into_response(); 43 - } 44 - }; 45 - if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did.as_str()) 46 - .execute(&mut *tx) 47 - .await 48 - { 49 - error!("Failed to delete session tokens for {}: {:?}", did, e); 50 - return ApiError::InternalError(Some("Failed to delete session tokens".into())) 51 - .into_response(); 52 - } 53 - if let Err(e) = sqlx::query!("DELETE FROM used_refresh_tokens WHERE session_id IN (SELECT id FROM session_tokens WHERE did = $1)", did.as_str()) 54 - .execute(&mut *tx) 55 - .await 56 - { 57 - error!("Failed to delete used refresh tokens for {}: {:?}", did, e); 58 - } 59 - if let Err(e) = sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id) 60 - .execute(&mut *tx) 61 - .await 62 - { 63 - error!("Failed to delete records for user {}: {:?}", user_id, e); 64 - return ApiError::InternalError(Some("Failed to delete records".into())).into_response(); 65 - } 66 - if let Err(e) = sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id) 67 - .execute(&mut *tx) 68 - .await 69 - { 70 - error!("Failed to delete repos for user {}: {:?}", user_id, e); 71 - return ApiError::InternalError(Some("Failed to delete repos".into())).into_response(); 72 - } 73 - if let Err(e) = sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id) 74 - .execute(&mut *tx) 75 - .await 76 - { 77 - error!("Failed to delete blobs for user {}: {:?}", user_id, e); 78 - return ApiError::InternalError(Some("Failed to delete blobs".into())).into_response(); 79 - } 80 - if let Err(e) = sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1", user_id) 81 - .execute(&mut *tx) 82 - .await 83 - { 84 - error!( 85 - "Failed to delete app passwords for user {}: {:?}", 86 - user_id, e 87 - ); 88 - return ApiError::InternalError(Some("Failed to delete app passwords".into())) 89 - .into_response(); 90 - } 91 - if let Err(e) = sqlx::query!( 92 - "DELETE FROM invite_code_uses WHERE used_by_user = $1", 93 - user_id 94 - ) 95 - .execute(&mut *tx) 96 - .await 97 - { 98 - error!( 99 - "Failed to delete invite code uses for user {}: {:?}", 100 - user_id, e 101 - ); 102 - } 103 - if let Err(e) = sqlx::query!( 104 - "DELETE FROM invite_codes WHERE created_by_user = $1", 105 - user_id 106 - ) 107 - .execute(&mut *tx) 108 - .await 109 - { 110 - error!( 111 - "Failed to delete invite codes for user {}: {:?}", 112 - user_id, e 113 - ); 114 - } 115 - if let Err(e) = sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id) 116 - .execute(&mut *tx) 35 + if let Err(e) = state 36 + .user_repo 37 + .admin_delete_account_complete(user_id, did) 117 38 .await 118 39 { 119 - error!("Failed to delete user keys for user {}: {:?}", user_id, e); 120 - return ApiError::InternalError(Some("Failed to delete user keys".into())).into_response(); 121 - } 122 - if let Err(e) = sqlx::query!("DELETE FROM users WHERE id = $1", user_id) 123 - .execute(&mut *tx) 124 - .await 125 - { 126 - error!("Failed to delete user {}: {:?}", user_id, e); 127 - return ApiError::InternalError(Some("Failed to delete user".into())).into_response(); 128 - } 129 - if let Err(e) = tx.commit().await { 130 - error!("Failed to commit account deletion transaction: {:?}", e); 131 - return ApiError::InternalError(Some("Failed to commit deletion".into())).into_response(); 40 + error!("Failed to delete account {}: {:?}", did, e); 41 + return ApiError::InternalError(Some("Failed to delete account".into())).into_response(); 132 42 } 133 43 if let Err(e) = 134 44 crate::api::repo::record::sequence_account_event(&state, did, false, Some("deleted")).await
+21 -24
crates/tranquil-pds/src/api/admin/account/email.rs
··· 35 35 if content.is_empty() { 36 36 return ApiError::InvalidRequest("content is required".into()).into_response(); 37 37 } 38 - let user = sqlx::query!( 39 - "SELECT id, email, handle FROM users WHERE did = $1", 40 - input.recipient_did.as_str() 41 - ) 42 - .fetch_optional(&state.db) 43 - .await; 44 - let (user_id, email, handle) = match user { 45 - Ok(Some(row)) => { 46 - let email = match row.email { 47 - Some(e) => e, 48 - None => { 49 - return ApiError::NoEmail.into_response(); 50 - } 51 - }; 52 - (row.id, email, row.handle) 53 - } 38 + let user = match state.user_repo.get_by_did(&input.recipient_did).await { 39 + Ok(Some(row)) => row, 54 40 Ok(None) => { 55 41 return ApiError::AccountNotFound.into_response(); 56 42 } ··· 59 45 return ApiError::InternalError(None).into_response(); 60 46 } 61 47 }; 48 + let email = match user.email { 49 + Some(e) => e, 50 + None => { 51 + return ApiError::NoEmail.into_response(); 52 + } 53 + }; 54 + let (user_id, handle) = (user.id, user.handle); 62 55 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 63 56 let subject = input 64 57 .subject 65 58 .clone() 66 59 .unwrap_or_else(|| format!("Message from {}", hostname)); 67 - let item = crate::comms::NewComms::email( 68 - user_id, 69 - crate::comms::CommsType::AdminEmail, 70 - email, 71 - subject, 72 - content.to_string(), 73 - ); 74 - let result = crate::comms::enqueue_comms(&state.db, item).await; 60 + let result = state 61 + .infra_repo 62 + .enqueue_comms( 63 + Some(user_id), 64 + tranquil_db_traits::CommsChannel::Email, 65 + tranquil_db_traits::CommsType::AdminEmail, 66 + &email, 67 + Some(&subject), 68 + content, 69 + None, 70 + ) 71 + .await; 75 72 match result { 76 73 Ok(_) => { 77 74 tracing::info!(
+134 -204
crates/tranquil-pds/src/api/admin/account/info.rs
··· 9 9 response::{IntoResponse, Response}, 10 10 }; 11 11 use serde::{Deserialize, Serialize}; 12 + use std::collections::HashMap; 12 13 use tracing::error; 13 14 14 15 #[derive(Deserialize)] ··· 43 44 pub code: String, 44 45 pub available: i32, 45 46 pub disabled: bool, 46 - pub for_account: Did, 47 - pub created_by: Did, 47 + #[serde(skip_serializing_if = "Option::is_none")] 48 + pub for_account: Option<Did>, 49 + #[serde(skip_serializing_if = "Option::is_none")] 50 + pub created_by: Option<Did>, 48 51 pub created_at: String, 49 52 pub uses: Vec<InviteCodeUseInfo>, 50 53 } ··· 67 70 _auth: BearerAuthAdmin, 68 71 Query(params): Query<GetAccountInfoParams>, 69 72 ) -> Response { 70 - let result = sqlx::query!( 71 - r#" 72 - SELECT id, did, handle, email, created_at, invites_disabled, email_verified, deactivated_at 73 - FROM users 74 - WHERE did = $1 75 - "#, 76 - params.did.as_str() 77 - ) 78 - .fetch_optional(&state.db) 79 - .await; 80 - match result { 81 - Ok(Some(row)) => { 82 - let invited_by = get_invited_by(&state.db, row.id).await; 83 - let invites = get_invites_for_user(&state.db, row.id).await; 84 - ( 85 - StatusCode::OK, 86 - Json(AccountInfo { 87 - did: row.did.into(), 88 - handle: row.handle.into(), 89 - email: row.email, 90 - indexed_at: row.created_at.to_rfc3339(), 91 - invite_note: None, 92 - invites_disabled: row.invites_disabled.unwrap_or(false), 93 - email_confirmed_at: if row.email_verified { 94 - Some(row.created_at.to_rfc3339()) 95 - } else { 96 - None 97 - }, 98 - deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 99 - invited_by, 100 - invites, 101 - }), 102 - ) 103 - .into_response() 104 - } 105 - Ok(None) => ApiError::AccountNotFound.into_response(), 73 + let account = match state.infra_repo.get_admin_account_info_by_did(&params.did).await { 74 + Ok(Some(a)) => a, 75 + Ok(None) => return ApiError::AccountNotFound.into_response(), 106 76 Err(e) => { 107 77 error!("DB error in get_account_info: {:?}", e); 108 - ApiError::InternalError(None).into_response() 78 + return ApiError::InternalError(None).into_response(); 109 79 } 110 - } 80 + }; 81 + 82 + let invited_by = get_invited_by(&state, account.id).await; 83 + let invites = get_invites_for_user(&state, account.id).await; 84 + 85 + ( 86 + StatusCode::OK, 87 + Json(AccountInfo { 88 + did: account.did, 89 + handle: account.handle, 90 + email: account.email, 91 + indexed_at: account.created_at.to_rfc3339(), 92 + invite_note: None, 93 + invites_disabled: account.invites_disabled, 94 + email_confirmed_at: if account.email_verified { 95 + Some(account.created_at.to_rfc3339()) 96 + } else { 97 + None 98 + }, 99 + deactivated_at: account.deactivated_at.map(|dt| dt.to_rfc3339()), 100 + invited_by, 101 + invites, 102 + }), 103 + ) 104 + .into_response() 111 105 } 112 106 113 - async fn get_invited_by(db: &sqlx::PgPool, user_id: uuid::Uuid) -> Option<InviteCodeInfo> { 114 - let use_row = sqlx::query!( 115 - r#" 116 - SELECT icu.code 117 - FROM invite_code_uses icu 118 - WHERE icu.used_by_user = $1 119 - LIMIT 1 120 - "#, 121 - user_id 122 - ) 123 - .fetch_optional(db) 124 - .await 125 - .ok()??; 126 - get_invite_code_info(db, &use_row.code).await 107 + async fn get_invited_by(state: &AppState, user_id: uuid::Uuid) -> Option<InviteCodeInfo> { 108 + let code = state 109 + .infra_repo 110 + .get_invite_code_used_by_user(user_id) 111 + .await 112 + .ok()??; 113 + 114 + get_invite_code_info(state, &code).await 127 115 } 128 116 129 - async fn get_invites_for_user( 130 - db: &sqlx::PgPool, 131 - user_id: uuid::Uuid, 132 - ) -> Option<Vec<InviteCodeInfo>> { 133 - let invite_codes = sqlx::query!( 134 - r#" 135 - SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by 136 - FROM invite_codes ic 137 - JOIN users u ON ic.created_by_user = u.id 138 - WHERE ic.created_by_user = $1 139 - "#, 140 - user_id 141 - ) 142 - .fetch_all(db) 143 - .await 144 - .ok()?; 117 + async fn get_invites_for_user(state: &AppState, user_id: uuid::Uuid) -> Option<Vec<InviteCodeInfo>> { 118 + let invite_codes = state 119 + .infra_repo 120 + .get_invites_created_by_user(user_id) 121 + .await 122 + .ok()?; 145 123 146 124 if invite_codes.is_empty() { 147 125 return None; 148 126 } 149 127 150 128 let code_strings: Vec<String> = invite_codes.iter().map(|ic| ic.code.clone()).collect(); 151 - let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> = 152 - std::collections::HashMap::new(); 153 - sqlx::query!( 154 - r#" 155 - SELECT icu.code, u.did as used_by, icu.used_at 156 - FROM invite_code_uses icu 157 - JOIN users u ON icu.used_by_user = u.id 158 - WHERE icu.code = ANY($1) 159 - "#, 160 - &code_strings 161 - ) 162 - .fetch_all(db) 163 - .await 164 - .ok()? 165 - .into_iter() 166 - .for_each(|r| { 167 - uses_by_code 168 - .entry(r.code) 169 - .or_default() 170 - .push(InviteCodeUseInfo { 171 - used_by: r.used_by.into(), 172 - used_at: r.used_at.to_rfc3339(), 129 + 130 + let uses = state 131 + .infra_repo 132 + .get_invite_code_uses_batch(&code_strings) 133 + .await 134 + .ok()?; 135 + 136 + let uses_by_code: HashMap<String, Vec<InviteCodeUseInfo>> = 137 + uses.into_iter().fold(HashMap::new(), |mut acc, u| { 138 + acc.entry(u.code.clone()).or_default().push(InviteCodeUseInfo { 139 + used_by: u.used_by_did, 140 + used_at: u.used_at.to_rfc3339(), 173 141 }); 174 - }); 142 + acc 143 + }); 175 144 176 145 let invites: Vec<InviteCodeInfo> = invite_codes 177 146 .into_iter() 178 147 .map(|ic| InviteCodeInfo { 179 148 code: ic.code.clone(), 180 149 available: ic.available_uses, 181 - disabled: ic.disabled.unwrap_or(false), 182 - for_account: ic.for_account.into(), 183 - created_by: ic.created_by.into(), 150 + disabled: ic.disabled, 151 + for_account: ic.for_account, 152 + created_by: ic.created_by, 184 153 created_at: ic.created_at.to_rfc3339(), 185 154 uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(), 186 155 }) ··· 193 162 } 194 163 } 195 164 196 - async fn get_invite_code_info(db: &sqlx::PgPool, code: &str) -> Option<InviteCodeInfo> { 197 - let row = sqlx::query!( 198 - r#" 199 - SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, u.did as created_by 200 - FROM invite_codes ic 201 - JOIN users u ON ic.created_by_user = u.id 202 - WHERE ic.code = $1 203 - "#, 204 - code 205 - ) 206 - .fetch_optional(db) 207 - .await 208 - .ok()??; 209 - let uses = sqlx::query!( 210 - r#" 211 - SELECT u.did as used_by, icu.used_at 212 - FROM invite_code_uses icu 213 - JOIN users u ON icu.used_by_user = u.id 214 - WHERE icu.code = $1 215 - "#, 216 - code 217 - ) 218 - .fetch_all(db) 219 - .await 220 - .ok()?; 165 + async fn get_invite_code_info(state: &AppState, code: &str) -> Option<InviteCodeInfo> { 166 + let info = state.infra_repo.get_invite_code_info(code).await.ok()??; 167 + 168 + let uses = state 169 + .infra_repo 170 + .get_invite_code_uses(code) 171 + .await 172 + .ok() 173 + .unwrap_or_default(); 174 + 221 175 Some(InviteCodeInfo { 222 - code: row.code, 223 - available: row.available_uses, 224 - disabled: row.disabled.unwrap_or(false), 225 - for_account: row.for_account.into(), 226 - created_by: row.created_by.into(), 227 - created_at: row.created_at.to_rfc3339(), 176 + code: info.code, 177 + available: info.available_uses, 178 + disabled: info.disabled, 179 + for_account: info.for_account, 180 + created_by: info.created_by, 181 + created_at: info.created_at.to_rfc3339(), 228 182 uses: uses 229 183 .into_iter() 230 184 .map(|u| InviteCodeUseInfo { 231 - used_by: u.used_by.into(), 185 + used_by: u.used_by_did, 232 186 used_at: u.used_at.to_rfc3339(), 233 187 }) 234 188 .collect(), ··· 244 198 .into_iter() 245 199 .filter(|d| !d.is_empty()) 246 200 .collect(); 201 + 247 202 if dids.is_empty() { 248 203 return ApiError::InvalidRequest("dids is required".into()).into_response(); 249 204 } 250 - let users = match sqlx::query!( 251 - r#" 252 - SELECT id, did, handle, email, created_at, invites_disabled, email_verified, deactivated_at 253 - FROM users 254 - WHERE did = ANY($1) 255 - "#, 256 - &dids 257 - ) 258 - .fetch_all(&state.db) 259 - .await 260 - { 261 - Ok(rows) => rows, 205 + 206 + let dids_typed: Vec<Did> = dids 207 + .iter() 208 + .filter_map(|d| d.parse().ok()) 209 + .collect(); 210 + let accounts = match state.infra_repo.get_admin_account_infos_by_dids(&dids_typed).await { 211 + Ok(accounts) => accounts, 262 212 Err(e) => { 263 213 error!("Failed to fetch account infos: {:?}", e); 264 214 return ApiError::InternalError(None).into_response(); 265 215 } 266 216 }; 267 217 268 - let user_ids: Vec<uuid::Uuid> = users.iter().map(|u| u.id).collect(); 218 + let user_ids: Vec<uuid::Uuid> = accounts.iter().map(|u| u.id).collect(); 269 219 270 - let all_invite_codes = sqlx::query!( 271 - r#" 272 - SELECT ic.code, ic.available_uses, ic.disabled, ic.for_account, ic.created_at, 273 - ic.created_by_user, u.did as created_by 274 - FROM invite_codes ic 275 - JOIN users u ON ic.created_by_user = u.id 276 - WHERE ic.created_by_user = ANY($1) 277 - "#, 278 - &user_ids 279 - ) 280 - .fetch_all(&state.db) 281 - .await 282 - .unwrap_or_default(); 220 + let all_invite_codes = state 221 + .infra_repo 222 + .get_invite_codes_by_users(&user_ids) 223 + .await 224 + .unwrap_or_default(); 225 + 226 + let all_codes: Vec<String> = all_invite_codes.iter().map(|(_, c)| c.code.clone()).collect(); 283 227 284 - let all_codes: Vec<String> = all_invite_codes.iter().map(|c| c.code.clone()).collect(); 285 228 let all_invite_uses = if !all_codes.is_empty() { 286 - sqlx::query!( 287 - r#" 288 - SELECT icu.code, u.did as used_by, icu.used_at 289 - FROM invite_code_uses icu 290 - JOIN users u ON icu.used_by_user = u.id 291 - WHERE icu.code = ANY($1) 292 - "#, 293 - &all_codes 294 - ) 295 - .fetch_all(&state.db) 296 - .await 297 - .unwrap_or_default() 229 + state 230 + .infra_repo 231 + .get_invite_code_uses_batch(&all_codes) 232 + .await 233 + .unwrap_or_default() 298 234 } else { 299 235 Vec::new() 300 236 }; 301 237 302 - let invited_by_map: std::collections::HashMap<uuid::Uuid, String> = sqlx::query!( 303 - r#" 304 - SELECT icu.used_by_user, icu.code 305 - FROM invite_code_uses icu 306 - WHERE icu.used_by_user = ANY($1) 307 - "#, 308 - &user_ids 309 - ) 310 - .fetch_all(&state.db) 311 - .await 312 - .unwrap_or_default() 313 - .into_iter() 314 - .map(|r| (r.used_by_user, r.code)) 315 - .collect(); 238 + let invited_by_map: HashMap<uuid::Uuid, String> = state 239 + .infra_repo 240 + .get_invite_code_uses_by_users(&user_ids) 241 + .await 242 + .unwrap_or_default() 243 + .into_iter() 244 + .collect(); 316 245 317 - let uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> = 246 + let uses_by_code: HashMap<String, Vec<InviteCodeUseInfo>> = 318 247 all_invite_uses 319 248 .into_iter() 320 - .fold(std::collections::HashMap::new(), |mut acc, u| { 249 + .fold(HashMap::new(), |mut acc, u| { 321 250 acc.entry(u.code.clone()).or_default().push(InviteCodeUseInfo { 322 - used_by: u.used_by.into(), 251 + used_by: u.used_by_did, 323 252 used_at: u.used_at.to_rfc3339(), 324 253 }); 325 254 acc 326 255 }); 327 256 328 257 let (codes_by_user, code_info_map): ( 329 - std::collections::HashMap<uuid::Uuid, Vec<InviteCodeInfo>>, 330 - std::collections::HashMap<String, InviteCodeInfo>, 258 + HashMap<uuid::Uuid, Vec<InviteCodeInfo>>, 259 + HashMap<String, InviteCodeInfo>, 331 260 ) = all_invite_codes.into_iter().fold( 332 - (std::collections::HashMap::new(), std::collections::HashMap::new()), 333 - |(mut by_user, mut by_code), ic| { 261 + (HashMap::new(), HashMap::new()), 262 + |(mut by_user, mut by_code), (user_id, ic)| { 334 263 let info = InviteCodeInfo { 335 264 code: ic.code.clone(), 336 265 available: ic.available_uses, 337 - disabled: ic.disabled.unwrap_or(false), 338 - for_account: ic.for_account.into(), 339 - created_by: ic.created_by.into(), 266 + disabled: ic.disabled, 267 + for_account: ic.for_account, 268 + created_by: ic.created_by, 340 269 created_at: ic.created_at.to_rfc3339(), 341 270 uses: uses_by_code.get(&ic.code).cloned().unwrap_or_default(), 342 271 }; 343 272 by_code.insert(ic.code.clone(), info.clone()); 344 - by_user.entry(ic.created_by_user).or_default().push(info); 273 + by_user.entry(user_id).or_default().push(info); 345 274 (by_user, by_code) 346 275 }, 347 276 ); 348 277 349 - let infos: Vec<AccountInfo> = users 278 + let infos: Vec<AccountInfo> = accounts 350 279 .into_iter() 351 - .map(|row| { 280 + .map(|account| { 352 281 let invited_by = invited_by_map 353 - .get(&row.id) 282 + .get(&account.id) 354 283 .and_then(|code| code_info_map.get(code).cloned()); 355 - let invites = codes_by_user.get(&row.id).cloned(); 284 + let invites = codes_by_user.get(&account.id).cloned(); 356 285 AccountInfo { 357 - did: row.did.into(), 358 - handle: row.handle.into(), 359 - email: row.email, 360 - indexed_at: row.created_at.to_rfc3339(), 286 + did: account.did, 287 + handle: account.handle, 288 + email: account.email, 289 + indexed_at: account.created_at.to_rfc3339(), 361 290 invite_note: None, 362 - invites_disabled: row.invites_disabled.unwrap_or(false), 363 - email_confirmed_at: if row.email_verified { 364 - Some(row.created_at.to_rfc3339()) 291 + invites_disabled: account.invites_disabled, 292 + email_confirmed_at: if account.email_verified { 293 + Some(account.created_at.to_rfc3339()) 365 294 } else { 366 295 None 367 296 }, 368 - deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 297 + deactivated_at: account.deactivated_at.map(|dt| dt.to_rfc3339()), 369 298 invited_by, 370 299 invites, 371 300 } 372 301 }) 373 302 .collect(); 303 + 374 304 (StatusCode::OK, Json(GetAccountInfosOutput { infos })).into_response() 375 305 }
+22 -53
crates/tranquil-pds/src/api/admin/account/search.rs
··· 54 54 Query(params): Query<SearchAccountsParams>, 55 55 ) -> Response { 56 56 let limit = params.limit.clamp(1, 100); 57 - let cursor_did = params.cursor.as_deref().unwrap_or(""); 58 57 let email_filter = params.email.as_deref().map(|e| format!("%{}%", e)); 59 58 let handle_filter = params.handle.as_deref().map(|h| format!("%{}%", h)); 60 - let result = sqlx::query_as::< 61 - _, 62 - ( 63 - String, 64 - String, 65 - Option<String>, 66 - chrono::DateTime<chrono::Utc>, 67 - bool, 68 - Option<chrono::DateTime<chrono::Utc>>, 69 - Option<bool>, 70 - ), 71 - >( 72 - r#" 73 - SELECT did, handle, email, created_at, email_verified, deactivated_at, invites_disabled 74 - FROM users 75 - WHERE did > $1 76 - AND ($2::text IS NULL OR email ILIKE $2) 77 - AND ($3::text IS NULL OR handle ILIKE $3) 78 - ORDER BY did ASC 79 - LIMIT $4 80 - "#, 81 - ) 82 - .bind(cursor_did) 83 - .bind(&email_filter) 84 - .bind(&handle_filter) 85 - .bind(limit + 1) 86 - .fetch_all(&state.db) 87 - .await; 59 + let cursor_did: Option<Did> = params.cursor.as_ref().and_then(|c| c.parse().ok()); 60 + let result = state 61 + .user_repo 62 + .search_accounts( 63 + cursor_did.as_ref(), 64 + email_filter.as_deref(), 65 + handle_filter.as_deref(), 66 + limit + 1, 67 + ) 68 + .await; 88 69 match result { 89 70 Ok(rows) => { 90 71 let has_more = rows.len() > limit as usize; 91 72 let accounts: Vec<AccountView> = rows 92 73 .into_iter() 93 74 .take(limit as usize) 94 - .map( 95 - |( 96 - did, 97 - handle, 98 - email, 99 - created_at, 100 - email_verified, 101 - deactivated_at, 102 - invites_disabled, 103 - )| { 104 - AccountView { 105 - did: did.clone().into(), 106 - handle: handle.into(), 107 - email, 108 - indexed_at: created_at.to_rfc3339(), 109 - email_confirmed_at: if email_verified { 110 - Some(created_at.to_rfc3339()) 111 - } else { 112 - None 113 - }, 114 - deactivated_at: deactivated_at.map(|dt| dt.to_rfc3339()), 115 - invites_disabled, 116 - } 75 + .map(|row| AccountView { 76 + did: row.did.clone(), 77 + handle: row.handle, 78 + email: row.email, 79 + indexed_at: row.created_at.to_rfc3339(), 80 + email_confirmed_at: if row.email_verified { 81 + Some(row.created_at.to_rfc3339()) 82 + } else { 83 + None 117 84 }, 118 - ) 85 + deactivated_at: row.deactivated_at.map(|dt| dt.to_rfc3339()), 86 + invites_disabled: row.invites_disabled, 87 + }) 119 88 .collect(); 120 89 let next_cursor = if has_more { 121 90 accounts.last().map(|a| a.did.to_string())
+26 -50
crates/tranquil-pds/src/api/admin/account/update.rs
··· 27 27 if account.is_empty() || email.is_empty() { 28 28 return ApiError::InvalidRequest("account and email are required".into()).into_response(); 29 29 } 30 - let result = sqlx::query!("UPDATE users SET email = $1 WHERE did = $2", email, account) 31 - .execute(&state.db) 32 - .await; 33 - match result { 34 - Ok(r) => { 35 - if r.rows_affected() == 0 { 36 - return ApiError::AccountNotFound.into_response(); 37 - } 38 - EmptyResponse::ok().into_response() 39 - } 30 + let account_did: Did = match account.parse() { 31 + Ok(d) => d, 32 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 33 + }; 34 + match state.user_repo.admin_update_email(&account_did, email).await { 35 + Ok(0) => ApiError::AccountNotFound.into_response(), 36 + Ok(_) => EmptyResponse::ok().into_response(), 40 37 Err(e) => { 41 38 error!("DB error updating email: {:?}", e); 42 39 ApiError::InternalError(None).into_response() ··· 67 64 return ApiError::InvalidHandle(None).into_response(); 68 65 } 69 66 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 67 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 70 68 let handle = if !input_handle.contains('.') { 71 - format!("{}.{}", input_handle, hostname) 69 + format!("{}.{}", input_handle, hostname_for_handles) 72 70 } else { 73 71 input_handle.to_string() 74 72 }; 75 - let old_handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did.as_str()) 76 - .fetch_optional(&state.db) 73 + let old_handle = state 74 + .user_repo 75 + .get_handle_by_did(did) 77 76 .await 78 77 .ok() 79 78 .flatten(); 80 - let existing = sqlx::query!( 81 - "SELECT id FROM users WHERE handle = $1 AND did != $2", 82 - handle, 83 - did.as_str() 84 - ) 85 - .fetch_optional(&state.db) 86 - .await; 87 - if let Ok(Some(_)) = existing { 79 + let user_id = match state.user_repo.get_id_by_did(did).await { 80 + Ok(Some(id)) => id, 81 + _ => return ApiError::AccountNotFound.into_response(), 82 + }; 83 + let handle_for_check = Handle::new_unchecked(&handle); 84 + if let Ok(true) = state.user_repo.check_handle_exists(&handle_for_check, user_id).await { 88 85 return ApiError::HandleTaken.into_response(); 89 86 } 90 - let result = sqlx::query!( 91 - "UPDATE users SET handle = $1 WHERE did = $2", 92 - handle, 93 - did.as_str() 94 - ) 95 - .execute(&state.db) 96 - .await; 97 - match result { 98 - Ok(r) => { 99 - if r.rows_affected() == 0 { 100 - return ApiError::AccountNotFound.into_response(); 101 - } 87 + match state.user_repo.admin_update_handle(did, &handle_for_check).await { 88 + Ok(0) => ApiError::AccountNotFound.into_response(), 89 + Ok(_) => { 102 90 if let Some(old) = old_handle { 103 91 let _ = state.cache.delete(&format!("handle:{}", old)).await; 104 92 } 105 93 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 106 - let handle_typed = Handle::new_unchecked(&handle); 107 94 if let Err(e) = 108 - crate::api::repo::record::sequence_identity_event(&state, did, Some(&handle_typed)) 95 + crate::api::repo::record::sequence_identity_event(&state, did, Some(&handle_for_check)) 109 96 .await 110 97 { 111 98 warn!( ··· 114 101 ); 115 102 } 116 103 if let Err(e) = 117 - crate::api::identity::did::update_plc_handle(&state, did.as_str(), &handle).await 104 + crate::api::identity::did::update_plc_handle(&state, did, &handle_for_check).await 118 105 { 119 106 warn!("Failed to update PLC handle for admin handle update: {}", e); 120 107 } ··· 150 137 return ApiError::InternalError(None).into_response(); 151 138 } 152 139 }; 153 - let result = sqlx::query!( 154 - "UPDATE users SET password_hash = $1 WHERE did = $2", 155 - password_hash, 156 - did.as_str() 157 - ) 158 - .execute(&state.db) 159 - .await; 160 - match result { 161 - Ok(r) => { 162 - if r.rows_affected() == 0 { 163 - return ApiError::AccountNotFound.into_response(); 164 - } 165 - EmptyResponse::ok().into_response() 166 - } 140 + match state.user_repo.admin_update_password(did, &password_hash).await { 141 + Ok(0) => ApiError::AccountNotFound.into_response(), 142 + Ok(_) => EmptyResponse::ok().into_response(), 167 143 Err(e) => { 168 144 error!("DB error updating password: {:?}", e); 169 145 ApiError::InternalError(None).into_response()
+123 -55
crates/tranquil-pds/src/api/admin/config.rs
··· 4 4 use axum::{Json, extract::State}; 5 5 use serde::{Deserialize, Serialize}; 6 6 use tracing::error; 7 + use tranquil_types::CidLink; 7 8 8 9 #[derive(Serialize)] 9 10 #[serde(rename_all = "camelCase")] ··· 42 43 pub async fn get_server_config( 43 44 State(state): State<AppState>, 44 45 ) -> Result<Json<ServerConfigResponse>, ApiError> { 45 - let rows: Vec<(String, String)> = sqlx::query_as( 46 - "SELECT key, value FROM server_config WHERE key IN ('server_name', 'primary_color', 'primary_color_dark', 'secondary_color', 'secondary_color_dark', 'logo_cid')" 47 - ) 48 - .fetch_all(&state.db) 49 - .await?; 46 + let keys = &[ 47 + "server_name", 48 + "primary_color", 49 + "primary_color_dark", 50 + "secondary_color", 51 + "secondary_color_dark", 52 + "logo_cid", 53 + ]; 54 + 55 + let rows = state 56 + .infra_repo 57 + .get_server_configs(keys) 58 + .await 59 + .map_err(|e| { 60 + error!("DB error fetching server config: {:?}", e); 61 + ApiError::InternalError(None) 62 + })?; 50 63 51 - let config_map: std::collections::HashMap<String, String> = 52 - rows.into_iter().collect(); 64 + let config_map: std::collections::HashMap<String, String> = rows.into_iter().collect(); 53 65 54 66 Ok(Json(ServerConfigResponse { 55 67 server_name: config_map ··· 64 76 })) 65 77 } 66 78 67 - async fn upsert_config(db: &sqlx::PgPool, key: &str, value: &str) -> Result<(), sqlx::Error> { 68 - sqlx::query( 69 - "INSERT INTO server_config (key, value, updated_at) VALUES ($1, $2, NOW()) 70 - ON CONFLICT (key) DO UPDATE SET value = $2, updated_at = NOW()", 71 - ) 72 - .bind(key) 73 - .bind(value) 74 - .execute(db) 75 - .await?; 76 - Ok(()) 77 - } 78 - 79 - async fn delete_config(db: &sqlx::PgPool, key: &str) -> Result<(), sqlx::Error> { 80 - sqlx::query("DELETE FROM server_config WHERE key = $1") 81 - .bind(key) 82 - .execute(db) 83 - .await?; 84 - Ok(()) 85 - } 86 - 87 79 pub async fn update_server_config( 88 80 State(state): State<AppState>, 89 81 _admin: BearerAuthAdmin, ··· 96 88 "Server name must be 1-100 characters".into(), 97 89 )); 98 90 } 99 - upsert_config(&state.db, "server_name", trimmed).await?; 91 + state 92 + .infra_repo 93 + .upsert_server_config("server_name", trimmed) 94 + .await 95 + .map_err(|e| { 96 + error!("DB error upserting server_name: {:?}", e); 97 + ApiError::InternalError(None) 98 + })?; 100 99 } 101 100 102 101 if let Some(ref color) = req.primary_color { 103 102 if color.is_empty() { 104 - delete_config(&state.db, "primary_color").await?; 103 + state 104 + .infra_repo 105 + .delete_server_config("primary_color") 106 + .await 107 + .map_err(|e| { 108 + error!("DB error deleting primary_color: {:?}", e); 109 + ApiError::InternalError(None) 110 + })?; 105 111 } else if is_valid_hex_color(color) { 106 - upsert_config(&state.db, "primary_color", color).await?; 112 + state 113 + .infra_repo 114 + .upsert_server_config("primary_color", color) 115 + .await 116 + .map_err(|e| { 117 + error!("DB error upserting primary_color: {:?}", e); 118 + ApiError::InternalError(None) 119 + })?; 107 120 } else { 108 121 return Err(ApiError::InvalidRequest( 109 122 "Invalid primary color format (expected #RRGGBB)".into(), ··· 113 126 114 127 if let Some(ref color) = req.primary_color_dark { 115 128 if color.is_empty() { 116 - delete_config(&state.db, "primary_color_dark").await?; 129 + state 130 + .infra_repo 131 + .delete_server_config("primary_color_dark") 132 + .await 133 + .map_err(|e| { 134 + error!("DB error deleting primary_color_dark: {:?}", e); 135 + ApiError::InternalError(None) 136 + })?; 117 137 } else if is_valid_hex_color(color) { 118 - upsert_config(&state.db, "primary_color_dark", color).await?; 138 + state 139 + .infra_repo 140 + .upsert_server_config("primary_color_dark", color) 141 + .await 142 + .map_err(|e| { 143 + error!("DB error upserting primary_color_dark: {:?}", e); 144 + ApiError::InternalError(None) 145 + })?; 119 146 } else { 120 147 return Err(ApiError::InvalidRequest( 121 148 "Invalid primary dark color format (expected #RRGGBB)".into(), ··· 125 152 126 153 if let Some(ref color) = req.secondary_color { 127 154 if color.is_empty() { 128 - delete_config(&state.db, "secondary_color").await?; 155 + state 156 + .infra_repo 157 + .delete_server_config("secondary_color") 158 + .await 159 + .map_err(|e| { 160 + error!("DB error deleting secondary_color: {:?}", e); 161 + ApiError::InternalError(None) 162 + })?; 129 163 } else if is_valid_hex_color(color) { 130 - upsert_config(&state.db, "secondary_color", color).await?; 164 + state 165 + .infra_repo 166 + .upsert_server_config("secondary_color", color) 167 + .await 168 + .map_err(|e| { 169 + error!("DB error upserting secondary_color: {:?}", e); 170 + ApiError::InternalError(None) 171 + })?; 131 172 } else { 132 173 return Err(ApiError::InvalidRequest( 133 174 "Invalid secondary color format (expected #RRGGBB)".into(), ··· 137 178 138 179 if let Some(ref color) = req.secondary_color_dark { 139 180 if color.is_empty() { 140 - delete_config(&state.db, "secondary_color_dark").await?; 181 + state 182 + .infra_repo 183 + .delete_server_config("secondary_color_dark") 184 + .await 185 + .map_err(|e| { 186 + error!("DB error deleting secondary_color_dark: {:?}", e); 187 + ApiError::InternalError(None) 188 + })?; 141 189 } else if is_valid_hex_color(color) { 142 - upsert_config(&state.db, "secondary_color_dark", color).await?; 190 + state 191 + .infra_repo 192 + .upsert_server_config("secondary_color_dark", color) 193 + .await 194 + .map_err(|e| { 195 + error!("DB error upserting secondary_color_dark: {:?}", e); 196 + ApiError::InternalError(None) 197 + })?; 143 198 } else { 144 199 return Err(ApiError::InvalidRequest( 145 200 "Invalid secondary dark color format (expected #RRGGBB)".into(), ··· 148 203 } 149 204 150 205 if let Some(ref logo_cid) = req.logo_cid { 151 - let old_logo_cid: Option<String> = 152 - sqlx::query_scalar("SELECT value FROM server_config WHERE key = 'logo_cid'") 153 - .fetch_optional(&state.db) 154 - .await?; 206 + let old_logo_cid = state 207 + .infra_repo 208 + .get_server_config("logo_cid") 209 + .await 210 + .ok() 211 + .flatten(); 155 212 156 213 let should_delete_old = match (&old_logo_cid, logo_cid.is_empty()) { 157 214 (Some(old), true) => Some(old.clone()), ··· 159 216 _ => None, 160 217 }; 161 218 162 - if let Some(old_cid) = should_delete_old 163 - && let Ok(Some(blob)) = 164 - sqlx::query!("SELECT storage_key FROM blobs WHERE cid = $1", old_cid) 165 - .fetch_optional(&state.db) 166 - .await 167 - { 168 - if let Err(e) = state.blob_store.delete(&blob.storage_key).await { 169 - error!("Failed to delete old logo blob from storage: {:?}", e); 170 - } 171 - if let Err(e) = sqlx::query!("DELETE FROM blobs WHERE cid = $1", old_cid) 172 - .execute(&state.db) 173 - .await 219 + if let Some(old_cid_str) = should_delete_old { 220 + let old_cid = CidLink::new_unchecked(old_cid_str); 221 + if let Ok(Some(storage_key)) = 222 + state.infra_repo.get_blob_storage_key_by_cid(&old_cid).await 174 223 { 175 - error!("Failed to delete old logo blob record: {:?}", e); 224 + if let Err(e) = state.blob_store.delete(&storage_key).await { 225 + error!("Failed to delete old logo blob from storage: {:?}", e); 226 + } 227 + if let Err(e) = state.infra_repo.delete_blob_by_cid(&old_cid).await { 228 + error!("Failed to delete old logo blob record: {:?}", e); 229 + } 176 230 } 177 231 } 178 232 179 233 if logo_cid.is_empty() { 180 - delete_config(&state.db, "logo_cid").await?; 234 + state 235 + .infra_repo 236 + .delete_server_config("logo_cid") 237 + .await 238 + .map_err(|e| { 239 + error!("DB error deleting logo_cid: {:?}", e); 240 + ApiError::InternalError(None) 241 + })?; 181 242 } else { 182 - upsert_config(&state.db, "logo_cid", logo_cid).await?; 243 + state 244 + .infra_repo 245 + .upsert_server_config("logo_cid", logo_cid) 246 + .await 247 + .map_err(|e| { 248 + error!("DB error upserting logo_cid: {:?}", e); 249 + ApiError::InternalError(None) 250 + })?; 183 251 } 184 252 } 185 253
+71 -137
crates/tranquil-pds/src/api/admin/invite.rs
··· 10 10 }; 11 11 use serde::{Deserialize, Serialize}; 12 12 use tracing::error; 13 + use tranquil_db_traits::InviteCodeSortOrder; 13 14 14 15 #[derive(Deserialize)] 15 16 #[serde(rename_all = "camelCase")] ··· 24 25 Json(input): Json<DisableInviteCodesInput>, 25 26 ) -> Response { 26 27 if let Some(codes) = &input.codes { 27 - let _ = sqlx::query!( 28 - "UPDATE invite_codes SET disabled = TRUE WHERE code = ANY($1)", 29 - codes as &[String] 30 - ) 31 - .execute(&state.db) 32 - .await; 28 + if let Err(e) = state.infra_repo.disable_invite_codes_by_code(codes).await { 29 + error!("DB error disabling invite codes: {:?}", e); 30 + } 33 31 } 34 32 if let Some(accounts) = &input.accounts { 35 - let _ = sqlx::query!( 36 - "UPDATE invite_codes SET disabled = TRUE WHERE created_by_user IN (SELECT id FROM users WHERE did = ANY($1))", 37 - accounts as &[String] 38 - ) 39 - .execute(&state.db) 40 - .await; 33 + let accounts_typed: Vec<tranquil_types::Did> = accounts 34 + .iter() 35 + .filter_map(|a| a.parse().ok()) 36 + .collect(); 37 + if let Err(e) = state 38 + .infra_repo 39 + .disable_invite_codes_by_account(&accounts_typed) 40 + .await 41 + { 42 + error!("DB error disabling invite codes by account: {:?}", e); 43 + } 41 44 } 42 45 EmptyResponse::ok().into_response() 43 46 } ··· 81 84 Query(params): Query<GetInviteCodesParams>, 82 85 ) -> Response { 83 86 let limit = params.limit.unwrap_or(100).clamp(1, 500); 84 - let sort = params.sort.as_deref().unwrap_or("recent"); 85 - let order_clause = match sort { 86 - "usage" => "available_uses DESC", 87 - _ => "created_at DESC", 87 + let sort_order = match params.sort.as_deref() { 88 + Some("usage") => InviteCodeSortOrder::Usage, 89 + _ => InviteCodeSortOrder::Recent, 88 90 }; 89 - let codes_result = if let Some(cursor) = &params.cursor { 90 - sqlx::query_as::< 91 - _, 92 - ( 93 - String, 94 - i32, 95 - Option<bool>, 96 - uuid::Uuid, 97 - chrono::DateTime<chrono::Utc>, 98 - ), 99 - >(&format!( 100 - r#" 101 - SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at 102 - FROM invite_codes ic 103 - WHERE ic.created_at < (SELECT created_at FROM invite_codes WHERE code = $1) 104 - ORDER BY {} 105 - LIMIT $2 106 - "#, 107 - order_clause 108 - )) 109 - .bind(cursor) 110 - .bind(limit) 111 - .fetch_all(&state.db) 112 - .await 113 - } else { 114 - sqlx::query_as::< 115 - _, 116 - ( 117 - String, 118 - i32, 119 - Option<bool>, 120 - uuid::Uuid, 121 - chrono::DateTime<chrono::Utc>, 122 - ), 123 - >(&format!( 124 - r#" 125 - SELECT ic.code, ic.available_uses, ic.disabled, ic.created_by_user, ic.created_at 126 - FROM invite_codes ic 127 - ORDER BY {} 128 - LIMIT $1 129 - "#, 130 - order_clause 131 - )) 132 - .bind(limit) 133 - .fetch_all(&state.db) 91 + 92 + let codes_rows = match state 93 + .infra_repo 94 + .list_invite_codes(params.cursor.as_deref(), limit, sort_order) 134 95 .await 135 - }; 136 - let codes_rows = match codes_result { 96 + { 137 97 Ok(rows) => rows, 138 98 Err(e) => { 139 99 error!("DB error fetching invite codes: {:?}", e); ··· 141 101 } 142 102 }; 143 103 144 - let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|(_, _, _, uid, _)| *uid).collect(); 145 - let code_strings: Vec<String> = codes_rows.iter().map(|(c, _, _, _, _)| c.clone()).collect(); 146 - 147 - let mut creator_dids: std::collections::HashMap<uuid::Uuid, String> = 148 - std::collections::HashMap::new(); 149 - sqlx::query!( 150 - "SELECT id, did FROM users WHERE id = ANY($1)", 151 - &user_ids 152 - ) 153 - .fetch_all(&state.db) 154 - .await 155 - .unwrap_or_default() 156 - .into_iter() 157 - .for_each(|r| { 158 - creator_dids.insert(r.id, r.did); 159 - }); 104 + let user_ids: Vec<uuid::Uuid> = codes_rows.iter().map(|r| r.created_by_user).collect(); 105 + let code_strings: Vec<String> = codes_rows.iter().map(|r| r.code.clone()).collect(); 160 106 161 - let mut uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> = 162 - std::collections::HashMap::new(); 163 - if !code_strings.is_empty() { 164 - sqlx::query!( 165 - r#" 166 - SELECT icu.code, u.did, icu.used_at 167 - FROM invite_code_uses icu 168 - JOIN users u ON icu.used_by_user = u.id 169 - WHERE icu.code = ANY($1) 170 - ORDER BY icu.used_at DESC 171 - "#, 172 - &code_strings 173 - ) 174 - .fetch_all(&state.db) 107 + let creator_dids: std::collections::HashMap<uuid::Uuid, tranquil_types::Did> = state 108 + .infra_repo 109 + .get_user_dids_by_ids(&user_ids) 175 110 .await 176 111 .unwrap_or_default() 177 112 .into_iter() 178 - .for_each(|r| { 179 - uses_by_code 180 - .entry(r.code) 181 - .or_default() 182 - .push(InviteCodeUseInfo { 183 - used_by: r.did, 184 - used_at: r.used_at.to_rfc3339(), 113 + .collect(); 114 + 115 + let uses_by_code: std::collections::HashMap<String, Vec<InviteCodeUseInfo>> = if code_strings 116 + .is_empty() 117 + { 118 + std::collections::HashMap::new() 119 + } else { 120 + state 121 + .infra_repo 122 + .get_invite_code_uses_batch(&code_strings) 123 + .await 124 + .unwrap_or_default() 125 + .into_iter() 126 + .fold(std::collections::HashMap::new(), |mut acc, u| { 127 + acc.entry(u.code.clone()).or_default().push(InviteCodeUseInfo { 128 + used_by: u.used_by_did.to_string(), 129 + used_at: u.used_at.to_rfc3339(), 185 130 }); 186 - }); 187 - } 131 + acc 132 + }) 133 + }; 188 134 189 135 let codes: Vec<InviteCodeInfo> = codes_rows 190 136 .iter() 191 - .map(|(code, available_uses, disabled, created_by_user, created_at)| { 137 + .map(|r| { 192 138 let creator_did = creator_dids 193 - .get(created_by_user) 194 - .cloned() 139 + .get(&r.created_by_user) 140 + .map(|d| d.to_string()) 195 141 .unwrap_or_else(|| "unknown".to_string()); 196 142 InviteCodeInfo { 197 - code: code.clone(), 198 - available: *available_uses, 199 - disabled: disabled.unwrap_or(false), 143 + code: r.code.clone(), 144 + available: r.available_uses, 145 + disabled: r.disabled.unwrap_or(false), 200 146 for_account: creator_did.clone(), 201 147 created_by: creator_did, 202 - created_at: created_at.to_rfc3339(), 203 - uses: uses_by_code.get(code).cloned().unwrap_or_default(), 148 + created_at: r.created_at.to_rfc3339(), 149 + uses: uses_by_code.get(&r.code).cloned().unwrap_or_default(), 204 150 } 205 151 }) 206 152 .collect(); 207 153 208 154 let next_cursor = if codes_rows.len() == limit as usize { 209 - codes_rows.last().map(|(code, _, _, _, _)| code.clone()) 155 + codes_rows.last().map(|r| r.code.clone()) 210 156 } else { 211 157 None 212 158 }; ··· 234 180 if account.is_empty() { 235 181 return ApiError::InvalidRequest("account is required".into()).into_response(); 236 182 } 237 - let result = sqlx::query!( 238 - "UPDATE users SET invites_disabled = TRUE WHERE did = $1", 239 - account 240 - ) 241 - .execute(&state.db) 242 - .await; 243 - match result { 244 - Ok(r) => { 245 - if r.rows_affected() == 0 { 246 - return ApiError::AccountNotFound.into_response(); 247 - } 248 - EmptyResponse::ok().into_response() 249 - } 183 + let account_did: tranquil_types::Did = match account.parse() { 184 + Ok(d) => d, 185 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 186 + }; 187 + match state.user_repo.set_invites_disabled(&account_did, true).await { 188 + Ok(true) => EmptyResponse::ok().into_response(), 189 + Ok(false) => ApiError::AccountNotFound.into_response(), 250 190 Err(e) => { 251 191 error!("DB error disabling account invites: {:?}", e); 252 192 ApiError::InternalError(None).into_response() ··· 268 208 if account.is_empty() { 269 209 return ApiError::InvalidRequest("account is required".into()).into_response(); 270 210 } 271 - let result = sqlx::query!( 272 - "UPDATE users SET invites_disabled = FALSE WHERE did = $1", 273 - account 274 - ) 275 - .execute(&state.db) 276 - .await; 277 - match result { 278 - Ok(r) => { 279 - if r.rows_affected() == 0 { 280 - return ApiError::AccountNotFound.into_response(); 281 - } 282 - EmptyResponse::ok().into_response() 283 - } 211 + let account_did: tranquil_types::Did = match account.parse() { 212 + Ok(d) => d, 213 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 214 + }; 215 + match state.user_repo.set_invites_disabled(&account_did, false).await { 216 + Ok(true) => EmptyResponse::ok().into_response(), 217 + Ok(false) => ApiError::AccountNotFound.into_response(), 284 218 Err(e) => { 285 219 error!("DB error enabling account invites: {:?}", e); 286 220 ApiError::InternalError(None).into_response()
+4 -36
crates/tranquil-pds/src/api/admin/server_stats.rs
··· 17 17 } 18 18 19 19 pub async fn get_server_stats(State(state): State<AppState>, _auth: BearerAuthAdmin) -> Response { 20 - let user_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM users") 21 - .fetch_one(&state.db) 22 - .await 23 - { 24 - Ok(Some(count)) => count, 25 - Ok(None) => 0, 26 - Err(_) => 0, 27 - }; 28 - 29 - let repo_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM repos") 30 - .fetch_one(&state.db) 31 - .await 32 - { 33 - Ok(Some(count)) => count, 34 - Ok(None) => 0, 35 - Err(_) => 0, 36 - }; 37 - 38 - let record_count: i64 = match sqlx::query_scalar!("SELECT COUNT(*) FROM records") 39 - .fetch_one(&state.db) 40 - .await 41 - { 42 - Ok(Some(count)) => count, 43 - Ok(None) => 0, 44 - Err(_) => 0, 45 - }; 46 - 47 - let blob_storage_bytes: i64 = 48 - match sqlx::query_scalar!("SELECT COALESCE(SUM(size_bytes), 0)::BIGINT FROM blobs") 49 - .fetch_one(&state.db) 50 - .await 51 - { 52 - Ok(Some(bytes)) => bytes, 53 - Ok(None) => 0, 54 - Err(_) => 0, 55 - }; 20 + let user_count = state.user_repo.count_users().await.unwrap_or(0); 21 + let repo_count = state.repo_repo.count_repos().await.unwrap_or(0); 22 + let record_count = state.repo_repo.count_all_records().await.unwrap_or(0); 23 + let blob_storage_bytes = state.blob_repo.sum_blob_storage().await.unwrap_or(0); 56 24 57 25 Json(ServerStatsResponse { 58 26 user_count,
+62 -94
crates/tranquil-pds/src/api/admin/status.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::auth::BearerAuthAdmin; 3 3 use crate::state::AppState; 4 - use crate::types::Did; 4 + use crate::types::{CidLink, Did}; 5 5 use axum::{ 6 6 Json, 7 7 extract::{Query, State}, ··· 41 41 if params.did.is_none() && params.uri.is_none() && params.blob.is_none() { 42 42 return ApiError::InvalidRequest("Must provide did, uri, or blob".into()).into_response(); 43 43 } 44 - if let Some(did) = &params.did { 45 - let user = sqlx::query!( 46 - "SELECT did, deactivated_at, takedown_ref FROM users WHERE did = $1", 47 - did 48 - ) 49 - .fetch_optional(&state.db) 50 - .await; 51 - match user { 52 - Ok(Some(row)) => { 53 - let deactivated = row.deactivated_at.map(|_| StatusAttr { 44 + if let Some(did_str) = &params.did { 45 + let did: Did = match did_str.parse() { 46 + Ok(d) => d, 47 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 48 + }; 49 + match state.user_repo.get_status_by_did(&did).await { 50 + Ok(Some(status)) => { 51 + let deactivated = status.deactivated_at.map(|_| StatusAttr { 54 52 applied: true, 55 53 r#ref: None, 56 54 }); 57 - let takedown = row.takedown_ref.as_ref().map(|r| StatusAttr { 55 + let takedown = status.takedown_ref.as_ref().map(|r| StatusAttr { 58 56 applied: true, 59 57 r#ref: Some(r.clone()), 60 58 }); ··· 63 61 Json(SubjectStatus { 64 62 subject: json!({ 65 63 "$type": "com.atproto.admin.defs#repoRef", 66 - "did": row.did 64 + "did": did_str 67 65 }), 68 66 takedown, 69 67 deactivated, ··· 80 78 } 81 79 } 82 80 } 83 - if let Some(uri) = &params.uri { 84 - let record = sqlx::query!( 85 - "SELECT r.id, r.takedown_ref FROM records r WHERE r.record_cid = $1", 86 - uri 87 - ) 88 - .fetch_optional(&state.db) 89 - .await; 90 - match record { 91 - Ok(Some(row)) => { 92 - let takedown = row.takedown_ref.as_ref().map(|r| StatusAttr { 81 + if let Some(uri_str) = &params.uri { 82 + let cid: CidLink = match uri_str.parse() { 83 + Ok(c) => c, 84 + Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 85 + }; 86 + match state.repo_repo.get_record_by_cid(&cid).await { 87 + Ok(Some(record)) => { 88 + let takedown = record.takedown_ref.as_ref().map(|r| StatusAttr { 93 89 applied: true, 94 90 r#ref: Some(r.clone()), 95 91 }); ··· 98 94 Json(SubjectStatus { 99 95 subject: json!({ 100 96 "$type": "com.atproto.repo.strongRef", 101 - "uri": uri, 102 - "cid": uri 97 + "uri": uri_str, 98 + "cid": uri_str 103 99 }), 104 100 takedown, 105 101 deactivated: None, ··· 116 112 } 117 113 } 118 114 } 119 - if let Some(blob_cid) = &params.blob { 115 + if let Some(blob_cid_str) = &params.blob { 116 + let blob_cid: CidLink = match blob_cid_str.parse() { 117 + Ok(c) => c, 118 + Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 119 + }; 120 120 let did = match &params.did { 121 121 Some(d) => d, 122 122 None => { ··· 124 124 .into_response(); 125 125 } 126 126 }; 127 - let blob = sqlx::query!( 128 - "SELECT cid, takedown_ref FROM blobs WHERE cid = $1", 129 - blob_cid 130 - ) 131 - .fetch_optional(&state.db) 132 - .await; 133 - match blob { 134 - Ok(Some(row)) => { 135 - let takedown = row.takedown_ref.as_ref().map(|r| StatusAttr { 127 + match state.blob_repo.get_blob_with_takedown(&blob_cid).await { 128 + Ok(Some(blob)) => { 129 + let takedown = blob.takedown_ref.as_ref().map(|r| StatusAttr { 136 130 applied: true, 137 131 r#ref: Some(r.clone()), 138 132 }); ··· 142 136 subject: json!({ 143 137 "$type": "com.atproto.admin.defs#repoBlobRef", 144 138 "did": did, 145 - "cid": row.cid 139 + "cid": blob.cid 146 140 }), 147 141 takedown, 148 142 deactivated: None, ··· 187 181 let did_str = input.subject.get("did").and_then(|d| d.as_str()); 188 182 if let Some(did_str) = did_str { 189 183 let did = Did::new_unchecked(did_str); 190 - let mut tx = match state.db.begin().await { 191 - Ok(tx) => tx, 192 - Err(e) => { 193 - error!("Failed to begin transaction: {:?}", e); 194 - return ApiError::InternalError(None).into_response(); 195 - } 196 - }; 197 184 if let Some(takedown) = &input.takedown { 198 185 let takedown_ref = if takedown.applied { 199 - takedown.r#ref.clone() 186 + takedown.r#ref.as_deref() 200 187 } else { 201 188 None 202 189 }; 203 - if let Err(e) = sqlx::query!( 204 - "UPDATE users SET takedown_ref = $1 WHERE did = $2", 205 - takedown_ref, 206 - did.as_str() 207 - ) 208 - .execute(&mut *tx) 209 - .await 190 + if let Err(e) = state 191 + .user_repo 192 + .set_user_takedown(&did, takedown_ref) 193 + .await 210 194 { 211 195 error!("Failed to update user takedown status for {}: {:?}", did, e); 212 196 return ApiError::InternalError(Some( ··· 217 201 } 218 202 if let Some(deactivated) = &input.deactivated { 219 203 let result = if deactivated.applied { 220 - sqlx::query!( 221 - "UPDATE users SET deactivated_at = NOW() WHERE did = $1", 222 - did.as_str() 223 - ) 224 - .execute(&mut *tx) 225 - .await 204 + state.user_repo.deactivate_account(&did, None).await 226 205 } else { 227 - sqlx::query!( 228 - "UPDATE users SET deactivated_at = NULL WHERE did = $1", 229 - did.as_str() 230 - ) 231 - .execute(&mut *tx) 232 - .await 206 + state.user_repo.activate_account(&did).await 233 207 }; 234 208 if let Err(e) = result { 235 209 error!( ··· 241 215 )) 242 216 .into_response(); 243 217 } 244 - } 245 - if let Err(e) = tx.commit().await { 246 - error!("Failed to commit transaction: {:?}", e); 247 - return ApiError::InternalError(None).into_response(); 248 218 } 249 219 if let Some(takedown) = &input.takedown { 250 220 let status = if takedown.applied { ··· 280 250 warn!("Failed to sequence account event for deactivation: {}", e); 281 251 } 282 252 } 283 - if let Ok(Some(handle)) = 284 - sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did.as_str()) 285 - .fetch_optional(&state.db) 286 - .await 287 - { 253 + if let Ok(Some(handle)) = state.user_repo.get_handle_by_did(&did).await { 288 254 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 289 255 } 290 256 return ( ··· 304 270 } 305 271 } 306 272 Some("com.atproto.repo.strongRef") => { 307 - let uri = input.subject.get("uri").and_then(|u| u.as_str()); 308 - if let Some(uri) = uri { 273 + let uri_str = input.subject.get("uri").and_then(|u| u.as_str()); 274 + if let Some(uri_str) = uri_str { 275 + let cid: CidLink = match uri_str.parse() { 276 + Ok(c) => c, 277 + Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 278 + }; 309 279 if let Some(takedown) = &input.takedown { 310 280 let takedown_ref = if takedown.applied { 311 - takedown.r#ref.clone() 281 + takedown.r#ref.as_deref() 312 282 } else { 313 283 None 314 284 }; 315 - if let Err(e) = sqlx::query!( 316 - "UPDATE records SET takedown_ref = $1 WHERE record_cid = $2", 317 - takedown_ref, 318 - uri 319 - ) 320 - .execute(&state.db) 321 - .await 285 + if let Err(e) = state 286 + .repo_repo 287 + .set_record_takedown(&cid, takedown_ref) 288 + .await 322 289 { 323 290 error!( 324 291 "Failed to update record takedown status for {}: {:?}", 325 - uri, e 292 + uri_str, e 326 293 ); 327 294 return ApiError::InternalError(Some( 328 295 "Failed to update takedown status".into(), ··· 344 311 } 345 312 } 346 313 Some("com.atproto.admin.defs#repoBlobRef") => { 347 - let cid = input.subject.get("cid").and_then(|c| c.as_str()); 348 - if let Some(cid) = cid { 314 + let cid_str = input.subject.get("cid").and_then(|c| c.as_str()); 315 + if let Some(cid_str) = cid_str { 316 + let cid: CidLink = match cid_str.parse() { 317 + Ok(c) => c, 318 + Err(_) => return ApiError::InvalidRequest("Invalid CID format".into()).into_response(), 319 + }; 349 320 if let Some(takedown) = &input.takedown { 350 321 let takedown_ref = if takedown.applied { 351 - takedown.r#ref.clone() 322 + takedown.r#ref.as_deref() 352 323 } else { 353 324 None 354 325 }; 355 - if let Err(e) = sqlx::query!( 356 - "UPDATE blobs SET takedown_ref = $1 WHERE cid = $2", 357 - takedown_ref, 358 - cid 359 - ) 360 - .execute(&state.db) 361 - .await 326 + if let Err(e) = state 327 + .blob_repo 328 + .update_blob_takedown(&cid, takedown_ref) 329 + .await 362 330 { 363 - error!("Failed to update blob takedown status for {}: {:?}", cid, e); 331 + error!("Failed to update blob takedown status for {}: {:?}", cid_str, e); 364 332 return ApiError::InternalError(Some( 365 333 "Failed to update takedown status".into(), 366 334 ))
+12 -15
crates/tranquil-pds/src/api/age_assurance.rs
··· 43 43 let http_uri = "/"; 44 44 45 45 let auth_user = match validate_token_with_dpop( 46 - &state.db, 46 + state.user_repo.as_ref(), 47 + state.oauth_repo.as_ref(), 47 48 &extracted.token, 48 49 extracted.is_dpop, 49 50 dpop_proof, ··· 64 65 } 65 66 }; 66 67 67 - let row = match sqlx::query!( 68 - "SELECT created_at FROM users WHERE did = $1", 69 - &auth_user.did 70 - ) 71 - .fetch_optional(&state.db) 72 - .await 73 - { 74 - Ok(r) => { 75 - tracing::debug!(?r, "age assurance: query result"); 76 - r 68 + match state.user_repo.get_by_did(&auth_user.did).await { 69 + Ok(Some(user)) => { 70 + tracing::debug!(created_at = ?user.created_at, "age assurance: got user"); 71 + Some(user.created_at.to_rfc3339()) 72 + } 73 + Ok(None) => { 74 + tracing::debug!("age assurance: user not found"); 75 + None 77 76 } 78 77 Err(e) => { 79 78 tracing::warn!(?e, "age assurance: query failed"); 80 - return None; 79 + None 81 80 } 82 - }; 83 - 84 - row.map(|r| r.created_at.to_rfc3339()) 81 + } 85 82 }
+33 -132
crates/tranquil-pds/src/api/backup.rs
··· 14 14 use serde::{Deserialize, Serialize}; 15 15 use serde_json::json; 16 16 use std::str::FromStr; 17 + use tranquil_db::{BackupRepository, OldBackupInfo}; 17 18 use tracing::{error, info, warn}; 18 19 19 20 #[derive(Serialize)] ··· 35 36 } 36 37 37 38 pub async fn list_backups(State(state): State<AppState>, auth: BearerAuth) -> Response { 38 - let user = match sqlx::query!( 39 - "SELECT id, backup_enabled FROM users WHERE did = $1", 40 - auth.0.did.as_str() 41 - ) 42 - .fetch_optional(&state.db) 43 - .await 44 - { 45 - Ok(Some(u)) => u, 39 + let (user_id, backup_enabled) = match state.backup_repo.get_user_backup_status(&auth.0.did).await { 40 + Ok(Some(status)) => status, 46 41 Ok(None) => { 47 42 return ApiError::AccountNotFound.into_response(); 48 43 } ··· 52 47 } 53 48 }; 54 49 55 - let backups = match sqlx::query!( 56 - r#" 57 - SELECT id, repo_rev, repo_root_cid, block_count, size_bytes, created_at 58 - FROM account_backups 59 - WHERE user_id = $1 60 - ORDER BY created_at DESC 61 - "#, 62 - user.id 63 - ) 64 - .fetch_all(&state.db) 65 - .await 66 - { 50 + let backups = match state.backup_repo.list_backups_for_user(user_id).await { 67 51 Ok(rows) => rows, 68 52 Err(e) => { 69 53 error!("DB error fetching backups: {:?}", e); ··· 87 71 StatusCode::OK, 88 72 Json(ListBackupsOutput { 89 73 backups: backup_list, 90 - backup_enabled: user.backup_enabled, 74 + backup_enabled, 91 75 }), 92 76 ) 93 77 .into_response() ··· 110 94 } 111 95 }; 112 96 113 - let backup = match sqlx::query!( 114 - r#" 115 - SELECT ab.storage_key, ab.repo_rev 116 - FROM account_backups ab 117 - JOIN users u ON u.id = ab.user_id 118 - WHERE ab.id = $1 AND u.did = $2 119 - "#, 120 - backup_id, 121 - auth.0.did.as_str() 122 - ) 123 - .fetch_optional(&state.db) 124 - .await 125 - { 97 + let backup_info = match state.backup_repo.get_backup_storage_info(backup_id, &auth.0.did).await { 126 98 Ok(Some(b)) => b, 127 99 Ok(None) => { 128 100 return ApiError::BackupNotFound.into_response(); ··· 140 112 } 141 113 }; 142 114 143 - let car_bytes = match backup_storage.get_backup(&backup.storage_key).await { 115 + let car_bytes = match backup_storage.get_backup(&backup_info.storage_key).await { 144 116 Ok(bytes) => bytes, 145 117 Err(e) => { 146 118 error!("Failed to fetch backup from storage: {:?}", e); ··· 155 127 (axum::http::header::CONTENT_TYPE, "application/vnd.ipld.car"), 156 128 ( 157 129 axum::http::header::CONTENT_DISPOSITION, 158 - &format!("attachment; filename=\"{}.car\"", backup.repo_rev), 130 + &format!("attachment; filename=\"{}.car\"", backup_info.repo_rev), 159 131 ), 160 132 ], 161 133 car_bytes, ··· 180 152 } 181 153 }; 182 154 183 - let user = match sqlx::query!( 184 - r#" 185 - SELECT u.id, u.did, u.backup_enabled, u.deactivated_at, r.repo_root_cid, r.repo_rev 186 - FROM users u 187 - JOIN repos r ON r.user_id = u.id 188 - WHERE u.did = $1 189 - "#, 190 - auth.0.did.as_str() 191 - ) 192 - .fetch_optional(&state.db) 193 - .await 194 - { 155 + let user = match state.backup_repo.get_user_for_backup(&auth.0.did).await { 195 156 Ok(Some(u)) => u, 196 157 Ok(None) => { 197 158 return ApiError::AccountNotFound.into_response(); ··· 221 182 }; 222 183 223 184 let car_bytes = 224 - match generate_full_backup(&state.db, &state.block_store, user.id, &head_cid).await { 185 + match generate_full_backup(state.repo_repo.as_ref(), &state.block_store, user.id, &head_cid).await { 225 186 Ok(bytes) => bytes, 226 187 Err(e) => { 227 188 error!("Failed to generate CAR: {:?}", e); ··· 244 205 } 245 206 }; 246 207 247 - let backup_id = match sqlx::query_scalar!( 248 - r#" 249 - INSERT INTO account_backups (user_id, storage_key, repo_root_cid, repo_rev, block_count, size_bytes) 250 - VALUES ($1, $2, $3, $4, $5, $6) 251 - RETURNING id 252 - "#, 208 + let backup_id = match state.backup_repo.insert_backup( 253 209 user.id, 254 - storage_key, 255 - user.repo_root_cid, 256 - repo_rev, 210 + &storage_key, 211 + &user.repo_root_cid, 212 + &repo_rev, 257 213 block_count, 258 - size_bytes 259 - ) 260 - .fetch_one(&state.db) 261 - .await 262 - { 214 + size_bytes, 215 + ).await { 263 216 Ok(id) => id, 264 217 Err(e) => { 265 218 error!("DB error inserting backup: {:?}", e); ··· 282 235 ); 283 236 284 237 let retention = BackupStorage::retention_count(); 285 - if let Err(e) = cleanup_old_backups(&state.db, backup_storage, user.id, retention).await { 238 + if let Err(e) = cleanup_old_backups(state.backup_repo.as_ref(), backup_storage, user.id, retention).await { 286 239 warn!(did = %user.did, error = %e, "Failed to cleanup old backups after manual backup"); 287 240 } 288 241 ··· 299 252 } 300 253 301 254 async fn cleanup_old_backups( 302 - db: &sqlx::PgPool, 255 + backup_repo: &dyn BackupRepository, 303 256 backup_storage: &BackupStorage, 304 257 user_id: uuid::Uuid, 305 258 retention_count: u32, 306 259 ) -> Result<(), String> { 307 - let old_backups = sqlx::query!( 308 - r#" 309 - SELECT id, storage_key 310 - FROM account_backups 311 - WHERE user_id = $1 312 - ORDER BY created_at DESC 313 - OFFSET $2 314 - "#, 315 - user_id, 316 - retention_count as i64 317 - ) 318 - .fetch_all(db) 319 - .await 320 - .map_err(|e| format!("DB error fetching old backups: {}", e))?; 260 + let old_backups: Vec<OldBackupInfo> = backup_repo 261 + .get_old_backups(user_id, retention_count as i64) 262 + .await 263 + .map_err(|e| format!("DB error fetching old backups: {}", e))?; 321 264 322 265 for backup in old_backups { 323 266 if let Err(e) = backup_storage.delete_backup(&backup.storage_key).await { ··· 329 272 continue; 330 273 } 331 274 332 - sqlx::query!("DELETE FROM account_backups WHERE id = $1", backup.id) 333 - .execute(db) 275 + backup_repo 276 + .delete_backup(backup.id) 334 277 .await 335 278 .map_err(|e| format!("Failed to delete old backup record: {}", e))?; 336 279 } ··· 355 298 } 356 299 }; 357 300 358 - let backup = match sqlx::query!( 359 - r#" 360 - SELECT ab.id, ab.storage_key, u.deactivated_at 361 - FROM account_backups ab 362 - JOIN users u ON u.id = ab.user_id 363 - WHERE ab.id = $1 AND u.did = $2 364 - "#, 365 - backup_id, 366 - auth.0.did.as_str() 367 - ) 368 - .fetch_optional(&state.db) 369 - .await 370 - { 301 + let backup = match state.backup_repo.get_backup_for_deletion(backup_id, &auth.0.did).await { 371 302 Ok(Some(b)) => b, 372 303 Ok(None) => { 373 304 return ApiError::BackupNotFound.into_response(); ··· 392 323 ); 393 324 } 394 325 395 - if let Err(e) = sqlx::query!("DELETE FROM account_backups WHERE id = $1", backup.id) 396 - .execute(&state.db) 397 - .await 398 - { 326 + if let Err(e) = state.backup_repo.delete_backup(backup.id).await { 399 327 error!("DB error deleting backup: {:?}", e); 400 328 return ApiError::InternalError(Some("Failed to delete backup".into())).into_response(); 401 329 } ··· 416 344 auth: BearerAuth, 417 345 Json(input): Json<SetBackupEnabledInput>, 418 346 ) -> Response { 419 - let user = match sqlx::query!( 420 - "SELECT deactivated_at FROM users WHERE did = $1", 421 - auth.0.did.as_str() 422 - ) 423 - .fetch_optional(&state.db) 424 - .await 425 - { 426 - Ok(Some(u)) => u, 347 + let deactivated_at = match state.backup_repo.get_user_deactivated_status(&auth.0.did).await { 348 + Ok(Some(status)) => status, 427 349 Ok(None) => { 428 350 return ApiError::AccountNotFound.into_response(); 429 351 } ··· 433 355 } 434 356 }; 435 357 436 - if user.deactivated_at.is_some() { 358 + if deactivated_at.is_some() { 437 359 return ApiError::AccountDeactivated.into_response(); 438 360 } 439 361 440 - if let Err(e) = sqlx::query!( 441 - "UPDATE users SET backup_enabled = $1 WHERE did = $2", 442 - input.enabled, 443 - auth.0.did.as_str() 444 - ) 445 - .execute(&state.db) 446 - .await 447 - { 362 + if let Err(e) = state.backup_repo.update_backup_enabled(&auth.0.did, input.enabled).await { 448 363 error!("DB error updating backup_enabled: {:?}", e); 449 364 return ApiError::InternalError(Some("Failed to update setting".into())).into_response(); 450 365 } ··· 455 370 } 456 371 457 372 pub async fn export_blobs(State(state): State<AppState>, auth: BearerAuth) -> Response { 458 - let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", auth.0.did.as_str()) 459 - .fetch_optional(&state.db) 460 - .await 461 - { 462 - Ok(Some(u)) => u, 373 + let user_id = match state.backup_repo.get_user_id_by_did(&auth.0.did).await { 374 + Ok(Some(id)) => id, 463 375 Ok(None) => { 464 376 return ApiError::AccountNotFound.into_response(); 465 377 } ··· 469 381 } 470 382 }; 471 383 472 - let blobs = match sqlx::query!( 473 - r#" 474 - SELECT DISTINCT b.cid, b.storage_key, b.mime_type 475 - FROM blobs b 476 - JOIN record_blobs rb ON rb.blob_cid = b.cid 477 - WHERE rb.repo_id = $1 478 - "#, 479 - user.id 480 - ) 481 - .fetch_all(&state.db) 482 - .await 483 - { 384 + let blobs = match state.backup_repo.get_blobs_for_export(user_id).await { 484 385 Ok(rows) => rows, 485 386 Err(e) => { 486 387 error!("DB error fetching blobs: {:?}", e);
+160 -248
crates/tranquil-pds/src/api/delegation.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::create_signed_commit; 3 3 use crate::auth::BearerAuth; 4 - use crate::delegation::{self, DelegationActionType}; 5 - use crate::oauth::db as oauth_db; 4 + use crate::delegation::{DelegationActionType, SCOPE_PRESETS, scopes}; 6 5 use crate::state::{AppState, RateLimitKind}; 7 6 use crate::types::{Did, Handle, Nsid, Rkey}; 8 7 use crate::util::extract_client_ip; ··· 35 34 } 36 35 37 36 pub async fn list_controllers(State(state): State<AppState>, auth: BearerAuth) -> Response { 38 - let controllers = match delegation::get_delegations_for_account(&state.db, &auth.0.did).await { 37 + let controllers = match state 38 + .delegation_repo 39 + .get_delegations_for_account(&auth.0.did) 40 + .await 41 + { 39 42 Ok(c) => c, 40 43 Err(e) => { 41 44 tracing::error!("Failed to list controllers: {:?}", e); ··· 49 52 .into_iter() 50 53 .map(|c| ControllerInfo { 51 54 did: c.did.into(), 52 - handle: c.handle, 55 + handle: c.handle.into(), 53 56 granted_scopes: c.granted_scopes, 54 57 granted_at: c.granted_at, 55 58 is_active: c.is_active, ··· 70 73 auth: BearerAuth, 71 74 Json(input): Json<AddControllerInput>, 72 75 ) -> Response { 73 - if let Err(e) = delegation::scopes::validate_delegation_scopes(&input.granted_scopes) { 76 + if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 74 77 return ApiError::InvalidScopes(e).into_response(); 75 78 } 76 79 77 - let controller_exists: bool = sqlx::query_scalar!( 78 - r#"SELECT EXISTS(SELECT 1 FROM users WHERE did = $1) as "exists!""#, 79 - input.controller_did.as_str() 80 - ) 81 - .fetch_one(&state.db) 82 - .await 83 - .unwrap_or(false); 80 + let controller_exists = state 81 + .user_repo 82 + .get_by_did(&input.controller_did) 83 + .await 84 + .ok() 85 + .flatten() 86 + .is_some(); 84 87 85 88 if !controller_exists { 86 89 return ApiError::ControllerNotFound.into_response(); 87 90 } 88 91 89 - match delegation::controls_any_accounts(&state.db, &auth.0.did).await { 92 + match state.delegation_repo.controls_any_accounts(&auth.0.did).await { 90 93 Ok(true) => { 91 94 return ApiError::InvalidDelegation( 92 95 "Cannot add controllers to an account that controls other accounts".into(), ··· 101 104 Ok(false) => {} 102 105 } 103 106 104 - match delegation::has_any_controllers(&state.db, &input.controller_did).await { 107 + match state 108 + .delegation_repo 109 + .has_any_controllers(&input.controller_did) 110 + .await 111 + { 105 112 Ok(true) => { 106 113 return ApiError::InvalidDelegation( 107 114 "Cannot add a controlled account as a controller".into(), ··· 116 123 Ok(false) => {} 117 124 } 118 125 119 - match delegation::create_delegation( 120 - &state.db, 121 - &auth.0.did, 122 - &input.controller_did, 123 - &input.granted_scopes, 124 - &auth.0.did, 125 - ) 126 - .await 126 + match state 127 + .delegation_repo 128 + .create_delegation( 129 + &auth.0.did, 130 + &input.controller_did, 131 + &input.granted_scopes, 132 + &auth.0.did, 133 + ) 134 + .await 127 135 { 128 136 Ok(_) => { 129 - let _ = delegation::log_delegation_action( 130 - &state.db, 131 - &auth.0.did, 132 - &auth.0.did, 133 - Some(&input.controller_did), 134 - DelegationActionType::GrantCreated, 135 - Some(serde_json::json!({ 136 - "granted_scopes": input.granted_scopes 137 - })), 138 - None, 139 - None, 140 - ) 141 - .await; 137 + let _ = state 138 + .delegation_repo 139 + .log_delegation_action( 140 + &auth.0.did, 141 + &auth.0.did, 142 + Some(&input.controller_did), 143 + DelegationActionType::GrantCreated, 144 + Some(serde_json::json!({ 145 + "granted_scopes": input.granted_scopes 146 + })), 147 + None, 148 + None, 149 + ) 150 + .await; 142 151 143 152 ( 144 153 StatusCode::OK, ··· 165 174 auth: BearerAuth, 166 175 Json(input): Json<RemoveControllerInput>, 167 176 ) -> Response { 168 - match delegation::revoke_delegation(&state.db, &auth.0.did, &input.controller_did, &auth.0.did) 177 + match state 178 + .delegation_repo 179 + .revoke_delegation(&auth.0.did, &input.controller_did, &auth.0.did) 169 180 .await 170 181 { 171 182 Ok(true) => { 172 - let revoked_app_passwords = sqlx::query_scalar!( 173 - r#"DELETE FROM app_passwords 174 - WHERE user_id = (SELECT id FROM users WHERE did = $1) 175 - AND created_by_controller_did = $2 176 - RETURNING id"#, 177 - &auth.0.did, 178 - input.controller_did.as_str() 179 - ) 180 - .fetch_all(&state.db) 181 - .await 182 - .map(|r| r.len()) 183 - .unwrap_or(0); 183 + let revoked_app_passwords = state 184 + .session_repo 185 + .delete_app_passwords_by_controller(&auth.0.did, &input.controller_did) 186 + .await 187 + .unwrap_or(0) as usize; 184 188 185 - let revoked_oauth_tokens = oauth_db::revoke_tokens_for_controller( 186 - &state.db, 187 - &auth.0.did, 188 - &input.controller_did, 189 - ) 190 - .await 191 - .unwrap_or(0); 189 + let revoked_oauth_tokens = state 190 + .oauth_repo 191 + .revoke_tokens_for_controller(&auth.0.did, &input.controller_did) 192 + .await 193 + .unwrap_or(0); 192 194 193 - let _ = delegation::log_delegation_action( 194 - &state.db, 195 - &auth.0.did, 196 - &auth.0.did, 197 - Some(&input.controller_did), 198 - DelegationActionType::GrantRevoked, 199 - Some(serde_json::json!({ 200 - "revoked_app_passwords": revoked_app_passwords, 201 - "revoked_oauth_tokens": revoked_oauth_tokens 202 - })), 203 - None, 204 - None, 205 - ) 206 - .await; 195 + let _ = state 196 + .delegation_repo 197 + .log_delegation_action( 198 + &auth.0.did, 199 + &auth.0.did, 200 + Some(&input.controller_did), 201 + DelegationActionType::GrantRevoked, 202 + Some(serde_json::json!({ 203 + "revoked_app_passwords": revoked_app_passwords, 204 + "revoked_oauth_tokens": revoked_oauth_tokens 205 + })), 206 + None, 207 + None, 208 + ) 209 + .await; 207 210 208 211 ( 209 212 StatusCode::OK, ··· 232 235 auth: BearerAuth, 233 236 Json(input): Json<UpdateControllerScopesInput>, 234 237 ) -> Response { 235 - if let Err(e) = delegation::scopes::validate_delegation_scopes(&input.granted_scopes) { 238 + if let Err(e) = scopes::validate_delegation_scopes(&input.granted_scopes) { 236 239 return ApiError::InvalidScopes(e).into_response(); 237 240 } 238 241 239 - match delegation::update_delegation_scopes( 240 - &state.db, 241 - &auth.0.did, 242 - &input.controller_did, 243 - &input.granted_scopes, 244 - ) 245 - .await 242 + match state 243 + .delegation_repo 244 + .update_delegation_scopes(&auth.0.did, &input.controller_did, &input.granted_scopes) 245 + .await 246 246 { 247 247 Ok(true) => { 248 - let _ = delegation::log_delegation_action( 249 - &state.db, 250 - &auth.0.did, 251 - &auth.0.did, 252 - Some(&input.controller_did), 253 - DelegationActionType::ScopesModified, 254 - Some(serde_json::json!({ 255 - "new_scopes": input.granted_scopes 256 - })), 257 - None, 258 - None, 259 - ) 260 - .await; 248 + let _ = state 249 + .delegation_repo 250 + .log_delegation_action( 251 + &auth.0.did, 252 + &auth.0.did, 253 + Some(&input.controller_did), 254 + DelegationActionType::ScopesModified, 255 + Some(serde_json::json!({ 256 + "new_scopes": input.granted_scopes 257 + })), 258 + None, 259 + None, 260 + ) 261 + .await; 261 262 262 263 ( 263 264 StatusCode::OK, ··· 291 292 } 292 293 293 294 pub async fn list_controlled_accounts(State(state): State<AppState>, auth: BearerAuth) -> Response { 294 - let accounts = match delegation::get_accounts_controlled_by(&state.db, &auth.0.did).await { 295 + let accounts = match state 296 + .delegation_repo 297 + .get_accounts_controlled_by(&auth.0.did) 298 + .await 299 + { 295 300 Ok(a) => a, 296 301 Err(e) => { 297 302 tracing::error!("Failed to list controlled accounts: {:?}", e); ··· 305 310 .into_iter() 306 311 .map(|a| DelegatedAccountInfo { 307 312 did: a.did.into(), 308 - handle: a.handle, 313 + handle: a.handle.into(), 309 314 granted_scopes: a.granted_scopes, 310 315 granted_at: a.granted_at, 311 316 }) ··· 352 357 let limit = params.limit.clamp(1, 100); 353 358 let offset = params.offset.max(0); 354 359 355 - let entries = 356 - match delegation::audit::get_audit_log_for_account(&state.db, &auth.0.did, limit, offset) 357 - .await 358 - { 359 - Ok(e) => e, 360 - Err(e) => { 361 - tracing::error!("Failed to get audit log: {:?}", e); 362 - return ApiError::InternalError(Some("Failed to get audit log".into())) 363 - .into_response(); 364 - } 365 - }; 360 + let entries = match state 361 + .delegation_repo 362 + .get_audit_log_for_account(&auth.0.did, limit, offset) 363 + .await 364 + { 365 + Ok(e) => e, 366 + Err(e) => { 367 + tracing::error!("Failed to get audit log: {:?}", e); 368 + return ApiError::InternalError(Some("Failed to get audit log".into())).into_response(); 369 + } 370 + }; 366 371 367 - let total = delegation::audit::count_audit_log_entries(&state.db, &auth.0.did) 372 + let total = state 373 + .delegation_repo 374 + .count_audit_log_entries(&auth.0.did) 368 375 .await 369 376 .unwrap_or_default(); 370 377 ··· 401 408 402 409 pub async fn get_scope_presets() -> Response { 403 410 Json(GetScopePresetsResponse { 404 - presets: delegation::SCOPE_PRESETS 411 + presets: SCOPE_PRESETS 405 412 .iter() 406 413 .map(|p| ScopePresetInfo { 407 414 name: p.name, ··· 448 455 .into_response(); 449 456 } 450 457 451 - if let Err(e) = delegation::scopes::validate_delegation_scopes(&input.controller_scopes) { 458 + if let Err(e) = scopes::validate_delegation_scopes(&input.controller_scopes) { 452 459 return ApiError::InvalidScopes(e).into_response(); 453 460 } 454 461 455 - match delegation::has_any_controllers(&state.db, &auth.0.did).await { 462 + match state.delegation_repo.has_any_controllers(&auth.0.did).await { 456 463 Ok(true) => { 457 464 return ApiError::InvalidDelegation( 458 465 "Cannot create delegated accounts from a controlled account".into(), ··· 468 475 } 469 476 470 477 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 471 - let pds_suffix = format!(".{}", hostname); 478 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 479 + let pds_suffix = format!(".{}", hostname_for_handles); 472 480 473 481 let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) { 474 482 let handle_to_validate = if input.handle.ends_with(&pds_suffix) { ··· 480 488 &input.handle 481 489 }; 482 490 match crate::api::validation::validate_short_handle(handle_to_validate) { 483 - Ok(h) => format!("{}.{}", h, hostname), 491 + Ok(h) => format!("{}.{}", h, hostname_for_handles), 484 492 Err(e) => { 485 493 return ApiError::InvalidRequest(e.to_string()).into_response(); 486 494 } ··· 501 509 } 502 510 503 511 if let Some(ref code) = input.invite_code { 504 - let valid = sqlx::query_scalar!( 505 - "SELECT available_uses > 0 AND NOT disabled FROM invite_codes WHERE code = $1", 506 - code 507 - ) 508 - .fetch_optional(&state.db) 509 - .await 510 - .ok() 511 - .flatten() 512 - .unwrap_or(Some(false)); 512 + let valid = state.infra_repo.is_invite_code_valid(code).await.unwrap_or(false); 513 513 514 - if valid != Some(true) { 514 + if !valid { 515 515 return ApiError::InvalidInviteCode.into_response(); 516 516 } 517 517 } else { ··· 572 572 let handle = Handle::new_unchecked(&handle); 573 573 info!(did = %did, handle = %handle, controller = %&auth.0.did, "Created DID for delegated account"); 574 574 575 - let mut tx = match state.db.begin().await { 576 - Ok(tx) => tx, 577 - Err(e) => { 578 - error!("Error starting transaction: {:?}", e); 579 - return ApiError::InternalError(None).into_response(); 580 - } 581 - }; 582 - 583 - let user_insert: Result<(uuid::Uuid,), _> = sqlx::query_as( 584 - r#"INSERT INTO users ( 585 - handle, email, did, password_hash, password_required, 586 - account_type, preferred_comms_channel 587 - ) VALUES ($1, $2, $3, NULL, FALSE, 'delegated'::account_type, 'email'::comms_channel) RETURNING id"#, 588 - ) 589 - .bind(handle.as_str()) 590 - .bind(&email) 591 - .bind(did.as_str()) 592 - .fetch_one(&mut *tx) 593 - .await; 594 - 595 - let user_id = match user_insert { 596 - Ok((id,)) => id, 597 - Err(e) => { 598 - if let Some(db_err) = e.as_database_error() 599 - && db_err.code().as_deref() == Some("23505") 600 - { 601 - let constraint = db_err.constraint().unwrap_or(""); 602 - if constraint.contains("handle") { 603 - return ApiError::HandleNotAvailable(None).into_response(); 604 - } else if constraint.contains("email") { 605 - return ApiError::EmailTaken.into_response(); 606 - } 607 - } 608 - error!("Error inserting user: {:?}", e); 609 - return ApiError::InternalError(None).into_response(); 610 - } 611 - }; 612 - 613 575 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 614 576 Ok(bytes) => bytes, 615 577 Err(e) => { ··· 618 580 } 619 581 }; 620 582 621 - if let Err(e) = sqlx::query!( 622 - "INSERT INTO user_keys (user_id, key_bytes, encryption_version, encrypted_at) VALUES ($1, $2, $3, NOW())", 623 - user_id, 624 - &encrypted_key_bytes[..], 625 - crate::config::ENCRYPTION_VERSION 626 - ) 627 - .execute(&mut *tx) 628 - .await 629 - { 630 - error!("Error inserting user key: {:?}", e); 631 - return ApiError::InternalError(None).into_response(); 632 - } 633 - 634 - if let Err(e) = sqlx::query!( 635 - r#"INSERT INTO account_delegations (delegated_did, controller_did, granted_scopes, granted_by) 636 - VALUES ($1, $2, $3, $4)"#, 637 - did.as_str(), 638 - auth.0.did.as_str(), 639 - input.controller_scopes, 640 - auth.0.did.as_str() 641 - ) 642 - .execute(&mut *tx) 643 - .await 644 - { 645 - error!("Error creating initial delegation: {:?}", e); 646 - return ApiError::InternalError(None).into_response(); 647 - } 648 - 649 583 let mst = Mst::new(Arc::new(state.block_store.clone())); 650 584 let mst_root = match mst.persist().await { 651 585 Ok(c) => c, ··· 670 604 return ApiError::InternalError(None).into_response(); 671 605 } 672 606 }; 673 - let commit_cid_str = commit_cid.to_string(); 674 - let rev_str = rev.as_ref().to_string(); 675 - if let Err(e) = sqlx::query!( 676 - "INSERT INTO repos (user_id, repo_root_cid, repo_rev) VALUES ($1, $2, $3)", 677 - user_id, 678 - commit_cid_str, 679 - rev_str 680 - ) 681 - .execute(&mut *tx) 682 - .await 683 - { 684 - error!("Error inserting repo: {:?}", e); 685 - return ApiError::InternalError(None).into_response(); 686 - } 687 607 let genesis_block_cids = vec![mst_root.to_bytes(), commit_cid.to_bytes()]; 688 - if let Err(e) = sqlx::query!( 689 - r#" 690 - INSERT INTO user_blocks (user_id, block_cid) 691 - SELECT $1, block_cid FROM UNNEST($2::bytea[]) AS t(block_cid) 692 - ON CONFLICT (user_id, block_cid) DO NOTHING 693 - "#, 694 - user_id, 695 - &genesis_block_cids 696 - ) 697 - .execute(&mut *tx) 698 - .await 699 - { 700 - error!("Error inserting user_blocks: {:?}", e); 701 - return ApiError::InternalError(None).into_response(); 702 - } 703 608 704 - if let Some(ref code) = input.invite_code { 705 - let _ = sqlx::query!( 706 - "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 707 - code 708 - ) 709 - .execute(&mut *tx) 710 - .await; 711 - 712 - let _ = sqlx::query!( 713 - "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 714 - code, 715 - user_id 716 - ) 717 - .execute(&mut *tx) 718 - .await; 719 - } 609 + let create_input = tranquil_db_traits::CreateDelegatedAccountInput { 610 + handle: handle.clone(), 611 + email: email.clone(), 612 + did: did.clone(), 613 + controller_did: auth.0.did.clone(), 614 + controller_scopes: input.controller_scopes.clone(), 615 + encrypted_key_bytes, 616 + encryption_version: crate::config::ENCRYPTION_VERSION, 617 + commit_cid: commit_cid.to_string(), 618 + repo_rev: rev.as_ref().to_string(), 619 + genesis_block_cids, 620 + invite_code: input.invite_code.clone(), 621 + }; 720 622 721 - if let Err(e) = tx.commit().await { 722 - error!("Error committing transaction: {:?}", e); 723 - return ApiError::InternalError(None).into_response(); 724 - } 623 + let _user_id = match state.user_repo.create_delegated_account(&create_input).await { 624 + Ok(id) => id, 625 + Err(tranquil_db_traits::CreateAccountError::HandleTaken) => { 626 + return ApiError::HandleNotAvailable(None).into_response(); 627 + } 628 + Err(tranquil_db_traits::CreateAccountError::EmailTaken) => { 629 + return ApiError::EmailTaken.into_response(); 630 + } 631 + Err(e) => { 632 + error!("Error creating delegated account: {:?}", e); 633 + return ApiError::InternalError(None).into_response(); 634 + } 635 + }; 725 636 726 637 if let Err(e) = 727 638 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle)).await ··· 751 662 warn!("Failed to create default profile for {}: {}", did, e); 752 663 } 753 664 754 - let _ = delegation::log_delegation_action( 755 - &state.db, 756 - &did, 757 - &auth.0.did, 758 - Some(&auth.0.did), 759 - DelegationActionType::GrantCreated, 760 - Some(json!({ 761 - "account_created": true, 762 - "granted_scopes": input.controller_scopes 763 - })), 764 - None, 765 - None, 766 - ) 767 - .await; 665 + let _ = state 666 + .delegation_repo 667 + .log_delegation_action( 668 + &did, 669 + &auth.0.did, 670 + Some(&auth.0.did), 671 + DelegationActionType::GrantCreated, 672 + Some(json!({ 673 + "account_created": true, 674 + "granted_scopes": input.controller_scopes 675 + })), 676 + None, 677 + None, 678 + ) 679 + .await; 768 680 769 681 info!(did = %did, handle = %handle, controller = %&auth.0.did, "Delegated account created"); 770 682
+14 -12
crates/tranquil-pds/src/api/error.rs
··· 158 158 Self::RepoTakendown | Self::RepoDeactivated | Self::RepoNotFound(_) => { 159 159 StatusCode::BAD_REQUEST 160 160 } 161 - Self::InvalidSwap(_) | Self::TotpAlreadyEnabled => StatusCode::CONFLICT, 161 + Self::TotpAlreadyEnabled => StatusCode::CONFLICT, 162 + Self::InvalidSwap(_) => StatusCode::BAD_REQUEST, 162 163 Self::InvalidRequest(_) 163 164 | Self::InvalidHandle(_) 164 165 | Self::HandleNotAvailable(_) ··· 474 475 crate::auth::TokenValidationError::OAuthTokenExpired => { 475 476 Self::OAuthExpiredToken(Some("Token has expired".to_string())) 476 477 } 477 - } 478 - } 479 - } 480 - 481 - impl From<crate::util::DbLookupError> for ApiError { 482 - fn from(e: crate::util::DbLookupError) -> Self { 483 - match e { 484 - crate::util::DbLookupError::NotFound => Self::AccountNotFound, 485 - crate::util::DbLookupError::DatabaseError(db_err) => { 486 - tracing::error!("Database error: {:?}", db_err); 487 - Self::DatabaseError 478 + crate::auth::TokenValidationError::InvalidToken => { 479 + Self::AuthenticationFailed(Some("Invalid token format".to_string())) 488 480 } 489 481 } 490 482 } 491 483 } 484 + 492 485 493 486 impl From<crate::auth::extractor::AuthError> for ApiError { 494 487 fn from(e: crate::auth::extractor::AuthError) -> Self { ··· 642 635 tracing::error!("Storage error: {:?}", e); 643 636 Self::InternalError(Some("Storage operation failed".into())) 644 637 } 638 + } 639 + 640 + pub fn parse_did(s: &str) -> Result<tranquil_types::Did, Response> { 641 + s.parse() 642 + .map_err(|_| ApiError::InvalidDid("Invalid DID format".into()).into_response()) 643 + } 644 + 645 + pub fn parse_did_option(s: Option<&str>) -> Result<Option<tranquil_types::Did>, Response> { 646 + s.map(parse_did).transpose() 645 647 } 646 648 647 649 pub struct AtpJson<T>(pub T);
+132 -278
crates/tranquil-pds/src/api/identity/account.rs
··· 243 243 }) 244 244 }; 245 245 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 246 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 246 247 let pds_endpoint = format!("https://{}", hostname); 247 - let suffix = format!(".{}", hostname); 248 + let suffix = format!(".{}", hostname_for_handles); 248 249 let handle = if input.handle.ends_with(&suffix) { 249 - format!("{}.{}", validated_short_handle, hostname) 250 + format!("{}.{}", validated_short_handle, hostname_for_handles) 250 251 } else if input.handle.contains('.') { 251 252 validated_short_handle.clone() 252 253 } else { 253 - format!("{}.{}", validated_short_handle, hostname) 254 + format!("{}.{}", validated_short_handle, hostname_for_handles) 254 255 }; 255 256 let (secret_key_bytes, reserved_key_id): (Vec<u8>, Option<uuid::Uuid>) = 256 257 if let Some(signing_key_did) = &input.signing_key { 257 - let reserved = sqlx::query!( 258 - r#" 259 - SELECT id, private_key_bytes 260 - FROM reserved_signing_keys 261 - WHERE public_key_did_key = $1 262 - AND used_at IS NULL 263 - AND expires_at > NOW() 264 - FOR UPDATE 265 - "#, 266 - signing_key_did 267 - ) 268 - .fetch_optional(&state.db) 269 - .await; 270 - match reserved { 271 - Ok(Some(row)) => (row.private_key_bytes, Some(row.id)), 258 + match state.infra_repo.get_reserved_signing_key(signing_key_did).await { 259 + Ok(Some(key)) => (key.private_key_bytes, Some(key.id)), 272 260 Ok(None) => { 273 261 return ApiError::InvalidSigningKey.into_response(); 274 262 } ··· 294 282 if !crate::api::server::meta::is_self_hosted_did_web_enabled() { 295 283 return ApiError::SelfHostedDidWebDisabled.into_response(); 296 284 } 297 - let subdomain_host = format!("{}.{}", input.handle, hostname); 285 + let subdomain_host = format!("{}.{}", input.handle, hostname_for_handles); 298 286 let encoded_subdomain = subdomain_host.replace(':', "%3A"); 299 287 let self_hosted_did = format!("did:web:{}", encoded_subdomain); 300 288 info!(did = %self_hosted_did, "Creating self-hosted did:web account (subdomain)"); ··· 414 402 } 415 403 } 416 404 }; 417 - let mut tx = match state.db.begin().await { 418 - Ok(tx) => tx, 419 - Err(e) => { 420 - error!("Error starting transaction: {:?}", e); 421 - return ApiError::InternalError(None).into_response(); 422 - } 423 - }; 424 405 if is_migration { 425 - let existing_account: Option<(uuid::Uuid, String, Option<chrono::DateTime<chrono::Utc>>)> = 426 - sqlx::query_as("SELECT id, handle, deactivated_at FROM users WHERE did = $1 FOR UPDATE") 427 - .bind(&did) 428 - .fetch_optional(&mut *tx) 429 - .await 430 - .unwrap_or(None); 431 - if let Some((account_id, old_handle, deactivated_at)) = existing_account { 432 - if deactivated_at.is_some() { 433 - info!(did = %did, old_handle = %old_handle, new_handle = %handle, "Preparing existing account for inbound migration"); 434 - let update_result: Result<_, sqlx::Error> = 435 - sqlx::query("UPDATE users SET handle = $1 WHERE id = $2") 436 - .bind(&handle) 437 - .bind(account_id) 438 - .execute(&mut *tx) 439 - .await; 440 - if let Err(e) = update_result { 441 - if let Some(db_err) = e.as_database_error() 442 - && db_err 443 - .constraint() 444 - .map(|c| c.contains("handle")) 445 - .unwrap_or(false) 446 - { 447 - return ApiError::HandleTaken.into_response(); 448 - } 449 - error!("Error reactivating account: {:?}", e); 450 - return ApiError::InternalError(None).into_response(); 451 - } 452 - if let Err(e) = tx.commit().await { 453 - error!("Error committing reactivation: {:?}", e); 454 - return ApiError::InternalError(None).into_response(); 455 - } 456 - let key_row: Option<(Vec<u8>, i32)> = sqlx::query_as( 457 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 458 - ) 459 - .bind(account_id) 460 - .fetch_optional(&state.db) 461 - .await 462 - .unwrap_or(None); 463 - let secret_key_bytes = match key_row { 464 - Some((key_bytes, encryption_version)) => { 465 - match crate::config::decrypt_key(&key_bytes, Some(encryption_version)) { 406 + let reactivate_input = tranquil_db_traits::MigrationReactivationInput { 407 + did: Did::new_unchecked(&did), 408 + new_handle: Handle::new_unchecked(&handle), 409 + }; 410 + match state.user_repo.reactivate_migration_account(&reactivate_input).await { 411 + Ok(reactivated) => { 412 + info!(did = %did, old_handle = %reactivated.old_handle, new_handle = %handle, "Preparing existing account for inbound migration"); 413 + let secret_key_bytes = match state.user_repo.get_user_key_by_id(reactivated.user_id).await { 414 + Ok(Some(key_info)) => { 415 + match crate::config::decrypt_key(&key_info.key_bytes, key_info.encryption_version) { 466 416 Ok(k) => k, 467 417 Err(e) => { 468 418 error!("Error decrypting key for reactivated account: {:?}", e); ··· 470 420 } 471 421 } 472 422 } 473 - None => { 423 + _ => { 474 424 error!("No signing key found for reactivated account"); 475 425 return ApiError::InternalError(Some( 476 426 "Account signing key not found".into(), ··· 496 446 return ApiError::InternalError(None).into_response(); 497 447 } 498 448 }; 499 - let session_result: Result<_, sqlx::Error> = sqlx::query( 500 - "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)", 501 - ) 502 - .bind(&did) 503 - .bind(&access_meta.jti) 504 - .bind(&refresh_meta.jti) 505 - .bind(access_meta.expires_at) 506 - .bind(refresh_meta.expires_at) 507 - .execute(&state.db) 508 - .await; 509 - if let Err(e) = session_result { 449 + let session_data = tranquil_db_traits::SessionTokenCreate { 450 + did: Did::new_unchecked(&did), 451 + access_jti: access_meta.jti.clone(), 452 + refresh_jti: refresh_meta.jti.clone(), 453 + access_expires_at: access_meta.expires_at, 454 + refresh_expires_at: refresh_meta.expires_at, 455 + legacy_login: false, 456 + mfa_verified: false, 457 + scope: None, 458 + controller_did: None, 459 + app_password_name: None, 460 + }; 461 + if let Err(e) = state.session_repo.create_session(&session_data).await { 510 462 error!("Error creating session: {:?}", e); 511 463 return ApiError::InternalError(None).into_response(); 512 464 } ··· 514 466 axum::http::StatusCode::OK, 515 467 Json(CreateAccountOutput { 516 468 handle: handle.clone().into(), 517 - did: did.clone().into(), 469 + did: Did::new_unchecked(&did), 518 470 did_doc: state.did_resolver.resolve_did_document(&did).await, 519 471 access_jwt: access_meta.token, 520 472 refresh_jwt: refresh_meta.token, ··· 523 475 }), 524 476 ) 525 477 .into_response(); 526 - } else { 478 + } 479 + Err(tranquil_db_traits::MigrationReactivationError::NotFound) => { 480 + } 481 + Err(tranquil_db_traits::MigrationReactivationError::NotDeactivated) => { 527 482 return ApiError::AccountAlreadyExists.into_response(); 528 483 } 484 + Err(tranquil_db_traits::MigrationReactivationError::HandleTaken) => { 485 + return ApiError::HandleTaken.into_response(); 486 + } 487 + Err(e) => { 488 + error!("Error reactivating migration account: {:?}", e); 489 + return ApiError::InternalError(None).into_response(); 490 + } 529 491 } 530 492 } 531 - let exists_result: Option<(i32,)> = 532 - sqlx::query_as("SELECT 1 FROM users WHERE handle = $1 AND deactivated_at IS NULL") 533 - .bind(&handle) 534 - .fetch_optional(&mut *tx) 535 - .await 536 - .unwrap_or(None); 537 - if exists_result.is_some() { 493 + 494 + let handle_typed = Handle::new_unchecked(&handle); 495 + let handle_available = match state.user_repo.check_handle_available_for_new_account(&handle_typed).await { 496 + Ok(available) => available, 497 + Err(e) => { 498 + error!("Error checking handle availability: {:?}", e); 499 + return ApiError::InternalError(None).into_response(); 500 + } 501 + }; 502 + if !handle_available { 538 503 return ApiError::HandleTaken.into_response(); 539 504 } 505 + 540 506 let invite_code_required = std::env::var("INVITE_CODE_REQUIRED") 541 507 .map(|v| v == "true" || v == "1") 542 508 .unwrap_or(false); ··· 552 518 if let Some(code) = &input.invite_code 553 519 && !code.trim().is_empty() 554 520 { 555 - let invite_query = sqlx::query!( 556 - "SELECT available_uses FROM invite_codes WHERE code = $1 FOR UPDATE", 557 - code 558 - ) 559 - .fetch_optional(&mut *tx) 560 - .await; 561 - match invite_query { 562 - Ok(Some(row)) => { 563 - if row.available_uses <= 0 { 564 - return ApiError::InvalidInviteCode.into_response(); 565 - } 566 - let update_invite = sqlx::query!( 567 - "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 568 - code 569 - ) 570 - .execute(&mut *tx) 571 - .await; 572 - if let Err(e) = update_invite { 573 - error!("Error updating invite code: {:?}", e); 574 - return ApiError::InternalError(None).into_response(); 575 - } 576 - } 577 - Ok(None) => { 578 - return ApiError::InvalidInviteCode.into_response(); 579 - } 521 + let valid = match state.user_repo.check_and_consume_invite_code(code).await { 522 + Ok(v) => v, 580 523 Err(e) => { 581 524 error!("Error checking invite code: {:?}", e); 582 525 return ApiError::InternalError(None).into_response(); 583 526 } 527 + }; 528 + if !valid { 529 + return ApiError::InvalidInviteCode.into_response(); 584 530 } 585 531 } 532 + 586 533 if let Err(e) = validate_password(&input.password) { 587 534 return ApiError::InvalidRequest(e.to_string()).into_response(); 588 535 } ··· 600 547 return ApiError::InternalError(None).into_response(); 601 548 } 602 549 }; 603 - let is_first_user = sqlx::query_scalar!("SELECT COUNT(*) as count FROM users") 604 - .fetch_one(&mut *tx) 605 - .await 606 - .map(|c| c.unwrap_or(0) == 0) 607 - .unwrap_or(false); 550 + 608 551 let deactivated_at: Option<chrono::DateTime<chrono::Utc>> = if is_migration || is_did_web_byod { 609 552 Some(chrono::Utc::now()) 610 553 } else { 611 554 None 612 555 }; 613 - let user_insert: Result<(uuid::Uuid,), _> = sqlx::query_as( 614 - r#"INSERT INTO users ( 615 - handle, email, did, password_hash, 616 - preferred_comms_channel, 617 - discord_id, telegram_username, signal_number, 618 - is_admin, deactivated_at, email_verified 619 - ) VALUES ($1, $2, $3, $4, $5::comms_channel, $6, $7, $8, $9, $10, $11) RETURNING id"#, 620 - ) 621 - .bind(&handle) 622 - .bind(&email) 623 - .bind(&did) 624 - .bind(&password_hash) 625 - .bind(verification_channel) 626 - .bind( 627 - input 628 - .discord_id 629 - .as_deref() 630 - .map(|s| s.trim()) 631 - .filter(|s| !s.is_empty()), 632 - ) 633 - .bind( 634 - input 635 - .telegram_username 636 - .as_deref() 637 - .map(|s| s.trim()) 638 - .filter(|s| !s.is_empty()), 639 - ) 640 - .bind( 641 - input 642 - .signal_number 643 - .as_deref() 644 - .map(|s| s.trim()) 645 - .filter(|s| !s.is_empty()), 646 - ) 647 - .bind(is_first_user) 648 - .bind(deactivated_at) 649 - .bind(false) 650 - .fetch_one(&mut *tx) 651 - .await; 652 - let user_id = match user_insert { 653 - Ok((id,)) => id, 654 - Err(e) => { 655 - if let Some(db_err) = e.as_database_error() 656 - && db_err.code().as_deref() == Some("23505") 657 - { 658 - let constraint = db_err.constraint().unwrap_or(""); 659 - if constraint.contains("handle") || constraint.contains("users_handle") { 660 - return ApiError::HandleNotAvailable(None).into_response(); 661 - } else if constraint.contains("email") || constraint.contains("users_email") { 662 - return ApiError::EmailTaken.into_response(); 663 - } else if constraint.contains("did") || constraint.contains("users_did") { 664 - return ApiError::AccountAlreadyExists.into_response(); 665 - } 666 - } 667 - error!("Error inserting user: {:?}", e); 668 - return ApiError::InternalError(None).into_response(); 669 - } 670 - }; 671 556 672 557 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 673 558 Ok(enc) => enc, ··· 676 561 return ApiError::InternalError(None).into_response(); 677 562 } 678 563 }; 679 - let key_insert = sqlx::query!( 680 - "INSERT INTO user_keys (user_id, key_bytes, encryption_version, encrypted_at) VALUES ($1, $2, $3, NOW())", 681 - user_id, 682 - &encrypted_key_bytes[..], 683 - crate::config::ENCRYPTION_VERSION 684 - ) 685 - .execute(&mut *tx) 686 - .await; 687 - if let Err(e) = key_insert { 688 - error!("Error inserting user key: {:?}", e); 689 - return ApiError::InternalError(None).into_response(); 690 - } 691 - if let Some(key_id) = reserved_key_id { 692 - let mark_used = sqlx::query!( 693 - "UPDATE reserved_signing_keys SET used_at = NOW() WHERE id = $1", 694 - key_id 695 - ) 696 - .execute(&mut *tx) 697 - .await; 698 - if let Err(e) = mark_used { 699 - error!("Error marking reserved key as used: {:?}", e); 700 - return ApiError::InternalError(None).into_response(); 701 - } 702 - } 564 + 703 565 let mst = Mst::new(Arc::new(state.block_store.clone())); 704 566 let mst_root = match mst.persist().await { 705 567 Ok(c) => c, ··· 727 589 }; 728 590 let commit_cid_str = commit_cid.to_string(); 729 591 let rev_str = rev.as_ref().to_string(); 730 - let repo_insert = sqlx::query!( 731 - "INSERT INTO repos (user_id, repo_root_cid, repo_rev) VALUES ($1, $2, $3)", 732 - user_id, 733 - commit_cid_str, 734 - rev_str 735 - ) 736 - .execute(&mut *tx) 737 - .await; 738 - if let Err(e) = repo_insert { 739 - error!("Error initializing repo: {:?}", e); 740 - return ApiError::InternalError(None).into_response(); 741 - } 742 592 let genesis_block_cids = vec![mst_root.to_bytes(), commit_cid.to_bytes()]; 743 - if let Err(e) = sqlx::query!( 744 - r#" 745 - INSERT INTO user_blocks (user_id, block_cid) 746 - SELECT $1, block_cid FROM UNNEST($2::bytea[]) AS t(block_cid) 747 - ON CONFLICT (user_id, block_cid) DO NOTHING 748 - "#, 749 - user_id, 750 - &genesis_block_cids 751 - ) 752 - .execute(&mut *tx) 753 - .await 754 - { 755 - error!("Error inserting user_blocks: {:?}", e); 756 - return ApiError::InternalError(None).into_response(); 757 - } 758 - if let Some(code) = &input.invite_code 759 - && !code.trim().is_empty() 760 - { 761 - let use_insert = sqlx::query!( 762 - "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 763 - code, 764 - user_id 765 - ) 766 - .execute(&mut *tx) 767 - .await; 768 - if let Err(e) = use_insert { 769 - error!("Error recording invite usage: {:?}", e); 770 - return ApiError::InternalError(None).into_response(); 771 - } 772 - } 773 - if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_ok() { 774 - let birthdate_pref = json!({ 593 + 594 + let birthdate_pref = std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").ok().map(|_| { 595 + json!({ 775 596 "$type": "app.bsky.actor.defs#personalDetailsPref", 776 597 "birthDate": "1998-05-06T00:00:00.000Z" 777 - }); 778 - if let Err(e) = sqlx::query!( 779 - "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3) 780 - ON CONFLICT (user_id, name) DO NOTHING", 781 - user_id, 782 - "app.bsky.actor.defs#personalDetailsPref", 783 - birthdate_pref 784 - ) 785 - .execute(&mut *tx) 786 - .await 787 - { 788 - warn!("Failed to set default birthdate preference: {:?}", e); 598 + }) 599 + }); 600 + 601 + let preferred_comms_channel = match verification_channel { 602 + "email" => tranquil_db_traits::CommsChannel::Email, 603 + "discord" => tranquil_db_traits::CommsChannel::Discord, 604 + "telegram" => tranquil_db_traits::CommsChannel::Telegram, 605 + "signal" => tranquil_db_traits::CommsChannel::Signal, 606 + _ => tranquil_db_traits::CommsChannel::Email, 607 + }; 608 + 609 + let create_input = tranquil_db_traits::CreatePasswordAccountInput { 610 + handle: Handle::new_unchecked(&handle), 611 + email: email.clone(), 612 + did: Did::new_unchecked(&did), 613 + password_hash, 614 + preferred_comms_channel, 615 + discord_id: input.discord_id.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()).map(String::from), 616 + telegram_username: input.telegram_username.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()).map(String::from), 617 + signal_number: input.signal_number.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()).map(String::from), 618 + deactivated_at, 619 + encrypted_key_bytes, 620 + encryption_version: crate::config::ENCRYPTION_VERSION, 621 + reserved_key_id, 622 + commit_cid: commit_cid_str.clone(), 623 + repo_rev: rev_str.clone(), 624 + genesis_block_cids, 625 + invite_code: input.invite_code.clone(), 626 + birthdate_pref, 627 + }; 628 + 629 + let create_result = match state.user_repo.create_password_account(&create_input).await { 630 + Ok(r) => r, 631 + Err(tranquil_db_traits::CreateAccountError::HandleTaken) => { 632 + return ApiError::HandleNotAvailable(None).into_response(); 789 633 } 790 - } 791 - if let Err(e) = tx.commit().await { 792 - error!("Error committing transaction: {:?}", e); 793 - return ApiError::InternalError(None).into_response(); 794 - } 634 + Err(tranquil_db_traits::CreateAccountError::EmailTaken) => { 635 + return ApiError::EmailTaken.into_response(); 636 + } 637 + Err(tranquil_db_traits::CreateAccountError::DidExists) => { 638 + return ApiError::AccountAlreadyExists.into_response(); 639 + } 640 + Err(e) => { 641 + error!("Error creating password account: {:?}", e); 642 + return ApiError::InternalError(None).into_response(); 643 + } 644 + }; 645 + let user_id = create_result.user_id; 795 646 if !is_migration && !is_did_web_byod { 796 647 let did_typed = Did::new_unchecked(&did); 797 648 let handle_typed = Handle::new_unchecked(&handle); ··· 858 709 ); 859 710 let formatted_token = 860 711 crate::auth::verification_token::format_token_for_display(&verification_token); 861 - if let Err(e) = crate::comms::enqueue_signup_verification( 862 - &state.db, 712 + if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 713 + state.infra_repo.as_ref(), 863 714 user_id, 864 715 verification_channel, 865 716 recipient, 866 717 &formatted_token, 867 - None, 718 + &hostname, 868 719 ) 869 720 .await 870 721 { ··· 877 728 } else if let Some(ref user_email) = email { 878 729 let token = crate::auth::verification_token::generate_migration_token(&did, user_email); 879 730 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 880 - if let Err(e) = crate::comms::enqueue_migration_verification( 881 - &state.db, 731 + if let Err(e) = crate::comms::comms_repo::enqueue_migration_verification( 732 + state.user_repo.as_ref(), 733 + state.infra_repo.as_ref(), 882 734 user_id, 883 735 user_email, 884 736 &formatted_token, ··· 906 758 return ApiError::InternalError(None).into_response(); 907 759 } 908 760 }; 909 - if let Err(e) = sqlx::query!( 910 - "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at) VALUES ($1, $2, $3, $4, $5)", 911 - did, 912 - access_meta.jti, 913 - refresh_meta.jti, 914 - access_meta.expires_at, 915 - refresh_meta.expires_at 916 - ) 917 - .execute(&state.db) 918 - .await 919 - { 761 + let session_data = tranquil_db_traits::SessionTokenCreate { 762 + did: Did::new_unchecked(&did), 763 + access_jti: access_meta.jti.clone(), 764 + refresh_jti: refresh_meta.jti.clone(), 765 + access_expires_at: access_meta.expires_at, 766 + refresh_expires_at: refresh_meta.expires_at, 767 + legacy_login: false, 768 + mfa_verified: false, 769 + scope: None, 770 + controller_did: None, 771 + app_password_name: None, 772 + }; 773 + if let Err(e) = state.session_repo.create_session(&session_data).await { 920 774 error!("createAccount: Error creating session: {:?}", e); 921 775 return ApiError::InternalError(None).into_response(); 922 776 } ··· 934 788 StatusCode::OK, 935 789 Json(CreateAccountOutput { 936 790 handle: handle.clone().into(), 937 - did: did.into(), 791 + did: Did::new_unchecked(&did), 938 792 did_doc, 939 793 access_jwt: access_meta.token, 940 794 refresh_jwt: refresh_meta.token,
+93 -125
crates/tranquil-pds/src/api/identity/did.rs
··· 34 34 State(state): State<AppState>, 35 35 Query(params): Query<ResolveHandleParams>, 36 36 ) -> Response { 37 - let handle = params.handle.trim(); 38 - if handle.is_empty() { 37 + let handle_str = params.handle.trim(); 38 + if handle_str.is_empty() { 39 39 return ApiError::InvalidRequest("handle is required".into()).into_response(); 40 40 } 41 - let cache_key = format!("handle:{}", handle); 41 + let cache_key = format!("handle:{}", handle_str); 42 42 if let Some(did) = state.cache.get(&cache_key).await { 43 43 return DidResponse::response(did).into_response(); 44 44 } 45 - let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 46 - .fetch_optional(&state.db) 47 - .await; 45 + let handle: Handle = match handle_str.parse() { 46 + Ok(h) => h, 47 + Err(_) => return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(), 48 + }; 49 + let user = state.user_repo.get_by_handle(&handle).await; 48 50 match user { 49 51 Ok(Some(row)) => { 50 52 let _ = state ··· 53 55 .await; 54 56 DidResponse::response(row.did).into_response() 55 57 } 56 - Ok(None) => match crate::handle::resolve_handle(handle).await { 58 + Ok(None) => match crate::handle::resolve_handle(handle.as_str()).await { 57 59 Ok(did) => { 58 60 let _ = state 59 61 .cache ··· 130 132 } 131 133 132 134 async fn serve_subdomain_did_doc(state: &AppState, handle: &str, hostname: &str) -> Response { 133 - let full_handle = format!("{}.{}", handle, hostname); 134 - let user = sqlx::query!( 135 - "SELECT id, did, migrated_to_pds FROM users WHERE handle = $1", 136 - full_handle 137 - ) 138 - .fetch_optional(&state.db) 139 - .await; 140 - let (user_id, did, migrated_to_pds) = match user { 141 - Ok(Some(row)) => (row.id, row.did, row.migrated_to_pds), 135 + let hostname_for_handles = hostname.split(':').next().unwrap_or(hostname); 136 + let full_handle = format!("{}.{}", handle, hostname_for_handles); 137 + let full_handle_typed: Handle = match full_handle.parse() { 138 + Ok(h) => h, 139 + Err(_) => return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(), 140 + }; 141 + let user = match state.user_repo.get_did_web_info_by_handle(&full_handle_typed).await { 142 + Ok(Some(u)) => u, 142 143 Ok(None) => { 143 144 return ApiError::NotFoundMsg("User not found".into()).into_response(); 144 145 } ··· 147 148 return ApiError::InternalError(None).into_response(); 148 149 } 149 150 }; 151 + let (user_id, did, migrated_to_pds) = (user.id, user.did, user.migrated_to_pds); 150 152 if !did.starts_with("did:web:") { 151 153 return ApiError::NotFoundMsg("User is not did:web".into()).into_response(); 152 154 } 153 - let subdomain_host = format!("{}.{}", handle, hostname); 155 + let subdomain_host = format!("{}.{}", handle, hostname_for_handles); 154 156 let encoded_subdomain = subdomain_host.replace(':', "%3A"); 155 157 let expected_self_hosted = format!("did:web:{}", encoded_subdomain); 156 158 if did != expected_self_hosted { ··· 158 160 .into_response(); 159 161 } 160 162 161 - let overrides = sqlx::query!( 162 - "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1", 163 - user_id 164 - ) 165 - .fetch_optional(&state.db) 166 - .await 167 - .ok() 168 - .flatten(); 163 + let overrides = state 164 + .user_repo 165 + .get_did_web_overrides(user_id) 166 + .await 167 + .ok() 168 + .flatten(); 169 169 170 170 let service_endpoint = migrated_to_pds.unwrap_or_else(|| format!("https://{}", hostname)); 171 171 ··· 204 204 .into_response(); 205 205 } 206 206 207 - let key_row = sqlx::query!( 208 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 209 - user_id 210 - ) 211 - .fetch_optional(&state.db) 212 - .await; 213 - let key_bytes: Vec<u8> = match key_row { 214 - Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 215 - Ok(k) => k, 216 - Err(_) => { 217 - return ApiError::InternalError(None).into_response(); 218 - } 219 - }, 220 - _ => { 207 + let key_info = match state.user_repo.get_user_key_by_id(user_id).await { 208 + Ok(Some(k)) => k, 209 + Ok(None) => return ApiError::InternalError(None).into_response(), 210 + Err(_) => return ApiError::InternalError(None).into_response(), 211 + }; 212 + let key_bytes: Vec<u8> = match crate::config::decrypt_key(&key_info.key_bytes, key_info.encryption_version) { 213 + Ok(k) => k, 214 + Err(_) => { 221 215 return ApiError::InternalError(None).into_response(); 222 216 } 223 217 }; ··· 264 258 265 259 pub async fn user_did_doc(State(state): State<AppState>, Path(handle): Path<String>) -> Response { 266 260 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 267 - let full_handle = format!("{}.{}", handle, hostname); 268 - let user = sqlx::query!( 269 - "SELECT id, did, migrated_to_pds FROM users WHERE handle = $1", 270 - full_handle 271 - ) 272 - .fetch_optional(&state.db) 273 - .await; 274 - let (user_id, did, migrated_to_pds) = match user { 275 - Ok(Some(row)) => (row.id, row.did, row.migrated_to_pds), 261 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 262 + let full_handle = format!("{}.{}", handle, hostname_for_handles); 263 + let full_handle_typed: Handle = match full_handle.parse() { 264 + Ok(h) => h, 265 + Err(_) => return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(), 266 + }; 267 + let user = match state.user_repo.get_did_web_info_by_handle(&full_handle_typed).await { 268 + Ok(Some(u)) => u, 276 269 Ok(None) => { 277 270 return ApiError::NotFoundMsg("User not found".into()).into_response(); 278 271 } ··· 281 274 return ApiError::InternalError(None).into_response(); 282 275 } 283 276 }; 277 + let (user_id, did, migrated_to_pds) = (user.id, user.did, user.migrated_to_pds); 284 278 if !did.starts_with("did:web:") { 285 279 return ApiError::NotFoundMsg("User is not did:web".into()).into_response(); 286 280 } 287 281 let encoded_hostname = hostname.replace(':', "%3A"); 288 282 let old_path_format = format!("did:web:{}:u:{}", encoded_hostname, handle); 289 - let subdomain_host = format!("{}.{}", handle, hostname); 283 + let subdomain_host = format!("{}.{}", handle, hostname_for_handles); 290 284 let encoded_subdomain = subdomain_host.replace(':', "%3A"); 291 285 let new_subdomain_format = format!("did:web:{}", encoded_subdomain); 292 286 if did != old_path_format && did != new_subdomain_format { ··· 294 288 .into_response(); 295 289 } 296 290 297 - let overrides = sqlx::query!( 298 - "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1", 299 - user_id 300 - ) 301 - .fetch_optional(&state.db) 302 - .await 303 - .ok() 304 - .flatten(); 291 + let overrides = state 292 + .user_repo 293 + .get_did_web_overrides(user_id) 294 + .await 295 + .ok() 296 + .flatten(); 305 297 306 298 let service_endpoint = migrated_to_pds.unwrap_or_else(|| format!("https://{}", hostname)); 307 299 ··· 340 332 .into_response(); 341 333 } 342 334 343 - let key_row = sqlx::query!( 344 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 345 - user_id 346 - ) 347 - .fetch_optional(&state.db) 348 - .await; 349 - let key_bytes: Vec<u8> = match key_row { 350 - Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 351 - Ok(k) => k, 352 - Err(_) => { 353 - return ApiError::InternalError(None).into_response(); 354 - } 355 - }, 356 - _ => { 335 + let key_info = match state.user_repo.get_user_key_by_id(user_id).await { 336 + Ok(Some(k)) => k, 337 + Ok(None) => return ApiError::InternalError(None).into_response(), 338 + Err(_) => return ApiError::InternalError(None).into_response(), 339 + }; 340 + let key_bytes: Vec<u8> = match crate::config::decrypt_key(&key_info.key_bytes, key_info.encryption_version) { 341 + Ok(k) => k, 342 + Err(_) => { 357 343 return ApiError::InternalError(None).into_response(); 358 344 } 359 345 }; ··· 404 390 handle: &str, 405 391 expected_signing_key: Option<&str>, 406 392 ) -> Result<(), String> { 407 - let subdomain_host = format!("{}.{}", handle, hostname); 393 + let hostname_for_handles = hostname.split(':').next().unwrap_or(hostname); 394 + let subdomain_host = format!("{}.{}", handle, hostname_for_handles); 408 395 let encoded_subdomain = subdomain_host.replace(':', "%3A"); 409 396 let expected_subdomain_did = format!("did:web:{}", encoded_subdomain); 410 397 if did == expected_subdomain_did { ··· 527 514 auth: BearerAuthAllowDeactivated, 528 515 ) -> Response { 529 516 let auth_user = auth.0; 530 - let user = match sqlx::query!( 531 - "SELECT handle FROM users u JOIN user_keys k ON u.id = k.user_id WHERE u.did = $1", 532 - &auth_user.did 533 - ) 534 - .fetch_optional(&state.db) 535 - .await 536 - { 537 - Ok(Some(row)) => row, 538 - _ => return ApiError::InternalError(None).into_response(), 517 + let handle = match state.user_repo.get_handle_by_did(&auth_user.did).await { 518 + Ok(Some(h)) => h, 519 + Ok(None) => return ApiError::InternalError(None).into_response(), 520 + Err(_) => return ApiError::InternalError(None).into_response(), 539 521 }; 540 522 let key_bytes = match auth_user.key_bytes { 541 523 Some(kb) => kb, ··· 571 553 StatusCode::OK, 572 554 Json(GetRecommendedDidCredentialsOutput { 573 555 rotation_keys, 574 - also_known_as: vec![format!("at://{}", user.handle)], 556 + also_known_as: vec![format!("at://{}", handle)], 575 557 verification_methods: VerificationMethods { atproto: did_key }, 576 558 services: Services { 577 559 atproto_pds: AtprotoPds { ··· 619 601 return ApiError::RateLimitExceeded(Some("Daily handle update limit exceeded.".into())) 620 602 .into_response(); 621 603 } 622 - let user_row = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did.as_str()) 623 - .fetch_optional(&state.db) 624 - .await 625 - { 604 + let user_row = match state.user_repo.get_id_and_handle_by_did(&did).await { 626 605 Ok(Some(row)) => row, 627 - _ => return ApiError::InternalError(None).into_response(), 606 + Ok(None) => return ApiError::InternalError(None).into_response(), 607 + Err(_) => return ApiError::InternalError(None).into_response(), 628 608 }; 629 609 let user_id = user_row.id; 630 610 let current_handle = user_row.handle; ··· 657 637 .into_response(); 658 638 } 659 639 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 660 - let suffix = format!(".{}", hostname); 661 - let is_service_domain = crate::handle::is_service_domain_handle(&new_handle, &hostname); 640 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 641 + let suffix = format!(".{}", hostname_for_handles); 642 + let is_service_domain = crate::handle::is_service_domain_handle(&new_handle, hostname_for_handles); 662 643 let handle = if is_service_domain { 663 644 let short_part = if new_handle.ends_with(&suffix) { 664 645 new_handle.strip_suffix(&suffix).unwrap_or(&new_handle) ··· 668 649 let full_handle = if new_handle.ends_with(&suffix) { 669 650 new_handle.clone() 670 651 } else { 671 - format!("{}.{}", new_handle, hostname) 652 + format!("{}.{}", new_handle, hostname_for_handles) 672 653 }; 673 654 if full_handle == current_handle { 674 655 let handle_typed = Handle::new_unchecked(&full_handle); ··· 727 708 } 728 709 new_handle.clone() 729 710 }; 730 - let existing = sqlx::query!( 731 - "SELECT id FROM users WHERE handle = $1 AND id != $2", 732 - handle, 733 - user_id 734 - ) 735 - .fetch_optional(&state.db) 736 - .await; 737 - if let Ok(Some(_)) = existing { 711 + let handle_typed: Handle = match handle.parse() { 712 + Ok(h) => h, 713 + Err(_) => return ApiError::InvalidHandle(Some("Invalid handle format".into())).into_response(), 714 + }; 715 + let handle_exists = match state.user_repo.check_handle_exists(&handle_typed, user_id).await { 716 + Ok(exists) => exists, 717 + Err(_) => return ApiError::InternalError(None).into_response(), 718 + }; 719 + if handle_exists { 738 720 return ApiError::HandleTaken.into_response(); 739 721 } 740 - let result = sqlx::query!( 741 - "UPDATE users SET handle = $1 WHERE id = $2", 742 - handle, 743 - user_id 744 - ) 745 - .execute(&state.db) 746 - .await; 722 + let result = state.user_repo.update_handle(user_id, &handle_typed).await; 747 723 match result { 748 724 Ok(_) => { 749 725 if !current_handle.is_empty() { ··· 753 729 .await; 754 730 } 755 731 let _ = state.cache.delete(&format!("handle:{}", handle)).await; 756 - let handle_typed = Handle::new_unchecked(&handle); 757 732 if let Err(e) = 758 733 crate::api::repo::record::sequence_identity_event(&state, &did, Some(&handle_typed)) 759 734 .await 760 735 { 761 736 warn!("Failed to sequence identity event for handle update: {}", e); 762 737 } 763 - if let Err(e) = update_plc_handle(&state, &did, &handle).await { 738 + if let Err(e) = update_plc_handle(&state, &did, &handle_typed).await { 764 739 warn!("Failed to update PLC handle: {}", e); 765 740 } 766 741 EmptyResponse::ok().into_response() ··· 774 749 775 750 pub async fn update_plc_handle( 776 751 state: &AppState, 777 - did: &str, 778 - new_handle: &str, 752 + did: &crate::types::Did, 753 + new_handle: &Handle, 779 754 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { 780 - if !did.starts_with("did:plc:") { 755 + if !did.as_str().starts_with("did:plc:") { 781 756 return Ok(()); 782 757 } 783 - let user_row = sqlx::query!( 784 - r#"SELECT u.id, uk.key_bytes, uk.encryption_version 785 - FROM users u 786 - JOIN user_keys uk ON u.id = uk.user_id 787 - WHERE u.did = $1"#, 788 - did 789 - ) 790 - .fetch_optional(&state.db) 791 - .await?; 792 - let user_row = match user_row { 758 + let user_row = match state.user_repo.get_user_with_key_by_did(did).await? { 793 759 Some(r) => r, 794 760 None => return Ok(()), 795 761 }; ··· 810 776 Some(h) => h, 811 777 None => return (StatusCode::BAD_REQUEST, "Missing host header").into_response(), 812 778 }; 813 - let handle = host.split(':').next().unwrap_or(host); 814 - let user = sqlx::query!("SELECT did FROM users WHERE handle = $1", handle) 815 - .fetch_optional(&state.db) 816 - .await; 779 + let handle_str = host.split(':').next().unwrap_or(host); 780 + let handle: Handle = match handle_str.parse() { 781 + Ok(h) => h, 782 + Err(_) => return (StatusCode::BAD_REQUEST, "Invalid handle format").into_response(), 783 + }; 784 + let user = state.user_repo.get_by_handle(&handle).await; 817 785 match user { 818 - Ok(Some(row)) => row.did.into_response(), 786 + Ok(Some(row)) => row.did.to_string().into_response(), 819 787 Ok(None) => (StatusCode::NOT_FOUND, "Handle not found").into_response(), 820 788 Err(e) => { 821 789 error!("DB error in well-known atproto-did: {:?}", e);
+15 -24
crates/tranquil-pds/src/api/identity/plc/request.rs
··· 25 25 ) { 26 26 return e; 27 27 } 28 - let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", &auth_user.did) 29 - .fetch_optional(&state.db) 30 - .await 31 - { 32 - Ok(Some(row)) => row, 28 + let user_id = match state.user_repo.get_id_by_did(&auth_user.did).await { 29 + Ok(Some(id)) => id, 33 30 Ok(None) => return ApiError::AccountNotFound.into_response(), 34 31 Err(e) => { 35 32 error!("DB error: {:?}", e); 36 33 return ApiError::InternalError(None).into_response(); 37 34 } 38 35 }; 39 - let _ = sqlx::query!( 40 - "DELETE FROM plc_operation_tokens WHERE user_id = $1 OR expires_at < NOW()", 41 - user.id 42 - ) 43 - .execute(&state.db) 44 - .await; 36 + let _ = state.infra_repo.delete_plc_tokens_for_user(user_id).await; 45 37 let plc_token = generate_plc_token(); 46 38 let expires_at = Utc::now() + Duration::minutes(10); 47 - if let Err(e) = sqlx::query!( 48 - r#" 49 - INSERT INTO plc_operation_tokens (user_id, token, expires_at) 50 - VALUES ($1, $2, $3) 51 - "#, 52 - user.id, 53 - plc_token, 54 - expires_at 55 - ) 56 - .execute(&state.db) 57 - .await 39 + if let Err(e) = state 40 + .infra_repo 41 + .insert_plc_token(user_id, &plc_token, expires_at) 42 + .await 58 43 { 59 44 error!("Failed to create PLC token: {:?}", e); 60 45 return ApiError::InternalError(None).into_response(); 61 46 } 62 47 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 63 - if let Err(e) = 64 - crate::comms::enqueue_plc_operation(&state.db, user.id, &plc_token, &hostname).await 48 + if let Err(e) = crate::comms::comms_repo::enqueue_plc_operation( 49 + state.user_repo.as_ref(), 50 + state.infra_repo.as_ref(), 51 + user_id, 52 + &plc_token, 53 + &hostname, 54 + ) 55 + .await 65 56 { 66 57 warn!("Failed to enqueue PLC operation notification: {:?}", e); 67 58 }
+17 -37
crates/tranquil-pds/src/api/identity/plc/sign.rs
··· 67 67 .into_response(); 68 68 } 69 69 }; 70 - let user = match sqlx::query!("SELECT id FROM users WHERE did = $1", did) 71 - .fetch_optional(&state.db) 72 - .await 73 - { 74 - Ok(Some(row)) => row, 75 - _ => { 76 - return ApiError::AccountNotFound.into_response(); 70 + let user_id = match state.user_repo.get_id_by_did(did).await { 71 + Ok(Some(id)) => id, 72 + Ok(None) => return ApiError::AccountNotFound.into_response(), 73 + Err(e) => { 74 + error!("DB error: {:?}", e); 75 + return ApiError::InternalError(None).into_response(); 77 76 } 78 77 }; 79 - let token_row = match sqlx::query!( 80 - "SELECT id, expires_at FROM plc_operation_tokens WHERE user_id = $1 AND token = $2", 81 - user.id, 82 - token 83 - ) 84 - .fetch_optional(&state.db) 85 - .await 86 - { 87 - Ok(Some(row)) => row, 78 + let token_expiry = match state.infra_repo.get_plc_token_expiry(user_id, token).await { 79 + Ok(Some(expiry)) => expiry, 88 80 Ok(None) => { 89 81 return ApiError::InvalidToken(Some("Invalid or expired token".into())).into_response(); 90 82 } ··· 93 85 return ApiError::InternalError(None).into_response(); 94 86 } 95 87 }; 96 - if Utc::now() > token_row.expires_at { 97 - let _ = sqlx::query!( 98 - "DELETE FROM plc_operation_tokens WHERE id = $1", 99 - token_row.id 100 - ) 101 - .execute(&state.db) 102 - .await; 88 + if Utc::now() > token_expiry { 89 + let _ = state.infra_repo.delete_plc_token(user_id, token).await; 103 90 return ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); 104 91 } 105 - let key_row = match sqlx::query!( 106 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 107 - user.id 108 - ) 109 - .fetch_optional(&state.db) 110 - .await 111 - { 92 + let key_row = match state.user_repo.get_user_key_by_id(user_id).await { 112 93 Ok(Some(row)) => row, 113 - _ => { 94 + Ok(None) => { 114 95 return ApiError::InternalError(Some("User signing key not found".into())) 115 96 .into_response(); 97 + } 98 + Err(e) => { 99 + error!("DB error: {:?}", e); 100 + return ApiError::InternalError(None).into_response(); 116 101 } 117 102 }; 118 103 let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) ··· 179 164 return ApiError::InternalError(None).into_response(); 180 165 } 181 166 }; 182 - let _ = sqlx::query!( 183 - "DELETE FROM plc_operation_tokens WHERE id = $1", 184 - token_row.id 185 - ) 186 - .execute(&state.db) 187 - .await; 167 + let _ = state.infra_repo.delete_plc_token(user_id, token).await; 188 168 info!("Signed PLC operation for user {}", did); 189 169 ( 190 170 StatusCode::OK,
+18 -27
crates/tranquil-pds/src/api/identity/plc/submit.rs
··· 44 44 let op = &input.operation; 45 45 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 46 46 let public_url = format!("https://{}", hostname); 47 - let user = match sqlx::query!("SELECT id, handle FROM users WHERE did = $1", did) 48 - .fetch_optional(&state.db) 49 - .await 50 - { 51 - Ok(Some(row)) => row, 52 - _ => { 53 - return ApiError::AccountNotFound.into_response(); 47 + let user = match state.user_repo.get_id_and_handle_by_did(did).await { 48 + Ok(Some(u)) => u, 49 + Ok(None) => return ApiError::AccountNotFound.into_response(), 50 + Err(e) => { 51 + error!("DB error: {:?}", e); 52 + return ApiError::InternalError(None).into_response(); 54 53 } 55 54 }; 56 - let key_row = match sqlx::query!( 57 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 58 - user.id 59 - ) 60 - .fetch_optional(&state.db) 61 - .await 62 - { 55 + let key_row = match state.user_repo.get_user_key_by_id(user.id).await { 63 56 Ok(Some(row)) => row, 64 - _ => { 57 + Ok(None) => { 65 58 return ApiError::InternalError(Some("User signing key not found".into())) 66 59 .into_response(); 60 + } 61 + Err(e) => { 62 + error!("DB error: {:?}", e); 63 + return ApiError::InternalError(None).into_response(); 67 64 } 68 65 }; 69 66 let key_bytes = match crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) ··· 139 136 { 140 137 return ApiError::from(e).into_response(); 141 138 } 142 - match sqlx::query!( 143 - "INSERT INTO repo_seq (did, event_type, handle) VALUES ($1, 'identity', $2) RETURNING seq", 144 - did, 145 - user.handle 146 - ) 147 - .fetch_one(&state.db) 148 - .await 139 + match state 140 + .repo_repo 141 + .insert_identity_event(did, Some(&user.handle)) 142 + .await 149 143 { 150 - Ok(row) => { 151 - if let Err(e) = sqlx::query(&format!("NOTIFY repo_updates, '{}'", row.seq)) 152 - .execute(&state.db) 153 - .await 154 - { 144 + Ok(seq) => { 145 + if let Err(e) = state.repo_repo.notify_update(seq).await { 155 146 warn!("Failed to notify identity event: {:?}", e); 156 147 } 157 148 }
+9 -23
crates/tranquil-pds/src/api/moderation/mod.rs
··· 72 72 let key_bytes = match &auth_user.key_bytes { 73 73 Some(kb) => kb.clone(), 74 74 None => { 75 - match sqlx::query_as::<_, (Vec<u8>, Option<i32>)>( 76 - "SELECT k.key_bytes, k.encryption_version 77 - FROM users u 78 - JOIN user_keys k ON u.id = k.user_id 79 - WHERE u.did = $1", 80 - ) 81 - .bind(&auth_user.did) 82 - .fetch_optional(&state.db) 83 - .await 84 - { 85 - Ok(Some((key_bytes_enc, encryption_version))) => { 86 - match crate::config::decrypt_key(&key_bytes_enc, encryption_version) { 75 + match state.user_repo.get_with_key_by_did(&auth_user.did).await { 76 + Ok(Some(user_with_key)) => { 77 + match crate::config::decrypt_key(&user_with_key.key_bytes, user_with_key.encryption_version) { 87 78 Ok(key) => key, 88 79 Err(e) => { 89 80 error!(error = ?e, "Failed to decrypt user key for report service auth"); ··· 185 176 186 177 async fn create_report_locally( 187 178 state: &AppState, 188 - did: &str, 179 + did: &crate::types::Did, 189 180 is_takendown: bool, 190 181 input: CreateReportInput, 191 182 ) -> Response { ··· 214 205 let report_id = (uuid::Uuid::now_v7().as_u128() & 0x7FFF_FFFF_FFFF_FFFF) as i64; 215 206 let subject_json = json!(input.subject); 216 207 217 - let insert = sqlx::query!( 218 - "INSERT INTO reports (id, reason_type, reason, subject_json, reported_by_did, created_at) VALUES ($1, $2, $3, $4, $5, $6)", 208 + if let Err(e) = state.infra_repo.insert_report( 219 209 report_id, 220 - input.reason_type, 221 - input.reason, 210 + &input.reason_type, 211 + input.reason.as_deref(), 222 212 subject_json, 223 213 did, 224 - created_at 225 - ) 226 - .execute(&state.db) 227 - .await; 228 - 229 - if let Err(e) = insert { 214 + created_at, 215 + ).await { 230 216 error!("Failed to insert report: {:?}", e); 231 217 return ApiError::InternalError(None).into_response(); 232 218 }
+73 -139
crates/tranquil-pds/src/api/notification_prefs.rs
··· 8 8 }; 9 9 use serde::{Deserialize, Serialize}; 10 10 use serde_json::json; 11 - use sqlx::Row; 12 11 use tracing::info; 13 12 14 13 #[derive(Serialize)] ··· 26 25 27 26 pub async fn get_notification_prefs(State(state): State<AppState>, auth: BearerAuth) -> Response { 28 27 let user = auth.0; 29 - let row = match sqlx::query( 30 - r#" 31 - SELECT 32 - email, 33 - preferred_comms_channel::text as channel, 34 - discord_id, 35 - discord_verified, 36 - telegram_username, 37 - telegram_verified, 38 - signal_number, 39 - signal_verified 40 - FROM users 41 - WHERE did = $1 42 - "#, 43 - ) 44 - .bind(&user.did) 45 - .fetch_one(&state.db) 46 - .await 47 - { 48 - Ok(r) => r, 28 + let prefs = match state.user_repo.get_notification_prefs(&user.did).await { 29 + Ok(Some(p)) => p, 30 + Ok(None) => return ApiError::AccountNotFound.into_response(), 49 31 Err(e) => { 50 32 return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 51 33 } 52 34 }; 53 - let email: String = row.get("email"); 54 - let channel: String = row.get("channel"); 55 - let discord_id: Option<String> = row.get("discord_id"); 56 - let discord_verified: bool = row.get("discord_verified"); 57 - let telegram_username: Option<String> = row.get("telegram_username"); 58 - let telegram_verified: bool = row.get("telegram_verified"); 59 - let signal_number: Option<String> = row.get("signal_number"); 60 - let signal_verified: bool = row.get("signal_verified"); 61 35 Json(NotificationPrefsResponse { 62 - preferred_channel: channel, 63 - email, 64 - discord_id, 65 - discord_verified, 66 - telegram_username, 67 - telegram_verified, 68 - signal_number, 69 - signal_verified, 36 + preferred_channel: prefs.preferred_channel, 37 + email: prefs.email, 38 + discord_id: prefs.discord_id, 39 + discord_verified: prefs.discord_verified, 40 + telegram_username: prefs.telegram_username, 41 + telegram_verified: prefs.telegram_verified, 42 + signal_number: prefs.signal_number, 43 + signal_verified: prefs.signal_verified, 70 44 }) 71 45 .into_response() 72 46 } ··· 91 65 pub async fn get_notification_history(State(state): State<AppState>, auth: BearerAuth) -> Response { 92 66 let user = auth.0; 93 67 94 - let user_id: uuid::Uuid = 95 - match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &user.did) 96 - .fetch_one(&state.db) 97 - .await 98 - { 99 - Ok(id) => id, 100 - Err(e) => { 101 - return ApiError::InternalError(Some(format!("Database error: {}", e))) 102 - .into_response(); 103 - } 104 - }; 68 + let user_id: uuid::Uuid = match state.user_repo.get_id_by_did(&user.did).await { 69 + Ok(Some(id)) => id, 70 + Ok(None) => return ApiError::AccountNotFound.into_response(), 71 + Err(e) => { 72 + return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 73 + } 74 + }; 105 75 106 - let rows = match sqlx::query!( 107 - r#" 108 - SELECT 109 - created_at, 110 - channel as "channel: String", 111 - comms_type as "comms_type: String", 112 - status as "status: String", 113 - subject, 114 - body 115 - FROM comms_queue 116 - WHERE user_id = $1 117 - ORDER BY created_at DESC 118 - LIMIT 50 119 - "#, 120 - user_id 121 - ) 122 - .fetch_all(&state.db) 123 - .await 124 - { 76 + let rows = match state.infra_repo.get_notification_history(user_id, 50).await { 125 77 Ok(r) => r, 126 78 Err(e) => { 127 79 return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); ··· 181 133 } 182 134 183 135 pub async fn request_channel_verification( 184 - db: &sqlx::PgPool, 136 + state: &AppState, 185 137 user_id: uuid::Uuid, 186 138 did: &str, 187 139 channel: &str, ··· 195 147 if channel == "email" { 196 148 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 197 149 let handle_str = handle.unwrap_or("user"); 198 - crate::comms::enqueue_email_update( 199 - db, 150 + crate::comms::comms_repo::enqueue_email_update( 151 + state.infra_repo.as_ref(), 200 152 user_id, 201 153 identifier, 202 154 handle_str, ··· 206 158 .await 207 159 .map_err(|e| format!("Failed to enqueue email notification: {}", e))?; 208 160 } else { 209 - sqlx::query!( 210 - r#" 211 - INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body, metadata) 212 - VALUES ($1, $2::comms_channel, 'channel_verification', $3, 'Verify your channel', $4, $5) 213 - "#, 214 - user_id, 215 - channel as _, 216 - identifier, 217 - format!("Your verification code is: {}", formatted_token), 218 - json!({"code": formatted_token}) 219 - ) 220 - .execute(db) 221 - .await 222 - .map_err(|e| format!("Failed to enqueue notification: {}", e))?; 161 + let comms_channel = match channel { 162 + "discord" => tranquil_db_traits::CommsChannel::Discord, 163 + "telegram" => tranquil_db_traits::CommsChannel::Telegram, 164 + "signal" => tranquil_db_traits::CommsChannel::Signal, 165 + _ => return Err("Invalid channel".to_string()), 166 + }; 167 + state 168 + .infra_repo 169 + .enqueue_comms( 170 + Some(user_id), 171 + comms_channel, 172 + tranquil_db_traits::CommsType::ChannelVerification, 173 + identifier, 174 + Some("Verify your channel"), 175 + &format!("Your verification code is: {}", formatted_token), 176 + Some(json!({"code": formatted_token})), 177 + ) 178 + .await 179 + .map_err(|e| format!("Failed to enqueue notification: {}", e))?; 223 180 } 224 181 225 182 Ok(token) ··· 232 189 ) -> Response { 233 190 let user = auth.0; 234 191 235 - let user_row = match sqlx::query!( 236 - "SELECT id, handle, email FROM users WHERE did = $1", 237 - &user.did 238 - ) 239 - .fetch_one(&state.db) 240 - .await 192 + let user_row = match state 193 + .user_repo 194 + .get_id_handle_email_by_did(&user.did) 195 + .await 241 196 { 242 - Ok(row) => row, 197 + Ok(Some(row)) => row, 198 + Ok(None) => return ApiError::AccountNotFound.into_response(), 243 199 Err(e) => { 244 200 return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 245 201 } ··· 259 215 ) 260 216 .into_response(); 261 217 } 262 - if let Err(e) = sqlx::query( 263 - r#"UPDATE users SET preferred_comms_channel = $1::comms_channel, updated_at = NOW() WHERE did = $2"# 264 - ) 265 - .bind(channel) 266 - .bind(&user.did) 267 - .execute(&state.db) 268 - .await 218 + if let Err(e) = state 219 + .user_repo 220 + .update_preferred_comms_channel(&user.did, channel) 221 + .await 269 222 { 270 223 return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 271 224 } ··· 285 238 if current_email.as_ref().map(|e| e.to_lowercase()) == Some(email_clean.clone()) { 286 239 info!(did = %user.did, "Email unchanged, skipping"); 287 240 } else { 288 - let exists = sqlx::query!( 289 - "SELECT 1 as one FROM users WHERE LOWER(email) = $1 AND id != $2", 290 - email_clean, 291 - user_id 292 - ) 293 - .fetch_optional(&state.db) 294 - .await; 295 - 296 - if let Ok(Some(_)) = exists { 297 - return ApiError::EmailTaken.into_response(); 241 + match state.user_repo.check_email_exists(&email_clean, user_id).await { 242 + Ok(true) => return ApiError::EmailTaken.into_response(), 243 + Err(e) => { 244 + return ApiError::InternalError(Some(format!("Database error: {}", e))) 245 + .into_response(); 246 + } 247 + Ok(false) => {} 298 248 } 299 249 300 250 if let Err(e) = request_channel_verification( 301 - &state.db, 251 + &state, 302 252 user_id, 303 253 &user.did, 304 254 "email", ··· 316 266 317 267 if let Some(ref discord_id) = input.discord_id { 318 268 if discord_id.is_empty() { 319 - if let Err(e) = sqlx::query!( 320 - "UPDATE users SET discord_id = NULL, discord_verified = FALSE, updated_at = NOW() WHERE id = $1", 321 - user_id 322 - ) 323 - .execute(&state.db) 324 - .await 325 - { 326 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 269 + if let Err(e) = state.user_repo.clear_discord(user_id).await { 270 + return ApiError::InternalError(Some(format!("Database error: {}", e))) 271 + .into_response(); 327 272 } 328 273 info!(did = %user.did, "Cleared Discord ID"); 329 274 } else { 330 - if let Err(e) = request_channel_verification( 331 - &state.db, user_id, &user.did, "discord", discord_id, None, 332 - ) 333 - .await 275 + if let Err(e) = 276 + request_channel_verification(&state, user_id, &user.did, "discord", discord_id, None) 277 + .await 334 278 { 335 279 return ApiError::InternalError(Some(e)).into_response(); 336 280 } ··· 342 286 if let Some(ref telegram) = input.telegram_username { 343 287 let telegram_clean = telegram.trim_start_matches('@'); 344 288 if telegram_clean.is_empty() { 345 - if let Err(e) = sqlx::query!( 346 - "UPDATE users SET telegram_username = NULL, telegram_verified = FALSE, updated_at = NOW() WHERE id = $1", 347 - user_id 348 - ) 349 - .execute(&state.db) 350 - .await 351 - { 352 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 289 + if let Err(e) = state.user_repo.clear_telegram(user_id).await { 290 + return ApiError::InternalError(Some(format!("Database error: {}", e))) 291 + .into_response(); 353 292 } 354 293 info!(did = %user.did, "Cleared Telegram username"); 355 294 } else { 356 295 if let Err(e) = request_channel_verification( 357 - &state.db, 296 + &state, 358 297 user_id, 359 298 &user.did, 360 299 "telegram", ··· 372 311 373 312 if let Some(ref signal) = input.signal_number { 374 313 if signal.is_empty() { 375 - if let Err(e) = sqlx::query!( 376 - "UPDATE users SET signal_number = NULL, signal_verified = FALSE, updated_at = NOW() WHERE id = $1", 377 - user_id 378 - ) 379 - .execute(&state.db) 380 - .await 381 - { 382 - return ApiError::InternalError(Some(format!("Database error: {}", e))).into_response(); 314 + if let Err(e) = state.user_repo.clear_signal(user_id).await { 315 + return ApiError::InternalError(Some(format!("Database error: {}", e))) 316 + .into_response(); 383 317 } 384 318 info!(did = %user.did, "Cleared Signal number"); 385 319 } else { 386 320 if let Err(e) = 387 - request_channel_verification(&state.db, user_id, &user.did, "signal", signal, None) 321 + request_channel_verification(&state, user_id, &user.did, "signal", signal, None) 388 322 .await 389 323 { 390 324 return ApiError::InternalError(Some(e)).into_response();
+2 -1
crates/tranquil-pds/src/api/proxy.rs
··· 225 225 let http_uri = crate::util::build_full_url(&uri.to_string()); 226 226 227 227 match crate::auth::validate_token_with_dpop( 228 - &state.db, 228 + state.user_repo.as_ref(), 229 + state.oauth_repo.as_ref(), 229 230 &token, 230 231 extracted.is_dpop, 231 232 dpop_proof,
+70 -102
crates/tranquil-pds/src/api/repo/blob.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::auth::{BearerAuthAllowDeactivated, ServiceTokenVerifier, is_service_token}; 3 - use crate::delegation::{self, DelegationActionType}; 3 + use crate::delegation::DelegationActionType; 4 4 use crate::state::AppState; 5 + use crate::types::{CidLink, Did}; 5 6 use crate::util::get_max_blob_size; 6 7 use axum::body::Body; 7 8 use axum::{ ··· 55 56 56 57 let is_service_auth = is_service_token(&token); 57 58 58 - let (did, _is_migration, controller_did) = if is_service_auth { 59 + let (did, _is_migration, controller_did): (Did, bool, Option<Did>) = if is_service_auth { 59 60 debug!("Verifying service token for blob upload"); 60 61 let verifier = ServiceTokenVerifier::new(); 61 62 match verifier ··· 64 65 { 65 66 Ok(claims) => { 66 67 debug!("Service token verified for DID: {}", claims.iss); 67 - (claims.iss, false, None) 68 + let did: Did = match claims.iss.parse() { 69 + Ok(d) => d, 70 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 71 + }; 72 + (did, false, None) 68 73 } 69 74 Err(e) => { 70 75 error!("Service token verification failed: {:?}", e); ··· 82 87 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 83 88 ); 84 89 match crate::auth::validate_token_with_dpop( 85 - &state.db, 90 + state.user_repo.as_ref(), 91 + state.oauth_repo.as_ref(), 86 92 &token, 87 93 extracted.is_dpop, 88 94 dpop_proof, ··· 105 111 ) { 106 112 return e; 107 113 } 108 - let deactivated = sqlx::query_scalar!( 109 - "SELECT deactivated_at FROM users WHERE did = $1", 110 - &user.did 111 - ) 112 - .fetch_optional(&state.db) 113 - .await 114 - .ok() 115 - .flatten() 116 - .flatten(); 117 - let ctrl_did = user.controller_did.map(|d| d.to_string()); 118 - (user.did.to_string(), deactivated.is_some(), ctrl_did) 114 + let deactivated = state 115 + .user_repo 116 + .get_status_by_did(&user.did) 117 + .await 118 + .ok() 119 + .flatten() 120 + .and_then(|s| s.deactivated_at); 121 + let ctrl_did = user.controller_did.clone(); 122 + (user.did, deactivated.is_some(), ctrl_did) 119 123 } 120 124 Err(_) => { 121 125 return ApiError::AuthenticationFailed(None).into_response(); ··· 123 127 } 124 128 }; 125 129 126 - if crate::util::is_account_migrated(&state.db, &did) 130 + if state 131 + .user_repo 132 + .is_account_migrated(&did) 127 133 .await 128 134 .unwrap_or(false) 129 135 { ··· 135 141 .and_then(|h| h.to_str().ok()) 136 142 .unwrap_or("application/octet-stream"); 137 143 138 - let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", did) 139 - .fetch_optional(&state.db) 140 - .await; 141 - let user_id = match user_query { 142 - Ok(Some(row)) => row.id, 144 + let user_id = match state.user_repo.get_id_by_did(&did).await { 145 + Ok(Some(id)) => id, 143 146 _ => { 144 147 return ApiError::InternalError(None).into_response(); 145 148 } ··· 192 195 }; 193 196 let cid = Cid::new_v1(0x55, multihash); 194 197 let cid_str = cid.to_string(); 198 + let cid_link: CidLink = CidLink::new_unchecked(&cid_str); 195 199 let storage_key = format!("blobs/{}", cid_str); 196 200 197 201 info!( ··· 199 203 size, cid_str 200 204 ); 201 205 202 - let mut tx = match state.db.begin().await { 203 - Ok(tx) => tx, 204 - Err(e) => { 205 - let _ = state.blob_store.delete(&temp_key).await; 206 - error!("Failed to begin transaction: {:?}", e); 207 - return ApiError::InternalError(None).into_response(); 208 - } 209 - }; 210 - 211 - let insert = sqlx::query!( 212 - "INSERT INTO blobs (cid, mime_type, size_bytes, created_by_user, storage_key) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (cid) DO NOTHING RETURNING cid", 213 - cid_str, 214 - mime_type, 215 - size as i64, 216 - user_id, 217 - storage_key 218 - ) 219 - .fetch_optional(&mut *tx) 220 - .await; 221 - 222 - let was_inserted = match insert { 206 + let was_inserted = match state 207 + .blob_repo 208 + .insert_blob(&cid_link, &mime_type, size as i64, user_id, &storage_key) 209 + .await 210 + { 223 211 Ok(Some(_)) => true, 224 212 Ok(None) => false, 225 213 Err(e) => { ··· 229 217 } 230 218 }; 231 219 232 - if was_inserted && let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await { 233 - let _ = state.blob_store.delete(&temp_key).await; 234 - error!("Failed to copy blob to final location: {:?}", e); 235 - return ApiError::InternalError(Some("Failed to store blob".into())).into_response(); 220 + if was_inserted { 221 + if let Err(e) = state.blob_store.copy(&temp_key, &storage_key).await { 222 + let _ = state.blob_store.delete(&temp_key).await; 223 + error!("Failed to copy blob to final location: {:?}", e); 224 + return ApiError::InternalError(Some("Failed to store blob".into())).into_response(); 225 + } 236 226 } 237 227 238 228 let _ = state.blob_store.delete(&temp_key).await; 239 229 240 - if let Err(e) = tx.commit().await { 241 - error!("Failed to commit blob transaction: {:?}", e); 242 - if was_inserted && let Err(cleanup_err) = state.blob_store.delete(&storage_key).await { 243 - error!( 244 - "Failed to cleanup orphaned blob {}: {:?}", 245 - storage_key, cleanup_err 246 - ); 247 - } 248 - return ApiError::InternalError(None).into_response(); 249 - } 250 - 251 230 if let Some(ref controller) = controller_did { 252 - let _ = delegation::log_delegation_action( 253 - &state.db, 254 - &did, 255 - controller, 256 - Some(controller), 257 - DelegationActionType::BlobUpload, 258 - Some(json!({ 259 - "cid": cid_str, 260 - "mime_type": mime_type, 261 - "size": size 262 - })), 263 - None, 264 - None, 265 - ) 266 - .await; 231 + let _ = state 232 + .delegation_repo 233 + .log_delegation_action( 234 + &did, 235 + controller, 236 + Some(controller), 237 + DelegationActionType::BlobUpload, 238 + Some(json!({ 239 + "cid": cid_str, 240 + "mime_type": mime_type, 241 + "size": size 242 + })), 243 + None, 244 + None, 245 + ) 246 + .await; 267 247 } 268 248 269 249 Json(json!({ ··· 305 285 Query(params): Query<ListMissingBlobsParams>, 306 286 ) -> Response { 307 287 let auth_user = auth.0; 308 - let did = auth_user.did; 309 - let user_query = sqlx::query!("SELECT id FROM users WHERE did = $1", did.as_str()) 310 - .fetch_optional(&state.db) 311 - .await; 312 - let user_id = match user_query { 313 - Ok(Some(row)) => row.id, 314 - _ => { 288 + let did = &auth_user.did; 289 + let user = match state.user_repo.get_by_did(did).await { 290 + Ok(Some(u)) => u, 291 + Ok(None) => return ApiError::InternalError(None).into_response(), 292 + Err(e) => { 293 + error!("DB error fetching user: {:?}", e); 315 294 return ApiError::InternalError(None).into_response(); 316 295 } 317 296 }; 318 297 let limit = params.limit.unwrap_or(500).clamp(1, 1000); 319 - let cursor_cid = params.cursor.as_deref().unwrap_or(""); 320 - let missing_query = sqlx::query!( 321 - r#" 322 - SELECT rb.blob_cid, rb.record_uri 323 - FROM record_blobs rb 324 - LEFT JOIN blobs b ON rb.blob_cid = b.cid 325 - WHERE rb.repo_id = $1 AND b.cid IS NULL AND rb.blob_cid > $2 326 - ORDER BY rb.blob_cid 327 - LIMIT $3 328 - "#, 329 - user_id, 330 - cursor_cid, 331 - limit + 1 332 - ) 333 - .fetch_all(&state.db) 334 - .await; 335 - let rows = match missing_query { 336 - Ok(r) => r, 298 + let cursor = params.cursor.as_deref(); 299 + let missing = match state 300 + .blob_repo 301 + .list_missing_blobs(user.id, cursor, limit + 1) 302 + .await 303 + { 304 + Ok(m) => m, 337 305 Err(e) => { 338 306 error!("DB error fetching missing blobs: {:?}", e); 339 307 return ApiError::InternalError(None).into_response(); 340 308 } 341 309 }; 342 - let has_more = rows.len() > limit as usize; 343 - let blobs: Vec<RecordBlob> = rows 310 + let has_more = missing.len() > limit as usize; 311 + let blobs: Vec<RecordBlob> = missing 344 312 .into_iter() 345 313 .take(limit as usize) 346 - .map(|row| RecordBlob { 347 - cid: row.blob_cid, 348 - record_uri: row.record_uri, 314 + .map(|m| RecordBlob { 315 + cid: m.blob_cid.to_string(), 316 + record_uri: m.record_uri.to_string(), 349 317 }) 350 318 .collect(); 351 319 let next_cursor = if has_more {
+46 -96
crates/tranquil-pds/src/api/repo/import.rs
··· 6 6 use crate::sync::import::{ImportError, apply_import, parse_car}; 7 7 use crate::sync::verify::CarVerifier; 8 8 use crate::types::Did; 9 + use tranquil_types::{AtUri, CidLink}; 9 10 use axum::{ 10 11 body::Bytes, 11 12 extract::State, ··· 45 46 } 46 47 let auth_user = auth.0; 47 48 let did = &auth_user.did; 48 - let user = match sqlx::query!( 49 - "SELECT id, handle, deactivated_at, takedown_ref FROM users WHERE did = $1", 50 - did 51 - ) 52 - .fetch_optional(&state.db) 53 - .await 54 - { 49 + let user = match state.user_repo.get_by_did(did).await { 55 50 Ok(Some(row)) => row, 56 51 Ok(None) => { 57 52 return ApiError::AccountNotFound.into_response(); ··· 190 185 .ok() 191 186 .and_then(|s| s.parse().ok()) 192 187 .unwrap_or(DEFAULT_MAX_BLOCKS); 193 - match apply_import(&state.db, user_id, root, blocks.clone(), max_blocks).await { 188 + match apply_import(&state.repo_repo, user_id, root, blocks.clone(), max_blocks).await { 194 189 Ok(import_result) => { 195 190 info!( 196 191 "Successfully imported {} records for user {}", 197 192 import_result.records.len(), 198 193 did 199 194 ); 200 - let blob_refs: Vec<(String, String)> = import_result 195 + let blob_refs: Vec<(AtUri, CidLink)> = import_result 201 196 .records 202 197 .iter() 203 198 .flat_map(|record| { 204 - let record_uri = format!("at://{}/{}/{}", did, record.collection, record.rkey); 199 + let record_uri = AtUri::from_parts(did.as_str(), &record.collection, &record.rkey); 205 200 record 206 201 .blob_refs 207 202 .iter() 208 - .map(move |blob_ref| (record_uri.clone(), blob_ref.cid.clone())) 203 + .map(move |blob_ref| (record_uri.clone(), CidLink::new_unchecked(blob_ref.cid.clone()))) 209 204 }) 210 205 .collect(); 211 206 212 207 if !blob_refs.is_empty() { 213 - let (record_uris, blob_cids): (Vec<String>, Vec<String>) = 208 + let (record_uris, blob_cids): (Vec<AtUri>, Vec<CidLink>) = 214 209 blob_refs.into_iter().unzip(); 215 210 216 - match sqlx::query!( 217 - r#" 218 - INSERT INTO record_blobs (repo_id, record_uri, blob_cid) 219 - SELECT $1, * FROM UNNEST($2::text[], $3::text[]) 220 - ON CONFLICT (repo_id, record_uri, blob_cid) DO NOTHING 221 - "#, 222 - user_id, 223 - &record_uris, 224 - &blob_cids 225 - ) 226 - .execute(&state.db) 227 - .await 211 + match state 212 + .blob_repo 213 + .insert_record_blobs(user_id, &record_uris, &blob_cids) 214 + .await 228 215 { 229 - Ok(result) => { 216 + Ok(()) => { 230 217 info!( 231 218 "Recorded {} blob references for imported repo", 232 - result.rows_affected() 219 + blob_cids.len() 233 220 ); 234 221 } 235 222 Err(e) => { ··· 237 224 } 238 225 } 239 226 } 240 - let key_row = match sqlx::query!( 241 - r#"SELECT uk.key_bytes, uk.encryption_version 242 - FROM user_keys uk 243 - JOIN users u ON uk.user_id = u.id 244 - WHERE u.did = $1"#, 245 - did 246 - ) 247 - .fetch_optional(&state.db) 248 - .await 249 - { 227 + let key_row = match state.user_repo.get_user_with_key_by_did(did).await { 250 228 Ok(Some(row)) => row, 251 229 Ok(None) => { 252 230 error!("No signing key found for user {}", did); ··· 295 273 return ApiError::InternalError(None).into_response(); 296 274 } 297 275 }; 298 - let new_root_str = new_root_cid.to_string(); 299 - if let Err(e) = sqlx::query!( 300 - "UPDATE repos SET repo_root_cid = $1, repo_rev = $2, updated_at = NOW() WHERE user_id = $3", 301 - new_root_str, 302 - &new_rev_str, 303 - user_id 304 - ) 305 - .execute(&state.db) 306 - .await 276 + let new_root_cid_link = CidLink::new_unchecked(&new_root_cid.to_string()); 277 + if let Err(e) = state 278 + .repo_repo 279 + .update_repo_root(user_id, &new_root_cid_link, &new_rev_str) 280 + .await 307 281 { 308 282 error!("Failed to update repo root: {:?}", e); 309 283 return ApiError::InternalError(None).into_response(); 310 284 } 311 285 let mut all_block_cids: Vec<Vec<u8>> = blocks.keys().map(|c| c.to_bytes()).collect(); 312 286 all_block_cids.push(new_root_cid.to_bytes()); 313 - if let Err(e) = sqlx::query!( 314 - r#" 315 - INSERT INTO user_blocks (user_id, block_cid) 316 - SELECT $1, block_cid FROM UNNEST($2::bytea[]) AS t(block_cid) 317 - ON CONFLICT (user_id, block_cid) DO NOTHING 318 - "#, 319 - user_id, 320 - &all_block_cids 321 - ) 322 - .execute(&state.db) 323 - .await 287 + if let Err(e) = state 288 + .repo_repo 289 + .insert_user_blocks(user_id, &all_block_cids, &new_rev_str) 290 + .await 324 291 { 325 292 error!("Failed to insert user_blocks: {:?}", e); 326 293 return ApiError::InternalError(None).into_response(); 327 294 } 295 + let new_root_str = new_root_cid.to_string(); 328 296 info!( 329 297 "Created new commit for imported repo: cid={}, rev={}", 330 298 new_root_str, new_rev_str ··· 338 306 "$type": "app.bsky.actor.defs#personalDetailsPref", 339 307 "birthDate": "1998-05-06T00:00:00.000Z" 340 308 }); 341 - if let Err(e) = sqlx::query!( 342 - "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3) 343 - ON CONFLICT (user_id, name) DO NOTHING", 344 - user_id, 345 - "app.bsky.actor.defs#personalDetailsPref", 346 - birthdate_pref 347 - ) 348 - .execute(&state.db) 349 - .await 309 + if let Err(e) = state 310 + .infra_repo 311 + .insert_account_preference_if_not_exists( 312 + user_id, 313 + "app.bsky.actor.defs#personalDetailsPref", 314 + birthdate_pref, 315 + ) 316 + .await 350 317 { 351 318 warn!( 352 319 "Failed to set default birthdate preference for migrated user: {:?}", ··· 397 364 state: &AppState, 398 365 did: &Did, 399 366 commit_cid: &str, 400 - ) -> Result<(), sqlx::Error> { 401 - let prev_cid: Option<String> = None; 402 - let prev_data_cid: Option<String> = None; 403 - let ops = serde_json::json!([]); 404 - let blobs: Vec<String> = vec![]; 405 - let blocks_cids: Vec<String> = vec![]; 406 - let did_str = did.as_str(); 407 - 408 - let mut tx = state.db.begin().await?; 367 + ) -> Result<(), tranquil_db::DbError> { 368 + let data = tranquil_db::CommitEventData { 369 + did: did.clone(), 370 + event_type: "commit".to_string(), 371 + commit_cid: Some(CidLink::new_unchecked(commit_cid)), 372 + prev_cid: None, 373 + ops: Some(serde_json::json!([])), 374 + blobs: Some(vec![]), 375 + blocks_cids: Some(vec![]), 376 + prev_data_cid: None, 377 + rev: None, 378 + }; 409 379 410 - let seq_row = sqlx::query!( 411 - r#" 412 - INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids) 413 - VALUES ($1, 'commit', $2, $3, $4, $5, $6, $7) 414 - RETURNING seq 415 - "#, 416 - did_str, 417 - commit_cid, 418 - prev_cid, 419 - prev_data_cid, 420 - ops, 421 - &blobs, 422 - &blocks_cids 423 - ) 424 - .fetch_one(&mut *tx) 425 - .await?; 426 - 427 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 428 - .execute(&mut *tx) 429 - .await?; 430 - 431 - tx.commit().await?; 380 + let seq = state.repo_repo.insert_commit_event(&data).await?; 381 + state.repo_repo.notify_update(seq).await?; 432 382 Ok(()) 433 383 }
+26 -26
crates/tranquil-pds/src/api/repo/meta.rs
··· 19 19 Query(input): Query<DescribeRepoInput>, 20 20 ) -> Response { 21 21 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 22 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 22 23 let user_row = if input.repo.is_did() { 23 - sqlx::query!( 24 - "SELECT id, handle, did FROM users WHERE did = $1", 25 - input.repo.as_str() 26 - ) 27 - .fetch_optional(&state.db) 28 - .await 29 - .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 24 + let did: crate::types::Did = match input.repo.as_str().parse() { 25 + Ok(d) => d, 26 + Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), 27 + }; 28 + state 29 + .user_repo 30 + .get_by_did(&did) 31 + .await 32 + .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 30 33 } else { 31 34 let repo_str = input.repo.as_str(); 32 - let handle = if !repo_str.contains('.') { 33 - format!("{}.{}", repo_str, hostname) 35 + let handle_str = if !repo_str.contains('.') { 36 + format!("{}.{}", repo_str, hostname_for_handles) 34 37 } else { 35 38 repo_str.to_string() 36 39 }; 37 - sqlx::query!( 38 - "SELECT id, handle, did FROM users WHERE handle = $1", 39 - handle 40 - ) 41 - .fetch_optional(&state.db) 42 - .await 43 - .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 40 + let handle: crate::types::Handle = match handle_str.parse() { 41 + Ok(h) => h, 42 + Err(_) => return ApiError::InvalidRequest("Invalid handle format".into()).into_response(), 43 + }; 44 + state 45 + .user_repo 46 + .get_by_handle(&handle) 47 + .await 48 + .map(|opt| opt.map(|r| (r.id, r.handle, r.did))) 44 49 }; 45 50 let (user_id, handle, did) = match user_row { 46 51 Ok(Some((id, handle, did))) => (id, handle, did), ··· 51 56 return ApiError::InternalError(None).into_response(); 52 57 } 53 58 }; 54 - let collections_query = sqlx::query!( 55 - "SELECT DISTINCT collection FROM records WHERE repo_id = $1", 56 - user_id 57 - ) 58 - .fetch_all(&state.db) 59 - .await; 60 - let collections: Vec<String> = match collections_query { 61 - Ok(rows) => rows.iter().map(|r| r.collection.clone()).collect(), 62 - Err(_) => Vec::new(), 63 - }; 59 + let collections = state 60 + .repo_repo 61 + .list_collections(user_id) 62 + .await 63 + .unwrap_or_default(); 64 64 let did_doc = json!({ 65 65 "id": did, 66 66 "alsoKnownAs": [format!("at://{}", handle)]
+38 -34
crates/tranquil-pds/src/api/repo/record/batch.rs
··· 1 1 use super::validation::validate_record_with_status; 2 - use super::write::has_verified_comms_channel; 3 2 use crate::api::error::ApiError; 4 3 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 5 4 use crate::auth::BearerAuth; 6 - use crate::delegation::{self, DelegationActionType}; 5 + use crate::delegation::DelegationActionType; 7 6 use crate::repo::tracking::TrackingBlockStore; 8 7 use crate::state::AppState; 9 8 use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; ··· 280 279 return ApiError::InvalidRepo("Repo does not match authenticated user".into()) 281 280 .into_response(); 282 281 } 283 - if crate::util::is_account_migrated(&state.db, &did) 282 + if state 283 + .user_repo 284 + .is_account_migrated(&did) 284 285 .await 285 286 .unwrap_or(false) 286 287 { 287 288 return ApiError::AccountMigrated.into_response(); 288 289 } 289 - let is_verified = has_verified_comms_channel(&state.db, &did) 290 + let is_verified = state 291 + .user_repo 292 + .has_verified_comms_channel(&did) 290 293 .await 291 294 .unwrap_or(false); 292 - let is_delegated = crate::delegation::is_delegated_account(&state.db, &did) 295 + let is_delegated = state 296 + .delegation_repo 297 + .is_delegated_account(&did) 293 298 .await 294 299 .unwrap_or(false); 295 300 if !is_verified && !is_delegated { ··· 373 378 } 374 379 } 375 380 376 - let user_id: uuid::Uuid = 377 - match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 378 - .fetch_optional(&state.db) 379 - .await 380 - { 381 - Ok(Some(id)) => id, 382 - _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 383 - }; 384 - let root_cid_str: String = match sqlx::query_scalar!( 385 - "SELECT repo_root_cid FROM repos WHERE user_id = $1", 386 - user_id 387 - ) 388 - .fetch_optional(&state.db) 389 - .await 381 + let user_id: uuid::Uuid = match state 382 + .user_repo 383 + .get_id_by_did(&did) 384 + .await 385 + { 386 + Ok(Some(id)) => id, 387 + _ => return ApiError::InternalError(Some("User not found".into())).into_response(), 388 + }; 389 + let root_cid_str = match state 390 + .repo_repo 391 + .get_repo_root_cid_by_user_id(user_id) 392 + .await 390 393 { 391 394 Ok(Some(cid_str)) => cid_str, 392 395 _ => return ApiError::InternalError(Some("Repo root not found".into())).into_response(), ··· 544 547 }) 545 548 .collect(); 546 549 547 - let _ = delegation::log_delegation_action( 548 - &state.db, 549 - &did, 550 - controller, 551 - Some(controller), 552 - DelegationActionType::RepoWrite, 553 - Some(json!({ 554 - "action": "apply_writes", 555 - "count": input.writes.len(), 556 - "writes": write_summary 557 - })), 558 - None, 559 - None, 560 - ) 561 - .await; 550 + let _ = state 551 + .delegation_repo 552 + .log_delegation_action( 553 + &did, 554 + controller, 555 + Some(controller), 556 + DelegationActionType::RepoWrite, 557 + Some(json!({ 558 + "action": "apply_writes", 559 + "count": input.writes.len(), 560 + "writes": write_summary 561 + })), 562 + None, 563 + None, 564 + ) 565 + .await; 562 566 } 563 567 564 568 (
+135 -17
crates/tranquil-pds/src/api/repo/record/delete.rs
··· 1 1 use crate::api::error::ApiError; 2 2 use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log}; 3 3 use crate::api::repo::record::write::{CommitInfo, prepare_repo_write}; 4 - use crate::delegation::{self, DelegationActionType}; 4 + use crate::delegation::DelegationActionType; 5 5 use crate::repo::tracking::TrackingBlockStore; 6 6 use crate::state::AppState; 7 - use crate::types::{AtIdentifier, Nsid, Rkey}; 7 + use crate::types::{AtIdentifier, AtUri, Nsid, Rkey}; 8 8 use axum::{ 9 9 Json, 10 10 extract::State, ··· 183 183 }; 184 184 185 185 if let Some(ref controller) = controller_did { 186 - let _ = delegation::log_delegation_action( 187 - &state.db, 188 - &did, 189 - controller, 190 - Some(controller), 191 - DelegationActionType::RepoWrite, 192 - Some(json!({ 193 - "action": "delete", 194 - "collection": collection_for_audit, 195 - "rkey": rkey_for_audit 196 - })), 197 - None, 198 - None, 199 - ) 200 - .await; 186 + let _ = state 187 + .delegation_repo 188 + .log_delegation_action( 189 + &did, 190 + controller, 191 + Some(controller), 192 + DelegationActionType::RepoWrite, 193 + Some(json!({ 194 + "action": "delete", 195 + "collection": collection_for_audit, 196 + "rkey": rkey_for_audit 197 + })), 198 + None, 199 + None, 200 + ) 201 + .await; 202 + } 203 + 204 + let deleted_uri = AtUri::from_parts(&did, &input.collection, &input.rkey); 205 + if let Err(e) = state.backlink_repo.remove_backlinks_by_uri(&deleted_uri).await { 206 + error!("Failed to remove backlinks for {}: {}", deleted_uri, e); 201 207 } 202 208 203 209 ( ··· 211 217 ) 212 218 .into_response() 213 219 } 220 + 221 + use crate::types::Did; 222 + use uuid::Uuid; 223 + 224 + pub async fn delete_record_internal( 225 + state: &AppState, 226 + did: &Did, 227 + user_id: Uuid, 228 + collection: &Nsid, 229 + rkey: &Rkey, 230 + ) -> Result<(), String> { 231 + let root_cid_str = state 232 + .repo_repo 233 + .get_repo_root_cid_by_user_id(user_id) 234 + .await 235 + .map_err(|e| format!("DB error: {}", e))? 236 + .ok_or_else(|| "Repo root not found".to_string())?; 237 + 238 + let current_root_cid = 239 + Cid::from_str(root_cid_str.as_str()).map_err(|_| "Invalid repo root CID".to_string())?; 240 + 241 + let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 242 + let commit_bytes = tracking_store 243 + .get(&current_root_cid) 244 + .await 245 + .map_err(|e| format!("Failed to fetch commit: {:?}", e))? 246 + .ok_or_else(|| "Commit block not found".to_string())?; 247 + 248 + let commit = Commit::from_cbor(&commit_bytes) 249 + .map_err(|e| format!("Failed to parse commit: {:?}", e))?; 250 + 251 + let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 252 + let key = format!("{}/{}", collection, rkey); 253 + 254 + let prev_record_cid = mst 255 + .get(&key) 256 + .await 257 + .map_err(|e| format!("MST get error: {:?}", e))?; 258 + 259 + let Some(prev_cid) = prev_record_cid else { 260 + return Ok(()); 261 + }; 262 + 263 + let new_mst = mst 264 + .delete(&key) 265 + .await 266 + .map_err(|e| format!("Failed to delete from MST: {:?}", e))?; 267 + 268 + let new_mst_root = new_mst 269 + .persist() 270 + .await 271 + .map_err(|e| format!("Failed to persist MST: {:?}", e))?; 272 + 273 + let op = RecordOp::Delete { 274 + collection: collection.clone(), 275 + rkey: rkey.clone(), 276 + prev: Some(prev_cid), 277 + }; 278 + 279 + let mut new_mst_blocks = std::collections::BTreeMap::new(); 280 + let mut old_mst_blocks = std::collections::BTreeMap::new(); 281 + 282 + new_mst 283 + .blocks_for_path(&key, &mut new_mst_blocks) 284 + .await 285 + .map_err(|e| format!("Failed to get new MST blocks: {:?}", e))?; 286 + 287 + mst.blocks_for_path(&key, &mut old_mst_blocks) 288 + .await 289 + .map_err(|e| format!("Failed to get old MST blocks: {:?}", e))?; 290 + 291 + let mut relevant_blocks = new_mst_blocks.clone(); 292 + relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 293 + 294 + let written_cids: Vec<Cid> = tracking_store 295 + .get_all_relevant_cids() 296 + .into_iter() 297 + .chain(relevant_blocks.keys().copied()) 298 + .collect::<std::collections::HashSet<_>>() 299 + .into_iter() 300 + .collect(); 301 + 302 + let written_cids_str: Vec<String> = written_cids.iter().map(|c| c.to_string()).collect(); 303 + 304 + let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 305 + .chain( 306 + old_mst_blocks 307 + .keys() 308 + .filter(|cid| !new_mst_blocks.contains_key(*cid)) 309 + .copied(), 310 + ) 311 + .chain(std::iter::once(prev_cid)) 312 + .collect(); 313 + 314 + commit_and_log( 315 + state, 316 + CommitParams { 317 + did, 318 + user_id, 319 + current_root_cid: Some(current_root_cid), 320 + prev_data_cid: Some(commit.data), 321 + new_mst_root, 322 + ops: vec![op], 323 + blocks_cids: &written_cids_str, 324 + blobs: &[], 325 + obsolete_cids, 326 + }, 327 + ) 328 + .await?; 329 + 330 + Ok(()) 331 + }
+1 -1
crates/tranquil-pds/src/api/repo/record/mod.rs
··· 6 6 pub mod write; 7 7 8 8 pub use batch::apply_writes; 9 - pub use delete::{DeleteRecordInput, delete_record}; 9 + pub use delete::{DeleteRecordInput, delete_record, delete_record_internal}; 10 10 pub use read::{GetRecordInput, ListRecordsInput, ListRecordsOutput, get_record, list_records}; 11 11 pub use utils::*; 12 12 pub use write::{
+63 -78
crates/tranquil-pds/src/api/repo/record/read.rs
··· 59 59 Query(input): Query<GetRecordInput>, 60 60 ) -> Response { 61 61 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 62 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 62 63 let user_id_opt = if input.repo.is_did() { 63 - sqlx::query!("SELECT id FROM users WHERE did = $1", input.repo.as_str()) 64 - .fetch_optional(&state.db) 64 + let did: crate::types::Did = match input.repo.as_str().parse() { 65 + Ok(d) => d, 66 + Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), 67 + }; 68 + state 69 + .user_repo 70 + .get_id_by_did(&did) 65 71 .await 66 - .map(|opt| opt.map(|r| r.id)) 72 + .map_err(|_| ()) 67 73 } else { 68 74 let repo_str = input.repo.as_str(); 69 - let handle = if !repo_str.contains('.') { 70 - format!("{}.{}", repo_str, hostname) 75 + let handle_str = if !repo_str.contains('.') { 76 + format!("{}.{}", repo_str, hostname_for_handles) 71 77 } else { 72 78 repo_str.to_string() 73 79 }; 74 - sqlx::query!("SELECT id FROM users WHERE handle = $1", handle) 75 - .fetch_optional(&state.db) 80 + let handle: crate::types::Handle = match handle_str.parse() { 81 + Ok(h) => h, 82 + Err(_) => return ApiError::InvalidRequest("Invalid handle format".into()).into_response(), 83 + }; 84 + state 85 + .user_repo 86 + .get_id_by_handle(&handle) 76 87 .await 77 - .map(|opt| opt.map(|r| r.id)) 88 + .map_err(|_| ()) 78 89 }; 79 90 let user_id: uuid::Uuid = match user_id_opt { 80 91 Ok(Some(id)) => id, ··· 85 96 return ApiError::InternalError(None).into_response(); 86 97 } 87 98 }; 88 - let record_row = sqlx::query!( 89 - "SELECT record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey = $3", 90 - user_id, 91 - input.collection.as_str(), 92 - input.rkey.as_str() 93 - ) 94 - .fetch_optional(&state.db) 95 - .await; 96 - let record_cid_str: String = match record_row { 97 - Ok(Some(row)) => row.record_cid, 99 + let record_row = state 100 + .repo_repo 101 + .get_record_cid(user_id, &input.collection, &input.rkey) 102 + .await; 103 + let record_cid_link = match record_row { 104 + Ok(Some(cid)) => cid, 98 105 _ => { 99 106 return ApiError::RecordNotFound.into_response(); 100 107 } 101 108 }; 109 + let record_cid_str = record_cid_link.to_string(); 102 110 if let Some(expected_cid) = &input.cid 103 111 && &record_cid_str != expected_cid 104 112 { ··· 152 160 Query(input): Query<ListRecordsInput>, 153 161 ) -> Response { 154 162 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 163 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 155 164 let user_id_opt = if input.repo.is_did() { 156 - sqlx::query!("SELECT id FROM users WHERE did = $1", input.repo.as_str()) 157 - .fetch_optional(&state.db) 165 + let did: crate::types::Did = match input.repo.as_str().parse() { 166 + Ok(d) => d, 167 + Err(_) => return ApiError::InvalidRequest("Invalid DID format".into()).into_response(), 168 + }; 169 + state 170 + .user_repo 171 + .get_id_by_did(&did) 158 172 .await 159 - .map(|opt| opt.map(|r| r.id)) 173 + .map_err(|_| ()) 160 174 } else { 161 175 let repo_str = input.repo.as_str(); 162 - let handle = if !repo_str.contains('.') { 163 - format!("{}.{}", repo_str, hostname) 176 + let handle_str = if !repo_str.contains('.') { 177 + format!("{}.{}", repo_str, hostname_for_handles) 164 178 } else { 165 179 repo_str.to_string() 166 180 }; 167 - sqlx::query!("SELECT id FROM users WHERE handle = $1", handle) 168 - .fetch_optional(&state.db) 181 + let handle: crate::types::Handle = match handle_str.parse() { 182 + Ok(h) => h, 183 + Err(_) => return ApiError::InvalidRequest("Invalid handle format".into()).into_response(), 184 + }; 185 + state 186 + .user_repo 187 + .get_id_by_handle(&handle) 169 188 .await 170 - .map(|opt| opt.map(|r| r.id)) 189 + .map_err(|_| ()) 171 190 }; 172 191 let user_id: uuid::Uuid = match user_id_opt { 173 192 Ok(Some(id)) => id, ··· 181 200 let limit = input.limit.unwrap_or(50).clamp(1, 100); 182 201 let reverse = input.reverse.unwrap_or(false); 183 202 let limit_i64 = limit as i64; 184 - let order = if reverse { "ASC" } else { "DESC" }; 185 - let rows_res: Result<Vec<(String, String)>, sqlx::Error> = if let Some(cursor) = &input.cursor { 186 - let comparator = if reverse { ">" } else { "<" }; 187 - let query = format!( 188 - "SELECT rkey, record_cid FROM records WHERE repo_id = $1 AND collection = $2 AND rkey {} $3 ORDER BY rkey {} LIMIT $4", 189 - comparator, order 190 - ); 191 - sqlx::query_as(&query) 192 - .bind(user_id) 193 - .bind(input.collection.as_str()) 194 - .bind(cursor) 195 - .bind(limit_i64) 196 - .fetch_all(&state.db) 197 - .await 198 - } else { 199 - let mut conditions = vec!["repo_id = $1", "collection = $2"]; 200 - let mut param_idx = 3; 201 - if input.rkey_start.is_some() { 202 - conditions.push("rkey > $3"); 203 - param_idx += 1; 204 - } 205 - if input.rkey_end.is_some() { 206 - conditions.push(if param_idx == 3 { 207 - "rkey < $3" 208 - } else { 209 - "rkey < $4" 210 - }); 211 - param_idx += 1; 212 - } 213 - let limit_idx = param_idx; 214 - let query = format!( 215 - "SELECT rkey, record_cid FROM records WHERE {} ORDER BY rkey {} LIMIT ${}", 216 - conditions.join(" AND "), 217 - order, 218 - limit_idx 219 - ); 220 - let mut query_builder = sqlx::query_as::<_, (String, String)>(&query) 221 - .bind(user_id) 222 - .bind(input.collection.as_str()); 223 - if let Some(start) = &input.rkey_start { 224 - query_builder = query_builder.bind(start.as_str()); 225 - } 226 - if let Some(end) = &input.rkey_end { 227 - query_builder = query_builder.bind(end.as_str()); 228 - } 229 - query_builder.bind(limit_i64).fetch_all(&state.db).await 230 - }; 231 - let rows = match rows_res { 203 + let cursor_rkey = input.cursor.as_ref().and_then(|c| c.parse::<crate::types::Rkey>().ok()); 204 + let rows = match state 205 + .repo_repo 206 + .list_records( 207 + user_id, 208 + &input.collection, 209 + cursor_rkey.as_ref(), 210 + limit_i64, 211 + reverse, 212 + input.rkey_start.as_ref(), 213 + input.rkey_end.as_ref(), 214 + ) 215 + .await 216 + { 232 217 Ok(r) => r, 233 218 Err(e) => { 234 219 error!("Error listing records: {:?}", e); 235 220 return ApiError::InternalError(None).into_response(); 236 221 } 237 222 }; 238 - let last_rkey = rows.last().map(|(rkey, _)| rkey.clone()); 223 + let last_rkey = rows.last().map(|r| r.rkey.to_string()); 239 224 let parsed_rows: Vec<(Cid, String, String)> = rows 240 225 .iter() 241 - .filter_map(|(rkey, cid_str)| { 242 - Cid::from_str(cid_str) 226 + .filter_map(|row| { 227 + Cid::from_str(row.record_cid.as_str()) 243 228 .ok() 244 - .map(|cid| (cid, rkey.clone(), cid_str.clone())) 229 + .map(|cid| (cid, row.rkey.to_string(), row.record_cid.to_string())) 245 230 }) 246 231 .collect(); 247 232 let cids: Vec<Cid> = parsed_rows.iter().map(|(cid, _, _)| *cid).collect();
+142 -327
crates/tranquil-pds/src/api/repo/record/utils.rs
··· 36 36 } 37 37 } 38 38 39 + use tranquil_db_traits::Backlink; 40 + use crate::types::AtUri; 41 + 42 + pub fn extract_backlinks(uri: &AtUri, record: &Value) -> Vec<Backlink> { 43 + let record_type = record 44 + .get("$type") 45 + .and_then(|v| v.as_str()) 46 + .unwrap_or_default(); 47 + 48 + match record_type { 49 + "app.bsky.graph.follow" | "app.bsky.graph.block" => record 50 + .get("subject") 51 + .and_then(|v| v.as_str()) 52 + .filter(|s| s.starts_with("did:")) 53 + .map(|subject| vec![Backlink { 54 + uri: uri.clone(), 55 + path: "subject".to_string(), 56 + link_to: subject.to_string(), 57 + }]) 58 + .unwrap_or_default(), 59 + "app.bsky.feed.like" | "app.bsky.feed.repost" => record 60 + .get("subject") 61 + .and_then(|v| v.get("uri")) 62 + .and_then(|v| v.as_str()) 63 + .filter(|s| s.starts_with("at://")) 64 + .map(|subject_uri| vec![Backlink { 65 + uri: uri.clone(), 66 + path: "subject.uri".to_string(), 67 + link_to: subject_uri.to_string(), 68 + }]) 69 + .unwrap_or_default(), 70 + _ => Vec::new(), 71 + } 72 + } 73 + 39 74 pub fn create_signed_commit( 40 75 did: &Did, 41 76 data: Cid, ··· 98 133 state: &AppState, 99 134 params: CommitParams<'_>, 100 135 ) -> Result<CommitResult, String> { 136 + use tranquil_db_traits::{ApplyCommitError, ApplyCommitInput, CommitEventData, RecordDelete, RecordUpsert}; 137 + 101 138 let CommitParams { 102 139 did, 103 140 user_id, ··· 109 146 blobs, 110 147 obsolete_cids, 111 148 } = params; 112 - let key_row = sqlx::query!( 113 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 114 - user_id 115 - ) 116 - .fetch_one(&state.db) 117 - .await 118 - .map_err(|e| format!("Failed to fetch signing key: {}", e))?; 149 + let key_row = state 150 + .user_repo 151 + .get_user_key_by_id(user_id) 152 + .await 153 + .map_err(|e| format!("Failed to fetch signing key: {}", e))? 154 + .ok_or_else(|| "Signing key not found".to_string())?; 119 155 let key_bytes = crate::config::decrypt_key(&key_row.key_bytes, key_row.encryption_version) 120 156 .map_err(|e| format!("Failed to decrypt signing key: {}", e))?; 121 157 let signing_key = ··· 129 165 .put(&new_commit_bytes) 130 166 .await 131 167 .map_err(|e| format!("Failed to save commit block: {:?}", e))?; 132 - let mut tx = state 133 - .db 134 - .begin() 135 - .await 136 - .map_err(|e| format!("Failed to begin transaction: {}", e))?; 137 - let lock_result = sqlx::query!( 138 - "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", 139 - user_id 140 - ) 141 - .fetch_optional(&mut *tx) 142 - .await; 143 - match lock_result { 144 - Err(e) => { 145 - if let Some(db_err) = e.as_database_error() 146 - && db_err.code().as_deref() == Some("55P03") 147 - { 148 - return Err( 149 - "ConcurrentModification: Another request is modifying this repo".to_string(), 150 - ); 151 - } 152 - return Err(format!("Failed to acquire repo lock: {}", e)); 153 - } 154 - Ok(Some(row)) => { 155 - if let Some(expected_root) = &current_root_cid 156 - && row.repo_root_cid != expected_root.to_string() 157 - { 158 - return Err( 159 - "ConcurrentModification: Repo has been modified since last read".to_string(), 160 - ); 161 - } 162 - } 163 - Ok(None) => { 164 - return Err("Repo not found".to_string()); 165 - } 166 - } 167 - let is_account_active = sqlx::query_scalar!( 168 - "SELECT deactivated_at IS NULL FROM users WHERE id = $1", 169 - user_id 170 - ) 171 - .fetch_optional(&mut *tx) 172 - .await 173 - .map_err(|e| format!("Failed to check account status: {}", e))? 174 - .flatten() 175 - .unwrap_or(false); 176 - sqlx::query!( 177 - "UPDATE repos SET repo_root_cid = $1, repo_rev = $2 WHERE user_id = $3", 178 - new_root_cid.to_string(), 179 - &rev_str, 180 - user_id 181 - ) 182 - .execute(&mut *tx) 183 - .await 184 - .map_err(|e| format!("DB Error (repos): {}", e))?; 168 + 185 169 let mut all_block_cids: Vec<Vec<u8>> = blocks_cids 186 170 .iter() 187 171 .filter_map(|s| Cid::from_str(s).ok()) 188 172 .map(|c| c.to_bytes()) 189 173 .collect(); 190 174 all_block_cids.push(new_root_cid.to_bytes()); 191 - if !all_block_cids.is_empty() { 192 - sqlx::query!( 193 - r#" 194 - INSERT INTO user_blocks (user_id, block_cid) 195 - SELECT $1, block_cid FROM UNNEST($2::bytea[]) AS t(block_cid) 196 - ON CONFLICT (user_id, block_cid) DO NOTHING 197 - "#, 198 - user_id, 199 - &all_block_cids 200 - ) 201 - .execute(&mut *tx) 202 - .await 203 - .map_err(|e| format!("DB Error (user_blocks): {}", e))?; 204 - } 205 - if !obsolete_cids.is_empty() { 206 - let obsolete_bytes: Vec<Vec<u8>> = obsolete_cids.iter().map(|c| c.to_bytes()).collect(); 207 - sqlx::query!( 208 - r#" 209 - DELETE FROM user_blocks 210 - WHERE user_id = $1 211 - AND block_cid = ANY($2) 212 - "#, 213 - user_id, 214 - &obsolete_bytes as &[Vec<u8>] 215 - ) 216 - .execute(&mut *tx) 217 - .await 218 - .map_err(|e| format!("DB Error (user_blocks delete obsolete): {}", e))?; 219 - } 220 - let (upserts, deletes): (Vec<_>, Vec<_>) = ops 221 - .iter() 222 - .partition(|op| matches!(op, RecordOp::Create { .. } | RecordOp::Update { .. })); 223 - let (upsert_collections, upsert_rkeys, upsert_cids): (Vec<String>, Vec<String>, Vec<String>) = 224 - upserts 225 - .into_iter() 226 - .filter_map(|op| match op { 227 - RecordOp::Create { 228 - collection, 229 - rkey, 230 - cid, 175 + 176 + let obsolete_bytes: Vec<Vec<u8>> = obsolete_cids.iter().map(|c| c.to_bytes()).collect(); 177 + 178 + let (record_upserts, record_deletes): (Vec<RecordUpsert>, Vec<RecordDelete>) = ops.iter().fold( 179 + (Vec::new(), Vec::new()), 180 + |(mut upserts, mut deletes), op| { 181 + match op { 182 + RecordOp::Create { collection, rkey, cid } 183 + | RecordOp::Update { collection, rkey, cid, .. } => { 184 + upserts.push(RecordUpsert { 185 + collection: collection.clone(), 186 + rkey: rkey.clone(), 187 + cid: crate::types::CidLink::new_unchecked(&cid.to_string()), 188 + }); 189 + } 190 + RecordOp::Delete { collection, rkey, .. } => { 191 + deletes.push(RecordDelete { 192 + collection: collection.clone(), 193 + rkey: rkey.clone(), 194 + }); 231 195 } 232 - | RecordOp::Update { 233 - collection, 234 - rkey, 235 - cid, 236 - .. 237 - } => Some((collection.to_string(), rkey.to_string(), cid.to_string())), 238 - _ => None, 239 - }) 240 - .fold( 241 - (Vec::new(), Vec::new(), Vec::new()), 242 - |(mut cols, mut rkeys, mut cids), (c, r, ci)| { 243 - cols.push(c); 244 - rkeys.push(r); 245 - cids.push(ci); 246 - (cols, rkeys, cids) 247 - }, 248 - ); 249 - let (delete_collections, delete_rkeys): (Vec<String>, Vec<String>) = deletes 250 - .into_iter() 251 - .filter_map(|op| match op { 252 - RecordOp::Delete { 253 - collection, rkey, .. 254 - } => Some((collection.to_string(), rkey.to_string())), 255 - _ => None, 256 - }) 257 - .unzip(); 258 - if !upsert_collections.is_empty() { 259 - sqlx::query!( 260 - r#" 261 - INSERT INTO records (repo_id, collection, rkey, record_cid, repo_rev) 262 - SELECT $1, collection, rkey, record_cid, $5 263 - FROM UNNEST($2::text[], $3::text[], $4::text[]) AS t(collection, rkey, record_cid) 264 - ON CONFLICT (repo_id, collection, rkey) DO UPDATE 265 - SET record_cid = EXCLUDED.record_cid, repo_rev = EXCLUDED.repo_rev, created_at = NOW() 266 - "#, 267 - user_id, 268 - &upsert_collections, 269 - &upsert_rkeys, 270 - &upsert_cids, 271 - rev_str 272 - ) 273 - .execute(&mut *tx) 274 - .await 275 - .map_err(|e| format!("DB Error (records batch upsert): {}", e))?; 276 - } 277 - if !delete_collections.is_empty() { 278 - sqlx::query!( 279 - r#" 280 - DELETE FROM records 281 - WHERE repo_id = $1 282 - AND (collection, rkey) IN (SELECT * FROM UNNEST($2::text[], $3::text[])) 283 - "#, 284 - user_id, 285 - &delete_collections, 286 - &delete_rkeys 287 - ) 288 - .execute(&mut *tx) 289 - .await 290 - .map_err(|e| format!("DB Error (records batch delete): {}", e))?; 291 - } 292 - let ops_json = ops 196 + } 197 + (upserts, deletes) 198 + }, 199 + ); 200 + 201 + let ops_json: Vec<serde_json::Value> = ops 293 202 .iter() 294 203 .map(|op| match op { 295 - RecordOp::Create { 296 - collection, 297 - rkey, 298 - cid, 299 - } => json!({ 204 + RecordOp::Create { collection, rkey, cid } => json!({ 300 205 "action": "create", 301 206 "path": format!("{}/{}", collection, rkey), 302 207 "cid": cid.to_string() 303 208 }), 304 - RecordOp::Update { 305 - collection, 306 - rkey, 307 - cid, 308 - prev, 309 - } => { 209 + RecordOp::Update { collection, rkey, cid, prev } => { 310 210 let mut obj = json!({ 311 211 "action": "update", 312 212 "path": format!("{}/{}", collection, rkey), ··· 317 217 } 318 218 obj 319 219 } 320 - RecordOp::Delete { 321 - collection, 322 - rkey, 323 - prev, 324 - } => { 220 + RecordOp::Delete { collection, rkey, prev } => { 325 221 let mut obj = json!({ 326 222 "action": "delete", 327 223 "path": format!("{}/{}", collection, rkey), ··· 333 229 obj 334 230 } 335 231 }) 336 - .collect::<Vec<_>>(); 337 - if is_account_active { 338 - let event_type = "commit"; 339 - let prev_cid_str = current_root_cid.map(|c| c.to_string()); 340 - let prev_data_cid_str = prev_data_cid.map(|c| c.to_string()); 341 - let seq_row = sqlx::query!( 342 - r#" 343 - INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, prev_data_cid) 344 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 345 - RETURNING seq 346 - "#, 347 - did.as_str(), 348 - event_type, 349 - new_root_cid.to_string(), 350 - prev_cid_str, 351 - json!(ops_json), 352 - blobs, 353 - blocks_cids, 354 - prev_data_cid_str, 355 - ) 356 - .fetch_one(&mut *tx) 232 + .collect(); 233 + 234 + let commit_event = CommitEventData { 235 + did: did.clone(), 236 + event_type: "commit".to_string(), 237 + commit_cid: Some(crate::types::CidLink::new_unchecked(&new_root_cid.to_string())), 238 + prev_cid: current_root_cid.map(|c| crate::types::CidLink::new_unchecked(&c.to_string())), 239 + ops: Some(json!(ops_json)), 240 + blobs: Some(blobs.to_vec()), 241 + blocks_cids: Some(blocks_cids.to_vec()), 242 + prev_data_cid: prev_data_cid.map(|c| crate::types::CidLink::new_unchecked(&c.to_string())), 243 + rev: Some(rev_str.clone()), 244 + }; 245 + 246 + let input = ApplyCommitInput { 247 + user_id, 248 + did: did.clone(), 249 + expected_root_cid: current_root_cid.map(|c| crate::types::CidLink::new_unchecked(&c.to_string())), 250 + new_root_cid: crate::types::CidLink::new_unchecked(&new_root_cid.to_string()), 251 + new_rev: rev_str.clone(), 252 + new_block_cids: all_block_cids, 253 + obsolete_block_cids: obsolete_bytes, 254 + record_upserts, 255 + record_deletes, 256 + commit_event, 257 + }; 258 + 259 + let result = state 260 + .repo_repo 261 + .apply_commit(input) 357 262 .await 358 - .map_err(|e| format!("DB Error (repo_seq): {}", e))?; 359 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 360 - .execute(&mut *tx) 361 - .await 362 - .map_err(|e| format!("DB Error (notify): {}", e))?; 363 - } 364 - tx.commit() 365 - .await 366 - .map_err(|e| format!("Failed to commit transaction: {}", e))?; 367 - if is_account_active { 263 + .map_err(|e| match e { 264 + ApplyCommitError::RepoNotFound => "Repo not found".to_string(), 265 + ApplyCommitError::ConcurrentModification => { 266 + "ConcurrentModification: Repo has been modified since last read".to_string() 267 + } 268 + ApplyCommitError::Database(msg) => format!("DB Error: {}", msg), 269 + })?; 270 + 271 + if result.is_account_active { 368 272 let _ = sequence_sync_event(state, did, &new_root_cid.to_string(), Some(&rev_str)).await; 369 273 } 274 + 370 275 Ok(CommitResult { 371 276 commit_cid: new_root_cid, 372 277 rev: rev_str, ··· 382 287 use crate::repo::tracking::TrackingBlockStore; 383 288 use jacquard_repo::mst::Mst; 384 289 use std::sync::Arc; 385 - let user_id: Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 386 - .fetch_optional(&state.db) 290 + let user_id: Uuid = state 291 + .user_repo 292 + .get_id_by_did(did) 387 293 .await 388 294 .map_err(|e| format!("DB error: {}", e))? 389 295 .ok_or_else(|| "User not found".to_string())?; 390 - let root_cid_str: String = sqlx::query_scalar!( 391 - "SELECT repo_root_cid FROM repos WHERE user_id = $1", 392 - user_id 393 - ) 394 - .fetch_optional(&state.db) 395 - .await 396 - .map_err(|e| format!("DB error: {}", e))? 397 - .ok_or_else(|| "Repo not found".to_string())?; 296 + let root_cid_link = state 297 + .repo_repo 298 + .get_repo_root_cid_by_user_id(user_id) 299 + .await 300 + .map_err(|e| format!("DB error: {}", e))? 301 + .ok_or_else(|| "Repo not found".to_string())?; 398 302 let current_root_cid = 399 - Cid::from_str(&root_cid_str).map_err(|_| "Invalid repo root CID".to_string())?; 303 + Cid::from_str(root_cid_link.as_str()).map_err(|_| "Invalid repo root CID".to_string())?; 400 304 let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 401 305 let commit_bytes = tracking_store 402 306 .get(&current_root_cid) ··· 481 385 did: &Did, 482 386 handle: Option<&Handle>, 483 387 ) -> Result<i64, String> { 484 - let mut tx = state 485 - .db 486 - .begin() 487 - .await 488 - .map_err(|e| format!("Failed to begin transaction: {}", e))?; 489 - let seq_row = sqlx::query!( 490 - r#" 491 - INSERT INTO repo_seq (did, event_type, handle) 492 - VALUES ($1, 'identity', $2) 493 - RETURNING seq 494 - "#, 495 - did.as_str(), 496 - handle.map(|h| h.as_str()), 497 - ) 498 - .fetch_one(&mut *tx) 499 - .await 500 - .map_err(|e| format!("DB Error (repo_seq identity): {}", e))?; 501 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 502 - .execute(&mut *tx) 503 - .await 504 - .map_err(|e| format!("DB Error (notify): {}", e))?; 505 - tx.commit() 388 + state 389 + .repo_repo 390 + .insert_identity_event(did, handle) 506 391 .await 507 - .map_err(|e| format!("Failed to commit transaction: {}", e))?; 508 - Ok(seq_row.seq) 392 + .map_err(|e| format!("DB Error (identity event): {}", e)) 509 393 } 510 394 pub async fn sequence_account_event( 511 395 state: &AppState, ··· 513 397 active: bool, 514 398 status: Option<&str>, 515 399 ) -> Result<i64, String> { 516 - let mut tx = state 517 - .db 518 - .begin() 400 + state 401 + .repo_repo 402 + .insert_account_event(did, active, status) 519 403 .await 520 - .map_err(|e| format!("Failed to begin transaction: {}", e))?; 521 - let seq_row = sqlx::query!( 522 - r#" 523 - INSERT INTO repo_seq (did, event_type, active, status) 524 - VALUES ($1, 'account', $2, $3) 525 - RETURNING seq 526 - "#, 527 - did.as_str(), 528 - active, 529 - status, 530 - ) 531 - .fetch_one(&mut *tx) 532 - .await 533 - .map_err(|e| format!("DB Error (repo_seq account): {}", e))?; 534 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 535 - .execute(&mut *tx) 536 - .await 537 - .map_err(|e| format!("DB Error (notify): {}", e))?; 538 - tx.commit() 539 - .await 540 - .map_err(|e| format!("Failed to commit transaction: {}", e))?; 541 - Ok(seq_row.seq) 404 + .map_err(|e| format!("DB Error (account event): {}", e)) 542 405 } 543 406 pub async fn sequence_sync_event( 544 407 state: &AppState, ··· 546 409 commit_cid: &str, 547 410 rev: Option<&str>, 548 411 ) -> Result<i64, String> { 549 - let mut tx = state 550 - .db 551 - .begin() 412 + let cid_link = crate::types::CidLink::new_unchecked(commit_cid); 413 + state 414 + .repo_repo 415 + .insert_sync_event(did, &cid_link, rev) 552 416 .await 553 - .map_err(|e| format!("Failed to begin transaction: {}", e))?; 554 - let seq_row = sqlx::query!( 555 - r#" 556 - INSERT INTO repo_seq (did, event_type, commit_cid, rev) 557 - VALUES ($1, 'sync', $2, $3) 558 - RETURNING seq 559 - "#, 560 - did.as_str(), 561 - commit_cid, 562 - rev, 563 - ) 564 - .fetch_one(&mut *tx) 565 - .await 566 - .map_err(|e| format!("DB Error (repo_seq sync): {}", e))?; 567 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 568 - .execute(&mut *tx) 569 - .await 570 - .map_err(|e| format!("DB Error (notify): {}", e))?; 571 - tx.commit() 572 - .await 573 - .map_err(|e| format!("Failed to commit transaction: {}", e))?; 574 - Ok(seq_row.seq) 417 + .map_err(|e| format!("DB Error (sync event): {}", e)) 575 418 } 576 419 577 420 pub async fn sequence_genesis_commit( ··· 581 424 mst_root_cid: &Cid, 582 425 rev: &str, 583 426 ) -> Result<i64, String> { 584 - let ops = serde_json::json!([]); 585 - let blobs: Vec<String> = vec![]; 586 - let blocks_cids: Vec<String> = vec![mst_root_cid.to_string(), commit_cid.to_string()]; 587 - let prev_cid: Option<&str> = None; 588 - let commit_cid_str = commit_cid.to_string(); 589 - let mut tx = state 590 - .db 591 - .begin() 592 - .await 593 - .map_err(|e| format!("Failed to begin transaction: {}", e))?; 594 - let seq_row = sqlx::query!( 595 - r#" 596 - INSERT INTO repo_seq (did, event_type, commit_cid, prev_cid, ops, blobs, blocks_cids, rev) 597 - VALUES ($1, 'commit', $2, $3::TEXT, $4, $5, $6, $7) 598 - RETURNING seq 599 - "#, 600 - did.as_str(), 601 - commit_cid_str, 602 - prev_cid, 603 - ops, 604 - &blobs, 605 - &blocks_cids, 606 - rev 607 - ) 608 - .fetch_one(&mut *tx) 609 - .await 610 - .map_err(|e| format!("DB Error (repo_seq genesis commit): {}", e))?; 611 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", seq_row.seq)) 612 - .execute(&mut *tx) 427 + let commit_cid_link = crate::types::CidLink::new_unchecked(&commit_cid.to_string()); 428 + let mst_root_cid_link = crate::types::CidLink::new_unchecked(&mst_root_cid.to_string()); 429 + state 430 + .repo_repo 431 + .insert_genesis_commit_event(did, &commit_cid_link, &mst_root_cid_link, rev) 613 432 .await 614 - .map_err(|e| format!("DB Error (notify): {}", e))?; 615 - tx.commit() 616 - .await 617 - .map_err(|e| format!("Failed to commit transaction: {}", e))?; 618 - Ok(seq_row.seq) 433 + .map_err(|e| format!("DB Error (genesis commit event): {}", e)) 619 434 }
+166 -104
crates/tranquil-pds/src/api/repo/record/write.rs
··· 1 1 use super::validation::validate_record_with_status; 2 2 use crate::api::error::ApiError; 3 - use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_blob_cids}; 4 - use crate::delegation::{self, DelegationActionType}; 3 + use crate::api::repo::record::utils::{CommitParams, RecordOp, commit_and_log, extract_backlinks, extract_blob_cids}; 4 + use crate::delegation::DelegationActionType; 5 5 use crate::repo::tracking::TrackingBlockStore; 6 6 use crate::state::AppState; 7 7 use crate::types::{AtIdentifier, AtUri, Did, Nsid, Rkey}; ··· 15 15 use jacquard_repo::{commit::Commit, mst::Mst, storage::BlockStore}; 16 16 use serde::{Deserialize, Serialize}; 17 17 use serde_json::json; 18 - use sqlx::{PgPool, Row}; 19 18 use std::str::FromStr; 20 19 use std::sync::Arc; 21 20 use tracing::error; 22 21 use uuid::Uuid; 23 22 24 - pub async fn has_verified_comms_channel(db: &PgPool, did: &Did) -> Result<bool, sqlx::Error> { 25 - let row = sqlx::query( 26 - r#" 27 - SELECT 28 - email_verified, 29 - discord_verified, 30 - telegram_verified, 31 - signal_verified 32 - FROM users 33 - WHERE did = $1 34 - "#, 35 - ) 36 - .bind(did.as_str()) 37 - .fetch_optional(db) 38 - .await?; 39 - match row { 40 - Some(r) => { 41 - let email_verified: bool = r.get("email_verified"); 42 - let discord_verified: bool = r.get("discord_verified"); 43 - let telegram_verified: bool = r.get("telegram_verified"); 44 - let signal_verified: bool = r.get("signal_verified"); 45 - Ok(email_verified || discord_verified || telegram_verified || signal_verified) 46 - } 47 - None => Ok(false), 48 - } 49 - } 50 - 51 23 pub struct RepoWriteAuth { 52 24 pub did: Did, 53 25 pub user_id: Uuid, ··· 70 42 .ok_or_else(|| ApiError::AuthenticationRequired.into_response())?; 71 43 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 72 44 let auth_user = crate::auth::validate_token_with_dpop( 73 - &state.db, 45 + state.user_repo.as_ref(), 46 + state.oauth_repo.as_ref(), 74 47 &extracted.token, 75 48 extracted.is_dpop, 76 49 dpop_proof, ··· 89 62 ApiError::InvalidRepo("Repo does not match authenticated user".into()).into_response(), 90 63 ); 91 64 } 92 - if crate::util::is_account_migrated(&state.db, &auth_user.did) 65 + if state 66 + .user_repo 67 + .is_account_migrated(&auth_user.did) 93 68 .await 94 69 .unwrap_or(false) 95 70 { 96 71 return Err(ApiError::AccountMigrated.into_response()); 97 72 } 98 - let is_verified = has_verified_comms_channel(&state.db, &auth_user.did) 73 + let is_verified = state 74 + .user_repo 75 + .has_verified_comms_channel(&auth_user.did) 99 76 .await 100 77 .unwrap_or(false); 101 - let is_delegated = crate::delegation::is_delegated_account(&state.db, &auth_user.did) 78 + let is_delegated = state 79 + .delegation_repo 80 + .is_delegated_account(&auth_user.did) 102 81 .await 103 82 .unwrap_or(false); 104 83 if !is_verified && !is_delegated { 105 84 return Err(ApiError::AccountNotVerified.into_response()); 106 85 } 107 - let user_id = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", &auth_user.did) 108 - .fetch_optional(&state.db) 86 + let user_id = state 87 + .user_repo 88 + .get_id_by_did(&auth_user.did) 109 89 .await 110 90 .map_err(|e| { 111 91 error!("DB error fetching user: {}", e); 112 92 ApiError::InternalError(None).into_response() 113 93 })? 114 94 .ok_or_else(|| ApiError::InternalError(Some("User not found".into())).into_response())?; 115 - let root_cid_str: String = sqlx::query_scalar!( 116 - "SELECT repo_root_cid FROM repos WHERE user_id = $1", 117 - user_id 118 - ) 119 - .fetch_optional(&state.db) 120 - .await 121 - .map_err(|e| { 122 - error!("DB error fetching repo root: {}", e); 123 - ApiError::InternalError(None).into_response() 124 - })? 125 - .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())).into_response())?; 95 + let root_cid_str = state 96 + .repo_repo 97 + .get_repo_root_cid_by_user_id(user_id) 98 + .await 99 + .map_err(|e| { 100 + error!("DB error fetching repo root: {}", e); 101 + ApiError::InternalError(None).into_response() 102 + })? 103 + .ok_or_else(|| ApiError::InternalError(Some("Repo root not found".into())).into_response())?; 126 104 let current_root_cid = Cid::from_str(&root_cid_str).map_err(|_| { 127 105 ApiError::InternalError(Some("Invalid repo root CID".into())).into_response() 128 106 })?; ··· 200 178 { 201 179 return ApiError::InvalidSwap(Some("Repo has been modified".into())).into_response(); 202 180 } 203 - let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 204 - let commit_bytes = match tracking_store.get(&current_root_cid).await { 205 - Ok(Some(b)) => b, 206 - _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 207 - }; 208 - let commit = match Commit::from_cbor(&commit_bytes) { 209 - Ok(c) => c, 210 - _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 211 - }; 212 - let mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 181 + 213 182 let validation_status = if input.validate == Some(false) { 214 183 None 215 184 } else { ··· 225 194 } 226 195 }; 227 196 let rkey = input.rkey.unwrap_or_else(Rkey::generate); 197 + 198 + let tracking_store = TrackingBlockStore::new(state.block_store.clone()); 199 + let commit_bytes = match tracking_store.get(&current_root_cid).await { 200 + Ok(Some(b)) => b, 201 + _ => return ApiError::InternalError(Some("Commit block not found".into())).into_response(), 202 + }; 203 + let commit = match Commit::from_cbor(&commit_bytes) { 204 + Ok(c) => c, 205 + _ => return ApiError::InternalError(Some("Failed to parse commit".into())).into_response(), 206 + }; 207 + let mut mst = Mst::load(Arc::new(tracking_store.clone()), commit.data, None); 208 + let initial_mst_root = commit.data; 209 + 210 + let mut ops: Vec<RecordOp> = Vec::new(); 211 + let mut conflict_uris_to_cleanup: Vec<AtUri> = Vec::new(); 212 + let mut all_old_mst_blocks = std::collections::BTreeMap::new(); 213 + 214 + if input.validate != Some(false) { 215 + let record_uri = AtUri::from_parts(&did, &input.collection, &rkey); 216 + let backlinks = extract_backlinks(&record_uri, &input.record); 217 + 218 + if !backlinks.is_empty() { 219 + let conflicts = match state 220 + .backlink_repo 221 + .get_backlink_conflicts(user_id, &input.collection, &backlinks) 222 + .await 223 + { 224 + Ok(c) => c, 225 + Err(e) => { 226 + error!("Failed to check backlink conflicts: {}", e); 227 + return ApiError::InternalError(None).into_response(); 228 + } 229 + }; 230 + 231 + for conflict_uri in conflicts { 232 + let conflict_rkey = match conflict_uri.rkey() { 233 + Some(r) => Rkey::from(r.to_string()), 234 + None => continue, 235 + }; 236 + let conflict_collection = match conflict_uri.collection() { 237 + Some(c) => Nsid::from(c.to_string()), 238 + None => continue, 239 + }; 240 + let conflict_key = format!("{}/{}", conflict_collection, conflict_rkey); 241 + 242 + let prev_cid = match mst.get(&conflict_key).await { 243 + Ok(Some(cid)) => cid, 244 + Ok(None) => continue, 245 + Err(_) => continue, 246 + }; 247 + 248 + if mst.blocks_for_path(&conflict_key, &mut all_old_mst_blocks).await.is_err() { 249 + error!("Failed to get old MST blocks for conflict {}", conflict_uri); 250 + } 251 + 252 + mst = match mst.delete(&conflict_key).await { 253 + Ok(m) => m, 254 + Err(e) => { 255 + error!("Failed to delete conflict from MST {}: {:?}", conflict_uri, e); 256 + continue; 257 + } 258 + }; 259 + 260 + ops.push(RecordOp::Delete { 261 + collection: conflict_collection, 262 + rkey: conflict_rkey, 263 + prev: Some(prev_cid), 264 + }); 265 + conflict_uris_to_cleanup.push(conflict_uri); 266 + } 267 + } 268 + } 269 + 228 270 let record_ipld = crate::util::json_to_ipld(&input.record); 229 271 let mut record_bytes = Vec::new(); 230 272 if serde_ipld_dagcbor::to_writer(&mut record_bytes, &record_ipld).is_err() { ··· 238 280 } 239 281 }; 240 282 let key = format!("{}/{}", input.collection, rkey); 283 + 284 + if mst.blocks_for_path(&key, &mut all_old_mst_blocks).await.is_err() { 285 + error!("Failed to get old MST blocks for new record path"); 286 + } 287 + 241 288 let new_mst = match mst.add(&key, record_cid).await { 242 289 Ok(m) => m, 243 290 _ => return ApiError::InternalError(Some("Failed to add to MST".into())).into_response(), ··· 246 293 Ok(c) => c, 247 294 _ => return ApiError::InternalError(Some("Failed to persist MST".into())).into_response(), 248 295 }; 249 - let op = RecordOp::Create { 296 + 297 + ops.push(RecordOp::Create { 250 298 collection: input.collection.clone(), 251 299 rkey: rkey.clone(), 252 300 cid: record_cid, 253 - }; 301 + }); 302 + 254 303 let mut new_mst_blocks = std::collections::BTreeMap::new(); 255 - let mut old_mst_blocks = std::collections::BTreeMap::new(); 256 304 if new_mst 257 305 .blocks_for_path(&key, &mut new_mst_blocks) 258 306 .await ··· 261 309 return ApiError::InternalError(Some("Failed to get new MST blocks for path".into())) 262 310 .into_response(); 263 311 } 264 - if mst 265 - .blocks_for_path(&key, &mut old_mst_blocks) 266 - .await 267 - .is_err() 268 - { 269 - return ApiError::InternalError(Some("Failed to get old MST blocks for path".into())) 270 - .into_response(); 271 - } 312 + 272 313 let mut relevant_blocks = new_mst_blocks.clone(); 273 - relevant_blocks.extend(old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 274 - relevant_blocks.insert(record_cid, bytes::Bytes::from(record_bytes)); 314 + relevant_blocks.extend(all_old_mst_blocks.iter().map(|(k, v)| (*k, v.clone()))); 315 + relevant_blocks.insert(record_cid, bytes::Bytes::new()); 275 316 let written_cids: Vec<Cid> = tracking_store 276 317 .get_all_relevant_cids() 277 318 .into_iter() ··· 283 324 let blob_cids = extract_blob_cids(&input.record); 284 325 let obsolete_cids: Vec<Cid> = std::iter::once(current_root_cid) 285 326 .chain( 286 - old_mst_blocks 327 + all_old_mst_blocks 287 328 .keys() 288 329 .filter(|cid| !new_mst_blocks.contains_key(*cid)) 289 330 .copied(), 290 331 ) 291 332 .collect(); 333 + 292 334 let commit_result = match commit_and_log( 293 335 &state, 294 336 CommitParams { 295 337 did: &did, 296 338 user_id, 297 339 current_root_cid: Some(current_root_cid), 298 - prev_data_cid: Some(commit.data), 340 + prev_data_cid: Some(initial_mst_root), 299 341 new_mst_root, 300 - ops: vec![op], 342 + ops, 301 343 blocks_cids: &written_cids_str, 302 344 blobs: &blob_cids, 303 345 obsolete_cids, ··· 312 354 Err(e) => return ApiError::InternalError(Some(e)).into_response(), 313 355 }; 314 356 357 + for conflict_uri in conflict_uris_to_cleanup { 358 + if let Err(e) = state.backlink_repo.remove_backlinks_by_uri(&conflict_uri).await { 359 + error!("Failed to remove backlinks for {}: {}", conflict_uri, e); 360 + } 361 + } 362 + 315 363 if let Some(ref controller) = controller_did { 316 - let _ = delegation::log_delegation_action( 317 - &state.db, 318 - &did, 319 - controller, 320 - Some(controller), 321 - DelegationActionType::RepoWrite, 322 - Some(json!({ 323 - "action": "create", 324 - "collection": input.collection, 325 - "rkey": rkey 326 - })), 327 - None, 328 - None, 329 - ) 330 - .await; 364 + let _ = state 365 + .delegation_repo 366 + .log_delegation_action( 367 + &did, 368 + controller, 369 + Some(controller), 370 + DelegationActionType::RepoWrite, 371 + Some(json!({ 372 + "action": "create", 373 + "collection": input.collection, 374 + "rkey": rkey 375 + })), 376 + None, 377 + None, 378 + ) 379 + .await; 380 + } 381 + 382 + let created_uri = AtUri::from_parts(&did, &input.collection, &rkey); 383 + let backlinks = extract_backlinks(&created_uri, &input.record); 384 + if !backlinks.is_empty() { 385 + if let Err(e) = state 386 + .backlink_repo 387 + .add_backlinks(user_id, &backlinks) 388 + .await 389 + { 390 + error!("Failed to add backlinks for {}: {}", created_uri, e); 391 + } 331 392 } 332 393 333 394 ( 334 395 StatusCode::OK, 335 396 Json(CreateRecordOutput { 336 - uri: AtUri::from_parts(&did, &input.collection, &rkey), 397 + uri: created_uri, 337 398 cid: record_cid.to_string(), 338 399 commit: CommitInfo { 339 400 cid: commit_result.commit_cid.to_string(), ··· 574 635 }; 575 636 576 637 if let Some(ref controller) = controller_did { 577 - let _ = delegation::log_delegation_action( 578 - &state.db, 579 - &did, 580 - controller, 581 - Some(controller), 582 - DelegationActionType::RepoWrite, 583 - Some(json!({ 584 - "action": if is_update { "update" } else { "create" }, 585 - "collection": input.collection, 586 - "rkey": input.rkey 587 - })), 588 - None, 589 - None, 590 - ) 591 - .await; 638 + let _ = state 639 + .delegation_repo 640 + .log_delegation_action( 641 + &did, 642 + controller, 643 + Some(controller), 644 + DelegationActionType::RepoWrite, 645 + Some(json!({ 646 + "action": if is_update { "update" } else { "create" }, 647 + "collection": input.collection, 648 + "rkey": input.rkey 649 + })), 650 + None, 651 + None, 652 + ) 653 + .await; 592 654 } 593 655 594 656 (
+141 -237
crates/tranquil-pds/src/api/server/account_status.rs
··· 3 3 use crate::cache::Cache; 4 4 use crate::plc::PlcClient; 5 5 use crate::state::AppState; 6 - use crate::types::{Handle, PlainPassword}; 6 + use crate::types::PlainPassword; 7 7 use axum::{ 8 8 Json, 9 9 extract::State, ··· 54 54 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 55 55 ); 56 56 let did = match crate::auth::validate_token_with_dpop( 57 - &state.db, 57 + state.user_repo.as_ref(), 58 + state.oauth_repo.as_ref(), 58 59 &extracted.token, 59 60 extracted.is_dpop, 60 61 dpop_proof, ··· 68 69 Ok(user) => user.did, 69 70 Err(e) => return ApiError::from(e).into_response(), 70 71 }; 71 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 72 - .fetch_optional(&state.db) 73 - .await 74 - { 72 + let user_id = match state.user_repo.get_id_by_did(&did).await { 75 73 Ok(Some(id)) => id, 76 74 _ => { 77 75 return ApiError::InternalError(None).into_response(); 78 76 } 79 77 }; 80 - let user_status = sqlx::query!( 81 - "SELECT deactivated_at FROM users WHERE did = $1", 82 - did.as_str() 83 - ) 84 - .fetch_optional(&state.db) 85 - .await; 86 - let deactivated_at = match user_status { 87 - Ok(Some(row)) => row.deactivated_at, 88 - _ => None, 89 - }; 90 - let repo_result = sqlx::query!( 91 - "SELECT repo_root_cid, repo_rev FROM repos WHERE user_id = $1", 92 - user_id 93 - ) 94 - .fetch_optional(&state.db) 95 - .await; 96 - let (repo_commit, repo_rev_from_db) = match repo_result { 97 - Ok(Some(row)) => (row.repo_root_cid, row.repo_rev), 98 - _ => (String::new(), None), 99 - }; 100 - let block_count: i64 = sqlx::query_scalar!( 101 - "SELECT COUNT(*) FROM user_blocks WHERE user_id = $1", 102 - user_id 103 - ) 104 - .fetch_one(&state.db) 105 - .await 106 - .unwrap_or(Some(0)) 107 - .unwrap_or(0); 78 + let is_active = state 79 + .user_repo 80 + .is_account_active_by_did(&did) 81 + .await 82 + .ok() 83 + .flatten() 84 + .unwrap_or(false); 85 + let repo_info = state.repo_repo.get_repo(user_id).await.ok().flatten(); 86 + let (repo_commit, repo_rev_from_db) = repo_info 87 + .map(|r| (r.repo_root_cid.to_string(), r.repo_rev)) 88 + .unwrap_or_else(|| (String::new(), None)); 89 + let block_count: i64 = state 90 + .repo_repo 91 + .count_user_blocks(user_id) 92 + .await 93 + .unwrap_or(0); 108 94 let repo_rev = if let Some(rev) = repo_rev_from_db { 109 95 rev 110 96 } else if !repo_commit.is_empty() { ··· 123 109 } else { 124 110 String::new() 125 111 }; 126 - let record_count: i64 = 127 - sqlx::query_scalar!("SELECT COUNT(*) FROM records WHERE repo_id = $1", user_id) 128 - .fetch_one(&state.db) 129 - .await 130 - .unwrap_or(Some(0)) 131 - .unwrap_or(0); 132 - let imported_blobs: i64 = sqlx::query_scalar!( 133 - "SELECT COUNT(*) FROM blobs WHERE created_by_user = $1", 134 - user_id 135 - ) 136 - .fetch_one(&state.db) 137 - .await 138 - .unwrap_or(Some(0)) 139 - .unwrap_or(0); 140 - let expected_blobs: i64 = sqlx::query_scalar!( 141 - "SELECT COUNT(DISTINCT blob_cid) FROM record_blobs WHERE repo_id = $1", 142 - user_id 143 - ) 144 - .fetch_one(&state.db) 145 - .await 146 - .unwrap_or(Some(0)) 147 - .unwrap_or(0); 148 - let valid_did = is_valid_did_for_service(&state.db, state.cache.clone(), did.as_str()).await; 112 + let record_count: i64 = state.repo_repo.count_records(user_id).await.unwrap_or(0); 113 + let imported_blobs: i64 = state.blob_repo.count_blobs_by_user(user_id).await.unwrap_or(0); 114 + let expected_blobs: i64 = state 115 + .blob_repo 116 + .count_distinct_record_blobs(user_id) 117 + .await 118 + .unwrap_or(0); 119 + let valid_did = is_valid_did_for_service(state.user_repo.as_ref(), state.cache.clone(), &did).await; 149 120 ( 150 121 StatusCode::OK, 151 122 Json(CheckAccountStatusOutput { 152 - activated: deactivated_at.is_none(), 123 + activated: is_active, 153 124 valid_did, 154 125 repo_commit: repo_commit.clone(), 155 126 repo_rev, 156 - repo_blocks: block_count as i64, 127 + repo_blocks: block_count, 157 128 indexed_records: record_count, 158 129 private_state_values: 0, 159 130 expected_blobs, ··· 163 134 .into_response() 164 135 } 165 136 166 - async fn is_valid_did_for_service(db: &sqlx::PgPool, cache: Arc<dyn Cache>, did: &str) -> bool { 167 - assert_valid_did_document_for_service(db, cache, did, false) 137 + async fn is_valid_did_for_service(user_repo: &dyn tranquil_db_traits::UserRepository, cache: Arc<dyn Cache>, did: &crate::types::Did) -> bool { 138 + assert_valid_did_document_for_service(user_repo, cache, did, false) 168 139 .await 169 140 .is_ok() 170 141 } 171 142 172 143 async fn assert_valid_did_document_for_service( 173 - db: &sqlx::PgPool, 144 + user_repo: &dyn tranquil_db_traits::UserRepository, 174 145 cache: Arc<dyn Cache>, 175 - did: &str, 146 + did: &crate::types::Did, 176 147 with_retry: bool, 177 148 ) -> Result<(), ApiError> { 178 149 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 179 150 let expected_endpoint = format!("https://{}", hostname); 180 151 181 - if did.starts_with("did:plc:") { 152 + if did.as_str().starts_with("did:plc:") { 182 153 let max_attempts = if with_retry { 5 } else { 1 }; 183 154 let cache_for_retry = cache.clone(); 184 - let did_owned = did.to_string(); 155 + let did_owned = did.as_str().to_string(); 185 156 let expected_owned = expected_endpoint.clone(); 186 157 let attempt_counter = Arc::new(AtomicUsize::new(0)); 187 158 ··· 264 235 .and_then(|v| v.get("atproto")) 265 236 .and_then(|k| k.as_str()); 266 237 267 - let user_row = sqlx::query!( 268 - "SELECT uk.key_bytes, uk.encryption_version FROM user_keys uk JOIN users u ON uk.user_id = u.id WHERE u.did = $1", 269 - did 270 - ) 271 - .fetch_optional(db) 272 - .await 273 - .map_err(|e| { 274 - error!("Failed to fetch user key: {:?}", e); 275 - ApiError::InternalError(None) 276 - })?; 238 + let user_key = user_repo 239 + .get_user_key_by_did(&did) 240 + .await 241 + .map_err(|e| { 242 + error!("Failed to fetch user key: {:?}", e); 243 + ApiError::InternalError(None) 244 + })?; 277 245 278 - if let Some(row) = user_row { 279 - let key_bytes = crate::config::decrypt_key(&row.key_bytes, row.encryption_version) 246 + if let Some(key_info) = user_key { 247 + let key_bytes = crate::config::decrypt_key(&key_info.key_bytes, key_info.encryption_version) 280 248 .map_err(|e| { 281 249 error!("Failed to decrypt user key: {}", e); 282 250 ApiError::InternalError(None) ··· 297 265 )); 298 266 } 299 267 } 300 - } else if let Some(host_and_path) = did.strip_prefix("did:web:") { 268 + } else if let Some(host_and_path) = did.as_str().strip_prefix("did:web:") { 301 269 let client = crate::api::proxy_client::did_resolution_client(); 302 270 let decoded = host_and_path.replace("%3A", ":"); 303 271 let parts: Vec<&str> = decoded.split(':').collect(); ··· 374 342 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 375 343 ); 376 344 let auth_user = match crate::auth::validate_token_with_dpop( 377 - &state.db, 345 + state.user_repo.as_ref(), 346 + state.oauth_repo.as_ref(), 378 347 &extracted.token, 379 348 extracted.is_dpop, 380 349 dpop_proof, ··· 414 383 ); 415 384 let did_validation_start = std::time::Instant::now(); 416 385 if let Err(e) = 417 - assert_valid_did_document_for_service(&state.db, state.cache.clone(), did.as_str(), true) 386 + assert_valid_did_document_for_service(state.user_repo.as_ref(), state.cache.clone(), &did, true) 418 387 .await 419 388 { 420 389 info!( ··· 430 399 did_validation_start.elapsed() 431 400 ); 432 401 433 - let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did.as_str()) 434 - .fetch_optional(&state.db) 402 + let handle = state 403 + .user_repo 404 + .get_handle_by_did(&did) 435 405 .await 436 406 .ok() 437 407 .flatten(); ··· 439 409 "[MIGRATION] activateAccount: Activating account did={} handle={:?}", 440 410 did, handle 441 411 ); 442 - let result = sqlx::query!( 443 - "UPDATE users SET deactivated_at = NULL WHERE did = $1", 444 - did.as_str() 445 - ) 446 - .execute(&state.db) 447 - .await; 412 + let result = state.user_repo.activate_account(&did).await; 448 413 match result { 449 414 Ok(_) => { 450 415 info!( ··· 472 437 "[MIGRATION] activateAccount: Sequencing identity event for did={} handle={:?}", 473 438 did, handle 474 439 ); 475 - let handle_typed = handle.as_ref().map(Handle::new_unchecked); 440 + let handle_typed = handle.clone(); 476 441 if let Err(e) = crate::api::repo::record::sequence_identity_event( 477 442 &state, 478 443 &did, ··· 487 452 } else { 488 453 info!("[MIGRATION] activateAccount: Identity event sequenced successfully"); 489 454 } 490 - let repo_root = sqlx::query_scalar!( 491 - "SELECT r.repo_root_cid FROM repos r JOIN users u ON r.user_id = u.id WHERE u.did = $1", 492 - did.as_str() 493 - ) 494 - .fetch_optional(&state.db) 495 - .await 496 - .ok() 497 - .flatten(); 498 - if let Some(root_cid) = repo_root { 455 + let repo_root = state 456 + .repo_repo 457 + .get_repo_root_by_did(&did) 458 + .await 459 + .ok() 460 + .flatten(); 461 + if let Some(root_cid_link) = repo_root { 499 462 info!( 500 463 "[MIGRATION] activateAccount: Sequencing sync event for did={} root_cid={}", 501 - did, root_cid 464 + did, root_cid_link 502 465 ); 503 - let rev = if let Ok(cid) = Cid::from_str(&root_cid) { 466 + let rev = if let Ok(cid) = Cid::from_str(root_cid_link.as_str()) { 504 467 if let Ok(Some(block)) = state.block_store.get(&cid).await { 505 468 Commit::from_cbor(&block).ok().map(|c| c.rev().to_string()) 506 469 } else { ··· 512 475 if let Err(e) = crate::api::repo::record::sequence_sync_event( 513 476 &state, 514 477 &did, 515 - &root_cid, 478 + root_cid_link.as_str(), 516 479 rev.as_deref(), 517 480 ) 518 481 .await ··· 566 529 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 567 530 ); 568 531 let auth_user = match crate::auth::validate_token_with_dpop( 569 - &state.db, 532 + state.user_repo.as_ref(), 533 + state.oauth_repo.as_ref(), 570 534 &extracted.token, 571 535 extracted.is_dpop, 572 536 dpop_proof, ··· 598 562 599 563 let did = auth_user.did; 600 564 601 - let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", did.as_str()) 602 - .fetch_optional(&state.db) 565 + let handle = state 566 + .user_repo 567 + .get_handle_by_did(&did) 603 568 .await 604 569 .ok() 605 570 .flatten(); 606 571 607 - let result = sqlx::query!( 608 - "UPDATE users SET deactivated_at = NOW(), delete_after = $2 WHERE did = $1", 609 - did.as_str(), 610 - delete_after 611 - ) 612 - .execute(&state.db) 613 - .await; 572 + let result = state 573 + .user_repo 574 + .deactivate_account(&did, delete_after) 575 + .await; 614 576 615 577 match result { 616 - Ok(_) => { 578 + Ok(true) => { 617 579 if let Some(ref h) = handle { 618 580 let _ = state.cache.delete(&format!("handle:{}", h)).await; 619 581 } ··· 627 589 { 628 590 warn!("Failed to sequence account deactivated event: {}", e); 629 591 } 592 + EmptyResponse::ok().into_response() 593 + } 594 + Ok(false) => { 630 595 EmptyResponse::ok().into_response() 631 596 } 632 597 Err(e) => { ··· 652 617 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 653 618 ); 654 619 let validated = match crate::auth::validate_token_with_dpop( 655 - &state.db, 620 + state.user_repo.as_ref(), 621 + state.oauth_repo.as_ref(), 656 622 &extracted.token, 657 623 extracted.is_dpop, 658 624 dpop_proof, ··· 668 634 }; 669 635 let did = validated.did.clone(); 670 636 671 - if !crate::api::server::reauth::check_legacy_session_mfa(&state.db, did.as_str()).await { 672 - return crate::api::server::reauth::legacy_mfa_required_response(&state.db, did.as_str()) 637 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &did).await { 638 + return crate::api::server::reauth::legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &did) 673 639 .await; 674 640 } 675 641 676 - let user_id = match sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did.as_str()) 677 - .fetch_optional(&state.db) 678 - .await 679 - { 642 + let user_id = match state.user_repo.get_id_by_did(&did).await { 680 643 Ok(Some(id)) => id, 681 644 _ => { 682 645 return ApiError::InternalError(None).into_response(); ··· 684 647 }; 685 648 let confirmation_token = Uuid::new_v4().to_string(); 686 649 let expires_at = Utc::now() + Duration::minutes(15); 687 - let insert = sqlx::query!( 688 - "INSERT INTO account_deletion_requests (token, did, expires_at) VALUES ($1, $2, $3)", 689 - confirmation_token, 690 - did.as_str(), 691 - expires_at 692 - ) 693 - .execute(&state.db) 694 - .await; 695 - if let Err(e) = insert { 650 + if let Err(e) = state 651 + .infra_repo 652 + .create_deletion_request(&confirmation_token, &did, expires_at) 653 + .await 654 + { 696 655 error!("DB error creating deletion token: {:?}", e); 697 656 return ApiError::InternalError(None).into_response(); 698 657 } 699 658 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 700 - if let Err(e) = 701 - crate::comms::enqueue_account_deletion(&state.db, user_id, &confirmation_token, &hostname) 702 - .await 659 + if let Err(e) = crate::comms::comms_repo::enqueue_account_deletion( 660 + state.user_repo.as_ref(), 661 + state.infra_repo.as_ref(), 662 + user_id, 663 + &confirmation_token, 664 + &hostname, 665 + ) 666 + .await 703 667 { 704 668 warn!("Failed to enqueue account deletion notification: {:?}", e); 705 669 } ··· 731 695 if token.is_empty() { 732 696 return ApiError::InvalidToken(Some("token is required".into())).into_response(); 733 697 } 734 - let user = sqlx::query!( 735 - "SELECT id, password_hash, handle FROM users WHERE did = $1", 736 - did.as_str() 737 - ) 738 - .fetch_optional(&state.db) 739 - .await; 740 - let (user_id, password_hash, handle) = match user { 741 - Ok(Some(row)) => (row.id, row.password_hash, row.handle), 698 + let user = match state.user_repo.get_user_for_deletion(did).await { 699 + Ok(Some(u)) => u, 742 700 Ok(None) => { 743 701 return ApiError::InvalidRequest("account not found".into()).into_response(); 744 702 } ··· 747 705 return ApiError::InternalError(None).into_response(); 748 706 } 749 707 }; 708 + let (user_id, password_hash, handle) = (user.id, user.password_hash, user.handle); 750 709 let password_valid = if password_hash 751 710 .as_ref() 752 711 .map(|h| verify(password, h).unwrap_or(false)) ··· 754 713 { 755 714 true 756 715 } else { 757 - let app_pass_rows = sqlx::query!( 758 - "SELECT password_hash FROM app_passwords WHERE user_id = $1", 759 - user_id 760 - ) 761 - .fetch_all(&state.db) 762 - .await 763 - .unwrap_or_default(); 764 - app_pass_rows 716 + let app_pass_hashes = state 717 + .session_repo 718 + .get_app_password_hashes_by_did(did) 719 + .await 720 + .unwrap_or_default(); 721 + app_pass_hashes 765 722 .iter() 766 - .any(|row| verify(password, &row.password_hash).unwrap_or(false)) 723 + .any(|h| verify(password, h).unwrap_or(false)) 767 724 }; 768 725 if !password_valid { 769 726 return ApiError::AuthenticationFailed(Some("Invalid password".into())).into_response(); 770 727 } 771 - let deletion_request = sqlx::query!( 772 - "SELECT did, expires_at FROM account_deletion_requests WHERE token = $1", 773 - token 774 - ) 775 - .fetch_optional(&state.db) 776 - .await; 777 - let (token_did, expires_at) = match deletion_request { 778 - Ok(Some(row)) => (row.did, row.expires_at), 728 + let deletion_request = match state.infra_repo.get_deletion_request(token).await { 729 + Ok(Some(req)) => req, 779 730 Ok(None) => { 780 731 return ApiError::InvalidToken(Some("Invalid or expired token".into())).into_response(); 781 732 } ··· 784 735 return ApiError::InternalError(None).into_response(); 785 736 } 786 737 }; 787 - if token_did != did.as_str() { 738 + if &deletion_request.did != did { 788 739 return ApiError::InvalidToken(Some("Token does not match account".into())).into_response(); 789 740 } 790 - if Utc::now() > expires_at { 791 - let _ = sqlx::query!( 792 - "DELETE FROM account_deletion_requests WHERE token = $1", 793 - token 794 - ) 795 - .execute(&state.db) 796 - .await; 741 + if Utc::now() > deletion_request.expires_at { 742 + let _ = state.infra_repo.delete_deletion_request(token).await; 797 743 return ApiError::ExpiredToken(None).into_response(); 798 744 } 799 - let mut tx = match state.db.begin().await { 800 - Ok(tx) => tx, 801 - Err(e) => { 802 - error!("Failed to begin transaction: {:?}", e); 803 - return ApiError::InternalError(None).into_response(); 804 - } 805 - }; 806 - let deletion_result: Result<(), sqlx::Error> = async { 807 - sqlx::query!("DELETE FROM session_tokens WHERE did = $1", did) 808 - .execute(&mut *tx) 809 - .await?; 810 - sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id) 811 - .execute(&mut *tx) 812 - .await?; 813 - sqlx::query!("DELETE FROM repos WHERE user_id = $1", user_id) 814 - .execute(&mut *tx) 815 - .await?; 816 - sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id) 817 - .execute(&mut *tx) 818 - .await?; 819 - sqlx::query!("DELETE FROM user_keys WHERE user_id = $1", user_id) 820 - .execute(&mut *tx) 821 - .await?; 822 - sqlx::query!("DELETE FROM app_passwords WHERE user_id = $1", user_id) 823 - .execute(&mut *tx) 824 - .await?; 825 - sqlx::query!("DELETE FROM account_deletion_requests WHERE did = $1", did) 826 - .execute(&mut *tx) 827 - .await?; 828 - sqlx::query!("DELETE FROM users WHERE id = $1", user_id) 829 - .execute(&mut *tx) 830 - .await?; 831 - Ok(()) 745 + if let Err(e) = state 746 + .user_repo 747 + .delete_account_complete(user_id, did) 748 + .await 749 + { 750 + error!("DB error deleting account: {:?}", e); 751 + return ApiError::InternalError(None).into_response(); 832 752 } 753 + let account_seq = crate::api::repo::record::sequence_account_event( 754 + &state, 755 + did, 756 + false, 757 + Some("deleted"), 758 + ) 833 759 .await; 834 - match deletion_result { 835 - Ok(()) => { 836 - if let Err(e) = tx.commit().await { 837 - error!("Failed to commit account deletion transaction: {:?}", e); 838 - return ApiError::InternalError(None).into_response(); 839 - } 840 - let account_seq = crate::api::repo::record::sequence_account_event( 841 - &state, 842 - did, 843 - false, 844 - Some("deleted"), 845 - ) 846 - .await; 847 - match account_seq { 848 - Ok(seq) => { 849 - if let Err(e) = sqlx::query!( 850 - "DELETE FROM repo_seq WHERE did = $1 AND seq != $2", 851 - did, 852 - seq 853 - ) 854 - .execute(&state.db) 855 - .await 856 - { 857 - warn!( 858 - "Failed to cleanup sequences for deleted account {}: {}", 859 - did, e 860 - ); 861 - } 862 - } 863 - Err(e) => { 864 - warn!( 865 - "Failed to sequence account deletion event for {}: {}", 866 - did, e 867 - ); 868 - } 760 + match account_seq { 761 + Ok(seq) => { 762 + if let Err(e) = state 763 + .repo_repo 764 + .delete_sequences_except(did, seq) 765 + .await 766 + { 767 + warn!( 768 + "Failed to cleanup sequences for deleted account {}: {}", 769 + did, e 770 + ); 869 771 } 870 - let _ = state.cache.delete(&format!("handle:{}", handle)).await; 871 - info!("Account {} deleted successfully", did); 872 - EmptyResponse::ok().into_response() 873 772 } 874 773 Err(e) => { 875 - error!("DB error deleting account, rolling back: {:?}", e); 876 - ApiError::InternalError(None).into_response() 774 + warn!( 775 + "Failed to sequence account deletion event for {}: {}", 776 + did, e 777 + ); 877 778 } 878 779 } 780 + let _ = state.cache.delete(&format!("handle:{}", handle)).await; 781 + info!("Account {} deleted successfully", did); 782 + EmptyResponse::ok().into_response() 879 783 }
+83 -79
crates/tranquil-pds/src/api/server/app_password.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 3 use crate::auth::BearerAuth; 4 - use crate::delegation::{self, DelegationActionType}; 4 + use crate::delegation::{DelegationActionType, intersect_scopes}; 5 5 use crate::state::{AppState, RateLimitKind}; 6 - use crate::util::get_user_id_by_did; 7 6 use axum::{ 8 7 Json, 9 8 extract::State, ··· 12 11 }; 13 12 use serde::{Deserialize, Serialize}; 14 13 use serde_json::json; 14 + use tranquil_db_traits::AppPasswordCreate; 15 15 use tracing::{error, warn}; 16 16 17 17 #[derive(Serialize)] ··· 35 35 State(state): State<AppState>, 36 36 BearerAuth(auth_user): BearerAuth, 37 37 ) -> Response { 38 - let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 39 - Ok(id) => id, 40 - Err(e) => return ApiError::from(e).into_response(), 38 + let user = match state.user_repo.get_by_did(&auth_user.did).await { 39 + Ok(Some(u)) => u, 40 + Ok(None) => return ApiError::AccountNotFound.into_response(), 41 + Err(e) => { 42 + error!("DB error getting user: {:?}", e); 43 + return ApiError::InternalError(None).into_response(); 44 + } 41 45 }; 42 - match sqlx::query!( 43 - "SELECT name, created_at, privileged, scopes, created_by_controller_did FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC", 44 - user_id 45 - ) 46 - .fetch_all(&state.db) 47 - .await 48 - { 46 + 47 + match state.session_repo.list_app_passwords(user.id).await { 49 48 Ok(rows) => { 50 49 let passwords: Vec<AppPassword> = rows 51 50 .iter() ··· 54 53 created_at: row.created_at.to_rfc3339(), 55 54 privileged: row.privileged, 56 55 scopes: row.scopes.clone(), 57 - created_by_controller: row.created_by_controller_did.clone(), 56 + created_by_controller: row.created_by_controller_did.as_ref().map(|d| d.to_string()), 58 57 }) 59 58 .collect(); 60 59 Json(ListAppPasswordsOutput { passwords }).into_response() ··· 98 97 warn!(ip = %client_ip, "App password creation rate limit exceeded"); 99 98 return ApiError::RateLimitExceeded(None).into_response(); 100 99 } 101 - let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 102 - Ok(id) => id, 103 - Err(e) => return ApiError::from(e).into_response(), 100 + 101 + let user = match state.user_repo.get_by_did(&auth_user.did).await { 102 + Ok(Some(u)) => u, 103 + Ok(None) => return ApiError::AccountNotFound.into_response(), 104 + Err(e) => { 105 + error!("DB error getting user: {:?}", e); 106 + return ApiError::InternalError(None).into_response(); 107 + } 104 108 }; 109 + 105 110 let name = input.name.trim(); 106 111 if name.is_empty() { 107 112 return ApiError::InvalidRequest("name is required".into()).into_response(); 108 113 } 109 - let existing = sqlx::query!( 110 - "SELECT id FROM app_passwords WHERE user_id = $1 AND name = $2", 111 - user_id, 112 - name 113 - ) 114 - .fetch_optional(&state.db) 115 - .await; 116 - if let Ok(Some(_)) = existing { 117 - return ApiError::DuplicateAppPassword.into_response(); 114 + 115 + match state.session_repo.get_app_password_by_name(user.id, name).await { 116 + Ok(Some(_)) => return ApiError::DuplicateAppPassword.into_response(), 117 + Err(e) => { 118 + error!("DB error checking app password: {:?}", e); 119 + return ApiError::InternalError(None).into_response(); 120 + } 121 + Ok(None) => {} 118 122 } 119 123 120 124 let (final_scopes, controller_did) = if let Some(ref controller) = auth_user.controller_did { 121 - let grant = delegation::get_delegation(&state.db, &auth_user.did, controller) 125 + let grant = state 126 + .delegation_repo 127 + .get_delegation(&auth_user.did, controller) 122 128 .await 123 129 .ok() 124 130 .flatten(); 125 131 let granted_scopes = grant.map(|g| g.granted_scopes).unwrap_or_default(); 126 132 127 133 let requested = input.scopes.as_deref().unwrap_or("atproto"); 128 - let intersected = delegation::intersect_scopes(requested, &granted_scopes); 134 + let intersected = intersect_scopes(requested, &granted_scopes); 129 135 130 136 if intersected.is_empty() && !granted_scopes.is_empty() { 131 137 return ApiError::InsufficientScope(None).into_response(); ··· 152 158 }) 153 159 .collect::<Vec<String>>() 154 160 .join("-"); 161 + 155 162 let password_clone = password.clone(); 156 163 let password_hash = match tokio::task::spawn_blocking(move || { 157 164 bcrypt::hash(&password_clone, bcrypt::DEFAULT_COST) ··· 168 175 return ApiError::InternalError(None).into_response(); 169 176 } 170 177 }; 178 + 171 179 let privileged = input.privileged.unwrap_or(false); 172 180 let created_at = chrono::Utc::now(); 173 - match sqlx::query!( 174 - "INSERT INTO app_passwords (user_id, name, password_hash, created_at, privileged, scopes, created_by_controller_did) VALUES ($1, $2, $3, $4, $5, $6, $7)", 175 - user_id, 176 - name, 181 + 182 + let create_data = AppPasswordCreate { 183 + user_id: user.id, 184 + name: name.to_string(), 177 185 password_hash, 178 - created_at, 179 186 privileged, 180 - final_scopes, 181 - controller_did.as_deref() 182 - ) 183 - .execute(&state.db) 184 - .await 185 - { 187 + scopes: final_scopes.clone(), 188 + created_by_controller_did: controller_did.clone(), 189 + }; 190 + 191 + match state.session_repo.create_app_password(&create_data).await { 186 192 Ok(_) => { 187 193 if let Some(ref controller) = controller_did { 188 - let _ = delegation::log_delegation_action( 189 - &state.db, 190 - &auth_user.did, 191 - controller, 192 - Some(controller), 193 - DelegationActionType::AccountAction, 194 - Some(json!({ 195 - "action": "create_app_password", 196 - "name": name, 197 - "scopes": final_scopes 198 - })), 199 - None, 200 - None, 201 - ) 202 - .await; 194 + let _ = state 195 + .delegation_repo 196 + .log_delegation_action( 197 + &auth_user.did, 198 + controller, 199 + Some(controller), 200 + DelegationActionType::AccountAction, 201 + Some(json!({ 202 + "action": "create_app_password", 203 + "name": name, 204 + "scopes": final_scopes 205 + })), 206 + None, 207 + None, 208 + ) 209 + .await; 203 210 } 204 211 Json(CreateAppPasswordOutput { 205 212 name: name.to_string(), ··· 227 234 BearerAuth(auth_user): BearerAuth, 228 235 Json(input): Json<RevokeAppPasswordInput>, 229 236 ) -> Response { 230 - let user_id = match get_user_id_by_did(&state.db, &auth_user.did).await { 231 - Ok(id) => id, 232 - Err(e) => return ApiError::from(e).into_response(), 237 + let user = match state.user_repo.get_by_did(&auth_user.did).await { 238 + Ok(Some(u)) => u, 239 + Ok(None) => return ApiError::AccountNotFound.into_response(), 240 + Err(e) => { 241 + error!("DB error getting user: {:?}", e); 242 + return ApiError::InternalError(None).into_response(); 243 + } 233 244 }; 245 + 234 246 let name = input.name.trim(); 235 247 if name.is_empty() { 236 248 return ApiError::InvalidRequest("name is required".into()).into_response(); 237 249 } 238 - let sessions_to_invalidate = sqlx::query_scalar!( 239 - "SELECT access_jti FROM session_tokens WHERE did = $1 AND app_password_name = $2", 240 - &auth_user.did, 241 - name 242 - ) 243 - .fetch_all(&state.db) 244 - .await 245 - .unwrap_or_default(); 246 - if let Err(e) = sqlx::query!( 247 - "DELETE FROM session_tokens WHERE did = $1 AND app_password_name = $2", 248 - &auth_user.did, 249 - name 250 - ) 251 - .execute(&state.db) 252 - .await 250 + 251 + let sessions_to_invalidate = state 252 + .session_repo 253 + .get_session_jtis_by_app_password(&auth_user.did, name) 254 + .await 255 + .unwrap_or_default(); 256 + 257 + if let Err(e) = state 258 + .session_repo 259 + .delete_sessions_by_app_password(&auth_user.did, name) 260 + .await 253 261 { 254 262 error!("DB error revoking sessions for app password: {:?}", e); 255 263 return ApiError::InternalError(None).into_response(); 256 264 } 265 + 257 266 futures::future::join_all(sessions_to_invalidate.iter().map(|jti| { 258 267 let cache_key = format!("auth:session:{}:{}", &auth_user.did, jti); 259 268 let cache = state.cache.clone(); ··· 262 271 } 263 272 })) 264 273 .await; 265 - if let Err(e) = sqlx::query!( 266 - "DELETE FROM app_passwords WHERE user_id = $1 AND name = $2", 267 - user_id, 268 - name 269 - ) 270 - .execute(&state.db) 271 - .await 272 - { 274 + 275 + if let Err(e) = state.session_repo.delete_app_password(user.id, name).await { 273 276 error!("DB error revoking app password: {:?}", e); 274 277 return ApiError::InternalError(None).into_response(); 275 278 } 279 + 276 280 EmptyResponse::ok().into_response() 277 281 }
+38 -83
crates/tranquil-pds/src/api/server/email.rs
··· 34 34 return e; 35 35 } 36 36 37 - let did = auth.0.did.to_string(); 38 - let user = match sqlx::query!( 39 - "SELECT id, handle, email, email_verified FROM users WHERE did = $1", 40 - did 41 - ) 42 - .fetch_optional(&state.db) 43 - .await 44 - { 37 + let user = match state.user_repo.get_email_info_by_did(&auth.0.did).await { 45 38 Ok(Some(row)) => row, 46 39 Ok(None) => { 47 40 return ApiError::AccountNotFound.into_response(); ··· 61 54 62 55 if token_required { 63 56 let code = crate::auth::verification_token::generate_channel_update_token( 64 - &did, 57 + &auth.0.did, 65 58 "email_update", 66 59 &current_email.to_lowercase(), 67 60 ); 68 61 let formatted_code = crate::auth::verification_token::format_token_for_display(&code); 69 62 70 63 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 71 - if let Err(e) = 72 - crate::comms::enqueue_email_update_token(&state.db, user.id, &formatted_code, &hostname) 73 - .await 64 + if let Err(e) = crate::comms::comms_repo::enqueue_email_update_token( 65 + state.user_repo.as_ref(), 66 + state.infra_repo.as_ref(), 67 + user.id, 68 + &formatted_code, 69 + &hostname, 70 + ) 71 + .await 74 72 { 75 73 warn!("Failed to enqueue email update notification: {:?}", e); 76 74 } ··· 111 109 return e; 112 110 } 113 111 114 - let did = auth.0.did.to_string(); 115 - let user = match sqlx::query!( 116 - "SELECT id, email, email_verified FROM users WHERE did = $1", 117 - did 118 - ) 119 - .fetch_optional(&state.db) 120 - .await 121 - { 112 + let did = &auth.0.did; 113 + let user = match state.user_repo.get_email_info_by_did(did).await { 122 114 Ok(Some(row)) => row, 123 115 Ok(None) => { 124 116 return ApiError::AccountNotFound.into_response(); ··· 154 146 155 147 match verified { 156 148 Ok(token_data) => { 157 - if token_data.did != did { 149 + if token_data.did != did.as_str() { 158 150 return ApiError::InvalidToken(None).into_response(); 159 151 } 160 152 } ··· 166 158 } 167 159 } 168 160 169 - let update = sqlx::query!( 170 - "UPDATE users SET email_verified = TRUE, updated_at = NOW() WHERE id = $1", 171 - user.id 172 - ) 173 - .execute(&state.db) 174 - .await; 175 - 176 - if let Err(e) = update { 161 + if let Err(e) = state.user_repo.set_email_verified(user.id, true).await { 177 162 error!("DB error confirming email: {:?}", e); 178 163 return ApiError::InternalError(None).into_response(); 179 164 } ··· 207 192 return e; 208 193 } 209 194 210 - let did = auth_user.did.to_string(); 211 - let user = match sqlx::query!( 212 - "SELECT id, email, email_verified FROM users WHERE did = $1", 213 - did 214 - ) 215 - .fetch_optional(&state.db) 216 - .await 217 - { 195 + let did = &auth_user.did; 196 + let user = match state.user_repo.get_email_info_by_did(did).await { 218 197 Ok(Some(row)) => row, 219 198 Ok(None) => { 220 199 return ApiError::AccountNotFound.into_response(); ··· 262 241 263 242 match verified { 264 243 Ok(token_data) => { 265 - if token_data.did != did { 244 + if token_data.did != did.as_str() { 266 245 return ApiError::InvalidToken(None).into_response(); 267 246 } 268 247 } ··· 275 254 } 276 255 } 277 256 278 - let exists = sqlx::query!( 279 - "SELECT 1 as one FROM users WHERE LOWER(email) = $1 AND id != $2", 280 - new_email, 281 - user_id 282 - ) 283 - .fetch_optional(&state.db) 284 - .await; 285 - 286 - if let Ok(Some(_)) = exists { 257 + if let Ok(true) = state.user_repo.check_email_exists(&new_email, user_id).await { 287 258 return ApiError::InvalidRequest("Email is already in use".into()).into_response(); 288 259 } 289 260 290 - let update: Result<sqlx::postgres::PgQueryResult, sqlx::Error> = sqlx::query!( 291 - "UPDATE users SET email = $1, email_verified = FALSE, updated_at = NOW() WHERE id = $2", 292 - new_email, 293 - user_id 294 - ) 295 - .execute(&state.db) 296 - .await; 297 - 298 - if let Err(e) = update { 261 + if let Err(e) = state.user_repo.update_email(user_id, &new_email).await { 299 262 error!("DB error updating email: {:?}", e); 300 - if e.as_database_error() 301 - .map(|db_err: &dyn sqlx::error::DatabaseError| db_err.is_unique_violation()) 302 - .unwrap_or(false) 303 - { 304 - return ApiError::EmailTaken.into_response(); 305 - } 306 263 return ApiError::InternalError(None).into_response(); 307 264 } 308 265 ··· 310 267 crate::auth::verification_token::generate_signup_token(&did, "email", &new_email); 311 268 let formatted_token = 312 269 crate::auth::verification_token::format_token_for_display(&verification_token); 313 - if let Err(e) = crate::comms::enqueue_signup_verification( 314 - &state.db, 270 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 271 + if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 272 + state.infra_repo.as_ref(), 315 273 user_id, 316 274 "email", 317 275 &new_email, 318 276 &formatted_token, 319 - None, 277 + &hostname, 320 278 ) 321 279 .await 322 280 { 323 281 warn!("Failed to send verification email to new address: {:?}", e); 324 282 } 325 283 326 - match sqlx::query!( 327 - "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, 'email_auth_factor', $2) ON CONFLICT (user_id, name) DO UPDATE SET value_json = $2", 328 - user_id, 329 - json!(input.email_auth_factor.unwrap_or(false)) 330 - ) 331 - .execute(&state.db) 332 - .await 284 + if let Err(e) = state 285 + .infra_repo 286 + .upsert_account_preference( 287 + user_id, 288 + "email_auth_factor", 289 + json!(input.email_auth_factor.unwrap_or(false)), 290 + ) 291 + .await 333 292 { 334 - Ok(_) => {} 335 - Err(e) => warn!("Failed to update email_auth_factor preference: {}", e), 293 + warn!("Failed to update email_auth_factor preference: {}", e); 336 294 } 337 295 338 296 info!("Email updated for user {}", user_id); ··· 357 315 return ApiError::RateLimitExceeded(None).into_response(); 358 316 } 359 317 360 - let user = sqlx::query!( 361 - "SELECT email_verified FROM users WHERE email = $1 OR handle = $1", 362 - input.identifier 363 - ) 364 - .fetch_optional(&state.db) 365 - .await; 366 - 367 - match user { 368 - Ok(Some(row)) => VerifiedResponse::response(row.email_verified).into_response(), 318 + match state 319 + .user_repo 320 + .check_email_verified_by_identifier(&input.identifier) 321 + .await 322 + { 323 + Ok(Some(verified)) => VerifiedResponse::response(verified).into_response(), 369 324 Ok(None) => ApiError::AccountNotFound.into_response(), 370 325 Err(e) => { 371 326 error!("DB error checking email verified: {:?}", e);
+78 -106
crates/tranquil-pds/src/api/server/invite.rs
··· 2 2 use crate::auth::BearerAuth; 3 3 use crate::auth::extractor::BearerAuthAdmin; 4 4 use crate::state::AppState; 5 + use crate::types::Did; 5 6 use axum::{ 6 7 Json, 7 8 extract::State, ··· 50 51 return ApiError::InvalidRequest("useCount must be at least 1".into()).into_response(); 51 52 } 52 53 53 - let for_account = input 54 - .for_account 55 - .unwrap_or_else(|| auth_user.did.to_string()); 54 + let for_account: Did = match &input.for_account { 55 + Some(acct) => match acct.parse() { 56 + Ok(d) => d, 57 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 58 + }, 59 + None => auth_user.did.clone(), 60 + }; 56 61 let code = gen_invite_code(); 57 62 58 - match sqlx::query!( 59 - "INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 60 - SELECT $1, $2, id, $3 FROM users WHERE is_admin = true LIMIT 1", 61 - code, 62 - input.use_count, 63 - for_account 64 - ) 65 - .execute(&state.db) 66 - .await 63 + match state 64 + .infra_repo 65 + .create_invite_code(&code, input.use_count, Some(&for_account)) 66 + .await 67 67 { 68 - Ok(result) => { 69 - if result.rows_affected() == 0 { 70 - error!("No admin user found to create invite code"); 71 - return ApiError::InternalError(None).into_response(); 72 - } 73 - Json(CreateInviteCodeOutput { code }).into_response() 68 + Ok(true) => Json(CreateInviteCodeOutput { code }).into_response(), 69 + Ok(false) => { 70 + error!("No admin user found to create invite code"); 71 + ApiError::InternalError(None).into_response() 74 72 } 75 73 Err(e) => { 76 74 error!("DB error creating invite code: {:?}", e); ··· 108 106 } 109 107 110 108 let code_count = input.code_count.unwrap_or(1).max(1); 111 - let for_accounts = input 112 - .for_accounts 113 - .filter(|v| !v.is_empty()) 114 - .unwrap_or_else(|| vec![auth_user.did.to_string()]); 115 - 116 - let admin_user_id = 117 - match sqlx::query_scalar!("SELECT id FROM users WHERE is_admin = true LIMIT 1") 118 - .fetch_optional(&state.db) 119 - .await 120 - { 121 - Ok(Some(id)) => id, 122 - Ok(None) => { 123 - error!("No admin user found to create invite codes"); 124 - return ApiError::InternalError(None).into_response(); 109 + let for_accounts: Vec<Did> = match &input.for_accounts { 110 + Some(accounts) if !accounts.is_empty() => { 111 + let parsed: Result<Vec<Did>, _> = accounts.iter().map(|a| a.parse()).collect(); 112 + match parsed { 113 + Ok(dids) => dids, 114 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 125 115 } 126 - Err(e) => { 127 - error!("DB error looking up admin user: {:?}", e); 128 - return ApiError::InternalError(None).into_response(); 129 - } 130 - }; 116 + } 117 + _ => vec![auth_user.did.clone()], 118 + }; 119 + 120 + let admin_user_id = match state.user_repo.get_any_admin_user_id().await { 121 + Ok(Some(id)) => id, 122 + Ok(None) => { 123 + error!("No admin user found to create invite codes"); 124 + return ApiError::InternalError(None).into_response(); 125 + } 126 + Err(e) => { 127 + error!("DB error looking up admin user: {:?}", e); 128 + return ApiError::InternalError(None).into_response(); 129 + } 130 + }; 131 131 132 132 let result = futures::future::try_join_all(for_accounts.into_iter().map(|account| { 133 - let db = state.db.clone(); 133 + let infra_repo = state.infra_repo.clone(); 134 134 let use_count = input.use_count; 135 135 async move { 136 136 let codes: Vec<String> = (0..code_count).map(|_| gen_invite_code()).collect(); 137 - sqlx::query!( 138 - r#" 139 - INSERT INTO invite_codes (code, available_uses, created_by_user, for_account) 140 - SELECT code, $2, $3, $4 FROM UNNEST($1::text[]) AS t(code) 141 - "#, 142 - &codes[..], 143 - use_count, 144 - admin_user_id, 145 - account 146 - ) 147 - .execute(&db) 148 - .await 149 - .map(|_| AccountCodes { account, codes }) 137 + infra_repo 138 + .create_invite_codes_batch(&codes, use_count, admin_user_id, Some(&account)) 139 + .await 140 + .map(|_| AccountCodes { account: account.to_string(), codes }) 150 141 } 151 142 })) 152 143 .await; ··· 203 194 ) -> Response { 204 195 let include_used = params.include_used.unwrap_or(true); 205 196 206 - let codes_rows = match sqlx::query!( 207 - r#" 208 - SELECT 209 - ic.code, 210 - ic.available_uses, 211 - ic.created_at, 212 - ic.disabled, 213 - ic.for_account, 214 - (SELECT COUNT(*) FROM invite_code_uses icu WHERE icu.code = ic.code)::int as "use_count!" 215 - FROM invite_codes ic 216 - WHERE ic.for_account = $1 217 - ORDER BY ic.created_at DESC 218 - "#, 219 - &auth_user.did 220 - ) 221 - .fetch_all(&state.db) 222 - .await 197 + let codes_info = match state 198 + .infra_repo 199 + .get_invite_codes_for_account(&auth_user.did) 200 + .await 223 201 { 224 - Ok(rows) => rows, 202 + Ok(info) => info, 225 203 Err(e) => { 226 204 error!("DB error fetching invite codes: {:?}", e); 227 205 return ApiError::InternalError(None).into_response(); 228 206 } 229 207 }; 230 208 231 - let filtered_rows: Vec<_> = codes_rows 209 + let filtered_codes: Vec<_> = codes_info 232 210 .into_iter() 233 - .filter(|row| { 234 - let disabled = row.disabled.unwrap_or(false); 235 - !disabled && (include_used || row.use_count < row.available_uses) 236 - }) 211 + .filter(|info| !info.disabled) 237 212 .collect(); 238 213 239 - let codes = futures::future::join_all(filtered_rows.into_iter().map(|row| { 240 - let db = state.db.clone(); 214 + let codes = futures::future::join_all(filtered_codes.into_iter().map(|info| { 215 + let infra_repo = state.infra_repo.clone(); 241 216 async move { 242 - let uses = sqlx::query!( 243 - r#" 244 - SELECT u.did, u.handle, icu.used_at 245 - FROM invite_code_uses icu 246 - JOIN users u ON icu.used_by_user = u.id 247 - WHERE icu.code = $1 248 - ORDER BY icu.used_at DESC 249 - "#, 250 - row.code 251 - ) 252 - .fetch_all(&db) 253 - .await 254 - .map(|use_rows| { 255 - use_rows 256 - .iter() 257 - .map(|u| InviteCodeUse { 258 - used_by: u.did.clone(), 259 - used_by_handle: Some(u.handle.clone()), 260 - used_at: u.used_at.to_rfc3339(), 261 - }) 262 - .collect() 263 - }) 264 - .unwrap_or_default(); 217 + let uses = infra_repo 218 + .get_invite_code_uses(&info.code) 219 + .await 220 + .map(|use_rows| { 221 + use_rows 222 + .into_iter() 223 + .map(|u| InviteCodeUse { 224 + used_by: u.used_by_did.to_string(), 225 + used_by_handle: u.used_by_handle.map(|h| h.to_string()), 226 + used_at: u.used_at.to_rfc3339(), 227 + }) 228 + .collect::<Vec<_>>() 229 + }) 230 + .unwrap_or_default(); 265 231 266 - InviteCode { 267 - code: row.code, 268 - available: row.available_uses, 232 + let use_count = uses.len() as i32; 233 + if !include_used && use_count >= info.available_uses { 234 + return None; 235 + } 236 + 237 + Some(InviteCode { 238 + code: info.code, 239 + available: info.available_uses, 269 240 disabled: false, 270 - for_account: row.for_account, 271 - created_by: "admin".to_string(), 272 - created_at: row.created_at.to_rfc3339(), 241 + for_account: info.for_account.map(|d| d.to_string()).unwrap_or_default(), 242 + created_by: info.created_by.map(|d| d.to_string()).unwrap_or_else(|| "admin".to_string()), 243 + created_at: info.created_at.to_rfc3339(), 273 244 uses, 274 - } 245 + }) 275 246 } 276 247 })) 277 248 .await; 278 249 250 + let codes: Vec<InviteCode> = codes.into_iter().flatten().collect(); 279 251 Json(GetAccountInviteCodesOutput { codes }).into_response() 280 252 }
+13 -22
crates/tranquil-pds/src/api/server/logo.rs
··· 9 9 use tracing::error; 10 10 11 11 pub async fn get_logo(State(state): State<AppState>) -> Response { 12 - let logo_cid: Option<String> = 13 - match sqlx::query_scalar("SELECT value FROM server_config WHERE key = 'logo_cid'") 14 - .fetch_optional(&state.db) 15 - .await 16 - { 17 - Ok(cid) => cid, 18 - Err(e) => { 19 - error!("DB error fetching logo_cid: {:?}", e); 20 - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); 21 - } 22 - }; 12 + let logo_cid = match state.infra_repo.get_server_config("logo_cid").await { 13 + Ok(cid) => cid, 14 + Err(e) => { 15 + error!("DB error fetching logo_cid: {:?}", e); 16 + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); 17 + } 18 + }; 23 19 24 - let cid = match logo_cid { 20 + let cid_str = match logo_cid { 25 21 Some(c) if !c.is_empty() => c, 26 22 _ => return StatusCode::NOT_FOUND.into_response(), 27 23 }; 24 + let cid = crate::types::CidLink::new_unchecked(&cid_str); 28 25 29 - let blob = match sqlx::query!( 30 - "SELECT storage_key, mime_type FROM blobs WHERE cid = $1", 31 - cid 32 - ) 33 - .fetch_optional(&state.db) 34 - .await 35 - { 36 - Ok(Some(row)) => row, 26 + let metadata = match state.blob_repo.get_blob_metadata(&cid).await { 27 + Ok(Some(m)) => m, 37 28 Ok(None) => return StatusCode::NOT_FOUND.into_response(), 38 29 Err(e) => { 39 30 error!("DB error fetching blob: {:?}", e); ··· 41 32 } 42 33 }; 43 34 44 - match state.blob_store.get(&blob.storage_key).await { 35 + match state.blob_store.get(&metadata.storage_key).await { 45 36 Ok(data) => Response::builder() 46 37 .status(StatusCode::OK) 47 - .header(header::CONTENT_TYPE, &blob.mime_type) 38 + .header(header::CONTENT_TYPE, &metadata.mime_type) 48 39 .header(header::CACHE_CONTROL, "public, max-age=3600") 49 40 .body(Body::from(data)) 50 41 .unwrap(),
+3 -7
crates/tranquil-pds/src/api/server/meta.rs
··· 1 1 use crate::state::AppState; 2 2 use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; 3 3 use serde_json::json; 4 - use tracing::error; 5 4 6 5 fn get_available_comms_channels() -> Vec<&'static str> { 7 6 let mut channels = vec!["email"]; ··· 64 63 })) 65 64 } 66 65 pub async fn health(State(state): State<AppState>) -> impl IntoResponse { 67 - match sqlx::query!("SELECT 1 as one").fetch_one(&state.db).await { 68 - Ok(_) => (StatusCode::OK, "OK"), 69 - Err(e) => { 70 - error!("Health check failed: {:?}", e); 71 - (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable") 72 - } 66 + match state.infra_repo.health_check().await { 67 + Ok(true) => (StatusCode::OK, "OK"), 68 + _ => (StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"), 73 69 } 74 70 }
+30 -68
crates/tranquil-pds/src/api/server/migration.rs
··· 6 6 http::StatusCode, 7 7 response::{IntoResponse, Response}, 8 8 }; 9 - use chrono::Utc; 10 9 use serde::{Deserialize, Serialize}; 11 10 use serde_json::json; 12 11 ··· 51 50 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 52 51 ); 53 52 let auth_user = match crate::auth::validate_token_with_dpop( 54 - &state.db, 53 + state.user_repo.as_ref(), 54 + state.oauth_repo.as_ref(), 55 55 &extracted.token, 56 56 extracted.is_dpop, 57 57 dpop_proof, ··· 73 73 .into_response(); 74 74 } 75 75 76 - let user = match sqlx::query!( 77 - "SELECT id, handle, deactivated_at FROM users WHERE did = $1", 78 - &auth_user.did 79 - ) 80 - .fetch_optional(&state.db) 81 - .await 82 - { 83 - Ok(Some(row)) => row, 76 + let user = match state.user_repo.get_user_for_did_doc(&auth_user.did).await { 77 + Ok(Some(u)) => u, 84 78 Ok(None) => return ApiError::AccountNotFound.into_response(), 85 79 Err(e) => { 86 80 tracing::error!("DB error getting user: {:?}", e); ··· 137 131 138 132 let also_known_as: Option<Vec<String>> = input.also_known_as.clone(); 139 133 140 - let now = Utc::now(); 141 - 142 - let upsert_result = sqlx::query!( 143 - r#" 144 - INSERT INTO did_web_overrides (user_id, verification_methods, also_known_as, updated_at) 145 - VALUES ($1, COALESCE($2, '[]'::jsonb), COALESCE($3, '{}'::text[]), $4) 146 - ON CONFLICT (user_id) DO UPDATE SET 147 - verification_methods = CASE WHEN $2 IS NOT NULL THEN $2 ELSE did_web_overrides.verification_methods END, 148 - also_known_as = CASE WHEN $3 IS NOT NULL THEN $3 ELSE did_web_overrides.also_known_as END, 149 - updated_at = $4 150 - "#, 151 - user.id, 152 - verification_methods_json, 153 - also_known_as.as_deref(), 154 - now 155 - ) 156 - .execute(&state.db) 157 - .await; 158 - 159 - if let Err(e) = upsert_result { 134 + if let Err(e) = state 135 + .user_repo 136 + .upsert_did_web_overrides(user.id, verification_methods_json, also_known_as) 137 + .await 138 + { 160 139 tracing::error!("DB error upserting did_web_overrides: {:?}", e); 161 140 return ApiError::InternalError(None).into_response(); 162 141 } 163 142 164 143 if let Some(ref endpoint) = input.service_endpoint { 165 144 let endpoint_clean = endpoint.trim().trim_end_matches('/'); 166 - let update_result = sqlx::query!( 167 - "UPDATE users SET migrated_to_pds = $1, migrated_at = $2 WHERE did = $3", 168 - endpoint_clean, 169 - now, 170 - &auth_user.did 171 - ) 172 - .execute(&state.db) 173 - .await; 174 - 175 - if let Err(e) = update_result { 145 + if let Err(e) = state 146 + .user_repo 147 + .update_migrated_to_pds(&auth_user.did, endpoint_clean) 148 + .await 149 + { 176 150 tracing::error!("DB error updating service endpoint: {:?}", e); 177 151 return ApiError::InternalError(None).into_response(); 178 152 } 179 153 } 180 154 181 - let did_doc = build_did_document(&state.db, &auth_user.did).await; 155 + let did_doc = build_did_document(&state, &auth_user.did).await; 182 156 183 157 tracing::info!("Updated DID document for {}", &auth_user.did); 184 158 ··· 208 182 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()) 209 183 ); 210 184 let auth_user = match crate::auth::validate_token_with_dpop( 211 - &state.db, 185 + state.user_repo.as_ref(), 186 + state.oauth_repo.as_ref(), 212 187 &extracted.token, 213 188 extracted.is_dpop, 214 189 dpop_proof, ··· 230 205 .into_response(); 231 206 } 232 207 233 - let did_doc = build_did_document(&state.db, &auth_user.did).await; 208 + let did_doc = build_did_document(&state, &auth_user.did).await; 234 209 235 210 (StatusCode::OK, Json(json!({ "didDocument": did_doc }))).into_response() 236 211 } 237 212 238 - async fn build_did_document(db: &sqlx::PgPool, did: &str) -> serde_json::Value { 213 + async fn build_did_document(state: &AppState, did: &crate::types::Did) -> serde_json::Value { 239 214 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 240 215 241 - let user = match sqlx::query!( 242 - "SELECT id, handle, migrated_to_pds FROM users WHERE did = $1", 243 - did 244 - ) 245 - .fetch_optional(db) 246 - .await 247 - { 216 + let user = match state.user_repo.get_user_for_did_doc_build(did).await { 248 217 Ok(Some(row)) => row, 249 218 _ => { 250 219 return json!({ ··· 253 222 } 254 223 }; 255 224 256 - let overrides = sqlx::query!( 257 - "SELECT verification_methods, also_known_as FROM did_web_overrides WHERE user_id = $1", 258 - user.id 259 - ) 260 - .fetch_optional(db) 261 - .await 262 - .ok() 263 - .flatten(); 225 + let overrides = state 226 + .user_repo 227 + .get_did_web_overrides(user.id) 228 + .await 229 + .ok() 230 + .flatten(); 264 231 265 232 let service_endpoint = user 266 233 .migrated_to_pds ··· 299 266 }); 300 267 } 301 268 302 - let key_row = sqlx::query!( 303 - "SELECT key_bytes, encryption_version FROM user_keys WHERE user_id = $1", 304 - user.id 305 - ) 306 - .fetch_optional(db) 307 - .await; 269 + let key_info = state.user_repo.get_user_key_by_id(user.id).await.ok().flatten(); 308 270 309 - let public_key_multibase = match key_row { 310 - Ok(Some(row)) => match crate::config::decrypt_key(&row.key_bytes, row.encryption_version) { 271 + let public_key_multibase = match key_info { 272 + Some(info) => match crate::config::decrypt_key(&info.key_bytes, info.encryption_version) { 311 273 Ok(key_bytes) => crate::api::identity::did::get_public_key_multibase(&key_bytes) 312 274 .unwrap_or_else(|_| "error".to_string()), 313 275 Err(_) => "error".to_string(), 314 276 }, 315 - _ => "error".to_string(), 277 + None => "error".to_string(), 316 278 }; 317 279 318 280 let also_known_as = if let Some(ref ovr) = overrides {
+3 -3
crates/tranquil-pds/src/api/server/mod.rs
··· 32 32 request_passkey_recovery, start_passkey_registration_for_setup, 33 33 }; 34 34 pub use passkeys::{ 35 - delete_passkey, finish_passkey_registration, has_passkeys_for_user, has_passkeys_for_user_db, 36 - list_passkeys, start_passkey_registration, update_passkey, 35 + delete_passkey, finish_passkey_registration, has_passkeys_for_user, list_passkeys, 36 + start_passkey_registration, update_passkey, 37 37 }; 38 38 pub use password::{ 39 39 change_password, get_password_status, remove_password, request_password_reset, reset_password, ··· 53 53 pub use signing_key::reserve_signing_key; 54 54 pub use totp::{ 55 55 create_totp_secret, disable_totp, enable_totp, get_totp_status, has_totp_enabled, 56 - has_totp_enabled_db, regenerate_backup_codes, verify_totp_or_backup_for_user, 56 + regenerate_backup_codes, verify_totp_or_backup_for_user, 57 57 }; 58 58 pub use trusted_devices::{ 59 59 extend_device_trust, is_device_trusted, list_trusted_devices, revoke_trusted_device,
+153 -341
crates/tranquil-pds/src/api/server/passkey_account.rs
··· 149 149 .unwrap_or(false); 150 150 151 151 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 152 - let pds_suffix = format!(".{}", hostname); 152 + let hostname_for_handles = hostname.split(':').next().unwrap_or(&hostname); 153 + let pds_suffix = format!(".{}", hostname_for_handles); 153 154 154 155 let handle = if !input.handle.contains('.') || input.handle.ends_with(&pds_suffix) { 155 156 let handle_to_validate = if input.handle.ends_with(&pds_suffix) { ··· 161 162 &input.handle 162 163 }; 163 164 match crate::api::validation::validate_short_handle(handle_to_validate) { 164 - Ok(h) => format!("{}.{}", h, hostname), 165 + Ok(h) => format!("{}.{}", h, hostname_for_handles), 165 166 Err(_) => { 166 167 return ApiError::InvalidHandle(None).into_response(); 167 168 } ··· 182 183 } 183 184 184 185 if let Some(ref code) = input.invite_code { 185 - let valid = sqlx::query_scalar!( 186 - "SELECT available_uses > 0 AND NOT disabled FROM invite_codes WHERE code = $1", 187 - code 188 - ) 189 - .fetch_optional(&state.db) 190 - .await 191 - .ok() 192 - .flatten() 193 - .unwrap_or(Some(false)); 186 + let valid = state.infra_repo.is_invite_code_valid(code).await.unwrap_or(false); 194 187 195 - if valid != Some(true) { 188 + if !valid { 196 189 return ApiError::InvalidInviteCode.into_response(); 197 190 } 198 191 } else { ··· 233 226 234 227 let (secret_key_bytes, reserved_key_id): (Vec<u8>, Option<Uuid>) = 235 228 if let Some(signing_key_did) = &input.signing_key { 236 - let reserved = sqlx::query!( 237 - r#" 238 - SELECT id, private_key_bytes 239 - FROM reserved_signing_keys 240 - WHERE public_key_did_key = $1 241 - AND used_at IS NULL 242 - AND expires_at > NOW() 243 - FOR UPDATE 244 - "#, 245 - signing_key_did 246 - ) 247 - .fetch_optional(&state.db) 248 - .await; 249 - match reserved { 250 - Ok(Some(row)) => (row.private_key_bytes, Some(row.id)), 229 + match state.infra_repo.get_reserved_signing_key(signing_key_did).await { 230 + Ok(Some(reserved)) => (reserved.private_key_bytes, Some(reserved.id)), 251 231 Ok(None) => { 252 232 return ApiError::InvalidSigningKey.into_response(); 253 233 } ··· 271 251 272 252 let did = match did_type { 273 253 "web" => { 274 - let subdomain_host = format!("{}.{}", input.handle, hostname); 254 + let subdomain_host = format!("{}.{}", input.handle, hostname_for_handles); 275 255 let encoded_subdomain = subdomain_host.replace(':', "%3A"); 276 256 let self_hosted_did = format!("did:web:{}", encoded_subdomain); 277 257 info!(did = %self_hosted_did, "Creating self-hosted did:web passkey account"); ··· 391 371 }; 392 372 let setup_expires_at = Utc::now() + Duration::hours(1); 393 373 394 - let mut tx = match state.db.begin().await { 395 - Ok(tx) => tx, 396 - Err(e) => { 397 - error!("Error starting transaction: {:?}", e); 398 - return ApiError::InternalError(None).into_response(); 399 - } 400 - }; 401 - 402 - let is_first_user = sqlx::query_scalar!("SELECT COUNT(*) as count FROM users") 403 - .fetch_one(&mut *tx) 404 - .await 405 - .map(|c| c.unwrap_or(0) == 0) 406 - .unwrap_or(false); 407 - 408 374 let deactivated_at: Option<chrono::DateTime<Utc>> = if is_byod_did_web { 409 375 Some(Utc::now()) 410 376 } else { 411 377 None 412 378 }; 413 379 414 - let user_insert: Result<(Uuid,), _> = sqlx::query_as( 415 - r#"INSERT INTO users ( 416 - handle, email, did, password_hash, password_required, 417 - preferred_comms_channel, 418 - discord_id, telegram_username, signal_number, 419 - recovery_token, recovery_token_expires_at, 420 - is_admin, deactivated_at 421 - ) VALUES ($1, $2, $3, NULL, FALSE, $4::comms_channel, $5, $6, $7, $8, $9, $10, $11) RETURNING id"#, 422 - ) 423 - .bind(&handle) 424 - .bind(&email) 425 - .bind(&did) 426 - .bind(verification_channel) 427 - .bind( 428 - input 429 - .discord_id 430 - .as_deref() 431 - .map(|s| s.trim()) 432 - .filter(|s| !s.is_empty()), 433 - ) 434 - .bind( 435 - input 436 - .telegram_username 437 - .as_deref() 438 - .map(|s| s.trim()) 439 - .filter(|s| !s.is_empty()), 440 - ) 441 - .bind( 442 - input 443 - .signal_number 444 - .as_deref() 445 - .map(|s| s.trim()) 446 - .filter(|s| !s.is_empty()), 447 - ) 448 - .bind(&setup_token_hash) 449 - .bind(setup_expires_at) 450 - .bind(is_first_user) 451 - .bind(deactivated_at) 452 - .fetch_one(&mut *tx) 453 - .await; 454 - 455 - let user_id = match user_insert { 456 - Ok((id,)) => id, 457 - Err(e) => { 458 - if let Some(db_err) = e.as_database_error() 459 - && db_err.code().as_deref() == Some("23505") 460 - { 461 - let constraint = db_err.constraint().unwrap_or(""); 462 - if constraint.contains("handle") { 463 - return ApiError::HandleNotAvailable(None).into_response(); 464 - } else if constraint.contains("email") { 465 - return ApiError::EmailTaken.into_response(); 466 - } 467 - } 468 - error!("Error inserting user: {:?}", e); 469 - return ApiError::InternalError(None).into_response(); 470 - } 471 - }; 472 - 473 380 let encrypted_key_bytes = match crate::config::encrypt_key(&secret_key_bytes) { 474 381 Ok(bytes) => bytes, 475 382 Err(e) => { ··· 478 385 } 479 386 }; 480 387 481 - if let Err(e) = sqlx::query!( 482 - "INSERT INTO user_keys (user_id, key_bytes, encryption_version, encrypted_at) VALUES ($1, $2, $3, NOW())", 483 - user_id, 484 - &encrypted_key_bytes[..], 485 - crate::config::ENCRYPTION_VERSION 486 - ) 487 - .execute(&mut *tx) 488 - .await 489 - { 490 - error!("Error inserting user key: {:?}", e); 491 - return ApiError::InternalError(None).into_response(); 492 - } 493 - 494 - if let Some(key_id) = reserved_key_id 495 - && let Err(e) = sqlx::query!( 496 - "UPDATE reserved_signing_keys SET used_at = NOW() WHERE id = $1", 497 - key_id 498 - ) 499 - .execute(&mut *tx) 500 - .await 501 - { 502 - error!("Error marking reserved key as used: {:?}", e); 503 - return ApiError::InternalError(None).into_response(); 504 - } 505 - 506 388 let mst = Mst::new(Arc::new(state.block_store.clone())); 507 389 let mst_root = match mst.persist().await { 508 390 Ok(c) => c, ··· 528 410 return ApiError::InternalError(None).into_response(); 529 411 } 530 412 }; 531 - let commit_cid_str = commit_cid.to_string(); 532 - let rev_str = rev.as_ref().to_string(); 533 - if let Err(e) = sqlx::query!( 534 - "INSERT INTO repos (user_id, repo_root_cid, repo_rev) VALUES ($1, $2, $3)", 535 - user_id, 536 - commit_cid_str, 537 - rev_str 538 - ) 539 - .execute(&mut *tx) 540 - .await 541 - { 542 - error!("Error inserting repo: {:?}", e); 543 - return ApiError::InternalError(None).into_response(); 544 - } 545 413 let genesis_block_cids = vec![mst_root.to_bytes(), commit_cid.to_bytes()]; 546 - if let Err(e) = sqlx::query!( 547 - r#" 548 - INSERT INTO user_blocks (user_id, block_cid) 549 - SELECT $1, block_cid FROM UNNEST($2::bytea[]) AS t(block_cid) 550 - ON CONFLICT (user_id, block_cid) DO NOTHING 551 - "#, 552 - user_id, 553 - &genesis_block_cids 554 - ) 555 - .execute(&mut *tx) 556 - .await 557 - { 558 - error!("Error inserting user_blocks: {:?}", e); 559 - return ApiError::InternalError(None).into_response(); 560 - } 561 414 562 - if let Some(ref code) = input.invite_code { 563 - let _ = sqlx::query!( 564 - "UPDATE invite_codes SET available_uses = available_uses - 1 WHERE code = $1", 565 - code 566 - ) 567 - .execute(&mut *tx) 568 - .await; 569 - 570 - let _ = sqlx::query!( 571 - "INSERT INTO invite_code_uses (code, used_by_user) VALUES ($1, $2)", 572 - code, 573 - user_id 574 - ) 575 - .execute(&mut *tx) 576 - .await; 577 - } 578 - 579 - if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_ok() { 580 - let birthdate_pref = json!({ 415 + let birthdate_pref = std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").ok().map(|_| { 416 + json!({ 581 417 "$type": "app.bsky.actor.defs#personalDetailsPref", 582 418 "birthDate": "1998-05-06T00:00:00.000Z" 583 - }); 584 - if let Err(e) = sqlx::query!( 585 - "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, $2, $3) 586 - ON CONFLICT (user_id, name) DO NOTHING", 587 - user_id, 588 - "app.bsky.actor.defs#personalDetailsPref", 589 - birthdate_pref 590 - ) 591 - .execute(&mut *tx) 592 - .await 593 - { 594 - warn!("Failed to set default birthdate preference: {:?}", e); 595 - } 596 - } 419 + }) 420 + }); 421 + 422 + let preferred_comms_channel = match verification_channel { 423 + "email" => tranquil_db_traits::CommsChannel::Email, 424 + "discord" => tranquil_db_traits::CommsChannel::Discord, 425 + "telegram" => tranquil_db_traits::CommsChannel::Telegram, 426 + "signal" => tranquil_db_traits::CommsChannel::Signal, 427 + _ => tranquil_db_traits::CommsChannel::Email, 428 + }; 597 429 598 - if let Err(e) = tx.commit().await { 599 - error!("Error committing transaction: {:?}", e); 600 - return ApiError::InternalError(None).into_response(); 601 - } 430 + let handle_typed = Handle::new_unchecked(&handle); 431 + let create_input = tranquil_db_traits::CreatePasskeyAccountInput { 432 + handle: handle_typed.clone(), 433 + email: email.clone().unwrap_or_default(), 434 + did: did_typed.clone(), 435 + preferred_comms_channel, 436 + discord_id: input.discord_id.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()).map(String::from), 437 + telegram_username: input.telegram_username.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()).map(String::from), 438 + signal_number: input.signal_number.as_deref().map(|s| s.trim()).filter(|s| !s.is_empty()).map(String::from), 439 + setup_token_hash, 440 + setup_expires_at, 441 + deactivated_at, 442 + encrypted_key_bytes, 443 + encryption_version: crate::config::ENCRYPTION_VERSION, 444 + reserved_key_id, 445 + commit_cid: commit_cid.to_string(), 446 + repo_rev: rev.as_ref().to_string(), 447 + genesis_block_cids, 448 + invite_code: input.invite_code.clone(), 449 + birthdate_pref, 450 + }; 451 + 452 + let create_result = match state.user_repo.create_passkey_account(&create_input).await { 453 + Ok(r) => r, 454 + Err(tranquil_db_traits::CreateAccountError::HandleTaken) => { 455 + return ApiError::HandleNotAvailable(None).into_response(); 456 + } 457 + Err(tranquil_db_traits::CreateAccountError::EmailTaken) => { 458 + return ApiError::EmailTaken.into_response(); 459 + } 460 + Err(e) => { 461 + error!("Error creating passkey account: {:?}", e); 462 + return ApiError::InternalError(None).into_response(); 463 + } 464 + }; 465 + let user_id = create_result.user_id; 602 466 603 467 if !is_byod_did_web { 604 - let handle_typed = Handle::new_unchecked(&handle); 605 468 if let Err(e) = crate::api::repo::record::sequence_identity_event( 606 469 &state, 607 470 &did_typed, ··· 642 505 ); 643 506 let formatted_token = 644 507 crate::auth::verification_token::format_token_for_display(&verification_token); 645 - if let Err(e) = crate::comms::enqueue_signup_verification( 646 - &state.db, 508 + if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 509 + state.infra_repo.as_ref(), 647 510 user_id, 648 511 verification_channel, 649 512 &verification_recipient, 650 513 &formatted_token, 651 - None, 514 + &hostname, 652 515 ) 653 516 .await 654 517 { ··· 662 525 Ok(token_meta) => { 663 526 let refresh_jti = uuid::Uuid::new_v4().to_string(); 664 527 let refresh_expires = chrono::Utc::now() + chrono::Duration::hours(24); 665 - let no_scope: Option<String> = None; 666 - if let Err(e) = sqlx::query!( 667 - "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at, legacy_login, mfa_verified, scope) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", 668 - did, 669 - token_meta.jti, 528 + let session_data = tranquil_db::SessionTokenCreate { 529 + did: did_typed.clone(), 530 + access_jti: token_meta.jti.clone(), 670 531 refresh_jti, 671 - token_meta.expires_at, 672 - refresh_expires, 673 - false, 674 - false, 675 - no_scope 676 - ) 677 - .execute(&state.db) 678 - .await 679 - { 532 + access_expires_at: token_meta.expires_at, 533 + refresh_expires_at: refresh_expires, 534 + legacy_login: false, 535 + mfa_verified: false, 536 + scope: None, 537 + controller_did: None, 538 + app_password_name: None, 539 + }; 540 + if let Err(e) = state.session_repo.create_session(&session_data).await { 680 541 warn!(did = %did, "Failed to insert migration session: {:?}", e); 681 542 } 682 543 info!(did = %did, "Generated migration access token for BYOD passkey account"); ··· 723 584 State(state): State<AppState>, 724 585 Json(input): Json<CompletePasskeySetupInput>, 725 586 ) -> Response { 726 - let user = sqlx::query!( 727 - r#"SELECT id, handle, recovery_token, recovery_token_expires_at, password_required 728 - FROM users WHERE did = $1"#, 729 - input.did.as_str() 730 - ) 731 - .fetch_optional(&state.db) 732 - .await; 733 - 734 - let user = match user { 587 + let user = match state.user_repo.get_user_for_passkey_setup(&input.did).await { 735 588 Ok(Some(u)) => u, 736 589 Ok(None) => { 737 590 return ApiError::AccountNotFound.into_response(); ··· 772 625 } 773 626 }; 774 627 775 - let reg_state = 776 - match crate::auth::webauthn::load_registration_state(&state.db, &input.did).await { 777 - Ok(Some(s)) => s, 778 - Ok(None) => { 779 - return ApiError::NoChallengeInProgress.into_response(); 780 - } 628 + let reg_state = match state 629 + .user_repo 630 + .load_webauthn_challenge(&input.did, "registration") 631 + .await 632 + { 633 + Ok(Some(json)) => match serde_json::from_str(&json) { 634 + Ok(s) => s, 781 635 Err(e) => { 782 - error!("Error loading registration state: {:?}", e); 636 + error!("Error deserializing registration state: {:?}", e); 783 637 return ApiError::InternalError(None).into_response(); 784 638 } 785 - }; 639 + }, 640 + Ok(None) => { 641 + return ApiError::NoChallengeInProgress.into_response(); 642 + } 643 + Err(e) => { 644 + error!("Error loading registration state: {:?}", e); 645 + return ApiError::InternalError(None).into_response(); 646 + } 647 + }; 786 648 787 649 let credential: webauthn_rs::prelude::RegisterPublicKeyCredential = 788 650 match serde_json::from_value(input.passkey_credential) { ··· 801 663 } 802 664 }; 803 665 804 - if let Err(e) = crate::auth::webauthn::save_passkey( 805 - &state.db, 806 - &input.did, 807 - &security_key, 808 - input.passkey_friendly_name.as_deref(), 809 - ) 810 - .await 666 + let credential_id = security_key.cred_id().to_vec(); 667 + let public_key = match serde_json::to_vec(&security_key) { 668 + Ok(pk) => pk, 669 + Err(e) => { 670 + error!("Error serializing security key: {:?}", e); 671 + return ApiError::InternalError(None).into_response(); 672 + } 673 + }; 674 + if let Err(e) = state 675 + .user_repo 676 + .save_passkey( 677 + &input.did, 678 + &credential_id, 679 + &public_key, 680 + input.passkey_friendly_name.as_deref(), 681 + ) 682 + .await 811 683 { 812 684 error!("Error saving passkey: {:?}", e); 813 685 return ApiError::InternalError(None).into_response(); ··· 823 695 } 824 696 }; 825 697 826 - let mut tx = match state.db.begin().await { 827 - Ok(tx) => tx, 828 - Err(e) => { 829 - error!("Failed to begin transaction: {:?}", e); 830 - return ApiError::InternalError(None).into_response(); 831 - } 698 + let setup_input = tranquil_db_traits::CompletePasskeySetupInput { 699 + user_id: user.id, 700 + did: input.did.clone(), 701 + app_password_name: app_password_name.clone(), 702 + app_password_hash: password_hash, 832 703 }; 833 - 834 - if let Err(e) = sqlx::query!( 835 - "INSERT INTO app_passwords (user_id, name, password_hash, privileged) VALUES ($1, $2, $3, FALSE)", 836 - user.id, 837 - app_password_name, 838 - password_hash 839 - ) 840 - .execute(&mut *tx) 841 - .await 842 - { 843 - error!("Error creating app password: {:?}", e); 704 + if let Err(e) = state.user_repo.complete_passkey_setup(&setup_input).await { 705 + error!("Error completing passkey setup: {:?}", e); 844 706 return ApiError::InternalError(None).into_response(); 845 707 } 846 708 847 - if let Err(e) = sqlx::query!( 848 - "UPDATE users SET recovery_token = NULL, recovery_token_expires_at = NULL WHERE did = $1", 849 - input.did.as_str() 850 - ) 851 - .execute(&mut *tx) 852 - .await 853 - { 854 - error!("Error clearing setup token: {:?}", e); 855 - return ApiError::InternalError(None).into_response(); 856 - } 857 - 858 - if let Err(e) = tx.commit().await { 859 - error!("Failed to commit setup transaction: {:?}", e); 860 - return ApiError::InternalError(None).into_response(); 861 - } 862 - 863 - let _ = crate::auth::webauthn::delete_registration_state(&state.db, &input.did).await; 709 + let _ = state 710 + .user_repo 711 + .delete_webauthn_challenge(&input.did, "registration") 712 + .await; 864 713 865 714 info!(did = %input.did, "Passkey-only account setup completed"); 866 715 ··· 877 726 State(state): State<AppState>, 878 727 Json(input): Json<StartPasskeyRegistrationInput>, 879 728 ) -> Response { 880 - let user = sqlx::query!( 881 - r#"SELECT handle, recovery_token, recovery_token_expires_at, password_required 882 - FROM users WHERE did = $1"#, 883 - input.did.as_str() 884 - ) 885 - .fetch_optional(&state.db) 886 - .await; 887 - 888 - let user = match user { 729 + let user = match state.user_repo.get_user_for_passkey_setup(&input.did).await { 889 730 Ok(Some(u)) => u, 890 731 Ok(None) => { 891 732 return ApiError::AccountNotFound.into_response(); ··· 926 767 } 927 768 }; 928 769 929 - let existing_passkeys = crate::auth::webauthn::get_passkeys_for_user(&state.db, &input.did) 770 + let existing_passkeys = state 771 + .user_repo 772 + .get_passkeys_for_user(&input.did) 930 773 .await 931 774 .unwrap_or_default(); 932 775 ··· 950 793 } 951 794 }; 952 795 953 - if let Err(e) = 954 - crate::auth::webauthn::save_registration_state(&state.db, &input.did, &reg_state).await 796 + let state_json = match serde_json::to_string(&reg_state) { 797 + Ok(json) => json, 798 + Err(e) => { 799 + error!("Failed to serialize registration state: {:?}", e); 800 + return ApiError::InternalError(None).into_response(); 801 + } 802 + }; 803 + if let Err(e) = state 804 + .user_repo 805 + .save_webauthn_challenge(&input.did, "registration", &state_json) 806 + .await 955 807 { 956 808 error!("Failed to save registration state: {:?}", e); 957 809 return ApiError::InternalError(None).into_response(); ··· 990 842 } 991 843 992 844 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 845 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 993 846 let identifier = input.email.trim().to_lowercase(); 994 847 let identifier = identifier.strip_prefix('@').unwrap_or(&identifier); 995 848 let normalized_handle = if identifier.contains('@') || identifier.contains('.') { 996 849 identifier.to_string() 997 850 } else { 998 - format!("{}.{}", identifier, pds_hostname) 851 + format!("{}.{}", identifier, hostname_for_handles) 999 852 }; 1000 853 1001 - let user = sqlx::query!( 1002 - "SELECT id, did, handle, password_required FROM users WHERE LOWER(email) = $1 OR handle = $2", 1003 - identifier, 1004 - normalized_handle 1005 - ) 1006 - .fetch_optional(&state.db) 1007 - .await; 1008 - 1009 - let user = match user { 854 + let user = match state.user_repo.get_user_for_passkey_recovery(identifier, &normalized_handle).await { 1010 855 Ok(Some(u)) if !u.password_required => u, 1011 856 _ => { 1012 857 return SuccessResponse::ok().into_response(); ··· 1022 867 }; 1023 868 let expires_at = Utc::now() + Duration::hours(1); 1024 869 1025 - if let Err(e) = sqlx::query!( 1026 - "UPDATE users SET recovery_token = $1, recovery_token_expires_at = $2 WHERE did = $3", 1027 - recovery_token_hash, 1028 - expires_at, 1029 - &user.did 1030 - ) 1031 - .execute(&state.db) 1032 - .await 1033 - { 870 + if let Err(e) = state.user_repo.set_recovery_token(&user.did, &recovery_token_hash, expires_at).await { 1034 871 error!("Error updating recovery token: {:?}", e); 1035 872 return ApiError::InternalError(None).into_response(); 1036 873 } ··· 1043 880 urlencoding::encode(&recovery_token) 1044 881 ); 1045 882 1046 - let _ = 1047 - crate::comms::enqueue_passkey_recovery(&state.db, user.id, &recovery_url, &hostname).await; 883 + let _ = crate::comms::comms_repo::enqueue_passkey_recovery( 884 + state.user_repo.as_ref(), 885 + state.infra_repo.as_ref(), 886 + user.id, 887 + &recovery_url, 888 + &hostname, 889 + ) 890 + .await; 1048 891 1049 892 info!(did = %user.did, "Passkey recovery requested"); 1050 893 SuccessResponse::ok().into_response() ··· 1066 909 return ApiError::InvalidRequest(e.to_string()).into_response(); 1067 910 } 1068 911 1069 - let user = sqlx::query!( 1070 - "SELECT id, did, recovery_token, recovery_token_expires_at FROM users WHERE did = $1", 1071 - input.did.as_str() 1072 - ) 1073 - .fetch_optional(&state.db) 1074 - .await; 1075 - 1076 - let user = match user { 912 + let user = match state.user_repo.get_user_for_recovery(&input.did).await { 1077 913 Ok(Some(u)) => u, 1078 914 _ => { 1079 915 return ApiError::InvalidRecoveryLink.into_response(); ··· 1104 940 } 1105 941 }; 1106 942 1107 - let mut tx = match state.db.begin().await { 1108 - Ok(tx) => tx, 1109 - Err(e) => { 1110 - error!("Failed to begin transaction: {:?}", e); 1111 - return ApiError::InternalError(None).into_response(); 1112 - } 943 + let recover_input = tranquil_db_traits::RecoverPasskeyAccountInput { 944 + did: input.did.clone(), 945 + password_hash, 1113 946 }; 1114 - 1115 - if let Err(e) = sqlx::query!( 1116 - "UPDATE users SET password_hash = $1, password_required = TRUE, recovery_token = NULL, recovery_token_expires_at = NULL WHERE did = $2", 1117 - password_hash, 1118 - input.did.as_str() 1119 - ) 1120 - .execute(&mut *tx) 1121 - .await 1122 - { 1123 - error!("Error updating password: {:?}", e); 1124 - return ApiError::InternalError(None).into_response(); 1125 - } 1126 - 1127 - let deleted = sqlx::query!("DELETE FROM passkeys WHERE did = $1", input.did.as_str()) 1128 - .execute(&mut *tx) 1129 - .await; 1130 - let passkeys_deleted = match deleted { 1131 - Ok(result) => result.rows_affected(), 947 + let result = match state.user_repo.recover_passkey_account(&recover_input).await { 948 + Ok(r) => r, 1132 949 Err(e) => { 1133 - error!(did = %input.did, "Failed to delete passkeys during recovery: {:?}", e); 950 + error!("Error recovering passkey account: {:?}", e); 1134 951 return ApiError::InternalError(None).into_response(); 1135 952 } 1136 953 }; 1137 954 1138 - if let Err(e) = tx.commit().await { 1139 - error!("Failed to commit recovery transaction: {:?}", e); 1140 - return ApiError::InternalError(None).into_response(); 1141 - } 1142 - 1143 - if passkeys_deleted > 0 { 1144 - info!(did = %input.did, count = passkeys_deleted, "Deleted lost passkeys during account recovery"); 955 + if result.passkeys_deleted > 0 { 956 + info!(did = %input.did, count = result.passkeys_deleted, "Deleted lost passkeys during account recovery"); 1145 957 } 1146 958 info!(did = %input.did, "Passkey-only account recovered with temporary password"); 1147 959 SuccessResponse::ok().into_response()
+66 -36
crates/tranquil-pds/src/api/server/passkeys.rs
··· 1 1 use crate::api::EmptyResponse; 2 2 use crate::api::error::ApiError; 3 3 use crate::auth::BearerAuth; 4 - use crate::auth::webauthn::{ 5 - self, WebAuthnConfig, delete_passkey as db_delete_passkey, delete_registration_state, 6 - get_passkeys_for_user, load_registration_state, save_passkey, save_registration_state, 7 - update_passkey_name as db_update_passkey_name, 8 - }; 4 + use crate::auth::webauthn::WebAuthnConfig; 9 5 use crate::state::AppState; 10 6 use axum::{ 11 7 Json, ··· 46 42 Err(e) => return e.into_response(), 47 43 }; 48 44 49 - let user = sqlx::query!("SELECT handle FROM users WHERE did = $1", &*auth.0.did) 50 - .fetch_optional(&state.db) 51 - .await; 52 - 53 - let handle = match user { 54 - Ok(Some(row)) => row.handle, 45 + let handle = match state.user_repo.get_handle_by_did(&auth.0.did).await { 46 + Ok(Some(h)) => h, 55 47 Ok(None) => { 56 48 return ApiError::AccountNotFound.into_response(); 57 49 } ··· 61 53 } 62 54 }; 63 55 64 - let existing_passkeys = match get_passkeys_for_user(&state.db, &auth.0.did).await { 56 + let existing_passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 65 57 Ok(passkeys) => passkeys, 66 58 Err(e) => { 67 59 error!("DB error fetching existing passkeys: {:?}", e); ··· 90 82 } 91 83 }; 92 84 93 - if let Err(e) = save_registration_state(&state.db, &auth.0.did, &reg_state).await { 85 + let state_json = match serde_json::to_string(&reg_state) { 86 + Ok(s) => s, 87 + Err(e) => { 88 + error!("Failed to serialize registration state: {:?}", e); 89 + return ApiError::InternalError(None).into_response(); 90 + } 91 + }; 92 + 93 + if let Err(e) = state 94 + .user_repo 95 + .save_webauthn_challenge(&auth.0.did, "registration", &state_json) 96 + .await 97 + { 94 98 error!("Failed to save registration state: {:?}", e); 95 99 return ApiError::InternalError(None).into_response(); 96 100 } ··· 126 130 Err(e) => return e.into_response(), 127 131 }; 128 132 129 - let reg_state = match load_registration_state(&state.db, &auth.0.did).await { 130 - Ok(Some(state)) => state, 133 + let reg_state_json = match state 134 + .user_repo 135 + .load_webauthn_challenge(&auth.0.did, "registration") 136 + .await 137 + { 138 + Ok(Some(json)) => json, 131 139 Ok(None) => { 132 140 return ApiError::NoRegistrationInProgress.into_response(); 133 141 } ··· 137 145 } 138 146 }; 139 147 148 + let reg_state: SecurityKeyRegistration = match serde_json::from_str(&reg_state_json) { 149 + Ok(s) => s, 150 + Err(e) => { 151 + error!("Failed to deserialize registration state: {:?}", e); 152 + return ApiError::InternalError(None).into_response(); 153 + } 154 + }; 155 + 140 156 let credential: RegisterPublicKeyCredential = match serde_json::from_value(input.credential) { 141 157 Ok(c) => c, 142 158 Err(e) => { ··· 153 169 } 154 170 }; 155 171 156 - let passkey_id = match save_passkey( 157 - &state.db, 158 - &auth.0.did, 159 - &passkey, 160 - input.friendly_name.as_deref(), 161 - ) 162 - .await 172 + let public_key = match serde_json::to_vec(&passkey) { 173 + Ok(pk) => pk, 174 + Err(e) => { 175 + error!("Failed to serialize passkey: {:?}", e); 176 + return ApiError::InternalError(None).into_response(); 177 + } 178 + }; 179 + 180 + let passkey_id = match state 181 + .user_repo 182 + .save_passkey( 183 + &auth.0.did, 184 + passkey.cred_id(), 185 + &public_key, 186 + input.friendly_name.as_deref(), 187 + ) 188 + .await 163 189 { 164 190 Ok(id) => id, 165 191 Err(e) => { ··· 168 194 } 169 195 }; 170 196 171 - if let Err(e) = delete_registration_state(&state.db, &auth.0.did).await { 197 + if let Err(e) = state 198 + .user_repo 199 + .delete_webauthn_challenge(&auth.0.did, "registration") 200 + .await 201 + { 172 202 warn!("Failed to delete registration state: {:?}", e); 173 203 } 174 204 ··· 203 233 } 204 234 205 235 pub async fn list_passkeys(State(state): State<AppState>, auth: BearerAuth) -> Response { 206 - let passkeys = match get_passkeys_for_user(&state.db, &auth.0.did).await { 236 + let passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 207 237 Ok(pks) => pks, 208 238 Err(e) => { 209 239 error!("DB error fetching passkeys: {:?}", e); ··· 239 269 auth: BearerAuth, 240 270 Json(input): Json<DeletePasskeyInput>, 241 271 ) -> Response { 242 - if !crate::api::server::reauth::check_legacy_session_mfa(&state.db, &auth.0.did).await { 243 - return crate::api::server::reauth::legacy_mfa_required_response(&state.db, &auth.0.did) 272 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did).await { 273 + return crate::api::server::reauth::legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did) 244 274 .await; 245 275 } 246 276 247 - if crate::api::server::reauth::check_reauth_required(&state.db, &auth.0.did).await { 248 - return crate::api::server::reauth::reauth_required_response(&state.db, &auth.0.did).await; 277 + if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.0.did).await { 278 + return crate::api::server::reauth::reauth_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 249 279 } 250 280 251 281 let id: uuid::Uuid = match input.id.parse() { ··· 255 285 } 256 286 }; 257 287 258 - match db_delete_passkey(&state.db, id, &auth.0.did).await { 288 + match state.user_repo.delete_passkey(id, &auth.0.did).await { 259 289 Ok(true) => { 260 290 info!(did = %auth.0.did, passkey_id = %id, "Passkey deleted"); 261 291 EmptyResponse::ok().into_response() ··· 287 317 } 288 318 }; 289 319 290 - match db_update_passkey_name(&state.db, id, &auth.0.did, &input.friendly_name).await { 320 + match state 321 + .user_repo 322 + .update_passkey_name(id, &auth.0.did, &input.friendly_name) 323 + .await 324 + { 291 325 Ok(true) => { 292 326 info!(did = %auth.0.did, passkey_id = %id, "Passkey renamed"); 293 327 EmptyResponse::ok().into_response() ··· 300 334 } 301 335 } 302 336 303 - pub async fn has_passkeys_for_user(state: &AppState, did: &str) -> bool { 304 - has_passkeys_for_user_db(&state.db, did).await 305 - } 306 - 307 - pub async fn has_passkeys_for_user_db(db: &sqlx::PgPool, did: &str) -> bool { 308 - webauthn::has_passkeys(db, did).await.unwrap_or(false) 337 + pub async fn has_passkeys_for_user(state: &AppState, did: &crate::types::Did) -> bool { 338 + state.user_repo.has_passkeys(did).await.unwrap_or(false) 309 339 }
+57 -162
crates/tranquil-pds/src/api/server/password.rs
··· 14 14 use chrono::{Duration, Utc}; 15 15 use serde::Deserialize; 16 16 use tracing::{error, info, warn}; 17 - use uuid::Uuid; 18 17 19 18 fn generate_reset_code() -> String { 20 19 crate::util::generate_token_code() ··· 58 57 return ApiError::InvalidRequest("email or handle is required".into()).into_response(); 59 58 } 60 59 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 60 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 61 61 let normalized = identifier.to_lowercase(); 62 62 let normalized = normalized.strip_prefix('@').unwrap_or(&normalized); 63 63 let normalized_handle = if normalized.contains('@') || normalized.contains('.') { 64 64 normalized.to_string() 65 65 } else { 66 - format!("{}.{}", normalized, pds_hostname) 66 + format!("{}.{}", normalized, hostname_for_handles) 67 67 }; 68 - let user = sqlx::query!( 69 - "SELECT id FROM users WHERE LOWER(email) = $1 OR handle = $2", 70 - normalized, 71 - normalized_handle 72 - ) 73 - .fetch_optional(&state.db) 74 - .await; 75 - let user_id = match user { 76 - Ok(Some(row)) => row.id, 68 + let user_id = match state 69 + .user_repo 70 + .get_id_by_email_or_handle(&normalized, &normalized_handle) 71 + .await 72 + { 73 + Ok(Some(id)) => id, 77 74 Ok(None) => { 78 75 info!("Password reset requested for unknown identifier"); 79 76 return EmptyResponse::ok().into_response(); ··· 85 82 }; 86 83 let code = generate_reset_code(); 87 84 let expires_at = Utc::now() + Duration::minutes(10); 88 - let update = sqlx::query!( 89 - "UPDATE users SET password_reset_code = $1, password_reset_code_expires_at = $2 WHERE id = $3", 90 - code, 91 - expires_at, 92 - user_id 93 - ) 94 - .execute(&state.db) 95 - .await; 96 - if let Err(e) = update { 85 + if let Err(e) = state 86 + .user_repo 87 + .set_password_reset_code(user_id, &code, expires_at) 88 + .await 89 + { 97 90 error!("DB error setting reset code: {:?}", e); 98 91 return ApiError::InternalError(None).into_response(); 99 92 } 100 93 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 - if let Err(e) = crate::comms::enqueue_password_reset(&state.db, user_id, &code, &hostname).await 94 + if let Err(e) = crate::comms::comms_repo::enqueue_password_reset( 95 + state.user_repo.as_ref(), 96 + state.infra_repo.as_ref(), 97 + user_id, 98 + &code, 99 + &hostname, 100 + ) 101 + .await 102 102 { 103 103 warn!("Failed to enqueue password reset notification: {:?}", e); 104 104 } ··· 136 136 if let Err(e) = validate_password(password) { 137 137 return ApiError::InvalidRequest(e.to_string()).into_response(); 138 138 } 139 - let user = sqlx::query!( 140 - "SELECT id, password_reset_code, password_reset_code_expires_at FROM users WHERE password_reset_code = $1", 141 - token 142 - ) 143 - .fetch_optional(&state.db) 144 - .await; 145 - let (user_id, expires_at) = match user { 146 - Ok(Some(row)) => { 147 - let expires = row.password_reset_code_expires_at; 148 - (row.id, expires) 149 - } 139 + let user = match state.user_repo.get_user_by_reset_code(token).await { 140 + Ok(Some(u)) => u, 150 141 Ok(None) => { 151 142 return ApiError::InvalidToken(None).into_response(); 152 143 } ··· 155 146 return ApiError::InternalError(None).into_response(); 156 147 } 157 148 }; 158 - if let Some(exp) = expires_at { 159 - if Utc::now() > exp { 160 - if let Err(e) = sqlx::query!( 161 - "UPDATE users SET password_reset_code = NULL, password_reset_code_expires_at = NULL WHERE id = $1", 162 - user_id 163 - ) 164 - .execute(&state.db) 165 - .await 166 - { 167 - error!("Failed to clear expired reset code: {:?}", e); 168 - } 169 - return ApiError::ExpiredToken(None).into_response(); 170 - } 171 - } else { 149 + let user_id = user.id; 150 + let Some(exp) = user.expires_at else { 172 151 return ApiError::InvalidToken(None).into_response(); 152 + }; 153 + if Utc::now() > exp { 154 + if let Err(e) = state.user_repo.clear_password_reset_code(user_id).await { 155 + error!("Failed to clear expired reset code: {:?}", e); 156 + } 157 + return ApiError::ExpiredToken(None).into_response(); 173 158 } 174 159 let password_clone = password.to_string(); 175 160 let password_hash = ··· 184 169 return ApiError::InternalError(None).into_response(); 185 170 } 186 171 }; 187 - let mut tx = match state.db.begin().await { 188 - Ok(tx) => tx, 189 - Err(e) => { 190 - error!("Failed to begin transaction: {:?}", e); 191 - return ApiError::InternalError(None).into_response(); 192 - } 193 - }; 194 - if let Err(e) = sqlx::query!( 195 - "UPDATE users SET password_hash = $1, password_reset_code = NULL, password_reset_code_expires_at = NULL, password_required = TRUE WHERE id = $2", 196 - password_hash, 197 - user_id 198 - ) 199 - .execute(&mut *tx) 200 - .await 201 - { 202 - error!("DB error updating password: {:?}", e); 203 - return ApiError::InternalError(None).into_response(); 204 - } 205 - let user_did = match sqlx::query_scalar!("SELECT did FROM users WHERE id = $1", user_id) 206 - .fetch_one(&mut *tx) 172 + let result = match state 173 + .user_repo 174 + .reset_password_with_sessions(user_id, &password_hash) 207 175 .await 208 176 { 209 - Ok(did) => did, 177 + Ok(r) => r, 210 178 Err(e) => { 211 - error!("Failed to get DID for user {}: {:?}", user_id, e); 179 + error!("Failed to reset password: {:?}", e); 212 180 return ApiError::InternalError(None).into_response(); 213 181 } 214 182 }; 215 - let session_jtis: Vec<String> = match sqlx::query_scalar!( 216 - "SELECT access_jti FROM session_tokens WHERE did = $1", 217 - user_did 218 - ) 219 - .fetch_all(&mut *tx) 220 - .await 221 - { 222 - Ok(jtis) => jtis, 223 - Err(e) => { 224 - error!("Failed to fetch session JTIs: {:?}", e); 225 - vec![] 226 - } 227 - }; 228 - if let Err(e) = sqlx::query!("DELETE FROM session_tokens WHERE did = $1", user_did) 229 - .execute(&mut *tx) 230 - .await 231 - { 232 - error!( 233 - "Failed to invalidate sessions after password reset: {:?}", 234 - e 235 - ); 236 - return ApiError::InternalError(None).into_response(); 237 - } 238 - if let Err(e) = tx.commit().await { 239 - error!("Failed to commit password reset transaction: {:?}", e); 240 - return ApiError::InternalError(None).into_response(); 241 - } 242 - futures::future::join_all(session_jtis.into_iter().map(|jti| { 243 - let cache_key = format!("auth:session:{}:{}", user_did, jti); 183 + futures::future::join_all(result.session_jtis.iter().map(|jti| { 184 + let cache_key = format!("auth:session:{}:{}", result.did, jti); 244 185 let cache = state.cache.clone(); 245 186 async move { 246 187 if let Err(e) = cache.delete(&cache_key).await { ··· 268 209 auth: BearerAuth, 269 210 Json(input): Json<ChangePasswordInput>, 270 211 ) -> Response { 271 - if !crate::api::server::reauth::check_legacy_session_mfa(&state.db, &auth.0.did).await { 272 - return crate::api::server::reauth::legacy_mfa_required_response(&state.db, &auth.0.did) 212 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did).await { 213 + return crate::api::server::reauth::legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did) 273 214 .await; 274 215 } 275 216 ··· 284 225 if let Err(e) = validate_password(new_password) { 285 226 return ApiError::InvalidRequest(e.to_string()).into_response(); 286 227 } 287 - let user = 288 - sqlx::query_as::<_, (Uuid, String)>("SELECT id, password_hash FROM users WHERE did = $1") 289 - .bind(&auth.0.did) 290 - .fetch_optional(&state.db) 291 - .await; 292 - let (user_id, password_hash) = match user { 293 - Ok(Some(row)) => row, 228 + let user = match state.user_repo.get_id_and_password_hash_by_did(&auth.0.did).await { 229 + Ok(Some(u)) => u, 294 230 Ok(None) => { 295 231 return ApiError::AccountNotFound.into_response(); 296 232 } ··· 299 235 return ApiError::InternalError(None).into_response(); 300 236 } 301 237 }; 238 + let (user_id, password_hash) = (user.id, user.password_hash); 302 239 let valid = match verify(current_password, &password_hash) { 303 240 Ok(v) => v, 304 241 Err(e) => { ··· 322 259 return ApiError::InternalError(None).into_response(); 323 260 } 324 261 }; 325 - if let Err(e) = sqlx::query("UPDATE users SET password_hash = $1 WHERE id = $2") 326 - .bind(&new_hash) 327 - .bind(user_id) 328 - .execute(&state.db) 329 - .await 330 - { 262 + if let Err(e) = state.user_repo.update_password_hash(user_id, &new_hash).await { 331 263 error!("DB error updating password: {:?}", e); 332 264 return ApiError::InternalError(None).into_response(); 333 265 } ··· 336 268 } 337 269 338 270 pub async fn get_password_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 339 - let user = sqlx::query!( 340 - "SELECT password_hash IS NOT NULL as has_password FROM users WHERE did = $1", 341 - &auth.0.did 342 - ) 343 - .fetch_optional(&state.db) 344 - .await; 345 - 346 - match user { 347 - Ok(Some(row)) => { 348 - HasPasswordResponse::response(row.has_password.unwrap_or(false)).into_response() 349 - } 271 + match state.user_repo.has_password_by_did(&auth.0.did).await { 272 + Ok(Some(has)) => HasPasswordResponse::response(has).into_response(), 350 273 Ok(None) => ApiError::AccountNotFound.into_response(), 351 274 Err(e) => { 352 275 error!("DB error: {:?}", e); ··· 356 279 } 357 280 358 281 pub async fn remove_password(State(state): State<AppState>, auth: BearerAuth) -> Response { 359 - if !crate::api::server::reauth::check_legacy_session_mfa(&state.db, &auth.0.did).await { 360 - return crate::api::server::reauth::legacy_mfa_required_response(&state.db, &auth.0.did) 282 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did).await { 283 + return crate::api::server::reauth::legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did) 361 284 .await; 362 285 } 363 286 364 287 if crate::api::server::reauth::check_reauth_required_cached( 365 - &state.db, 288 + &*state.session_repo, 366 289 &state.cache, 367 290 &auth.0.did, 368 291 ) 369 292 .await 370 293 { 371 - return crate::api::server::reauth::reauth_required_response(&state.db, &auth.0.did).await; 294 + return crate::api::server::reauth::reauth_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 372 295 } 373 296 374 - let has_passkeys = 375 - crate::api::server::passkeys::has_passkeys_for_user_db(&state.db, &auth.0.did).await; 297 + let has_passkeys = state.user_repo.has_passkeys(&auth.0.did).await.unwrap_or(false); 376 298 if !has_passkeys { 377 299 return ApiError::InvalidRequest( 378 300 "You must have at least one passkey registered before removing your password".into(), ··· 380 302 .into_response(); 381 303 } 382 304 383 - let user = sqlx::query!( 384 - "SELECT id, password_hash FROM users WHERE did = $1", 385 - &auth.0.did 386 - ) 387 - .fetch_optional(&state.db) 388 - .await; 389 - 390 - let user = match user { 305 + let user = match state.user_repo.get_password_info_by_did(&auth.0.did).await { 391 306 Ok(Some(u)) => u, 392 307 Ok(None) => { 393 308 return ApiError::AccountNotFound.into_response(); ··· 402 317 return ApiError::InvalidRequest("Account already has no password".into()).into_response(); 403 318 } 404 319 405 - if let Err(e) = sqlx::query!( 406 - "UPDATE users SET password_hash = NULL, password_required = FALSE WHERE id = $1", 407 - user.id 408 - ) 409 - .execute(&state.db) 410 - .await 411 - { 320 + if let Err(e) = state.user_repo.remove_user_password(user.id).await { 412 321 error!("DB error removing password: {:?}", e); 413 322 return ApiError::InternalError(None).into_response(); 414 323 } ··· 429 338 Json(input): Json<SetPasswordInput>, 430 339 ) -> Response { 431 340 if crate::api::server::reauth::check_reauth_required_cached( 432 - &state.db, 341 + &*state.session_repo, 433 342 &state.cache, 434 343 &auth.0.did, 435 344 ) 436 345 .await 437 346 { 438 - return crate::api::server::reauth::reauth_required_response(&state.db, &auth.0.did).await; 347 + return crate::api::server::reauth::reauth_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 439 348 } 440 349 441 350 let new_password = &input.new_password; ··· 446 355 return ApiError::InvalidRequest(e.to_string()).into_response(); 447 356 } 448 357 449 - let user = sqlx::query!( 450 - "SELECT id, password_hash FROM users WHERE did = $1", 451 - &auth.0.did 452 - ) 453 - .fetch_optional(&state.db) 454 - .await; 455 - 456 - let user = match user { 358 + let user = match state.user_repo.get_password_info_by_did(&auth.0.did).await { 457 359 Ok(Some(u)) => u, 458 360 Ok(None) => { 459 361 return ApiError::AccountNotFound.into_response(); ··· 485 387 } 486 388 }; 487 389 488 - if let Err(e) = sqlx::query!( 489 - "UPDATE users SET password_hash = $1, password_required = TRUE WHERE id = $2", 490 - new_hash, 491 - user.id 492 - ) 493 - .execute(&state.db) 494 - .await 495 - { 390 + if let Err(e) = state.user_repo.set_new_user_password(user.id, &new_hash).await { 496 391 error!("DB error setting password: {:?}", e); 497 392 return ApiError::InternalError(None).into_response(); 498 393 }
+117 -136
crates/tranquil-pds/src/api/server/reauth.rs
··· 7 7 }; 8 8 use chrono::{DateTime, Utc}; 9 9 use serde::{Deserialize, Serialize}; 10 - use sqlx::PgPool; 11 10 use tracing::{error, info, warn}; 11 + use tranquil_db_traits::{SessionRepository, UserRepository}; 12 12 13 13 use crate::auth::BearerAuth; 14 14 use crate::state::{AppState, RateLimitKind}; ··· 25 25 } 26 26 27 27 pub async fn get_reauth_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 28 - let session = sqlx::query!( 29 - "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 30 - &auth.0.did 31 - ) 32 - .fetch_optional(&state.db) 33 - .await; 34 - 35 - let last_reauth_at = match session { 36 - Ok(Some(row)) => row.last_reauth_at, 37 - Ok(None) => None, 28 + let last_reauth_at = match state.session_repo.get_last_reauth_at(&auth.0.did).await { 29 + Ok(t) => t, 38 30 Err(e) => { 39 31 error!("DB error: {:?}", e); 40 32 return ApiError::InternalError(None).into_response(); ··· 42 34 }; 43 35 44 36 let reauth_required = is_reauth_required(last_reauth_at); 45 - let available_methods = get_available_reauth_methods(&state.db, &auth.0.did).await; 37 + let available_methods = 38 + get_available_reauth_methods(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 46 39 47 40 Json(ReauthStatusResponse { 48 41 last_reauth_at, ··· 69 62 auth: BearerAuth, 70 63 Json(input): Json<PasswordReauthInput>, 71 64 ) -> Response { 72 - let user = sqlx::query!( 73 - "SELECT password_hash FROM users WHERE did = $1", 74 - &*&auth.0.did 75 - ) 76 - .fetch_optional(&state.db) 77 - .await; 78 - 79 - let password_hash = match user { 80 - Ok(Some(row)) => row.password_hash, 65 + let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 66 + Ok(Some(hash)) => hash, 81 67 Ok(None) => { 82 68 return ApiError::AccountNotFound.into_response(); 83 69 } ··· 87 73 } 88 74 }; 89 75 90 - let password_valid = password_hash 91 - .as_ref() 92 - .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false)) 93 - .unwrap_or(false); 76 + let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 94 77 95 78 if !password_valid { 96 - let app_passwords = sqlx::query!( 97 - "SELECT ap.password_hash FROM app_passwords ap 98 - JOIN users u ON ap.user_id = u.id 99 - WHERE u.did = $1", 100 - &auth.0.did 101 - ) 102 - .fetch_all(&state.db) 103 - .await 104 - .unwrap_or_default(); 79 + let app_password_hashes = state 80 + .session_repo 81 + .get_app_password_hashes_by_did(&auth.0.did) 82 + .await 83 + .unwrap_or_default(); 105 84 106 - let app_password_valid = app_passwords 85 + let app_password_valid = app_password_hashes 107 86 .iter() 108 - .any(|ap| bcrypt::verify(&input.password, &ap.password_hash).unwrap_or(false)); 87 + .any(|h| bcrypt::verify(&input.password, h).unwrap_or(false)); 109 88 110 89 if !app_password_valid { 111 90 warn!(did = %&auth.0.did, "Re-auth failed: invalid password"); ··· 113 92 } 114 93 } 115 94 116 - match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 95 + match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 117 96 Ok(reauthed_at) => { 118 97 info!(did = %&auth.0.did, "Re-auth successful via password"); 119 98 Json(ReauthResponse { reauthed_at }).into_response() ··· 156 135 return ApiError::InvalidCode(Some("Invalid TOTP or backup code".into())).into_response(); 157 136 } 158 137 159 - match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 138 + match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 160 139 Ok(reauthed_at) => { 161 140 info!(did = %&auth.0.did, "Re-auth successful via TOTP"); 162 141 Json(ReauthResponse { reauthed_at }).into_response() ··· 177 156 pub async fn reauth_passkey_start(State(state): State<AppState>, auth: BearerAuth) -> Response { 178 157 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 179 158 180 - let stored_passkeys = 181 - match crate::auth::webauthn::get_passkeys_for_user(&state.db, &auth.0.did).await { 182 - Ok(pks) => pks, 183 - Err(e) => { 184 - error!("Failed to get passkeys: {:?}", e); 185 - return ApiError::InternalError(None).into_response(); 186 - } 187 - }; 159 + let stored_passkeys = match state.user_repo.get_passkeys_for_user(&auth.0.did).await { 160 + Ok(pks) => pks, 161 + Err(e) => { 162 + error!("Failed to get passkeys: {:?}", e); 163 + return ApiError::InternalError(None).into_response(); 164 + } 165 + }; 188 166 189 167 if stored_passkeys.is_empty() { 190 168 return ApiError::NoPasskeys.into_response(); ··· 192 170 193 171 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys 194 172 .iter() 195 - .filter_map(|sp| sp.to_security_key().ok()) 173 + .filter_map(|sp| serde_json::from_slice(&sp.public_key).ok()) 196 174 .collect(); 197 175 198 176 if passkeys.is_empty() { ··· 215 193 } 216 194 }; 217 195 218 - if let Err(e) = 219 - crate::auth::webauthn::save_authentication_state(&state.db, &auth.0.did, &auth_state).await 196 + let state_json = match serde_json::to_string(&auth_state) { 197 + Ok(s) => s, 198 + Err(e) => { 199 + error!("Failed to serialize authentication state: {:?}", e); 200 + return ApiError::InternalError(None).into_response(); 201 + } 202 + }; 203 + 204 + if let Err(e) = state 205 + .user_repo 206 + .save_webauthn_challenge(&auth.0.did, "authentication", &state_json) 207 + .await 220 208 { 221 209 error!("Failed to save authentication state: {:?}", e); 222 210 return ApiError::InternalError(None).into_response(); ··· 239 227 ) -> Response { 240 228 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 241 229 242 - let auth_state = 243 - match crate::auth::webauthn::load_authentication_state(&state.db, &auth.0.did).await { 244 - Ok(Some(s)) => s, 245 - Ok(None) => { 246 - return ApiError::NoChallengeInProgress.into_response(); 247 - } 230 + let auth_state_json = match state 231 + .user_repo 232 + .load_webauthn_challenge(&auth.0.did, "authentication") 233 + .await 234 + { 235 + Ok(Some(json)) => json, 236 + Ok(None) => { 237 + return ApiError::NoChallengeInProgress.into_response(); 238 + } 239 + Err(e) => { 240 + error!("Failed to load authentication state: {:?}", e); 241 + return ApiError::InternalError(None).into_response(); 242 + } 243 + }; 244 + 245 + let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = 246 + match serde_json::from_str(&auth_state_json) { 247 + Ok(s) => s, 248 248 Err(e) => { 249 - error!("Failed to load authentication state: {:?}", e); 249 + error!("Failed to deserialize authentication state: {:?}", e); 250 250 return ApiError::InternalError(None).into_response(); 251 251 } 252 252 }; ··· 278 278 }; 279 279 280 280 let cred_id_bytes = auth_result.cred_id().as_ref(); 281 - match crate::auth::webauthn::update_passkey_counter( 282 - &state.db, 283 - cred_id_bytes, 284 - auth_result.counter(), 285 - ) 286 - .await 281 + match state 282 + .user_repo 283 + .update_passkey_counter(cred_id_bytes, auth_result.counter() as i32) 284 + .await 287 285 { 288 286 Ok(false) => { 289 287 warn!(did = %&auth.0.did, "Passkey counter anomaly detected - possible cloned key"); 290 - let _ = 291 - crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await; 288 + let _ = state 289 + .user_repo 290 + .delete_webauthn_challenge(&auth.0.did, "authentication") 291 + .await; 292 292 return ApiError::PasskeyCounterAnomaly.into_response(); 293 293 } 294 294 Err(e) => { ··· 297 297 Ok(true) => {} 298 298 } 299 299 300 - let _ = crate::auth::webauthn::delete_authentication_state(&state.db, &auth.0.did).await; 300 + let _ = state 301 + .user_repo 302 + .delete_webauthn_challenge(&auth.0.did, "authentication") 303 + .await; 301 304 302 - match update_last_reauth_cached(&state.db, &state.cache, &auth.0.did).await { 305 + match update_last_reauth_cached(&*state.session_repo, &state.cache, &auth.0.did).await { 303 306 Ok(reauthed_at) => { 304 307 info!(did = %&auth.0.did, "Re-auth successful via passkey"); 305 308 Json(ReauthResponse { reauthed_at }).into_response() ··· 312 315 } 313 316 314 317 pub async fn update_last_reauth_cached( 315 - db: &PgPool, 318 + session_repo: &dyn SessionRepository, 316 319 cache: &std::sync::Arc<dyn crate::cache::Cache>, 317 - did: &str, 318 - ) -> Result<DateTime<Utc>, sqlx::Error> { 319 - let now = Utc::now(); 320 - sqlx::query!( 321 - "UPDATE session_tokens SET last_reauth_at = $1, mfa_verified = TRUE WHERE did = $2", 322 - now, 323 - did 324 - ) 325 - .execute(db) 326 - .await?; 320 + did: &crate::types::Did, 321 + ) -> Result<DateTime<Utc>, tranquil_db_traits::DbError> { 322 + let now = session_repo.update_last_reauth(did).await?; 327 323 let cache_key = format!("reauth:{}", did); 328 324 let _ = cache 329 325 .set( ··· 345 341 } 346 342 } 347 343 348 - async fn get_available_reauth_methods(db: &PgPool, did: &str) -> Vec<String> { 344 + async fn get_available_reauth_methods( 345 + user_repo: &dyn UserRepository, 346 + _session_repo: &dyn SessionRepository, 347 + did: &crate::types::Did, 348 + ) -> Vec<String> { 349 349 let mut methods = Vec::new(); 350 350 351 - let has_password = sqlx::query_scalar!( 352 - "SELECT password_hash IS NOT NULL as has_pw FROM users WHERE did = $1", 353 - did 354 - ) 355 - .fetch_optional(db) 356 - .await 357 - .ok() 358 - .flatten() 359 - .unwrap_or(Some(false)); 351 + let has_password = user_repo 352 + .get_password_hash_by_did(did) 353 + .await 354 + .ok() 355 + .flatten() 356 + .is_some(); 360 357 361 - if has_password == Some(true) { 358 + if has_password { 362 359 methods.push("password".to_string()); 363 360 } 364 361 365 - let has_totp = crate::api::server::totp::has_totp_enabled_db(db, did).await; 362 + let has_totp = user_repo.has_totp_enabled(did).await.unwrap_or(false); 366 363 if has_totp { 367 364 methods.push("totp".to_string()); 368 365 } 369 366 370 - let has_passkeys = crate::api::server::passkeys::has_passkeys_for_user_db(db, did).await; 367 + let has_passkeys = user_repo.has_passkeys(did).await.unwrap_or(false); 371 368 if has_passkeys { 372 369 methods.push("passkey".to_string()); 373 370 } ··· 375 372 methods 376 373 } 377 374 378 - pub async fn check_reauth_required(db: &PgPool, did: &str) -> bool { 379 - let session = sqlx::query!( 380 - "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 381 - did 382 - ) 383 - .fetch_optional(db) 384 - .await; 385 - 386 - match session { 387 - Ok(Some(row)) => is_reauth_required(row.last_reauth_at), 375 + pub async fn check_reauth_required(session_repo: &dyn SessionRepository, did: &crate::types::Did) -> bool { 376 + match session_repo.get_last_reauth_at(did).await { 377 + Ok(last_reauth_at) => is_reauth_required(last_reauth_at), 388 378 _ => true, 389 379 } 390 380 } 391 381 392 382 pub async fn check_reauth_required_cached( 393 - db: &PgPool, 383 + session_repo: &dyn SessionRepository, 394 384 cache: &std::sync::Arc<dyn crate::cache::Cache>, 395 - did: &str, 385 + did: &crate::types::Did, 396 386 ) -> bool { 397 387 let cache_key = format!("reauth:{}", did); 398 388 if let Some(timestamp_str) = cache.get(&cache_key).await ··· 406 396 } 407 397 } 408 398 } 409 - let session = sqlx::query!( 410 - "SELECT last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 411 - did 412 - ) 413 - .fetch_optional(db) 414 - .await; 415 - 416 - match session { 417 - Ok(Some(row)) => is_reauth_required(row.last_reauth_at), 399 + match session_repo.get_last_reauth_at(did).await { 400 + Ok(last_reauth_at) => is_reauth_required(last_reauth_at), 418 401 _ => true, 419 402 } 420 403 } ··· 427 410 pub reauth_methods: Vec<String>, 428 411 } 429 412 430 - pub async fn reauth_required_response(db: &PgPool, did: &str) -> Response { 431 - let methods = get_available_reauth_methods(db, did).await; 413 + pub async fn reauth_required_response( 414 + user_repo: &dyn UserRepository, 415 + session_repo: &dyn SessionRepository, 416 + did: &crate::types::Did, 417 + ) -> Response { 418 + let methods = get_available_reauth_methods(user_repo, session_repo, did).await; 432 419 ( 433 420 StatusCode::UNAUTHORIZED, 434 421 Json(ReauthRequiredError { ··· 440 427 .into_response() 441 428 } 442 429 443 - pub async fn check_legacy_session_mfa(db: &PgPool, did: &str) -> bool { 444 - let session = sqlx::query!( 445 - "SELECT legacy_login, mfa_verified, last_reauth_at FROM session_tokens WHERE did = $1 ORDER BY created_at DESC LIMIT 1", 446 - did 447 - ) 448 - .fetch_optional(db) 449 - .await; 450 - 451 - match session { 452 - Ok(Some(row)) => { 453 - if !row.legacy_login { 430 + pub async fn check_legacy_session_mfa(session_repo: &dyn SessionRepository, did: &crate::types::Did) -> bool { 431 + match session_repo.get_session_mfa_status(did).await { 432 + Ok(Some(status)) => { 433 + if !status.legacy_login { 454 434 return true; 455 435 } 456 - if row.mfa_verified { 436 + if status.mfa_verified { 457 437 return true; 458 438 } 459 - if let Some(last_reauth) = row.last_reauth_at { 439 + if let Some(last_reauth) = status.last_reauth_at { 460 440 let elapsed = chrono::Utc::now().signed_duration_since(last_reauth); 461 441 if elapsed.num_seconds() <= REAUTH_WINDOW_SECONDS { 462 442 return true; ··· 468 448 } 469 449 } 470 450 471 - pub async fn update_mfa_verified(db: &PgPool, did: &str) -> Result<(), sqlx::Error> { 472 - sqlx::query!( 473 - "UPDATE session_tokens SET mfa_verified = TRUE, last_reauth_at = NOW() WHERE did = $1", 474 - did 475 - ) 476 - .execute(db) 477 - .await?; 478 - Ok(()) 451 + pub async fn update_mfa_verified( 452 + session_repo: &dyn SessionRepository, 453 + did: &crate::types::Did, 454 + ) -> Result<(), tranquil_db_traits::DbError> { 455 + session_repo.update_mfa_verified(did).await 479 456 } 480 457 481 - pub async fn legacy_mfa_required_response(db: &PgPool, did: &str) -> Response { 482 - let methods = get_available_reauth_methods(db, did).await; 458 + pub async fn legacy_mfa_required_response( 459 + user_repo: &dyn UserRepository, 460 + session_repo: &dyn SessionRepository, 461 + did: &crate::types::Did, 462 + ) -> Response { 463 + let methods = get_available_reauth_methods(user_repo, session_repo, did).await; 483 464 ( 484 465 StatusCode::FORBIDDEN, 485 466 Json(MfaVerificationRequiredError {
+28 -33
crates/tranquil-pds/src/api/server/service_auth.rs
··· 81 81 82 82 let auth_user = if is_dpop { 83 83 match crate::oauth::verify::verify_oauth_access_token( 84 - &state.db, 84 + state.oauth_repo.as_ref(), 85 85 &token, 86 86 dpop_proof, 87 87 "GET", ··· 119 119 } 120 120 } 121 121 } else { 122 - match crate::auth::validate_bearer_token_for_service_auth(&state.db, &token).await { 122 + match crate::auth::validate_bearer_token_for_service_auth(state.user_repo.as_ref(), &token).await { 123 123 Ok(user) => user, 124 124 Err(e) => { 125 125 warn!(error = ?e, "getServiceAuth auth validation failed"); ··· 137 137 Some(kb) => kb.clone(), 138 138 None => { 139 139 warn!(did = %&auth_user.did, "getServiceAuth: OAuth token has no key_bytes, fetching from DB"); 140 - match sqlx::query_as::<_, (Vec<u8>, Option<i32>)>( 141 - "SELECT k.key_bytes, k.encryption_version 142 - FROM users u 143 - JOIN user_keys k ON u.id = k.user_id 144 - WHERE u.did = $1", 145 - ) 146 - .bind(&auth_user.did) 147 - .fetch_optional(&state.db) 148 - .await 149 - { 150 - Ok(Some((key_bytes_enc, encryption_version))) => { 151 - match crate::config::decrypt_key(&key_bytes_enc, encryption_version) { 152 - Ok(key) => key, 153 - Err(e) => { 154 - error!(error = ?e, "Failed to decrypt user key for service auth"); 155 - return ApiError::AuthenticationFailed(Some( 156 - "Failed to get signing key".into(), 157 - )) 158 - .into_response(); 140 + match state.user_repo.get_user_info_by_did(&auth_user.did).await { 141 + Ok(Some(info)) => match info.key_bytes { 142 + Some(key_bytes_enc) => { 143 + match crate::config::decrypt_key(&key_bytes_enc, info.encryption_version) { 144 + Ok(key) => key, 145 + Err(e) => { 146 + error!(error = ?e, "Failed to decrypt user key for service auth"); 147 + return ApiError::AuthenticationFailed(Some( 148 + "Failed to get signing key".into(), 149 + )) 150 + .into_response(); 151 + } 159 152 } 160 153 } 161 - } 154 + None => { 155 + return ApiError::AuthenticationFailed(Some( 156 + "User has no signing key".into(), 157 + )) 158 + .into_response(); 159 + } 160 + }, 162 161 Ok(None) => { 163 162 return ApiError::AuthenticationFailed(Some("User has no signing key".into())) 164 163 .into_response(); ··· 196 195 } 197 196 } 198 197 199 - let user_status = sqlx::query!( 200 - "SELECT takedown_ref FROM users WHERE did = $1", 201 - &auth_user.did 202 - ) 203 - .fetch_optional(&state.db) 204 - .await; 205 - 206 - let is_takendown = match user_status { 207 - Ok(Some(row)) => row.takedown_ref.is_some(), 208 - _ => false, 209 - }; 198 + let is_takendown = state 199 + .user_repo 200 + .get_status_by_did(&auth_user.did) 201 + .await 202 + .ok() 203 + .flatten() 204 + .is_some_and(|s| s.takedown_ref.is_some()); 210 205 211 206 if is_takendown && lxm != Some("com.atproto.server.createAccount") { 212 207 return ApiError::InvalidToken(Some("Bad token scope".into())).into_response();
+200 -386
crates/tranquil-pds/src/api/server/session.rs
··· 13 13 use serde::{Deserialize, Serialize}; 14 14 use serde_json::json; 15 15 use tracing::{error, info, warn}; 16 + use tranquil_types::TokenId; 16 17 17 18 fn extract_client_ip(headers: &HeaderMap) -> String { 18 19 if let Some(forwarded) = headers.get("x-forwarded-for") ··· 90 91 return ApiError::RateLimitExceeded(None).into_response(); 91 92 } 92 93 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 93 - let normalized_identifier = normalize_handle(&input.identifier, &pds_hostname); 94 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 95 + let normalized_identifier = normalize_handle(&input.identifier, hostname_for_handles); 94 96 info!( 95 97 "Normalized identifier: {} -> {}", 96 98 input.identifier, normalized_identifier 97 99 ); 98 - let row = match sqlx::query!( 99 - r#"SELECT 100 - u.id, u.did, u.handle, u.password_hash, u.email, u.deactivated_at, u.takedown_ref, 101 - u.email_verified, u.discord_verified, u.telegram_verified, u.signal_verified, 102 - u.allow_legacy_login, u.migrated_to_pds, 103 - u.preferred_comms_channel as "preferred_comms_channel: crate::comms::CommsChannel", 104 - k.key_bytes, k.encryption_version, 105 - (SELECT verified FROM user_totp WHERE did = u.did) as totp_enabled 106 - FROM users u 107 - JOIN user_keys k ON u.id = k.user_id 108 - WHERE u.handle = $1 OR u.email = $1 OR u.did = $1"#, 109 - normalized_identifier 110 - ) 111 - .fetch_optional(&state.db) 112 - .await 100 + let row = match state 101 + .user_repo 102 + .get_login_full_by_identifier(&normalized_identifier) 103 + .await 113 104 { 114 105 Ok(Some(row)) => row, 115 106 Ok(None) => { ··· 141 132 { 142 133 (true, None, None, None) 143 134 } else { 144 - let app_passwords = sqlx::query!( 145 - "SELECT name, password_hash, scopes, created_by_controller_did FROM app_passwords WHERE user_id = $1 ORDER BY created_at DESC LIMIT 20", 146 - row.id 147 - ) 148 - .fetch_all(&state.db) 149 - .await 150 - .unwrap_or_default(); 135 + let app_passwords = state 136 + .session_repo 137 + .get_app_passwords_for_login(row.id) 138 + .await 139 + .unwrap_or_default(); 151 140 let matched = app_passwords 152 141 .iter() 153 142 .find(|app| verify(&input.password, &app.password_hash).unwrap_or(false)); ··· 178 167 } 179 168 let is_verified = 180 169 row.email_verified || row.discord_verified || row.telegram_verified || row.signal_verified; 181 - let is_delegated = crate::delegation::is_delegated_account(&state.db, &row.did) 170 + let is_delegated = state 171 + .delegation_repo 172 + .is_delegated_account(&row.did) 182 173 .await 183 174 .unwrap_or(false); 184 175 if !is_verified && !is_delegated { ··· 193 184 ) 194 185 .into_response(); 195 186 } 196 - let has_totp = row.totp_enabled.unwrap_or(false); 187 + let has_totp = row.totp_enabled; 197 188 let is_legacy_login = has_totp; 198 189 if has_totp && !row.allow_legacy_login { 199 190 warn!("Legacy login blocked for TOTP-enabled account: {}", row.did); ··· 229 220 }; 230 221 let did_for_doc = row.did.clone(); 231 222 let did_resolver = state.did_resolver.clone(); 223 + let session_data = tranquil_db_traits::SessionTokenCreate { 224 + did: row.did.clone(), 225 + access_jti: access_meta.jti.clone(), 226 + refresh_jti: refresh_meta.jti.clone(), 227 + access_expires_at: access_meta.expires_at, 228 + refresh_expires_at: refresh_meta.expires_at, 229 + legacy_login: is_legacy_login, 230 + mfa_verified: false, 231 + scope: app_password_scopes.clone(), 232 + controller_did: app_password_controller.clone(), 233 + app_password_name: app_password_name.clone(), 234 + }; 232 235 let (insert_result, did_doc) = tokio::join!( 233 - sqlx::query!( 234 - "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at, legacy_login, mfa_verified, scope, controller_did, app_password_name) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", 235 - row.did, 236 - access_meta.jti, 237 - refresh_meta.jti, 238 - access_meta.expires_at, 239 - refresh_meta.expires_at, 240 - is_legacy_login, 241 - false, 242 - app_password_scopes, 243 - app_password_controller, 244 - app_password_name 245 - ) 246 - .execute(&state.db), 236 + state.session_repo.create_session(&session_data), 247 237 did_resolver.resolve_did_document(&did_for_doc) 248 238 ); 249 239 if let Err(e) = insert_result { ··· 257 247 "Legacy login on TOTP-enabled account - sending notification" 258 248 ); 259 249 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 260 - if let Err(e) = crate::comms::queue_legacy_login_notification( 261 - &state.db, 250 + if let Err(e) = crate::comms::comms_repo::enqueue_legacy_login( 251 + state.user_repo.as_ref(), 252 + state.infra_repo.as_ref(), 262 253 row.id, 263 254 &hostname, 264 255 &client_ip, ··· 296 287 let did_for_doc = auth_user.did.clone(); 297 288 let did_resolver = state.did_resolver.clone(); 298 289 let (db_result, did_doc) = tokio::join!( 299 - sqlx::query!( 300 - r#"SELECT 301 - handle, email, email_verified, is_admin, deactivated_at, takedown_ref, preferred_locale, 302 - preferred_comms_channel as "preferred_channel: crate::comms::CommsChannel", 303 - discord_verified, telegram_verified, signal_verified, migrated_to_pds, migrated_at 304 - FROM users WHERE did = $1"#, 305 - &auth_user.did 306 - ) 307 - .fetch_optional(&state.db), 290 + state.user_repo.get_session_info_by_did(&auth_user.did), 308 291 did_resolver.resolve_did_document(&did_for_doc) 309 292 ); 310 293 match db_result { 311 294 Ok(Some(row)) => { 312 - let (preferred_channel, preferred_channel_verified) = match row.preferred_channel { 313 - crate::comms::CommsChannel::Email => ("email", row.email_verified), 314 - crate::comms::CommsChannel::Discord => ("discord", row.discord_verified), 315 - crate::comms::CommsChannel::Telegram => ("telegram", row.telegram_verified), 316 - crate::comms::CommsChannel::Signal => ("signal", row.signal_verified), 295 + let (preferred_channel, preferred_channel_verified) = match row.preferred_comms_channel { 296 + tranquil_db_traits::CommsChannel::Email => ("email", row.email_verified), 297 + tranquil_db_traits::CommsChannel::Discord => ("discord", row.discord_verified), 298 + tranquil_db_traits::CommsChannel::Telegram => ("telegram", row.telegram_verified), 299 + tranquil_db_traits::CommsChannel::Signal => ("signal", row.signal_verified), 317 300 }; 318 301 let pds_hostname = 319 302 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); ··· 379 362 Err(_) => return ApiError::AuthenticationFailed(None).into_response(), 380 363 }; 381 364 let did = crate::auth::get_did_from_token(&extracted.token).ok(); 382 - match sqlx::query!("DELETE FROM session_tokens WHERE access_jti = $1", jti) 383 - .execute(&state.db) 384 - .await 385 - { 386 - Ok(res) if res.rows_affected() > 0 => { 365 + match state.session_repo.delete_session_by_access_jti(&jti).await { 366 + Ok(rows) if rows > 0 => { 387 367 if let Some(did) = did { 388 368 let session_cache_key = format!("auth:session:{}:{}", did, jti); 389 369 let _ = state.cache.delete(&session_cache_key).await; ··· 391 371 EmptyResponse::ok().into_response() 392 372 } 393 373 Ok(_) => ApiError::AuthenticationFailed(None).into_response(), 394 - Err(e) => { 395 - error!("Database error in delete_session: {:?}", e); 396 - ApiError::AuthenticationFailed(None).into_response() 397 - } 374 + Err(_) => ApiError::AuthenticationFailed(None).into_response(), 398 375 } 399 376 } 400 377 ··· 424 401 .into_response(); 425 402 } 426 403 }; 427 - let mut tx = match state.db.begin().await { 428 - Ok(tx) => tx, 429 - Err(e) => { 430 - error!("Failed to begin transaction: {:?}", e); 431 - return ApiError::InternalError(None).into_response(); 432 - } 433 - }; 434 - if let Ok(Some(session_id)) = sqlx::query_scalar!( 435 - "SELECT session_id FROM used_refresh_tokens WHERE refresh_jti = $1 FOR UPDATE", 436 - refresh_jti 437 - ) 438 - .fetch_optional(&mut *tx) 439 - .await 404 + if let Ok(Some(_)) = state 405 + .session_repo 406 + .check_refresh_token_used(&refresh_jti) 407 + .await 440 408 { 441 - warn!( 442 - "Refresh token reuse detected! Revoking token family for session_id: {}", 443 - session_id 444 - ); 445 - let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_id) 446 - .execute(&mut *tx) 447 - .await; 448 - let _ = tx.commit().await; 409 + warn!("Refresh token reuse detected for jti: {}", refresh_jti); 449 410 return ApiError::AuthenticationFailed(Some( 450 411 "Refresh token has been revoked due to suspected compromise".into(), 451 412 )) 452 413 .into_response(); 453 414 } 454 - let session_row = match sqlx::query!( 455 - r#"SELECT st.id, st.did, st.scope, st.controller_did, k.key_bytes, k.encryption_version 456 - FROM session_tokens st 457 - JOIN users u ON st.did = u.did 458 - JOIN user_keys k ON u.id = k.user_id 459 - WHERE st.refresh_jti = $1 AND st.refresh_expires_at > NOW() 460 - FOR UPDATE OF st"#, 461 - refresh_jti 462 - ) 463 - .fetch_optional(&mut *tx) 464 - .await 415 + let session_row = match state 416 + .session_repo 417 + .get_session_for_refresh(&refresh_jti) 418 + .await 465 419 { 466 420 Ok(Some(row)) => row, 467 421 Ok(None) => { ··· 474 428 } 475 429 }; 476 430 let key_bytes = 477 - match crate::config::decrypt_key(&session_row.key_bytes, session_row.encryption_version) { 431 + match crate::config::decrypt_key(&session_row.key_bytes, Some(session_row.encryption_version)) { 478 432 Ok(k) => k, 479 433 Err(e) => { 480 434 error!("Failed to decrypt user key: {:?}", e); ··· 506 460 return ApiError::InternalError(None).into_response(); 507 461 } 508 462 }; 509 - match sqlx::query!( 510 - "INSERT INTO used_refresh_tokens (refresh_jti, session_id) VALUES ($1, $2) ON CONFLICT (refresh_jti) DO NOTHING", 511 - refresh_jti, 512 - session_row.id 513 - ) 514 - .execute(&mut *tx) 515 - .await 463 + let refresh_data = tranquil_db_traits::SessionRefreshData { 464 + old_refresh_jti: refresh_jti.clone(), 465 + session_id: session_row.id, 466 + new_access_jti: new_access_meta.jti.clone(), 467 + new_refresh_jti: new_refresh_meta.jti.clone(), 468 + new_access_expires_at: new_access_meta.expires_at, 469 + new_refresh_expires_at: new_refresh_meta.expires_at, 470 + }; 471 + match state 472 + .session_repo 473 + .refresh_session_atomic(&refresh_data) 474 + .await 516 475 { 517 - Ok(result) if result.rows_affected() == 0 => { 518 - warn!("Concurrent refresh token reuse detected for session_id: {}", session_row.id); 519 - let _ = sqlx::query!("DELETE FROM session_tokens WHERE id = $1", session_row.id) 520 - .execute(&mut *tx) 521 - .await; 522 - let _ = tx.commit().await; 523 - return ApiError::AuthenticationFailed(Some("Refresh token has been revoked due to suspected compromise".into())).into_response(); 476 + Ok(tranquil_db_traits::RefreshSessionResult::Success) => {} 477 + Ok(tranquil_db_traits::RefreshSessionResult::TokenAlreadyUsed) => { 478 + warn!("Refresh token reuse detected during atomic operation"); 479 + return ApiError::AuthenticationFailed(Some( 480 + "Refresh token has been revoked due to suspected compromise".into(), 481 + )) 482 + .into_response(); 483 + } 484 + Ok(tranquil_db_traits::RefreshSessionResult::ConcurrentRefresh) => { 485 + warn!("Concurrent refresh detected for session_id: {}", session_row.id); 486 + return ApiError::AuthenticationFailed(Some( 487 + "Refresh token has been revoked due to suspected compromise".into(), 488 + )) 489 + .into_response(); 524 490 } 525 491 Err(e) => { 526 - error!("Failed to record used refresh token: {:?}", e); 492 + error!("Database error during session refresh: {:?}", e); 527 493 return ApiError::InternalError(None).into_response(); 528 494 } 529 - Ok(_) => {} 530 - } 531 - if let Err(e) = sqlx::query!( 532 - "UPDATE session_tokens SET access_jti = $1, refresh_jti = $2, access_expires_at = $3, refresh_expires_at = $4, updated_at = NOW() WHERE id = $5", 533 - new_access_meta.jti, 534 - new_refresh_meta.jti, 535 - new_access_meta.expires_at, 536 - new_refresh_meta.expires_at, 537 - session_row.id 538 - ) 539 - .execute(&mut *tx) 540 - .await 541 - { 542 - error!("Database error updating session: {:?}", e); 543 - return ApiError::InternalError(None).into_response(); 544 - } 545 - if let Err(e) = tx.commit().await { 546 - error!("Failed to commit transaction: {:?}", e); 547 - return ApiError::InternalError(None).into_response(); 548 495 } 549 496 let did_for_doc = session_row.did.clone(); 550 497 let did_resolver = state.did_resolver.clone(); 551 498 let (db_result, did_doc) = tokio::join!( 552 - sqlx::query!( 553 - r#"SELECT 554 - handle, email, email_verified, is_admin, preferred_locale, deactivated_at, takedown_ref, 555 - preferred_comms_channel as "preferred_channel: crate::comms::CommsChannel", 556 - discord_verified, telegram_verified, signal_verified 557 - FROM users WHERE did = $1"#, 558 - session_row.did 559 - ) 560 - .fetch_optional(&state.db), 499 + state.user_repo.get_session_info_by_did(&session_row.did), 561 500 did_resolver.resolve_did_document(&did_for_doc) 562 501 ); 563 502 match db_result { 564 503 Ok(Some(u)) => { 565 - let (preferred_channel, preferred_channel_verified) = match u.preferred_channel { 566 - crate::comms::CommsChannel::Email => ("email", u.email_verified), 567 - crate::comms::CommsChannel::Discord => ("discord", u.discord_verified), 568 - crate::comms::CommsChannel::Telegram => ("telegram", u.telegram_verified), 569 - crate::comms::CommsChannel::Signal => ("signal", u.signal_verified), 504 + let (preferred_channel, preferred_channel_verified) = match u.preferred_comms_channel { 505 + tranquil_db_traits::CommsChannel::Email => ("email", u.email_verified), 506 + tranquil_db_traits::CommsChannel::Discord => ("discord", u.discord_verified), 507 + tranquil_db_traits::CommsChannel::Telegram => ("telegram", u.telegram_verified), 508 + tranquil_db_traits::CommsChannel::Signal => ("signal", u.signal_verified), 570 509 }; 571 510 let pds_hostname = 572 511 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); ··· 630 569 Json(input): Json<ConfirmSignupInput>, 631 570 ) -> Response { 632 571 info!("confirm_signup called for DID: {}", input.did); 633 - let row = match sqlx::query!( 634 - r#"SELECT 635 - u.id, u.did, u.handle, u.email, 636 - u.preferred_comms_channel as "channel: crate::comms::CommsChannel", 637 - u.discord_id, u.telegram_username, u.signal_number, 638 - k.key_bytes, k.encryption_version 639 - FROM users u 640 - JOIN user_keys k ON u.id = k.user_id 641 - WHERE u.did = $1"#, 642 - input.did.as_str() 643 - ) 644 - .fetch_optional(&state.db) 645 - .await 572 + let row = match state 573 + .user_repo 574 + .get_confirm_signup_by_did(&input.did) 575 + .await 646 576 { 647 577 Ok(Some(row)) => row, 648 578 Ok(None) => { ··· 657 587 }; 658 588 659 589 let (channel_str, identifier) = match row.channel { 660 - crate::comms::CommsChannel::Email => ("email", row.email.clone().unwrap_or_default()), 661 - crate::comms::CommsChannel::Discord => { 590 + tranquil_db_traits::CommsChannel::Email => ("email", row.email.clone().unwrap_or_default()), 591 + tranquil_db_traits::CommsChannel::Discord => { 662 592 ("discord", row.discord_id.clone().unwrap_or_default()) 663 593 } 664 - crate::comms::CommsChannel::Telegram => ( 594 + tranquil_db_traits::CommsChannel::Telegram => ( 665 595 "telegram", 666 596 row.telegram_username.clone().unwrap_or_default(), 667 597 ), 668 - crate::comms::CommsChannel::Signal => { 598 + tranquil_db_traits::CommsChannel::Signal => { 669 599 ("signal", row.signal_number.clone().unwrap_or_default()) 670 600 } 671 601 }; ··· 721 651 } 722 652 }; 723 653 724 - let mut tx = match state.db.begin().await { 725 - Ok(tx) => tx, 726 - Err(e) => { 727 - error!("Failed to begin transaction: {:?}", e); 728 - return ApiError::InternalError(None).into_response(); 729 - } 730 - }; 731 - 732 - let verified_column = match row.channel { 733 - crate::comms::CommsChannel::Email => "email_verified", 734 - crate::comms::CommsChannel::Discord => "discord_verified", 735 - crate::comms::CommsChannel::Telegram => "telegram_verified", 736 - crate::comms::CommsChannel::Signal => "signal_verified", 737 - }; 738 - let update_query = format!("UPDATE users SET {} = TRUE WHERE did = $1", verified_column); 739 - if let Err(e) = sqlx::query(&update_query) 740 - .bind(input.did.as_str()) 741 - .execute(&mut *tx) 654 + if let Err(e) = state 655 + .user_repo 656 + .set_channel_verified(&input.did, row.channel.clone()) 742 657 .await 743 658 { 744 659 error!("Failed to update verification status: {:?}", e); 745 660 return ApiError::InternalError(None).into_response(); 746 661 } 747 662 748 - let no_scope: Option<String> = None; 749 - if let Err(e) = sqlx::query!( 750 - "INSERT INTO session_tokens (did, access_jti, refresh_jti, access_expires_at, refresh_expires_at, legacy_login, mfa_verified, scope) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", 751 - row.did, 752 - access_meta.jti, 753 - refresh_meta.jti, 754 - access_meta.expires_at, 755 - refresh_meta.expires_at, 756 - false, 757 - false, 758 - no_scope 759 - ) 760 - .execute(&mut *tx) 761 - .await 762 - { 663 + let session_data = tranquil_db_traits::SessionTokenCreate { 664 + did: row.did.clone(), 665 + access_jti: access_meta.jti.clone(), 666 + refresh_jti: refresh_meta.jti.clone(), 667 + access_expires_at: access_meta.expires_at, 668 + refresh_expires_at: refresh_meta.expires_at, 669 + legacy_login: false, 670 + mfa_verified: false, 671 + scope: None, 672 + controller_did: None, 673 + app_password_name: None, 674 + }; 675 + if let Err(e) = state.session_repo.create_session(&session_data).await { 763 676 error!("Failed to insert session: {:?}", e); 764 677 return ApiError::InternalError(None).into_response(); 765 678 } 766 679 767 - if let Err(e) = tx.commit().await { 768 - error!("Failed to commit transaction: {:?}", e); 769 - return ApiError::InternalError(None).into_response(); 770 - } 771 - 772 680 let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 773 - if let Err(e) = crate::comms::enqueue_welcome(&state.db, row.id, &hostname).await { 681 + if let Err(e) = crate::comms::comms_repo::enqueue_welcome( 682 + state.user_repo.as_ref(), 683 + state.infra_repo.as_ref(), 684 + row.id, 685 + &hostname, 686 + ) 687 + .await 688 + { 774 689 warn!("Failed to enqueue welcome notification: {:?}", e); 775 690 } 776 - let email_verified = matches!(row.channel, crate::comms::CommsChannel::Email); 691 + let email_verified = matches!(row.channel, tranquil_db_traits::CommsChannel::Email); 777 692 let preferred_channel = match row.channel { 778 - crate::comms::CommsChannel::Email => "email", 779 - crate::comms::CommsChannel::Discord => "discord", 780 - crate::comms::CommsChannel::Telegram => "telegram", 781 - crate::comms::CommsChannel::Signal => "signal", 693 + tranquil_db_traits::CommsChannel::Email => "email", 694 + tranquil_db_traits::CommsChannel::Discord => "discord", 695 + tranquil_db_traits::CommsChannel::Telegram => "telegram", 696 + tranquil_db_traits::CommsChannel::Signal => "signal", 782 697 }; 783 698 Json(ConfirmSignupOutput { 784 699 access_jwt: access_meta.token, ··· 804 719 Json(input): Json<ResendVerificationInput>, 805 720 ) -> Response { 806 721 info!("resend_verification called for DID: {}", input.did); 807 - let row = match sqlx::query!( 808 - r#"SELECT 809 - id, handle, email, 810 - preferred_comms_channel as "channel: crate::comms::CommsChannel", 811 - discord_id, telegram_username, signal_number, 812 - email_verified, discord_verified, telegram_verified, signal_verified 813 - FROM users 814 - WHERE did = $1"#, 815 - input.did.as_str() 816 - ) 817 - .fetch_optional(&state.db) 818 - .await 722 + let row = match state 723 + .user_repo 724 + .get_resend_verification_by_did(&input.did) 725 + .await 819 726 { 820 727 Ok(Some(row)) => row, 821 728 Ok(None) => { ··· 833 740 } 834 741 835 742 let (channel_str, recipient) = match row.channel { 836 - crate::comms::CommsChannel::Email => ("email", row.email.clone().unwrap_or_default()), 837 - crate::comms::CommsChannel::Discord => { 743 + tranquil_db_traits::CommsChannel::Email => ("email", row.email.clone().unwrap_or_default()), 744 + tranquil_db_traits::CommsChannel::Discord => { 838 745 ("discord", row.discord_id.clone().unwrap_or_default()) 839 746 } 840 - crate::comms::CommsChannel::Telegram => ( 747 + tranquil_db_traits::CommsChannel::Telegram => ( 841 748 "telegram", 842 749 row.telegram_username.clone().unwrap_or_default(), 843 750 ), 844 - crate::comms::CommsChannel::Signal => { 751 + tranquil_db_traits::CommsChannel::Signal => { 845 752 ("signal", row.signal_number.clone().unwrap_or_default()) 846 753 } 847 754 }; ··· 851 758 let formatted_token = 852 759 crate::auth::verification_token::format_token_for_display(&verification_token); 853 760 854 - if let Err(e) = crate::comms::enqueue_signup_verification( 855 - &state.db, 761 + let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 762 + if let Err(e) = crate::comms::comms_repo::enqueue_signup_verification( 763 + state.infra_repo.as_ref(), 856 764 row.id, 857 765 channel_str, 858 766 &recipient, 859 767 &formatted_token, 860 - None, 768 + &hostname, 861 769 ) 862 770 .await 863 771 { ··· 894 802 .and_then(|v| v.strip_prefix("Bearer ")) 895 803 .and_then(|token| crate::auth::get_jti_from_token(token).ok()); 896 804 897 - let jwt_rows = match sqlx::query_as::< 898 - _, 899 - ( 900 - i32, 901 - String, 902 - chrono::DateTime<chrono::Utc>, 903 - chrono::DateTime<chrono::Utc>, 904 - ), 905 - >( 906 - r#" 907 - SELECT id, access_jti, created_at, refresh_expires_at 908 - FROM session_tokens 909 - WHERE did = $1 AND refresh_expires_at > NOW() 910 - ORDER BY created_at DESC 911 - "#, 912 - ) 913 - .bind(&auth.0.did) 914 - .fetch_all(&state.db) 915 - .await 916 - { 805 + let jwt_rows = match state.session_repo.list_sessions_by_did(&auth.0.did).await { 917 806 Ok(rows) => rows, 918 807 Err(e) => { 919 808 error!("DB error fetching JWT sessions: {:?}", e); ··· 921 810 } 922 811 }; 923 812 924 - let oauth_rows = match sqlx::query_as::< 925 - _, 926 - ( 927 - i32, 928 - String, 929 - chrono::DateTime<chrono::Utc>, 930 - chrono::DateTime<chrono::Utc>, 931 - String, 932 - ), 933 - >( 934 - r#" 935 - SELECT id, token_id, created_at, expires_at, client_id 936 - FROM oauth_token 937 - WHERE did = $1 AND expires_at > NOW() 938 - ORDER BY created_at DESC 939 - "#, 940 - ) 941 - .bind(&auth.0.did) 942 - .fetch_all(&state.db) 943 - .await 944 - { 813 + let oauth_rows = match state.oauth_repo.list_sessions_by_did(&auth.0.did).await { 945 814 Ok(rows) => rows, 946 815 Err(e) => { 947 816 error!("DB error fetching OAuth sessions: {:?}", e); ··· 949 818 } 950 819 }; 951 820 952 - let jwt_sessions = jwt_rows 953 - .into_iter() 954 - .map(|(id, access_jti, created_at, expires_at)| SessionInfo { 955 - id: format!("jwt:{}", id), 956 - session_type: "legacy".to_string(), 957 - client_name: None, 958 - created_at: created_at.to_rfc3339(), 959 - expires_at: expires_at.to_rfc3339(), 960 - is_current: current_jti.as_ref() == Some(&access_jti), 961 - }); 821 + let jwt_sessions = jwt_rows.into_iter().map(|row| SessionInfo { 822 + id: format!("jwt:{}", row.id), 823 + session_type: "legacy".to_string(), 824 + client_name: None, 825 + created_at: row.created_at.to_rfc3339(), 826 + expires_at: row.refresh_expires_at.to_rfc3339(), 827 + is_current: current_jti.as_ref() == Some(&row.access_jti), 828 + }); 962 829 963 830 let is_oauth = auth.0.is_oauth; 964 - let oauth_sessions = 965 - oauth_rows 966 - .into_iter() 967 - .map(|(id, token_id, created_at, expires_at, client_id)| { 968 - let client_name = extract_client_name(&client_id); 969 - let is_current_oauth = is_oauth && current_jti.as_ref() == Some(&token_id); 970 - SessionInfo { 971 - id: format!("oauth:{}", id), 972 - session_type: "oauth".to_string(), 973 - client_name: Some(client_name), 974 - created_at: created_at.to_rfc3339(), 975 - expires_at: expires_at.to_rfc3339(), 976 - is_current: is_current_oauth, 977 - } 978 - }); 831 + let oauth_sessions = oauth_rows.into_iter().map(|row| { 832 + let client_name = extract_client_name(&row.client_id); 833 + let is_current_oauth = is_oauth && current_jti.as_ref().map(|s| s.as_str()) == Some(row.token_id.as_str()); 834 + SessionInfo { 835 + id: format!("oauth:{}", row.id), 836 + session_type: "oauth".to_string(), 837 + client_name: Some(client_name), 838 + created_at: row.created_at.to_rfc3339(), 839 + expires_at: row.expires_at.to_rfc3339(), 840 + is_current: is_current_oauth, 841 + } 842 + }); 979 843 980 844 let mut sessions: Vec<SessionInfo> = jwt_sessions.chain(oauth_sessions).collect(); 981 845 sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at)); ··· 1008 872 let Ok(session_id) = jwt_id.parse::<i32>() else { 1009 873 return ApiError::InvalidRequest("Invalid session ID".into()).into_response(); 1010 874 }; 1011 - let session = sqlx::query_as::<_, (String,)>( 1012 - "SELECT access_jti FROM session_tokens WHERE id = $1 AND did = $2", 1013 - ) 1014 - .bind(session_id) 1015 - .bind(&auth.0.did) 1016 - .fetch_optional(&state.db) 1017 - .await; 1018 - let access_jti = match session { 1019 - Ok(Some((jti,))) => jti, 875 + let access_jti = match state 876 + .session_repo 877 + .get_session_access_jti_by_id(session_id, &auth.0.did) 878 + .await 879 + { 880 + Ok(Some(jti)) => jti, 1020 881 Ok(None) => { 1021 882 return ApiError::SessionNotFound.into_response(); 1022 883 } ··· 1025 886 return ApiError::InternalError(None).into_response(); 1026 887 } 1027 888 }; 1028 - if let Err(e) = sqlx::query("DELETE FROM session_tokens WHERE id = $1") 1029 - .bind(session_id) 1030 - .execute(&state.db) 1031 - .await 1032 - { 889 + if let Err(e) = state.session_repo.delete_session_by_id(session_id).await { 1033 890 error!("DB error deleting session: {:?}", e); 1034 891 return ApiError::InternalError(None).into_response(); 1035 892 } ··· 1042 899 let Ok(session_id) = oauth_id.parse::<i32>() else { 1043 900 return ApiError::InvalidRequest("Invalid session ID".into()).into_response(); 1044 901 }; 1045 - let result = sqlx::query("DELETE FROM oauth_token WHERE id = $1 AND did = $2") 1046 - .bind(session_id) 1047 - .bind(&auth.0.did) 1048 - .execute(&state.db) 1049 - .await; 1050 - match result { 1051 - Ok(r) if r.rows_affected() == 0 => { 902 + match state 903 + .oauth_repo 904 + .delete_session_by_id(session_id, &auth.0.did) 905 + .await 906 + { 907 + Ok(0) => { 1052 908 return ApiError::SessionNotFound.into_response(); 1053 909 } 1054 910 Err(e) => { ··· 1078 934 return ApiError::InvalidToken(None).into_response(); 1079 935 }; 1080 936 1081 - let mut tx = match state.db.begin().await { 1082 - Ok(tx) => tx, 1083 - Err(e) => { 1084 - error!("Failed to begin transaction: {:?}", e); 1085 - return ApiError::InternalError(None).into_response(); 1086 - } 1087 - }; 1088 - 1089 937 if auth.0.is_oauth { 1090 - if let Err(e) = sqlx::query("DELETE FROM session_tokens WHERE did = $1") 1091 - .bind(&auth.0.did) 1092 - .execute(&mut *tx) 1093 - .await 1094 - { 938 + if let Err(e) = state.session_repo.delete_sessions_by_did(&auth.0.did).await { 1095 939 error!("DB error revoking JWT sessions: {:?}", e); 1096 940 return ApiError::InternalError(None).into_response(); 1097 941 } 1098 - if let Err(e) = sqlx::query("DELETE FROM oauth_token WHERE did = $1 AND token_id != $2") 1099 - .bind(&auth.0.did) 1100 - .bind(jti) 1101 - .execute(&mut *tx) 942 + let jti_typed = TokenId::from(jti.clone()); 943 + if let Err(e) = state 944 + .oauth_repo 945 + .delete_sessions_by_did_except(&auth.0.did, &jti_typed) 1102 946 .await 1103 947 { 1104 948 error!("DB error revoking OAuth sessions: {:?}", e); 1105 949 return ApiError::InternalError(None).into_response(); 1106 950 } 1107 951 } else { 1108 - if let Err(e) = 1109 - sqlx::query("DELETE FROM session_tokens WHERE did = $1 AND access_jti != $2") 1110 - .bind(&auth.0.did) 1111 - .bind(jti) 1112 - .execute(&mut *tx) 1113 - .await 952 + if let Err(e) = state 953 + .session_repo 954 + .delete_sessions_by_did_except_jti(&auth.0.did, jti) 955 + .await 1114 956 { 1115 957 error!("DB error revoking JWT sessions: {:?}", e); 1116 958 return ApiError::InternalError(None).into_response(); 1117 959 } 1118 - if let Err(e) = sqlx::query("DELETE FROM oauth_token WHERE did = $1") 1119 - .bind(&auth.0.did) 1120 - .execute(&mut *tx) 1121 - .await 1122 - { 960 + if let Err(e) = state.oauth_repo.delete_sessions_by_did(&auth.0.did).await { 1123 961 error!("DB error revoking OAuth sessions: {:?}", e); 1124 962 return ApiError::InternalError(None).into_response(); 1125 963 } 1126 - } 1127 - 1128 - if let Err(e) = tx.commit().await { 1129 - error!("Failed to commit transaction: {:?}", e); 1130 - return ApiError::InternalError(None).into_response(); 1131 964 } 1132 965 1133 966 info!(did = %&auth.0.did, "All other sessions revoked"); ··· 1145 978 State(state): State<AppState>, 1146 979 auth: BearerAuth, 1147 980 ) -> Response { 1148 - let result = sqlx::query!( 1149 - r#"SELECT 1150 - u.allow_legacy_login, 1151 - (EXISTS(SELECT 1 FROM user_totp t WHERE t.did = u.did AND t.verified = TRUE) OR 1152 - EXISTS(SELECT 1 FROM passkeys p WHERE p.did = u.did)) as "has_mfa!" 1153 - FROM users u WHERE u.did = $1"#, 1154 - &auth.0.did 1155 - ) 1156 - .fetch_optional(&state.db) 1157 - .await; 1158 - 1159 - match result { 1160 - Ok(Some(row)) => Json(LegacyLoginPreferenceOutput { 1161 - allow_legacy_login: row.allow_legacy_login, 1162 - has_mfa: row.has_mfa, 981 + match state.user_repo.get_legacy_login_pref(&auth.0.did).await { 982 + Ok(Some(pref)) => Json(LegacyLoginPreferenceOutput { 983 + allow_legacy_login: pref.allow_legacy_login, 984 + has_mfa: pref.has_mfa, 1163 985 }) 1164 986 .into_response(), 1165 987 Ok(None) => ApiError::AccountNotFound.into_response(), ··· 1181 1003 auth: BearerAuth, 1182 1004 Json(input): Json<UpdateLegacyLoginInput>, 1183 1005 ) -> Response { 1184 - if !crate::api::server::reauth::check_legacy_session_mfa(&state.db, &auth.0.did).await { 1185 - return crate::api::server::reauth::legacy_mfa_required_response(&state.db, &auth.0.did) 1006 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did).await { 1007 + return crate::api::server::reauth::legacy_mfa_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did) 1186 1008 .await; 1187 1009 } 1188 1010 1189 - if crate::api::server::reauth::check_reauth_required(&state.db, &auth.0.did).await { 1190 - return crate::api::server::reauth::reauth_required_response(&state.db, &auth.0.did).await; 1011 + if crate::api::server::reauth::check_reauth_required(&*state.session_repo, &auth.0.did).await { 1012 + return crate::api::server::reauth::reauth_required_response(&*state.user_repo, &*state.session_repo, &auth.0.did).await; 1191 1013 } 1192 1014 1193 - let result = sqlx::query!( 1194 - "UPDATE users SET allow_legacy_login = $1 WHERE did = $2 RETURNING did", 1195 - input.allow_legacy_login, 1196 - &auth.0.did 1197 - ) 1198 - .fetch_optional(&state.db) 1199 - .await; 1200 - 1201 - match result { 1202 - Ok(Some(_)) => { 1015 + match state 1016 + .user_repo 1017 + .update_legacy_login(&auth.0.did, input.allow_legacy_login) 1018 + .await 1019 + { 1020 + Ok(true) => { 1203 1021 info!( 1204 1022 did = %&auth.0.did, 1205 1023 allow_legacy_login = input.allow_legacy_login, ··· 1210 1028 })) 1211 1029 .into_response() 1212 1030 } 1213 - Ok(None) => ApiError::AccountNotFound.into_response(), 1031 + Ok(false) => ApiError::AccountNotFound.into_response(), 1214 1032 Err(e) => { 1215 1033 error!("DB error: {:?}", e); 1216 1034 ApiError::InternalError(None).into_response() ··· 1239 1057 .into_response(); 1240 1058 } 1241 1059 1242 - let result = sqlx::query!( 1243 - "UPDATE users SET preferred_locale = $1 WHERE did = $2 RETURNING did", 1244 - input.preferred_locale, 1245 - &auth.0.did 1246 - ) 1247 - .fetch_optional(&state.db) 1248 - .await; 1249 - 1250 - match result { 1251 - Ok(Some(_)) => { 1060 + match state 1061 + .user_repo 1062 + .update_locale(&auth.0.did, &input.preferred_locale) 1063 + .await 1064 + { 1065 + Ok(true) => { 1252 1066 info!( 1253 1067 did = %&auth.0.did, 1254 1068 locale = %input.preferred_locale, ··· 1259 1073 })) 1260 1074 .into_response() 1261 1075 } 1262 - Ok(None) => ApiError::AccountNotFound.into_response(), 1076 + Ok(false) => ApiError::AccountNotFound.into_response(), 1263 1077 Err(e) => { 1264 1078 error!("DB error updating locale: {:?}", e); 1265 1079 ApiError::InternalError(None).into_response()
+19 -16
crates/tranquil-pds/src/api/server/signing_key.rs
··· 38 38 State(state): State<AppState>, 39 39 Json(input): Json<ReserveSigningKeyInput>, 40 40 ) -> Response { 41 + let did: Option<crate::types::Did> = match input.did { 42 + Some(ref d) => match d.parse() { 43 + Ok(parsed) => Some(parsed), 44 + Err(_) => return ApiError::InvalidDid("Invalid DID format".into()).into_response(), 45 + }, 46 + None => None, 47 + }; 41 48 let signing_key = SigningKey::random(&mut rand::thread_rng()); 42 49 let private_key_bytes = signing_key.to_bytes(); 43 50 let public_key_did_key = public_key_to_did_key(&signing_key); 44 51 let expires_at = Utc::now() + Duration::hours(24); 45 52 let private_bytes: &[u8] = &private_key_bytes; 46 - let result = sqlx::query!( 47 - r#" 48 - INSERT INTO reserved_signing_keys (did, public_key_did_key, private_key_bytes, expires_at) 49 - VALUES ($1, $2, $3, $4) 50 - RETURNING id 51 - "#, 52 - input.did, 53 - public_key_did_key, 54 - private_bytes, 55 - expires_at 56 - ) 57 - .fetch_one(&state.db) 58 - .await; 59 - match result { 60 - Ok(row) => { 61 - info!("Reserved signing key {} for did {:?}", row.id, input.did); 53 + match state 54 + .infra_repo 55 + .reserve_signing_key( 56 + did.as_ref(), 57 + &public_key_did_key, 58 + private_bytes, 59 + expires_at, 60 + ) 61 + .await 62 + { 63 + Ok(key_id) => { 64 + info!("Reserved signing key {} for did {:?}", key_id, input.did); 62 65 ( 63 66 StatusCode::OK, 64 67 Json(ReserveSigningKeyOutput {
+95 -269
crates/tranquil-pds/src/api/server/totp.rs
··· 13 13 extract::State, 14 14 response::{IntoResponse, Response}, 15 15 }; 16 - use chrono::Utc; 17 16 use serde::{Deserialize, Serialize}; 18 17 use tracing::{error, info, warn}; 19 18 ··· 28 27 } 29 28 30 29 pub async fn create_totp_secret(State(state): State<AppState>, auth: BearerAuth) -> Response { 31 - let existing = sqlx::query_scalar!( 32 - "SELECT verified FROM user_totp WHERE did = $1", 33 - &*&auth.0.did 34 - ) 35 - .fetch_optional(&state.db) 36 - .await; 37 - 38 - if let Ok(Some(true)) = existing { 39 - return ApiError::TotpAlreadyEnabled.into_response(); 30 + match state.user_repo.get_totp_record(&auth.0.did).await { 31 + Ok(Some(record)) if record.verified => return ApiError::TotpAlreadyEnabled.into_response(), 32 + Ok(_) => {} 33 + Err(e) => { 34 + error!("DB error checking TOTP: {:?}", e); 35 + return ApiError::InternalError(None).into_response(); 36 + } 40 37 } 41 38 42 39 let secret = generate_totp_secret(); 43 40 44 - let handle = sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", &*&auth.0.did) 45 - .fetch_optional(&state.db) 46 - .await; 47 - 48 - let handle = match handle { 41 + let handle = match state.user_repo.get_handle_by_did(&auth.0.did).await { 49 42 Ok(Some(h)) => h, 50 43 Ok(None) => return ApiError::AccountNotFound.into_response(), 51 44 Err(e) => { ··· 74 67 } 75 68 }; 76 69 77 - let result = sqlx::query!( 78 - r#" 79 - INSERT INTO user_totp (did, secret_encrypted, encryption_version, verified, created_at) 80 - VALUES ($1, $2, $3, false, NOW()) 81 - ON CONFLICT (did) DO UPDATE SET 82 - secret_encrypted = $2, 83 - encryption_version = $3, 84 - verified = false, 85 - created_at = NOW(), 86 - last_used = NULL 87 - "#, 88 - &auth.0.did, 89 - encrypted_secret, 90 - ENCRYPTION_VERSION 91 - ) 92 - .execute(&state.db) 93 - .await; 94 - 95 - if let Err(e) = result { 70 + if let Err(e) = state 71 + .user_repo 72 + .upsert_totp_secret(&auth.0.did, &encrypted_secret, ENCRYPTION_VERSION) 73 + .await 74 + { 96 75 error!("Failed to store TOTP secret: {:?}", e); 97 76 return ApiError::InternalError(None).into_response(); 98 77 } ··· 133 112 return ApiError::RateLimitExceeded(None).into_response(); 134 113 } 135 114 136 - let totp_row = sqlx::query!( 137 - "SELECT secret_encrypted, encryption_version, verified FROM user_totp WHERE did = $1", 138 - &auth.0.did 139 - ) 140 - .fetch_optional(&state.db) 141 - .await; 142 - 143 - let totp_row = match totp_row { 115 + let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 144 116 Ok(Some(row)) => row, 145 117 Ok(None) => return ApiError::TotpNotEnabled.into_response(), 146 118 Err(e) => { ··· 149 121 } 150 122 }; 151 123 152 - if totp_row.verified { 124 + if totp_record.verified { 153 125 return ApiError::TotpAlreadyEnabled.into_response(); 154 126 } 155 127 156 - let secret = match decrypt_totp_secret(&totp_row.secret_encrypted, totp_row.encryption_version) 157 - { 158 - Ok(s) => s, 159 - Err(e) => { 160 - error!("Failed to decrypt TOTP secret: {:?}", e); 161 - return ApiError::InternalError(None).into_response(); 162 - } 163 - }; 128 + let secret = 129 + match decrypt_totp_secret(&totp_record.secret_encrypted, totp_record.encryption_version) { 130 + Ok(s) => s, 131 + Err(e) => { 132 + error!("Failed to decrypt TOTP secret: {:?}", e); 133 + return ApiError::InternalError(None).into_response(); 134 + } 135 + }; 164 136 165 137 let code = input.code.trim(); 166 138 if !verify_totp_code(&secret, code) { ··· 168 140 } 169 141 170 142 let backup_codes = generate_backup_codes(); 171 - let mut tx = match state.db.begin().await { 172 - Ok(tx) => tx, 173 - Err(e) => { 174 - error!("Failed to begin transaction: {:?}", e); 175 - return ApiError::InternalError(None).into_response(); 176 - } 177 - }; 178 - 179 - if let Err(e) = sqlx::query!( 180 - "UPDATE user_totp SET verified = true, last_used = NOW() WHERE did = $1", 181 - &auth.0.did 182 - ) 183 - .execute(&mut *tx) 184 - .await 185 - { 186 - error!("Failed to enable TOTP: {:?}", e); 187 - return ApiError::InternalError(None).into_response(); 188 - } 189 - 190 - if let Err(e) = sqlx::query!("DELETE FROM backup_codes WHERE did = $1", &*&auth.0.did) 191 - .execute(&mut *tx) 192 - .await 193 - { 194 - error!("Failed to clear old backup codes: {:?}", e); 195 - return ApiError::InternalError(None).into_response(); 196 - } 197 - 198 143 let backup_hashes: Result<Vec<_>, _> = 199 144 backup_codes.iter().map(|c| hash_backup_code(c)).collect(); 200 145 let backup_hashes = match backup_hashes { ··· 205 150 } 206 151 }; 207 152 208 - if let Err(e) = sqlx::query!( 209 - r#" 210 - INSERT INTO backup_codes (did, code_hash, created_at) 211 - SELECT $1, hash, NOW() FROM UNNEST($2::text[]) AS t(hash) 212 - "#, 213 - &auth.0.did, 214 - &backup_hashes[..] 215 - ) 216 - .execute(&mut *tx) 217 - .await 153 + if let Err(e) = state 154 + .user_repo 155 + .enable_totp_with_backup_codes(&auth.0.did, &backup_hashes) 156 + .await 218 157 { 219 - error!("Failed to store backup codes: {:?}", e); 220 - return ApiError::InternalError(None).into_response(); 221 - } 222 - 223 - if let Err(e) = tx.commit().await { 224 - error!("Failed to commit transaction: {:?}", e); 158 + error!("Failed to enable TOTP: {:?}", e); 225 159 return ApiError::InternalError(None).into_response(); 226 160 } 227 161 ··· 241 175 auth: BearerAuth, 242 176 Json(input): Json<DisableTotpInput>, 243 177 ) -> Response { 244 - if !crate::api::server::reauth::check_legacy_session_mfa(&state.db, &auth.0.did).await { 245 - return crate::api::server::reauth::legacy_mfa_required_response(&state.db, &auth.0.did) 246 - .await; 178 + if !crate::api::server::reauth::check_legacy_session_mfa(&*state.session_repo, &auth.0.did).await 179 + { 180 + return crate::api::server::reauth::legacy_mfa_required_response( 181 + &*state.user_repo, 182 + &*state.session_repo, 183 + &auth.0.did, 184 + ) 185 + .await; 247 186 } 248 187 249 188 if !state ··· 254 193 return ApiError::RateLimitExceeded(None).into_response(); 255 194 } 256 195 257 - let user = sqlx::query!( 258 - "SELECT password_hash FROM users WHERE did = $1", 259 - &*&auth.0.did 260 - ) 261 - .fetch_optional(&state.db) 262 - .await; 263 - 264 - let password_hash = match user { 265 - Ok(Some(row)) => row.password_hash, 196 + let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 197 + Ok(Some(hash)) => hash, 266 198 Ok(None) => return ApiError::AccountNotFound.into_response(), 267 199 Err(e) => { 268 200 error!("DB error fetching user: {:?}", e); ··· 270 202 } 271 203 }; 272 204 273 - let password_valid = password_hash 274 - .as_ref() 275 - .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false)) 276 - .unwrap_or(false); 205 + let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 277 206 if !password_valid { 278 207 return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 279 208 } 280 209 281 - let totp_row = sqlx::query!( 282 - "SELECT secret_encrypted, encryption_version, verified FROM user_totp WHERE did = $1", 283 - &auth.0.did 284 - ) 285 - .fetch_optional(&state.db) 286 - .await; 287 - 288 - let totp_row = match totp_row { 210 + let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 289 211 Ok(Some(row)) if row.verified => row, 290 212 Ok(Some(_)) | Ok(None) => return ApiError::TotpNotEnabled.into_response(), 291 213 Err(e) => { ··· 298 220 let code_valid = if is_backup_code_format(code) { 299 221 verify_backup_code_for_user(&state, &auth.0.did, code).await 300 222 } else { 301 - let secret = 302 - match decrypt_totp_secret(&totp_row.secret_encrypted, totp_row.encryption_version) { 303 - Ok(s) => s, 304 - Err(e) => { 305 - error!("Failed to decrypt TOTP secret: {:?}", e); 306 - return ApiError::InternalError(None).into_response(); 307 - } 308 - }; 223 + let secret = match decrypt_totp_secret( 224 + &totp_record.secret_encrypted, 225 + totp_record.encryption_version, 226 + ) { 227 + Ok(s) => s, 228 + Err(e) => { 229 + error!("Failed to decrypt TOTP secret: {:?}", e); 230 + return ApiError::InternalError(None).into_response(); 231 + } 232 + }; 309 233 verify_totp_code(&secret, code) 310 234 }; 311 235 ··· 313 237 return ApiError::InvalidCode(Some("Invalid verification code".into())).into_response(); 314 238 } 315 239 316 - let mut tx = match state.db.begin().await { 317 - Ok(tx) => tx, 318 - Err(e) => { 319 - error!("Failed to begin transaction: {:?}", e); 320 - return ApiError::InternalError(None).into_response(); 321 - } 322 - }; 323 - 324 - if let Err(e) = sqlx::query!("DELETE FROM user_totp WHERE did = $1", &*&auth.0.did) 325 - .execute(&mut *tx) 240 + if let Err(e) = state 241 + .user_repo 242 + .delete_totp_and_backup_codes(&auth.0.did) 326 243 .await 327 244 { 328 245 error!("Failed to delete TOTP: {:?}", e); 329 246 return ApiError::InternalError(None).into_response(); 330 247 } 331 248 332 - if let Err(e) = sqlx::query!("DELETE FROM backup_codes WHERE did = $1", &*&auth.0.did) 333 - .execute(&mut *tx) 334 - .await 335 - { 336 - error!("Failed to delete backup codes: {:?}", e); 337 - return ApiError::InternalError(None).into_response(); 338 - } 339 - 340 - if let Err(e) = tx.commit().await { 341 - error!("Failed to commit transaction: {:?}", e); 342 - return ApiError::InternalError(None).into_response(); 343 - } 344 - 345 249 info!(did = %&auth.0.did, "TOTP disabled"); 346 250 347 251 EmptyResponse::ok().into_response() ··· 356 260 } 357 261 358 262 pub async fn get_totp_status(State(state): State<AppState>, auth: BearerAuth) -> Response { 359 - let totp_row = sqlx::query!( 360 - "SELECT verified FROM user_totp WHERE did = $1", 361 - &*&auth.0.did 362 - ) 363 - .fetch_optional(&state.db) 364 - .await; 365 - 366 - let enabled = match totp_row { 263 + let enabled = match state.user_repo.get_totp_record(&auth.0.did).await { 367 264 Ok(Some(row)) => row.verified, 368 265 Ok(None) => false, 369 266 Err(e) => { ··· 372 269 } 373 270 }; 374 271 375 - let backup_count_row = sqlx::query!( 376 - "SELECT COUNT(*) as count FROM backup_codes WHERE did = $1 AND used_at IS NULL", 377 - &auth.0.did 378 - ) 379 - .fetch_one(&state.db) 380 - .await; 381 - 382 - let backup_count = backup_count_row.map(|r| r.count.unwrap_or(0)).unwrap_or(0); 272 + let backup_count = match state.user_repo.count_unused_backup_codes(&auth.0.did).await { 273 + Ok(count) => count, 274 + Err(e) => { 275 + error!("DB error counting backup codes: {:?}", e); 276 + return ApiError::InternalError(None).into_response(); 277 + } 278 + }; 383 279 384 280 Json(GetTotpStatusResponse { 385 281 enabled, ··· 414 310 return ApiError::RateLimitExceeded(None).into_response(); 415 311 } 416 312 417 - let user = sqlx::query!( 418 - "SELECT password_hash FROM users WHERE did = $1", 419 - &*&auth.0.did 420 - ) 421 - .fetch_optional(&state.db) 422 - .await; 423 - 424 - let password_hash = match user { 425 - Ok(Some(row)) => row.password_hash, 313 + let password_hash = match state.user_repo.get_password_hash_by_did(&auth.0.did).await { 314 + Ok(Some(hash)) => hash, 426 315 Ok(None) => return ApiError::AccountNotFound.into_response(), 427 316 Err(e) => { 428 317 error!("DB error fetching user: {:?}", e); ··· 430 319 } 431 320 }; 432 321 433 - let password_valid = password_hash 434 - .as_ref() 435 - .map(|h| bcrypt::verify(&input.password, h).unwrap_or(false)) 436 - .unwrap_or(false); 322 + let password_valid = bcrypt::verify(&input.password, &password_hash).unwrap_or(false); 437 323 if !password_valid { 438 324 return ApiError::InvalidPassword("Password is incorrect".into()).into_response(); 439 325 } 440 326 441 - let totp_row = sqlx::query!( 442 - "SELECT secret_encrypted, encryption_version, verified FROM user_totp WHERE did = $1", 443 - &auth.0.did 444 - ) 445 - .fetch_optional(&state.db) 446 - .await; 447 - 448 - let totp_row = match totp_row { 327 + let totp_record = match state.user_repo.get_totp_record(&auth.0.did).await { 449 328 Ok(Some(row)) if row.verified => row, 450 329 Ok(Some(_)) | Ok(None) => return ApiError::TotpNotEnabled.into_response(), 451 330 Err(e) => { ··· 454 333 } 455 334 }; 456 335 457 - let secret = match decrypt_totp_secret(&totp_row.secret_encrypted, totp_row.encryption_version) 458 - { 459 - Ok(s) => s, 460 - Err(e) => { 461 - error!("Failed to decrypt TOTP secret: {:?}", e); 462 - return ApiError::InternalError(None).into_response(); 463 - } 464 - }; 336 + let secret = 337 + match decrypt_totp_secret(&totp_record.secret_encrypted, totp_record.encryption_version) { 338 + Ok(s) => s, 339 + Err(e) => { 340 + error!("Failed to decrypt TOTP secret: {:?}", e); 341 + return ApiError::InternalError(None).into_response(); 342 + } 343 + }; 465 344 466 345 let code = input.code.trim(); 467 346 if !verify_totp_code(&secret, code) { ··· 469 348 } 470 349 471 350 let backup_codes = generate_backup_codes(); 472 - let mut tx = match state.db.begin().await { 473 - Ok(tx) => tx, 474 - Err(e) => { 475 - error!("Failed to begin transaction: {:?}", e); 476 - return ApiError::InternalError(None).into_response(); 477 - } 478 - }; 479 - 480 - if let Err(e) = sqlx::query!("DELETE FROM backup_codes WHERE did = $1", &*&auth.0.did) 481 - .execute(&mut *tx) 482 - .await 483 - { 484 - error!("Failed to clear old backup codes: {:?}", e); 485 - return ApiError::InternalError(None).into_response(); 486 - } 487 - 488 351 let backup_hashes: Result<Vec<_>, _> = 489 352 backup_codes.iter().map(|c| hash_backup_code(c)).collect(); 490 353 let backup_hashes = match backup_hashes { ··· 495 358 } 496 359 }; 497 360 498 - if let Err(e) = sqlx::query!( 499 - r#" 500 - INSERT INTO backup_codes (did, code_hash, created_at) 501 - SELECT $1, hash, NOW() FROM UNNEST($2::text[]) AS t(hash) 502 - "#, 503 - &auth.0.did, 504 - &backup_hashes[..] 505 - ) 506 - .execute(&mut *tx) 507 - .await 361 + if let Err(e) = state 362 + .user_repo 363 + .replace_backup_codes(&auth.0.did, &backup_hashes) 364 + .await 508 365 { 509 - error!("Failed to store backup codes: {:?}", e); 510 - return ApiError::InternalError(None).into_response(); 511 - } 512 - 513 - if let Err(e) = tx.commit().await { 514 - error!("Failed to commit transaction: {:?}", e); 366 + error!("Failed to regenerate backup codes: {:?}", e); 515 367 return ApiError::InternalError(None).into_response(); 516 368 } 517 369 ··· 520 372 Json(RegenerateBackupCodesResponse { backup_codes }).into_response() 521 373 } 522 374 523 - async fn verify_backup_code_for_user(state: &AppState, did: &str, code: &str) -> bool { 375 + async fn verify_backup_code_for_user(state: &AppState, did: &crate::types::Did, code: &str) -> bool { 524 376 let code = code.trim().to_uppercase(); 525 377 526 - let backup_codes = sqlx::query!( 527 - "SELECT id, code_hash FROM backup_codes WHERE did = $1 AND used_at IS NULL", 528 - did 529 - ) 530 - .fetch_all(&state.db) 531 - .await; 532 - 533 - let backup_codes = match backup_codes { 378 + let backup_codes = match state.user_repo.get_unused_backup_codes(did).await { 534 379 Ok(codes) => codes, 535 380 Err(e) => { 536 381 warn!("Failed to fetch backup codes: {:?}", e); ··· 544 389 545 390 match matched { 546 391 Some(row) => { 547 - let _ = sqlx::query!( 548 - "UPDATE backup_codes SET used_at = $1 WHERE id = $2", 549 - Utc::now(), 550 - row.id 551 - ) 552 - .execute(&state.db) 553 - .await; 392 + let _ = state.user_repo.mark_backup_code_used(row.id).await; 554 393 true 555 394 } 556 395 None => false, 557 396 } 558 397 } 559 398 560 - pub async fn verify_totp_or_backup_for_user(state: &AppState, did: &str, code: &str) -> bool { 399 + pub async fn verify_totp_or_backup_for_user(state: &AppState, did: &crate::types::Did, code: &str) -> bool { 561 400 let code = code.trim(); 562 401 563 402 if is_backup_code_format(code) { 564 403 return verify_backup_code_for_user(state, did, code).await; 565 404 } 566 405 567 - let totp_row = sqlx::query!( 568 - "SELECT secret_encrypted, encryption_version, verified FROM user_totp WHERE did = $1", 569 - did 570 - ) 571 - .fetch_optional(&state.db) 572 - .await; 573 - 574 - let totp_row = match totp_row { 406 + let totp_record = match state.user_repo.get_totp_record(did).await { 575 407 Ok(Some(row)) if row.verified => row, 576 408 _ => return false, 577 409 }; 578 410 579 - let secret = match decrypt_totp_secret(&totp_row.secret_encrypted, totp_row.encryption_version) 580 - { 581 - Ok(s) => s, 582 - Err(_) => return false, 583 - }; 411 + let secret = 412 + match decrypt_totp_secret(&totp_record.secret_encrypted, totp_record.encryption_version) { 413 + Ok(s) => s, 414 + Err(_) => return false, 415 + }; 584 416 585 417 if verify_totp_code(&secret, code) { 586 - let _ = sqlx::query!("UPDATE user_totp SET last_used = NOW() WHERE did = $1", did) 587 - .execute(&state.db) 588 - .await; 418 + let _ = state.user_repo.update_totp_last_used(did).await; 589 419 return true; 590 420 } 591 421 592 422 false 593 423 } 594 424 595 - pub async fn has_totp_enabled(state: &AppState, did: &str) -> bool { 596 - has_totp_enabled_db(&state.db, did).await 597 - } 598 - 599 - pub async fn has_totp_enabled_db(db: &sqlx::PgPool, did: &str) -> bool { 600 - let result = sqlx::query_scalar!("SELECT verified FROM user_totp WHERE did = $1", did) 601 - .fetch_optional(db) 602 - .await; 603 - 604 - matches!(result, Ok(Some(true))) 425 + pub async fn has_totp_enabled(state: &AppState, did: &crate::types::Did) -> bool { 426 + state 427 + .user_repo 428 + .has_totp_enabled(did) 429 + .await 430 + .unwrap_or(false) 605 431 }
+53 -96
crates/tranquil-pds/src/api/server/trusted_devices.rs
··· 7 7 }; 8 8 use chrono::{DateTime, Duration, Utc}; 9 9 use serde::{Deserialize, Serialize}; 10 - use sqlx::PgPool; 11 10 use tracing::{error, info}; 11 + use tranquil_db_traits::OAuthRepository; 12 + use tranquil_types::DeviceId; 12 13 13 14 use crate::auth::BearerAuth; 14 15 use crate::state::AppState; ··· 71 72 } 72 73 73 74 pub async fn list_trusted_devices(State(state): State<AppState>, auth: BearerAuth) -> Response { 74 - let devices = sqlx::query!( 75 - r#"SELECT od.id, od.user_agent, od.friendly_name, od.trusted_at, od.trusted_until, od.last_seen_at 76 - FROM oauth_device od 77 - JOIN oauth_account_device oad ON od.id = oad.device_id 78 - WHERE oad.did = $1 AND od.trusted_until IS NOT NULL AND od.trusted_until > NOW() 79 - ORDER BY od.last_seen_at DESC"#, 80 - &auth.0.did 81 - ) 82 - .fetch_all(&state.db) 83 - .await; 84 - 85 - match devices { 75 + match state.oauth_repo.list_trusted_devices(&auth.0.did).await { 86 76 Ok(rows) => { 87 77 let devices = rows 88 78 .into_iter() ··· 120 110 auth: BearerAuth, 121 111 Json(input): Json<RevokeTrustedDeviceInput>, 122 112 ) -> Response { 123 - let device_exists = sqlx::query_scalar!( 124 - r#"SELECT 1 as one FROM oauth_device od 125 - JOIN oauth_account_device oad ON od.id = oad.device_id 126 - WHERE oad.did = $1 AND od.id = $2"#, 127 - &auth.0.did, 128 - input.device_id 129 - ) 130 - .fetch_optional(&state.db) 131 - .await; 132 - 133 - match device_exists { 134 - Ok(Some(_)) => {} 135 - Ok(None) => { 113 + let device_id = DeviceId::from(input.device_id.clone()); 114 + match state 115 + .oauth_repo 116 + .device_belongs_to_user(&device_id, &auth.0.did) 117 + .await 118 + { 119 + Ok(true) => {} 120 + Ok(false) => { 136 121 return ApiError::DeviceNotFound.into_response(); 137 122 } 138 123 Err(e) => { ··· 141 126 } 142 127 } 143 128 144 - let result = sqlx::query!( 145 - "UPDATE oauth_device SET trusted_at = NULL, trusted_until = NULL WHERE id = $1", 146 - input.device_id 147 - ) 148 - .execute(&state.db) 149 - .await; 150 - 151 - match result { 152 - Ok(_) => { 129 + match state.oauth_repo.revoke_device_trust(&device_id).await { 130 + Ok(()) => { 153 131 info!(did = %&auth.0.did, device_id = %input.device_id, "Trusted device revoked"); 154 132 SuccessResponse::ok().into_response() 155 133 } ··· 172 150 auth: BearerAuth, 173 151 Json(input): Json<UpdateTrustedDeviceInput>, 174 152 ) -> Response { 175 - let device_exists = sqlx::query_scalar!( 176 - r#"SELECT 1 as one FROM oauth_device od 177 - JOIN oauth_account_device oad ON od.id = oad.device_id 178 - WHERE oad.did = $1 AND od.id = $2"#, 179 - &auth.0.did, 180 - input.device_id 181 - ) 182 - .fetch_optional(&state.db) 183 - .await; 184 - 185 - match device_exists { 186 - Ok(Some(_)) => {} 187 - Ok(None) => { 153 + let device_id = DeviceId::from(input.device_id.clone()); 154 + match state 155 + .oauth_repo 156 + .device_belongs_to_user(&device_id, &auth.0.did) 157 + .await 158 + { 159 + Ok(true) => {} 160 + Ok(false) => { 188 161 return ApiError::DeviceNotFound.into_response(); 189 162 } 190 163 Err(e) => { ··· 193 166 } 194 167 } 195 168 196 - let result = sqlx::query!( 197 - "UPDATE oauth_device SET friendly_name = $1 WHERE id = $2", 198 - input.friendly_name, 199 - input.device_id 200 - ) 201 - .execute(&state.db) 202 - .await; 203 - 204 - match result { 205 - Ok(_) => { 169 + match state 170 + .oauth_repo 171 + .update_device_friendly_name(&device_id, input.friendly_name.as_deref()) 172 + .await 173 + { 174 + Ok(()) => { 206 175 info!(did = %auth.0.did, device_id = %input.device_id, "Trusted device updated"); 207 176 SuccessResponse::ok().into_response() 208 177 } ··· 213 182 } 214 183 } 215 184 216 - pub async fn get_device_trust_state(db: &PgPool, device_id: &str, did: &str) -> DeviceTrustState { 217 - let result = sqlx::query!( 218 - r#"SELECT trusted_at, trusted_until FROM oauth_device od 219 - JOIN oauth_account_device oad ON od.id = oad.device_id 220 - WHERE od.id = $1 AND oad.did = $2"#, 221 - device_id, 222 - did 223 - ) 224 - .fetch_optional(db) 225 - .await; 226 - 227 - match result { 228 - Ok(Some(row)) => DeviceTrustState::from_timestamps(row.trusted_at, row.trusted_until), 185 + pub async fn get_device_trust_state( 186 + oauth_repo: &dyn OAuthRepository, 187 + device_id: &str, 188 + did: &tranquil_types::Did, 189 + ) -> DeviceTrustState { 190 + let device_id_typed = DeviceId::from(device_id.to_string()); 191 + match oauth_repo.get_device_trust_info(&device_id_typed, did).await { 192 + Ok(Some(info)) => DeviceTrustState::from_timestamps(info.trusted_at, info.trusted_until), 229 193 _ => DeviceTrustState::Untrusted, 230 194 } 231 195 } 232 196 233 - pub async fn is_device_trusted(db: &PgPool, device_id: &str, did: &str) -> bool { 234 - get_device_trust_state(db, device_id, did) 197 + pub async fn is_device_trusted( 198 + oauth_repo: &dyn OAuthRepository, 199 + device_id: &str, 200 + did: &tranquil_types::Did, 201 + ) -> bool { 202 + get_device_trust_state(oauth_repo, device_id, did) 235 203 .await 236 204 .is_trusted() 237 205 } 238 206 239 - pub async fn trust_device(db: &PgPool, device_id: &str) -> Result<(), sqlx::Error> { 207 + pub async fn trust_device( 208 + oauth_repo: &dyn OAuthRepository, 209 + device_id: &str, 210 + ) -> Result<(), tranquil_db_traits::DbError> { 240 211 let now = Utc::now(); 241 212 let trusted_until = now + Duration::days(TRUST_DURATION_DAYS); 242 - 243 - sqlx::query!( 244 - "UPDATE oauth_device SET trusted_at = $1, trusted_until = $2 WHERE id = $3", 245 - now, 246 - trusted_until, 247 - device_id 248 - ) 249 - .execute(db) 250 - .await?; 251 - 252 - Ok(()) 213 + let device_id_typed = DeviceId::from(device_id.to_string()); 214 + oauth_repo.trust_device(&device_id_typed, now, trusted_until).await 253 215 } 254 216 255 - pub async fn extend_device_trust(db: &PgPool, device_id: &str) -> Result<(), sqlx::Error> { 217 + pub async fn extend_device_trust( 218 + oauth_repo: &dyn OAuthRepository, 219 + device_id: &str, 220 + ) -> Result<(), tranquil_db_traits::DbError> { 256 221 let trusted_until = Utc::now() + Duration::days(TRUST_DURATION_DAYS); 257 - 258 - sqlx::query!( 259 - "UPDATE oauth_device SET trusted_until = $1 WHERE id = $2 AND trusted_until IS NOT NULL", 260 - trusted_until, 261 - device_id 262 - ) 263 - .execute(db) 264 - .await?; 265 - 266 - Ok(()) 222 + let device_id_typed = DeviceId::from(device_id.to_string()); 223 + oauth_repo.extend_device_trust(&device_id_typed, trusted_until).await 267 224 }
+11 -17
crates/tranquil-pds/src/api/server/verify_email.rs
··· 55 55 ) -> Result<Json<ResendMigrationVerificationOutput>, ApiError> { 56 56 let email = input.email.trim().to_lowercase(); 57 57 58 - let user = sqlx::query!( 59 - "SELECT id, did, email, email_verified, handle FROM users WHERE LOWER(email) = $1", 60 - email 61 - ) 62 - .fetch_optional(&state.db) 63 - .await 64 - .map_err(|e| { 65 - warn!(error = %e, "Database error during resend verification"); 66 - ApiError::InternalError(None) 67 - })?; 68 - 69 - let user = match user { 70 - Some(u) => u, 71 - None => { 58 + let user = match state.user_repo.get_by_email(&email).await { 59 + Ok(Some(u)) => u, 60 + Ok(None) => { 72 61 return Ok(Json(ResendMigrationVerificationOutput { sent: true })); 62 + } 63 + Err(e) => { 64 + warn!(error = ?e, "Database error during resend verification"); 65 + return Err(ApiError::InternalError(None)); 73 66 } 74 67 }; 75 68 ··· 81 74 let token = crate::auth::verification_token::generate_migration_token(&user.did, &email); 82 75 let formatted_token = crate::auth::verification_token::format_token_for_display(&token); 83 76 84 - if let Err(e) = crate::comms::enqueue_migration_verification( 85 - &state.db, 77 + if let Err(e) = crate::comms::comms_repo::enqueue_migration_verification( 78 + state.user_repo.as_ref(), 79 + state.infra_repo.as_ref(), 86 80 user.id, 87 81 &email, 88 82 &formatted_token, ··· 90 84 ) 91 85 .await 92 86 { 93 - warn!(error = %e, "Failed to enqueue migration verification email"); 87 + warn!(error = ?e, "Failed to enqueue migration verification email"); 94 88 } 95 89 96 90 info!(did = %user.did, "Resent migration verification email");
+111 -100
crates/tranquil-pds/src/api/server/verify_token.rs
··· 74 74 return Err(ApiError::InvalidChannel); 75 75 } 76 76 77 - let user = sqlx::query!( 78 - "SELECT id, email, email_verified FROM users WHERE did = $1", 79 - did 80 - ) 81 - .fetch_optional(&state.db) 82 - .await 83 - .map_err(|e| { 84 - warn!(error = %e, "Database error during migration verification"); 85 - ApiError::InternalError(None) 86 - })?; 87 - 88 - let user = user.ok_or(ApiError::AccountNotFound)?; 77 + let did_typed: Did = did.parse().map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 78 + let user = state 79 + .user_repo 80 + .get_verification_info(&did_typed) 81 + .await 82 + .map_err(|e| { 83 + warn!(error = ?e, "Database error during migration verification"); 84 + ApiError::InternalError(None) 85 + })? 86 + .ok_or(ApiError::AccountNotFound)?; 89 87 90 88 if user.email.as_ref().map(|e| e.to_lowercase()) != Some(identifier.to_string()) { 91 89 return Err(ApiError::IdentifierMismatch); 92 90 } 93 91 94 92 if !user.email_verified { 95 - sqlx::query!( 96 - "UPDATE users SET email_verified = true WHERE id = $1", 97 - user.id 98 - ) 99 - .execute(&state.db) 100 - .await 101 - .map_err(|e| { 102 - warn!(error = %e, "Failed to update email_verified status"); 103 - ApiError::InternalError(None) 104 - })?; 93 + state 94 + .user_repo 95 + .set_email_verified_flag(user.id) 96 + .await 97 + .map_err(|e| { 98 + warn!(error = ?e, "Failed to update email_verified status"); 99 + ApiError::InternalError(None) 100 + })?; 105 101 } 106 102 107 103 info!(did = %did, "Migration email verified successfully"); ··· 120 116 channel: &str, 121 117 identifier: &str, 122 118 ) -> Result<Json<VerifyTokenOutput>, ApiError> { 123 - let user_id = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 124 - .fetch_one(&state.db) 119 + let did_typed: Did = did.parse().map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 120 + let user_id = state 121 + .user_repo 122 + .get_id_by_did(&did_typed) 125 123 .await 126 - .map_err(|_| ApiError::InternalError(None))?; 124 + .map_err(|_| ApiError::InternalError(None))? 125 + .ok_or(ApiError::AccountNotFound)?; 127 126 128 - let update_result = match channel { 129 - "email" => sqlx::query!( 130 - "UPDATE users SET email = $1, email_verified = TRUE, updated_at = NOW() WHERE id = $2", 131 - identifier, 132 - user_id 133 - ).execute(&state.db).await, 134 - "discord" => sqlx::query!( 135 - "UPDATE users SET discord_id = $1, discord_verified = TRUE, updated_at = NOW() WHERE id = $2", 136 - identifier, 137 - user_id 138 - ).execute(&state.db).await, 139 - "telegram" => sqlx::query!( 140 - "UPDATE users SET telegram_username = $1, telegram_verified = TRUE, updated_at = NOW() WHERE id = $2", 141 - identifier, 142 - user_id 143 - ).execute(&state.db).await, 144 - "signal" => sqlx::query!( 145 - "UPDATE users SET signal_number = $1, signal_verified = TRUE, updated_at = NOW() WHERE id = $2", 146 - identifier, 147 - user_id 148 - ).execute(&state.db).await, 127 + match channel { 128 + "email" => { 129 + let success = state 130 + .user_repo 131 + .verify_email_channel(user_id, identifier) 132 + .await 133 + .map_err(|e| { 134 + error!("Failed to update email channel: {:?}", e); 135 + ApiError::InternalError(None) 136 + })?; 137 + if !success { 138 + return Err(ApiError::EmailTaken); 139 + } 140 + } 141 + "discord" => { 142 + state 143 + .user_repo 144 + .verify_discord_channel(user_id, identifier) 145 + .await 146 + .map_err(|e| { 147 + error!("Failed to update discord channel: {:?}", e); 148 + ApiError::InternalError(None) 149 + })?; 150 + } 151 + "telegram" => { 152 + state 153 + .user_repo 154 + .verify_telegram_channel(user_id, identifier) 155 + .await 156 + .map_err(|e| { 157 + error!("Failed to update telegram channel: {:?}", e); 158 + ApiError::InternalError(None) 159 + })?; 160 + } 161 + "signal" => { 162 + state 163 + .user_repo 164 + .verify_signal_channel(user_id, identifier) 165 + .await 166 + .map_err(|e| { 167 + error!("Failed to update signal channel: {:?}", e); 168 + ApiError::InternalError(None) 169 + })?; 170 + } 149 171 _ => { 150 172 return Err(ApiError::InvalidChannel); 151 173 } 152 174 }; 153 175 154 - if let Err(e) = update_result { 155 - error!("Failed to update user channel: {:?}", e); 156 - if channel == "email" 157 - && e.as_database_error() 158 - .map(|db| db.is_unique_violation()) 159 - .unwrap_or(false) 160 - { 161 - return Err(ApiError::EmailTaken); 162 - } 163 - return Err(ApiError::InternalError(None)); 164 - } 165 - 166 176 info!(did = %did, channel = %channel, "Channel verified successfully"); 167 177 168 178 Ok(Json(VerifyTokenOutput { ··· 179 189 channel: &str, 180 190 _identifier: &str, 181 191 ) -> Result<Json<VerifyTokenOutput>, ApiError> { 182 - let user = sqlx::query!( 183 - "SELECT id, handle, email, email_verified, discord_verified, telegram_verified, signal_verified FROM users WHERE did = $1", 184 - did 185 - ) 186 - .fetch_optional(&state.db) 187 - .await 188 - .map_err(|e| { 189 - warn!(error = %e, "Database error during signup verification"); 190 - ApiError::InternalError(None) 191 - })?; 192 - 193 - let user = user.ok_or(ApiError::AccountNotFound)?; 192 + let did_typed: Did = did.parse().map_err(|_| ApiError::InvalidDid("Invalid DID format".into()))?; 193 + let user = state 194 + .user_repo 195 + .get_verification_info(&did_typed) 196 + .await 197 + .map_err(|e| { 198 + warn!(error = ?e, "Database error during signup verification"); 199 + ApiError::InternalError(None) 200 + })? 201 + .ok_or(ApiError::AccountNotFound)?; 194 202 195 203 let is_verified = user.email_verified 196 204 || user.discord_verified ··· 206 214 })); 207 215 } 208 216 209 - let update_result = match channel { 217 + match channel { 210 218 "email" => { 211 - sqlx::query!( 212 - "UPDATE users SET email_verified = TRUE WHERE id = $1", 213 - user.id 214 - ) 215 - .execute(&state.db) 216 - .await 219 + state 220 + .user_repo 221 + .set_email_verified_flag(user.id) 222 + .await 223 + .map_err(|e| { 224 + warn!(error = ?e, "Failed to update email verified status"); 225 + ApiError::InternalError(None) 226 + })?; 217 227 } 218 228 "discord" => { 219 - sqlx::query!( 220 - "UPDATE users SET discord_verified = TRUE WHERE id = $1", 221 - user.id 222 - ) 223 - .execute(&state.db) 224 - .await 229 + state 230 + .user_repo 231 + .set_discord_verified_flag(user.id) 232 + .await 233 + .map_err(|e| { 234 + warn!(error = ?e, "Failed to update discord verified status"); 235 + ApiError::InternalError(None) 236 + })?; 225 237 } 226 238 "telegram" => { 227 - sqlx::query!( 228 - "UPDATE users SET telegram_verified = TRUE WHERE id = $1", 229 - user.id 230 - ) 231 - .execute(&state.db) 232 - .await 239 + state 240 + .user_repo 241 + .set_telegram_verified_flag(user.id) 242 + .await 243 + .map_err(|e| { 244 + warn!(error = ?e, "Failed to update telegram verified status"); 245 + ApiError::InternalError(None) 246 + })?; 233 247 } 234 248 "signal" => { 235 - sqlx::query!( 236 - "UPDATE users SET signal_verified = TRUE WHERE id = $1", 237 - user.id 238 - ) 239 - .execute(&state.db) 240 - .await 249 + state 250 + .user_repo 251 + .set_signal_verified_flag(user.id) 252 + .await 253 + .map_err(|e| { 254 + warn!(error = ?e, "Failed to update signal verified status"); 255 + ApiError::InternalError(None) 256 + })?; 241 257 } 242 258 _ => { 243 259 return Err(ApiError::InvalidChannel); 244 260 } 245 261 }; 246 - 247 - update_result.map_err(|e| { 248 - warn!(error = %e, "Failed to update channel verified status"); 249 - ApiError::InternalError(None) 250 - })?; 251 262 252 263 info!(did = %did, channel = %channel, "Signup verified successfully"); 253 264
+2 -1
crates/tranquil-pds/src/api/temp.rs
··· 28 28 { 29 29 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 30 30 if let Ok(user) = validate_token_with_dpop( 31 - &state.db, 31 + state.user_repo.as_ref(), 32 + state.oauth_repo.as_ref(), 32 33 &extracted.token, 33 34 extracted.is_dpop, 34 35 dpop_proof,
+24 -10
crates/tranquil-pds/src/auth/extractor.rs
··· 130 130 let uri = build_full_url(&parts.uri.to_string()); 131 131 132 132 match validate_token_with_dpop( 133 - &state.db, 133 + state.user_repo.as_ref(), 134 + state.oauth_repo.as_ref(), 134 135 &extracted.token, 135 136 true, 136 137 dpop_proof, ··· 148 149 Err(_) => Err(AuthError::AuthenticationFailed), 149 150 } 150 151 } else { 151 - match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token) 152 - .await 152 + match validate_bearer_token_cached( 153 + state.user_repo.as_ref(), 154 + state.cache.as_ref(), 155 + &extracted.token, 156 + ) 157 + .await 153 158 { 154 159 Ok(user) => Ok(BearerAuth(user)), 155 160 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), ··· 186 191 let uri = build_full_url(&parts.uri.to_string()); 187 192 188 193 match validate_token_with_dpop( 189 - &state.db, 194 + state.user_repo.as_ref(), 195 + state.oauth_repo.as_ref(), 190 196 &extracted.token, 191 197 true, 192 198 dpop_proof, ··· 204 210 } 205 211 } else { 206 212 match validate_bearer_token_cached_allow_deactivated( 207 - &state.db, 213 + state.user_repo.as_ref(), 208 214 state.cache.as_ref(), 209 215 &extracted.token, 210 216 ) ··· 244 250 let uri = build_full_url(&parts.uri.to_string()); 245 251 246 252 match validate_token_with_dpop( 247 - &state.db, 253 + state.user_repo.as_ref(), 254 + state.oauth_repo.as_ref(), 248 255 &extracted.token, 249 256 true, 250 257 dpop_proof, ··· 261 268 Err(_) => Err(AuthError::AuthenticationFailed), 262 269 } 263 270 } else { 264 - match validate_bearer_token_allow_takendown(&state.db, &extracted.token).await { 271 + match validate_bearer_token_allow_takendown(state.user_repo.as_ref(), &extracted.token) 272 + .await 273 + { 265 274 Ok(user) => Ok(BearerAuthAllowTakendown(user)), 266 275 Err(TokenValidationError::AccountDeactivated) => Err(AuthError::AccountDeactivated), 267 276 Err(TokenValidationError::TokenExpired) => Err(AuthError::TokenExpired), ··· 296 305 let uri = build_full_url(&parts.uri.to_string()); 297 306 298 307 match validate_token_with_dpop( 299 - &state.db, 308 + state.user_repo.as_ref(), 309 + state.oauth_repo.as_ref(), 300 310 &extracted.token, 301 311 true, 302 312 dpop_proof, ··· 320 330 Err(_) => return Err(AuthError::AuthenticationFailed), 321 331 } 322 332 } else { 323 - match validate_bearer_token_cached(&state.db, state.cache.as_ref(), &extracted.token) 324 - .await 333 + match validate_bearer_token_cached( 334 + state.user_repo.as_ref(), 335 + state.cache.as_ref(), 336 + &extracted.token, 337 + ) 338 + .await 325 339 { 326 340 Ok(user) => user, 327 341 Err(TokenValidationError::AccountDeactivated) => {
+51 -79
crates/tranquil-pds/src/auth/mod.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 - use sqlx::PgPool; 3 2 use std::fmt; 4 3 use std::time::Duration; 5 4 ··· 7 6 use crate::cache::Cache; 8 7 use crate::oauth::scopes::ScopePermissions; 9 8 use crate::types::Did; 9 + use tranquil_db::UserRepository; 10 + use tranquil_db_traits::OAuthRepository; 10 11 11 12 pub mod extractor; 12 13 pub mod scope_check; ··· 62 63 AuthenticationFailed, 63 64 TokenExpired, 64 65 OAuthTokenExpired, 66 + InvalidToken, 65 67 } 66 68 67 69 impl fmt::Display for TokenValidationError { ··· 72 74 Self::KeyDecryptionFailed => write!(f, "KeyDecryptionFailed"), 73 75 Self::AuthenticationFailed => write!(f, "AuthenticationFailed"), 74 76 Self::TokenExpired | Self::OAuthTokenExpired => write!(f, "ExpiredToken"), 77 + Self::InvalidToken => write!(f, "InvalidToken"), 75 78 } 76 79 } 77 80 } ··· 105 108 } 106 109 107 110 pub async fn validate_bearer_token( 108 - db: &PgPool, 111 + user_repo: &dyn UserRepository, 109 112 token: &str, 110 113 ) -> Result<AuthenticatedUser, TokenValidationError> { 111 - validate_bearer_token_with_options_internal(db, None, token, false, false).await 114 + validate_bearer_token_with_options_internal(user_repo, None, token, false, false).await 112 115 } 113 116 114 117 pub async fn validate_bearer_token_allow_deactivated( 115 - db: &PgPool, 118 + user_repo: &dyn UserRepository, 116 119 token: &str, 117 120 ) -> Result<AuthenticatedUser, TokenValidationError> { 118 - validate_bearer_token_with_options_internal(db, None, token, true, false).await 121 + validate_bearer_token_with_options_internal(user_repo, None, token, true, false).await 119 122 } 120 123 121 124 pub async fn validate_bearer_token_cached( 122 - db: &PgPool, 125 + user_repo: &dyn UserRepository, 123 126 cache: &dyn Cache, 124 127 token: &str, 125 128 ) -> Result<AuthenticatedUser, TokenValidationError> { 126 - validate_bearer_token_with_options_internal(db, Some(cache), token, false, false).await 129 + validate_bearer_token_with_options_internal(user_repo, Some(cache), token, false, false).await 127 130 } 128 131 129 132 pub async fn validate_bearer_token_cached_allow_deactivated( 130 - db: &PgPool, 133 + user_repo: &dyn UserRepository, 131 134 cache: &dyn Cache, 132 135 token: &str, 133 136 ) -> Result<AuthenticatedUser, TokenValidationError> { 134 - validate_bearer_token_with_options_internal(db, Some(cache), token, true, false).await 137 + validate_bearer_token_with_options_internal(user_repo, Some(cache), token, true, false).await 135 138 } 136 139 137 140 pub async fn validate_bearer_token_for_service_auth( 138 - db: &PgPool, 141 + user_repo: &dyn UserRepository, 139 142 token: &str, 140 143 ) -> Result<AuthenticatedUser, TokenValidationError> { 141 - validate_bearer_token_with_options_internal(db, None, token, true, true).await 144 + validate_bearer_token_with_options_internal(user_repo, None, token, true, true).await 142 145 } 143 146 144 147 pub async fn validate_bearer_token_allow_takendown( 145 - db: &PgPool, 148 + user_repo: &dyn UserRepository, 146 149 token: &str, 147 150 ) -> Result<AuthenticatedUser, TokenValidationError> { 148 - validate_bearer_token_with_options_internal(db, None, token, false, true).await 151 + validate_bearer_token_with_options_internal(user_repo, None, token, false, true).await 149 152 } 150 153 151 154 async fn validate_bearer_token_with_options_internal( 152 - db: &PgPool, 155 + user_repo: &dyn UserRepository, 153 156 cache: Option<&dyn Cache>, 154 157 token: &str, 155 158 allow_deactivated: bool, ··· 157 160 ) -> Result<AuthenticatedUser, TokenValidationError> { 158 161 let did_from_token = get_did_from_token(token).ok(); 159 162 160 - if let Some(ref did) = did_from_token { 161 - let key_cache_key = format!("auth:key:{}", did); 163 + if let Some(ref did_str) = did_from_token { 164 + let did: tranquil_types::Did = match did_str.parse() { 165 + Ok(d) => d, 166 + Err(_) => return Err(TokenValidationError::InvalidToken), 167 + }; 168 + let key_cache_key = format!("auth:key:{}", did_str); 162 169 let mut cached_key: Option<Vec<u8>> = None; 163 170 164 171 if let Some(c) = cache { ··· 172 179 173 180 let (decrypted_key, deactivated_at, takedown_ref, is_admin) = if let Some(key) = cached_key 174 181 { 175 - let status_cache_key = format!("auth:status:{}", did); 182 + let status_cache_key = format!("auth:status:{}", did_str); 176 183 let cached_status: Option<CachedUserStatus> = if let Some(c) = cache { 177 184 c.get(&status_cache_key) 178 185 .await ··· 197 204 status.is_admin, 198 205 ) 199 206 } else { 200 - let user_status = sqlx::query!( 201 - "SELECT deactivated_at, takedown_ref, is_admin FROM users WHERE did = $1", 202 - did 203 - ) 204 - .fetch_optional(db) 205 - .await 206 - .ok() 207 - .flatten(); 207 + let user_status = user_repo.get_status_by_did(&did).await.ok().flatten(); 208 208 209 209 match user_status { 210 210 Some(status) => { ··· 234 234 None => (None, None, None, false), 235 235 } 236 236 } 237 - } else if let Some(user) = sqlx::query!( 238 - "SELECT k.key_bytes, k.encryption_version, u.deactivated_at, u.takedown_ref, u.is_admin 239 - FROM users u 240 - JOIN user_keys k ON u.id = k.user_id 241 - WHERE u.did = $1", 242 - did 243 - ) 244 - .fetch_optional(db) 245 - .await 246 - .ok() 247 - .flatten() 248 - { 237 + } else if let Some(user) = user_repo.get_with_key_by_did(&did).await.ok().flatten() { 249 238 let key = crate::config::decrypt_key(&user.key_bytes, user.encryption_version) 250 239 .map_err(|_| TokenValidationError::KeyDecryptionFailed)?; 251 240 ··· 310 299 } 311 300 312 301 if !session_valid { 313 - let session_row = sqlx::query!( 314 - "SELECT access_expires_at FROM session_tokens WHERE did = $1 AND access_jti = $2", 315 - did, 316 - jti 317 - ) 318 - .fetch_optional(db) 319 - .await 320 - .ok() 321 - .flatten(); 302 + let session_expiry = user_repo 303 + .get_session_access_expiry(&did, jti) 304 + .await 305 + .ok() 306 + .flatten(); 322 307 323 - if let Some(row) = session_row { 324 - if row.access_expires_at > chrono::Utc::now() { 308 + if let Some(expires_at) = session_expiry { 309 + if expires_at > chrono::Utc::now() { 325 310 session_valid = true; 326 311 if let Some(c) = cache { 327 312 let _ = c ··· 347 332 let status = 348 333 AccountStatus::from_db_fields(takedown_ref.as_deref(), deactivated_at); 349 334 return Ok(AuthenticatedUser { 350 - did: Did::new_unchecked(did.clone()), 335 + did: did.clone(), 351 336 key_bytes: Some(decrypted_key), 352 337 is_oauth: false, 353 338 is_admin, ··· 366 351 } 367 352 368 353 if let Ok(oauth_info) = crate::oauth::verify::extract_oauth_token_info(token) 369 - && let Some(oauth_token) = sqlx::query!( 370 - r#"SELECT t.did, t.expires_at, u.deactivated_at, u.takedown_ref, u.is_admin, 371 - k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 372 - FROM oauth_token t 373 - JOIN users u ON t.did = u.did 374 - LEFT JOIN user_keys k ON u.id = k.user_id 375 - WHERE t.token_id = $1"#, 376 - oauth_info.token_id 377 - ) 378 - .fetch_optional(db) 379 - .await 380 - .ok() 381 - .flatten() 354 + && let Some(oauth_token) = user_repo 355 + .get_oauth_token_with_user(&oauth_info.token_id) 356 + .await 357 + .ok() 358 + .flatten() 382 359 { 383 360 let status = AccountStatus::from_db_fields( 384 361 oauth_token.takedown_ref.as_deref(), ··· 428 405 429 406 #[allow(clippy::too_many_arguments)] 430 407 pub async fn validate_token_with_dpop( 431 - db: &PgPool, 408 + user_repo: &dyn UserRepository, 409 + oauth_repo: &dyn OAuthRepository, 432 410 token: &str, 433 411 is_dpop_token: bool, 434 412 dpop_proof: Option<&str>, ··· 439 417 ) -> Result<AuthenticatedUser, TokenValidationError> { 440 418 if !is_dpop_token { 441 419 if allow_takendown { 442 - return validate_bearer_token_allow_takendown(db, token).await; 420 + return validate_bearer_token_allow_takendown(user_repo, token).await; 443 421 } else if allow_deactivated { 444 - return validate_bearer_token_allow_deactivated(db, token).await; 422 + return validate_bearer_token_allow_deactivated(user_repo, token).await; 445 423 } else { 446 - return validate_bearer_token(db, token).await; 424 + return validate_bearer_token(user_repo, token).await; 447 425 } 448 426 } 449 427 match crate::oauth::verify::verify_oauth_access_token( 450 - db, 428 + oauth_repo, 451 429 token, 452 430 dpop_proof, 453 431 http_method, ··· 456 434 .await 457 435 { 458 436 Ok(result) => { 459 - let user_info = sqlx::query!( 460 - r#"SELECT u.deactivated_at, u.takedown_ref, u.is_admin, 461 - k.key_bytes as "key_bytes?", k.encryption_version as "encryption_version?" 462 - FROM users u 463 - LEFT JOIN user_keys k ON u.id = k.user_id 464 - WHERE u.did = $1"#, 465 - result.did 466 - ) 467 - .fetch_optional(db) 468 - .await 469 - .ok() 470 - .flatten(); 437 + let result_did: Did = result.did.parse().map_err(|_| TokenValidationError::InvalidToken)?; 438 + let user_info = user_repo 439 + .get_user_info_by_did(&result_did) 440 + .await 441 + .ok() 442 + .flatten(); 471 443 let Some(user_info) = user_info else { 472 444 return Err(TokenValidationError::AuthenticationFailed); 473 445 };
-322
crates/tranquil-pds/src/auth/webauthn.rs
··· 1 - use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; 2 - use chrono::{Duration, Utc}; 3 - use sqlx::{PgPool, Row}; 4 1 use uuid::Uuid; 5 2 use webauthn_rs::prelude::*; 6 3 ··· 80 77 .map_err(|e| format!("Failed to finish authentication: {}", e)) 81 78 } 82 79 } 83 - 84 - pub async fn save_registration_state( 85 - pool: &PgPool, 86 - did: &str, 87 - state: &SecurityKeyRegistration, 88 - ) -> Result<Uuid, sqlx::Error> { 89 - let id = Uuid::new_v4(); 90 - let state_json = serde_json::to_string(state) 91 - .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize state: {}", e)))?; 92 - let challenge = id.as_bytes().to_vec(); 93 - let expires_at = Utc::now() + Duration::minutes(5); 94 - 95 - sqlx::query!( 96 - r#" 97 - INSERT INTO webauthn_challenges (id, did, challenge, challenge_type, state_json, expires_at) 98 - VALUES ($1, $2, $3, 'registration', $4, $5) 99 - "#, 100 - id, 101 - did, 102 - challenge, 103 - state_json, 104 - expires_at, 105 - ) 106 - .execute(pool) 107 - .await?; 108 - 109 - Ok(id) 110 - } 111 - 112 - pub async fn load_registration_state( 113 - pool: &PgPool, 114 - did: &str, 115 - ) -> Result<Option<SecurityKeyRegistration>, sqlx::Error> { 116 - let row = sqlx::query!( 117 - r#" 118 - SELECT state_json FROM webauthn_challenges 119 - WHERE did = $1 AND challenge_type = 'registration' AND expires_at > NOW() 120 - ORDER BY created_at DESC 121 - LIMIT 1 122 - "#, 123 - did, 124 - ) 125 - .fetch_optional(pool) 126 - .await?; 127 - 128 - match row { 129 - Some(r) => { 130 - let state: SecurityKeyRegistration = 131 - serde_json::from_str(&r.state_json).map_err(|e| { 132 - sqlx::Error::Protocol(format!("Failed to deserialize state: {}", e)) 133 - })?; 134 - Ok(Some(state)) 135 - } 136 - None => Ok(None), 137 - } 138 - } 139 - 140 - pub async fn delete_registration_state(pool: &PgPool, did: &str) -> Result<(), sqlx::Error> { 141 - sqlx::query!( 142 - "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = 'registration'", 143 - did, 144 - ) 145 - .execute(pool) 146 - .await?; 147 - Ok(()) 148 - } 149 - 150 - pub async fn save_authentication_state( 151 - pool: &PgPool, 152 - did: &str, 153 - state: &SecurityKeyAuthentication, 154 - ) -> Result<Uuid, sqlx::Error> { 155 - let id = Uuid::new_v4(); 156 - let state_json = serde_json::to_string(state) 157 - .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize state: {}", e)))?; 158 - let challenge = id.as_bytes().to_vec(); 159 - let expires_at = Utc::now() + Duration::minutes(5); 160 - 161 - sqlx::query!( 162 - r#" 163 - INSERT INTO webauthn_challenges (id, did, challenge, challenge_type, state_json, expires_at) 164 - VALUES ($1, $2, $3, 'authentication', $4, $5) 165 - "#, 166 - id, 167 - did, 168 - challenge, 169 - state_json, 170 - expires_at, 171 - ) 172 - .execute(pool) 173 - .await?; 174 - 175 - Ok(id) 176 - } 177 - 178 - pub async fn load_authentication_state( 179 - pool: &PgPool, 180 - did: &str, 181 - ) -> Result<Option<SecurityKeyAuthentication>, sqlx::Error> { 182 - let row = sqlx::query!( 183 - r#" 184 - SELECT state_json FROM webauthn_challenges 185 - WHERE did = $1 AND challenge_type = 'authentication' AND expires_at > NOW() 186 - ORDER BY created_at DESC 187 - LIMIT 1 188 - "#, 189 - did, 190 - ) 191 - .fetch_optional(pool) 192 - .await?; 193 - 194 - match row { 195 - Some(r) => { 196 - let state: SecurityKeyAuthentication = 197 - serde_json::from_str(&r.state_json).map_err(|e| { 198 - sqlx::Error::Protocol(format!("Failed to deserialize state: {}", e)) 199 - })?; 200 - Ok(Some(state)) 201 - } 202 - None => Ok(None), 203 - } 204 - } 205 - 206 - pub async fn delete_authentication_state(pool: &PgPool, did: &str) -> Result<(), sqlx::Error> { 207 - sqlx::query!( 208 - "DELETE FROM webauthn_challenges WHERE did = $1 AND challenge_type = 'authentication'", 209 - did, 210 - ) 211 - .execute(pool) 212 - .await?; 213 - Ok(()) 214 - } 215 - 216 - pub async fn cleanup_expired_challenges(pool: &PgPool) -> Result<u64, sqlx::Error> { 217 - let result = sqlx::query!("DELETE FROM webauthn_challenges WHERE expires_at < NOW()") 218 - .execute(pool) 219 - .await?; 220 - Ok(result.rows_affected()) 221 - } 222 - 223 - #[derive(Debug, Clone)] 224 - pub struct StoredPasskey { 225 - pub id: Uuid, 226 - pub did: String, 227 - pub credential_id: Vec<u8>, 228 - pub public_key: Vec<u8>, 229 - pub sign_count: i32, 230 - pub created_at: chrono::DateTime<Utc>, 231 - pub last_used: Option<chrono::DateTime<Utc>>, 232 - pub friendly_name: Option<String>, 233 - pub aaguid: Option<Vec<u8>>, 234 - pub transports: Option<Vec<String>>, 235 - } 236 - 237 - impl StoredPasskey { 238 - pub fn to_security_key(&self) -> Result<SecurityKey, String> { 239 - serde_json::from_slice(&self.public_key) 240 - .map_err(|e| format!("Failed to deserialize security key: {}", e)) 241 - } 242 - 243 - pub fn credential_id_base64(&self) -> String { 244 - URL_SAFE_NO_PAD.encode(&self.credential_id) 245 - } 246 - } 247 - 248 - pub async fn save_passkey( 249 - pool: &PgPool, 250 - did: &str, 251 - security_key: &SecurityKey, 252 - friendly_name: Option<&str>, 253 - ) -> Result<Uuid, sqlx::Error> { 254 - let id = Uuid::new_v4(); 255 - let credential_id = security_key.cred_id().to_vec(); 256 - let public_key = serde_json::to_vec(security_key) 257 - .map_err(|e| sqlx::Error::Protocol(format!("Failed to serialize security key: {}", e)))?; 258 - let aaguid: Option<Vec<u8>> = None; 259 - 260 - sqlx::query!( 261 - r#" 262 - INSERT INTO passkeys (id, did, credential_id, public_key, sign_count, friendly_name, aaguid) 263 - VALUES ($1, $2, $3, $4, 0, $5, $6) 264 - "#, 265 - id, 266 - did, 267 - credential_id, 268 - public_key, 269 - friendly_name, 270 - aaguid, 271 - ) 272 - .execute(pool) 273 - .await?; 274 - 275 - Ok(id) 276 - } 277 - 278 - pub async fn get_passkeys_for_user( 279 - pool: &PgPool, 280 - did: &str, 281 - ) -> Result<Vec<StoredPasskey>, sqlx::Error> { 282 - let rows = sqlx::query!( 283 - r#" 284 - SELECT id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name, aaguid, transports 285 - FROM passkeys 286 - WHERE did = $1 287 - ORDER BY created_at DESC 288 - "#, 289 - did, 290 - ) 291 - .fetch_all(pool) 292 - .await?; 293 - 294 - Ok(rows 295 - .into_iter() 296 - .map(|r| StoredPasskey { 297 - id: r.id, 298 - did: r.did, 299 - credential_id: r.credential_id, 300 - public_key: r.public_key, 301 - sign_count: r.sign_count, 302 - created_at: r.created_at, 303 - last_used: r.last_used, 304 - friendly_name: r.friendly_name, 305 - aaguid: r.aaguid, 306 - transports: r.transports, 307 - }) 308 - .collect()) 309 - } 310 - 311 - pub async fn get_passkey_by_credential_id( 312 - pool: &PgPool, 313 - credential_id: &[u8], 314 - ) -> Result<Option<StoredPasskey>, sqlx::Error> { 315 - let row = sqlx::query!( 316 - r#" 317 - SELECT id, did, credential_id, public_key, sign_count, created_at, last_used, friendly_name, aaguid, transports 318 - FROM passkeys 319 - WHERE credential_id = $1 320 - "#, 321 - credential_id, 322 - ) 323 - .fetch_optional(pool) 324 - .await?; 325 - 326 - Ok(row.map(|r| StoredPasskey { 327 - id: r.id, 328 - did: r.did, 329 - credential_id: r.credential_id, 330 - public_key: r.public_key, 331 - sign_count: r.sign_count, 332 - created_at: r.created_at, 333 - last_used: r.last_used, 334 - friendly_name: r.friendly_name, 335 - aaguid: r.aaguid, 336 - transports: r.transports, 337 - })) 338 - } 339 - 340 - pub async fn update_passkey_counter( 341 - pool: &PgPool, 342 - credential_id: &[u8], 343 - new_counter: u32, 344 - ) -> Result<bool, sqlx::Error> { 345 - let stored = get_passkey_by_credential_id(pool, credential_id).await?; 346 - let Some(stored) = stored else { 347 - return Err(sqlx::Error::RowNotFound); 348 - }; 349 - 350 - if new_counter > 0 && new_counter <= stored.sign_count as u32 { 351 - tracing::warn!( 352 - credential_id = ?credential_id, 353 - stored_counter = stored.sign_count, 354 - new_counter = new_counter, 355 - "Passkey counter did not increment - possible cloned key!" 356 - ); 357 - return Ok(false); 358 - } 359 - 360 - sqlx::query!( 361 - "UPDATE passkeys SET sign_count = $1, last_used = NOW() WHERE credential_id = $2", 362 - new_counter as i32, 363 - credential_id, 364 - ) 365 - .execute(pool) 366 - .await?; 367 - Ok(true) 368 - } 369 - 370 - pub async fn delete_passkey(pool: &PgPool, id: Uuid, did: &str) -> Result<bool, sqlx::Error> { 371 - let result = sqlx::query("DELETE FROM passkeys WHERE id = $1 AND did = $2") 372 - .bind(id) 373 - .bind(did) 374 - .execute(pool) 375 - .await?; 376 - Ok(result.rows_affected() > 0) 377 - } 378 - 379 - pub async fn update_passkey_name( 380 - pool: &PgPool, 381 - id: Uuid, 382 - did: &str, 383 - name: &str, 384 - ) -> Result<bool, sqlx::Error> { 385 - let result = sqlx::query("UPDATE passkeys SET friendly_name = $1 WHERE id = $2 AND did = $3") 386 - .bind(name) 387 - .bind(id) 388 - .bind(did) 389 - .execute(pool) 390 - .await?; 391 - Ok(result.rows_affected() > 0) 392 - } 393 - 394 - pub async fn has_passkeys(pool: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 395 - let row = sqlx::query("SELECT COUNT(*) as count FROM passkeys WHERE did = $1") 396 - .bind(did) 397 - .fetch_one(pool) 398 - .await?; 399 - let count: i64 = row.get("count"); 400 - Ok(count > 0) 401 - }
+1 -6
crates/tranquil-pds/src/comms/mod.rs
··· 7 7 sanitize_header_value, validate_locale, 8 8 }; 9 9 10 - pub use service::{ 11 - CommsService, channel_display_name, enqueue_2fa_code, enqueue_account_deletion, enqueue_comms, 12 - enqueue_email_update, enqueue_email_update_token, enqueue_migration_verification, 13 - enqueue_passkey_recovery, enqueue_password_reset, enqueue_plc_operation, 14 - enqueue_signup_verification, enqueue_welcome, queue_legacy_login_notification, 15 - }; 10 + pub use service::{CommsService, channel_display_name, repo as comms_repo};
+439 -499
crates/tranquil-pds/src/comms/service.rs
··· 3 3 use std::time::Duration; 4 4 5 5 use chrono::Utc; 6 - use sqlx::PgPool; 7 6 use tokio::sync::watch; 8 7 use tokio::time::interval; 9 8 use tracing::{debug, error, info, warn}; 10 9 use tranquil_comms::{ 11 - CommsChannel, CommsSender, CommsStatus, CommsType, NewComms, QueuedComms, SendError, 12 - format_message, get_strings, 10 + CommsChannel, CommsSender, CommsStatus, CommsType, NewComms, SendError, format_message, 11 + get_strings, 13 12 }; 13 + use tranquil_db_traits::{InfraRepository, QueuedComms, UserRepository}; 14 14 use uuid::Uuid; 15 15 16 16 pub struct CommsService { 17 - db: PgPool, 17 + infra_repo: Arc<dyn InfraRepository>, 18 18 senders: HashMap<CommsChannel, Arc<dyn CommsSender>>, 19 19 poll_interval: Duration, 20 20 batch_size: i64, 21 21 } 22 22 23 23 impl CommsService { 24 - pub fn new(db: PgPool) -> Self { 24 + pub fn new(infra_repo: Arc<dyn InfraRepository>) -> Self { 25 25 let poll_interval_ms: u64 = std::env::var("NOTIFICATION_POLL_INTERVAL_MS") 26 26 .ok() 27 27 .and_then(|v| v.parse().ok()) ··· 31 31 .and_then(|v| v.parse().ok()) 32 32 .unwrap_or(100); 33 33 Self { 34 - db, 34 + infra_repo, 35 35 senders: HashMap::new(), 36 36 poll_interval: Duration::from_millis(poll_interval_ms), 37 37 batch_size, ··· 53 53 self 54 54 } 55 55 56 - pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, sqlx::Error> { 57 - let id = sqlx::query_scalar!( 58 - r#" 59 - INSERT INTO comms_queue 60 - (user_id, channel, comms_type, recipient, subject, body, metadata) 61 - VALUES ($1, $2, $3, $4, $5, $6, $7) 62 - RETURNING id 63 - "#, 64 - item.user_id, 65 - item.channel as CommsChannel, 66 - item.comms_type as CommsType, 67 - item.recipient, 68 - item.subject, 69 - item.body, 70 - item.metadata 71 - ) 72 - .fetch_one(&self.db) 73 - .await?; 56 + pub async fn enqueue(&self, item: NewComms) -> Result<Uuid, tranquil_db_traits::DbError> { 57 + let channel = match item.channel { 58 + CommsChannel::Email => tranquil_db_traits::CommsChannel::Email, 59 + CommsChannel::Discord => tranquil_db_traits::CommsChannel::Discord, 60 + CommsChannel::Telegram => tranquil_db_traits::CommsChannel::Telegram, 61 + CommsChannel::Signal => tranquil_db_traits::CommsChannel::Signal, 62 + }; 63 + let comms_type = match item.comms_type { 64 + CommsType::Welcome => tranquil_db_traits::CommsType::Welcome, 65 + CommsType::EmailVerification => tranquil_db_traits::CommsType::EmailVerification, 66 + CommsType::PasswordReset => tranquil_db_traits::CommsType::PasswordReset, 67 + CommsType::EmailUpdate => tranquil_db_traits::CommsType::EmailUpdate, 68 + CommsType::AccountDeletion => tranquil_db_traits::CommsType::AccountDeletion, 69 + CommsType::AdminEmail => tranquil_db_traits::CommsType::AdminEmail, 70 + CommsType::PlcOperation => tranquil_db_traits::CommsType::PlcOperation, 71 + CommsType::TwoFactorCode => tranquil_db_traits::CommsType::TwoFactorCode, 72 + CommsType::PasskeyRecovery => tranquil_db_traits::CommsType::PasskeyRecovery, 73 + CommsType::LegacyLoginAlert => tranquil_db_traits::CommsType::LegacyLoginAlert, 74 + CommsType::MigrationVerification => tranquil_db_traits::CommsType::MigrationVerification, 75 + CommsType::ChannelVerification => tranquil_db_traits::CommsType::ChannelVerification, 76 + }; 77 + let id = self 78 + .infra_repo 79 + .enqueue_comms( 80 + Some(item.user_id), 81 + channel, 82 + comms_type, 83 + &item.recipient, 84 + item.subject.as_deref(), 85 + &item.body, 86 + item.metadata, 87 + ) 88 + .await?; 74 89 debug!(comms_id = %id, "Comms enqueued"); 75 90 Ok(id) 76 91 } ··· 109 124 } 110 125 } 111 126 112 - async fn process_batch(&self) -> Result<(), sqlx::Error> { 127 + async fn process_batch(&self) -> Result<(), tranquil_db_traits::DbError> { 113 128 let items = self.fetch_pending().await?; 114 129 if items.is_empty() { 115 130 return Ok(()); ··· 119 134 Ok(()) 120 135 } 121 136 122 - async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, sqlx::Error> { 137 + async fn fetch_pending(&self) -> Result<Vec<QueuedComms>, tranquil_db_traits::DbError> { 123 138 let now = Utc::now(); 124 - sqlx::query_as!( 125 - QueuedComms, 126 - r#" 127 - UPDATE comms_queue 128 - SET status = 'processing', updated_at = NOW() 129 - WHERE id IN ( 130 - SELECT id FROM comms_queue 131 - WHERE status = 'pending' 132 - AND scheduled_for <= $1 133 - AND attempts < max_attempts 134 - ORDER BY scheduled_for ASC 135 - LIMIT $2 136 - FOR UPDATE SKIP LOCKED 137 - ) 138 - RETURNING 139 - id, user_id, 140 - channel as "channel: CommsChannel", 141 - comms_type as "comms_type: CommsType", 142 - status as "status: CommsStatus", 143 - recipient, subject, body, metadata, 144 - attempts, max_attempts, last_error, 145 - created_at, updated_at, scheduled_for, processed_at 146 - "#, 147 - now, 148 - self.batch_size 149 - ) 150 - .fetch_all(&self.db) 151 - .await 139 + self.infra_repo.fetch_pending_comms(now, self.batch_size).await 152 140 } 153 141 154 142 async fn process_item(&self, item: QueuedComms) { 155 143 let comms_id = item.id; 156 - let channel = item.channel; 144 + let channel = match item.channel { 145 + tranquil_db_traits::CommsChannel::Email => CommsChannel::Email, 146 + tranquil_db_traits::CommsChannel::Discord => CommsChannel::Discord, 147 + tranquil_db_traits::CommsChannel::Telegram => CommsChannel::Telegram, 148 + tranquil_db_traits::CommsChannel::Signal => CommsChannel::Signal, 149 + }; 150 + let comms_item = tranquil_comms::QueuedComms { 151 + id: item.id, 152 + user_id: item.user_id, 153 + channel, 154 + comms_type: match item.comms_type { 155 + tranquil_db_traits::CommsType::Welcome => CommsType::Welcome, 156 + tranquil_db_traits::CommsType::EmailVerification => CommsType::EmailVerification, 157 + tranquil_db_traits::CommsType::PasswordReset => CommsType::PasswordReset, 158 + tranquil_db_traits::CommsType::EmailUpdate => CommsType::EmailUpdate, 159 + tranquil_db_traits::CommsType::AccountDeletion => CommsType::AccountDeletion, 160 + tranquil_db_traits::CommsType::AdminEmail => CommsType::AdminEmail, 161 + tranquil_db_traits::CommsType::PlcOperation => CommsType::PlcOperation, 162 + tranquil_db_traits::CommsType::TwoFactorCode => CommsType::TwoFactorCode, 163 + tranquil_db_traits::CommsType::PasskeyRecovery => CommsType::PasskeyRecovery, 164 + tranquil_db_traits::CommsType::LegacyLoginAlert => CommsType::LegacyLoginAlert, 165 + tranquil_db_traits::CommsType::MigrationVerification => CommsType::MigrationVerification, 166 + tranquil_db_traits::CommsType::ChannelVerification => CommsType::ChannelVerification, 167 + }, 168 + status: match item.status { 169 + tranquil_db_traits::CommsStatus::Pending => CommsStatus::Pending, 170 + tranquil_db_traits::CommsStatus::Processing => CommsStatus::Processing, 171 + tranquil_db_traits::CommsStatus::Sent => CommsStatus::Sent, 172 + tranquil_db_traits::CommsStatus::Failed => CommsStatus::Failed, 173 + }, 174 + recipient: item.recipient, 175 + subject: item.subject, 176 + body: item.body, 177 + metadata: item.metadata, 178 + attempts: item.attempts, 179 + max_attempts: item.max_attempts, 180 + last_error: item.last_error, 181 + created_at: item.created_at, 182 + updated_at: item.updated_at, 183 + scheduled_for: item.scheduled_for, 184 + processed_at: item.processed_at, 185 + }; 157 186 let result = match self.senders.get(&channel) { 158 - Some(sender) => sender.send(&item).await, 187 + Some(sender) => sender.send(&comms_item).await, 159 188 None => { 160 189 warn!( 161 190 comms_id = %comms_id, ··· 194 223 } 195 224 } 196 225 197 - async fn mark_sent(&self, id: Uuid) -> Result<(), sqlx::Error> { 198 - sqlx::query!( 199 - r#" 200 - UPDATE comms_queue 201 - SET status = 'sent', processed_at = NOW(), updated_at = NOW() 202 - WHERE id = $1 203 - "#, 204 - id 205 - ) 206 - .execute(&self.db) 207 - .await?; 208 - Ok(()) 226 + async fn mark_sent(&self, id: Uuid) -> Result<(), tranquil_db_traits::DbError> { 227 + self.infra_repo.mark_comms_sent(id).await 209 228 } 210 229 211 - async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), sqlx::Error> { 212 - sqlx::query!( 213 - r#" 214 - UPDATE comms_queue 215 - SET 216 - status = CASE 217 - WHEN attempts + 1 >= max_attempts THEN 'failed'::comms_status 218 - ELSE 'pending'::comms_status 219 - END, 220 - attempts = attempts + 1, 221 - last_error = $2, 222 - updated_at = NOW(), 223 - scheduled_for = NOW() + (INTERVAL '1 minute' * (attempts + 1)) 224 - WHERE id = $1 225 - "#, 226 - id, 227 - error 228 - ) 229 - .execute(&self.db) 230 - .await?; 231 - Ok(()) 230 + async fn mark_failed(&self, id: Uuid, error: &str) -> Result<(), tranquil_db_traits::DbError> { 231 + self.infra_repo.mark_comms_failed(id, error).await 232 232 } 233 233 } 234 234 235 - pub async fn enqueue_comms(db: &PgPool, item: NewComms) -> Result<Uuid, sqlx::Error> { 236 - sqlx::query_scalar!( 237 - r#" 238 - INSERT INTO comms_queue 239 - (user_id, channel, comms_type, recipient, subject, body, metadata) 240 - VALUES ($1, $2, $3, $4, $5, $6, $7) 241 - RETURNING id 242 - "#, 243 - item.user_id, 244 - item.channel as CommsChannel, 245 - item.comms_type as CommsType, 246 - item.recipient, 247 - item.subject, 248 - item.body, 249 - item.metadata 250 - ) 251 - .fetch_one(db) 252 - .await 235 + pub fn channel_display_name(channel: CommsChannel) -> &'static str { 236 + match channel { 237 + CommsChannel::Email => "email", 238 + CommsChannel::Discord => "Discord", 239 + CommsChannel::Telegram => "Telegram", 240 + CommsChannel::Signal => "Signal", 241 + } 253 242 } 254 243 255 - pub struct UserCommsPrefs { 256 - pub channel: CommsChannel, 257 - pub email: Option<String>, 258 - pub handle: crate::types::Handle, 259 - pub locale: String, 244 + fn channel_from_str(s: &str) -> tranquil_db_traits::CommsChannel { 245 + match s { 246 + "discord" => tranquil_db_traits::CommsChannel::Discord, 247 + "telegram" => tranquil_db_traits::CommsChannel::Telegram, 248 + "signal" => tranquil_db_traits::CommsChannel::Signal, 249 + _ => tranquil_db_traits::CommsChannel::Email, 250 + } 260 251 } 261 252 262 - pub async fn get_user_comms_prefs( 263 - db: &PgPool, 264 - user_id: Uuid, 265 - ) -> Result<UserCommsPrefs, sqlx::Error> { 266 - let row = sqlx::query!( 267 - r#" 268 - SELECT 269 - email, 270 - handle, 271 - preferred_comms_channel as "channel: CommsChannel", 272 - preferred_locale 273 - FROM users 274 - WHERE id = $1 275 - "#, 276 - user_id 277 - ) 278 - .fetch_one(db) 279 - .await?; 280 - Ok(UserCommsPrefs { 281 - channel: row.channel, 282 - email: row.email, 283 - handle: row.handle.into(), 284 - locale: row.preferred_locale.unwrap_or_else(|| "en".to_string()), 285 - }) 286 - } 253 + 254 + pub mod repo { 255 + use super::*; 256 + use tranquil_db_traits::DbError; 287 257 288 - pub async fn enqueue_welcome( 289 - db: &PgPool, 290 - user_id: Uuid, 291 - hostname: &str, 292 - ) -> Result<Uuid, sqlx::Error> { 293 - let prefs = get_user_comms_prefs(db, user_id).await?; 294 - let strings = get_strings(&prefs.locale); 295 - let body = format_message( 296 - strings.welcome_body, 297 - &[("hostname", hostname), ("handle", &prefs.handle)], 298 - ); 299 - let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]); 300 - enqueue_comms( 301 - db, 302 - NewComms::new( 303 - user_id, 304 - prefs.channel, 258 + pub async fn enqueue_welcome( 259 + user_repo: &dyn UserRepository, 260 + infra_repo: &dyn InfraRepository, 261 + user_id: Uuid, 262 + hostname: &str, 263 + ) -> Result<Uuid, DbError> { 264 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 265 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 266 + let body = format_message( 267 + strings.welcome_body, 268 + &[("hostname", hostname), ("handle", &prefs.handle)], 269 + ); 270 + let subject = format_message(strings.welcome_subject, &[("hostname", hostname)]); 271 + let channel = channel_from_str(&prefs.preferred_channel); 272 + infra_repo.enqueue_comms( 273 + Some(user_id), 274 + channel, 305 275 CommsType::Welcome, 306 - prefs.email.unwrap_or_default(), 307 - Some(subject), 308 - body, 309 - ), 310 - ) 311 - .await 312 - } 276 + &prefs.email.unwrap_or_default(), 277 + Some(&subject), 278 + &body, 279 + None, 280 + ).await 281 + } 313 282 314 - pub async fn enqueue_password_reset( 315 - db: &PgPool, 316 - user_id: Uuid, 317 - code: &str, 318 - hostname: &str, 319 - ) -> Result<Uuid, sqlx::Error> { 320 - let prefs = get_user_comms_prefs(db, user_id).await?; 321 - let strings = get_strings(&prefs.locale); 322 - let body = format_message( 323 - strings.password_reset_body, 324 - &[("handle", &prefs.handle), ("code", code)], 325 - ); 326 - let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]); 327 - enqueue_comms( 328 - db, 329 - NewComms::new( 330 - user_id, 331 - prefs.channel, 283 + pub async fn enqueue_password_reset( 284 + user_repo: &dyn UserRepository, 285 + infra_repo: &dyn InfraRepository, 286 + user_id: Uuid, 287 + code: &str, 288 + hostname: &str, 289 + ) -> Result<Uuid, DbError> { 290 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 291 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 292 + let body = format_message( 293 + strings.password_reset_body, 294 + &[("handle", &prefs.handle), ("code", code)], 295 + ); 296 + let subject = format_message(strings.password_reset_subject, &[("hostname", hostname)]); 297 + let channel = channel_from_str(&prefs.preferred_channel); 298 + infra_repo.enqueue_comms( 299 + Some(user_id), 300 + channel, 332 301 CommsType::PasswordReset, 333 - prefs.email.unwrap_or_default(), 334 - Some(subject), 335 - body, 336 - ), 337 - ) 338 - .await 339 - } 302 + &prefs.email.unwrap_or_default(), 303 + Some(&subject), 304 + &body, 305 + None, 306 + ).await 307 + } 340 308 341 - pub async fn enqueue_email_update( 342 - db: &PgPool, 343 - user_id: Uuid, 344 - new_email: &str, 345 - handle: &str, 346 - code: &str, 347 - hostname: &str, 348 - ) -> Result<Uuid, sqlx::Error> { 349 - let prefs = get_user_comms_prefs(db, user_id).await?; 350 - let strings = get_strings(&prefs.locale); 351 - let encoded_email = urlencoding::encode(new_email); 352 - let encoded_token = urlencoding::encode(code); 353 - let verify_page = format!("https://{}/app/verify", hostname); 354 - let verify_link = format!( 355 - "https://{}/app/verify?token={}&identifier={}", 356 - hostname, encoded_token, encoded_email 357 - ); 358 - let body = format_message( 359 - strings.email_update_body, 360 - &[ 361 - ("handle", handle), 362 - ("code", code), 363 - ("verify_page", &verify_page), 364 - ("verify_link", &verify_link), 365 - ], 366 - ); 367 - let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 368 - enqueue_comms( 369 - db, 370 - NewComms::email( 371 - user_id, 309 + pub async fn enqueue_email_update( 310 + infra_repo: &dyn InfraRepository, 311 + user_id: Uuid, 312 + new_email: &str, 313 + handle: &str, 314 + code: &str, 315 + hostname: &str, 316 + ) -> Result<Uuid, DbError> { 317 + let strings = get_strings("en"); 318 + let encoded_email = urlencoding::encode(new_email); 319 + let encoded_token = urlencoding::encode(code); 320 + let verify_page = format!("https://{}/app/verify", hostname); 321 + let verify_link = format!( 322 + "https://{}/app/verify?token={}&identifier={}", 323 + hostname, encoded_token, encoded_email 324 + ); 325 + let body = format_message( 326 + strings.email_update_body, 327 + &[ 328 + ("handle", handle), 329 + ("code", code), 330 + ("verify_page", &verify_page), 331 + ("verify_link", &verify_link), 332 + ], 333 + ); 334 + let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 335 + infra_repo.enqueue_comms( 336 + Some(user_id), 337 + tranquil_db_traits::CommsChannel::Email, 372 338 CommsType::EmailUpdate, 373 - new_email.to_string(), 374 - subject, 375 - body, 376 - ), 377 - ) 378 - .await 379 - } 339 + new_email, 340 + Some(&subject), 341 + &body, 342 + None, 343 + ).await 344 + } 380 345 381 - pub async fn enqueue_email_update_token( 382 - db: &PgPool, 383 - user_id: Uuid, 384 - code: &str, 385 - hostname: &str, 386 - ) -> Result<Uuid, sqlx::Error> { 387 - let prefs = get_user_comms_prefs(db, user_id).await?; 388 - let strings = get_strings(&prefs.locale); 389 - let current_email = prefs.email.unwrap_or_default(); 390 - let verify_page = format!("https://{}/app/verify?type=email-update", hostname); 391 - let verify_link = format!( 392 - "https://{}/app/verify?type=email-update&token={}", 393 - hostname, 394 - urlencoding::encode(code) 395 - ); 396 - let body = format_message( 397 - strings.email_update_body, 398 - &[ 399 - ("handle", &prefs.handle), 400 - ("code", code), 401 - ("verify_page", &verify_page), 402 - ("verify_link", &verify_link), 403 - ], 404 - ); 405 - let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 406 - enqueue_comms( 407 - db, 408 - NewComms::email( 409 - user_id, 346 + pub async fn enqueue_email_update_token( 347 + user_repo: &dyn UserRepository, 348 + infra_repo: &dyn InfraRepository, 349 + user_id: Uuid, 350 + code: &str, 351 + hostname: &str, 352 + ) -> Result<Uuid, DbError> { 353 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 354 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 355 + let current_email = prefs.email.unwrap_or_default(); 356 + let verify_page = format!("https://{}/app/verify?type=email-update", hostname); 357 + let verify_link = format!( 358 + "https://{}/app/verify?type=email-update&token={}", 359 + hostname, 360 + urlencoding::encode(code) 361 + ); 362 + let body = format_message( 363 + strings.email_update_body, 364 + &[ 365 + ("handle", &prefs.handle), 366 + ("code", code), 367 + ("verify_page", &verify_page), 368 + ("verify_link", &verify_link), 369 + ], 370 + ); 371 + let subject = format_message(strings.email_update_subject, &[("hostname", hostname)]); 372 + infra_repo.enqueue_comms( 373 + Some(user_id), 374 + tranquil_db_traits::CommsChannel::Email, 410 375 CommsType::EmailUpdate, 411 - current_email, 412 - subject, 413 - body, 414 - ), 415 - ) 416 - .await 417 - } 376 + &current_email, 377 + Some(&subject), 378 + &body, 379 + None, 380 + ).await 381 + } 418 382 419 - pub async fn enqueue_account_deletion( 420 - db: &PgPool, 421 - user_id: Uuid, 422 - code: &str, 423 - hostname: &str, 424 - ) -> Result<Uuid, sqlx::Error> { 425 - let prefs = get_user_comms_prefs(db, user_id).await?; 426 - let strings = get_strings(&prefs.locale); 427 - let body = format_message( 428 - strings.account_deletion_body, 429 - &[("handle", &prefs.handle), ("code", code)], 430 - ); 431 - let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]); 432 - enqueue_comms( 433 - db, 434 - NewComms::new( 435 - user_id, 436 - prefs.channel, 383 + pub async fn enqueue_account_deletion( 384 + user_repo: &dyn UserRepository, 385 + infra_repo: &dyn InfraRepository, 386 + user_id: Uuid, 387 + code: &str, 388 + hostname: &str, 389 + ) -> Result<Uuid, DbError> { 390 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 391 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 392 + let body = format_message( 393 + strings.account_deletion_body, 394 + &[("handle", &prefs.handle), ("code", code)], 395 + ); 396 + let subject = format_message(strings.account_deletion_subject, &[("hostname", hostname)]); 397 + let channel = channel_from_str(&prefs.preferred_channel); 398 + infra_repo.enqueue_comms( 399 + Some(user_id), 400 + channel, 437 401 CommsType::AccountDeletion, 438 - prefs.email.unwrap_or_default(), 439 - Some(subject), 440 - body, 441 - ), 442 - ) 443 - .await 444 - } 402 + &prefs.email.unwrap_or_default(), 403 + Some(&subject), 404 + &body, 405 + None, 406 + ).await 407 + } 445 408 446 - pub async fn enqueue_plc_operation( 447 - db: &PgPool, 448 - user_id: Uuid, 449 - token: &str, 450 - hostname: &str, 451 - ) -> Result<Uuid, sqlx::Error> { 452 - let prefs = get_user_comms_prefs(db, user_id).await?; 453 - let strings = get_strings(&prefs.locale); 454 - let body = format_message( 455 - strings.plc_operation_body, 456 - &[("handle", &prefs.handle), ("token", token)], 457 - ); 458 - let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]); 459 - enqueue_comms( 460 - db, 461 - NewComms::new( 462 - user_id, 463 - prefs.channel, 409 + pub async fn enqueue_plc_operation( 410 + user_repo: &dyn UserRepository, 411 + infra_repo: &dyn InfraRepository, 412 + user_id: Uuid, 413 + token: &str, 414 + hostname: &str, 415 + ) -> Result<Uuid, DbError> { 416 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 417 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 418 + let body = format_message( 419 + strings.plc_operation_body, 420 + &[("handle", &prefs.handle), ("token", token)], 421 + ); 422 + let subject = format_message(strings.plc_operation_subject, &[("hostname", hostname)]); 423 + let channel = channel_from_str(&prefs.preferred_channel); 424 + infra_repo.enqueue_comms( 425 + Some(user_id), 426 + channel, 464 427 CommsType::PlcOperation, 465 - prefs.email.unwrap_or_default(), 466 - Some(subject), 467 - body, 468 - ), 469 - ) 470 - .await 471 - } 428 + &prefs.email.unwrap_or_default(), 429 + Some(&subject), 430 + &body, 431 + None, 432 + ).await 433 + } 472 434 473 - pub async fn enqueue_2fa_code( 474 - db: &PgPool, 475 - user_id: Uuid, 476 - code: &str, 477 - hostname: &str, 478 - ) -> Result<Uuid, sqlx::Error> { 479 - let prefs = get_user_comms_prefs(db, user_id).await?; 480 - let strings = get_strings(&prefs.locale); 481 - let body = format_message( 482 - strings.two_factor_code_body, 483 - &[("handle", &prefs.handle), ("code", code)], 484 - ); 485 - let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]); 486 - enqueue_comms( 487 - db, 488 - NewComms::new( 489 - user_id, 490 - prefs.channel, 491 - CommsType::TwoFactorCode, 492 - prefs.email.unwrap_or_default(), 493 - Some(subject), 494 - body, 495 - ), 496 - ) 497 - .await 498 - } 499 - 500 - pub async fn enqueue_passkey_recovery( 501 - db: &PgPool, 502 - user_id: Uuid, 503 - recovery_url: &str, 504 - hostname: &str, 505 - ) -> Result<Uuid, sqlx::Error> { 506 - let prefs = get_user_comms_prefs(db, user_id).await?; 507 - let strings = get_strings(&prefs.locale); 508 - let body = format_message( 509 - strings.passkey_recovery_body, 510 - &[("handle", &prefs.handle), ("url", recovery_url)], 511 - ); 512 - let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]); 513 - enqueue_comms( 514 - db, 515 - NewComms::new( 516 - user_id, 517 - prefs.channel, 435 + pub async fn enqueue_passkey_recovery( 436 + user_repo: &dyn UserRepository, 437 + infra_repo: &dyn InfraRepository, 438 + user_id: Uuid, 439 + recovery_url: &str, 440 + hostname: &str, 441 + ) -> Result<Uuid, DbError> { 442 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 443 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 444 + let body = format_message( 445 + strings.passkey_recovery_body, 446 + &[("handle", &prefs.handle), ("url", recovery_url)], 447 + ); 448 + let subject = format_message(strings.passkey_recovery_subject, &[("hostname", hostname)]); 449 + let channel = channel_from_str(&prefs.preferred_channel); 450 + infra_repo.enqueue_comms( 451 + Some(user_id), 452 + channel, 518 453 CommsType::PasskeyRecovery, 519 - prefs.email.unwrap_or_default(), 520 - Some(subject), 521 - body, 522 - ), 523 - ) 524 - .await 525 - } 454 + &prefs.email.unwrap_or_default(), 455 + Some(&subject), 456 + &body, 457 + None, 458 + ).await 459 + } 526 460 527 - pub fn channel_display_name(channel: CommsChannel) -> &'static str { 528 - match channel { 529 - CommsChannel::Email => "email", 530 - CommsChannel::Discord => "Discord", 531 - CommsChannel::Telegram => "Telegram", 532 - CommsChannel::Signal => "Signal", 461 + pub async fn enqueue_migration_verification( 462 + user_repo: &dyn UserRepository, 463 + infra_repo: &dyn InfraRepository, 464 + user_id: Uuid, 465 + email: &str, 466 + token: &str, 467 + hostname: &str, 468 + ) -> Result<Uuid, DbError> { 469 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 470 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 471 + let encoded_email = urlencoding::encode(email); 472 + let encoded_token = urlencoding::encode(token); 473 + let verify_page = format!("https://{}/app/verify", hostname); 474 + let verify_link = format!( 475 + "https://{}/app/verify?token={}&identifier={}", 476 + hostname, encoded_token, encoded_email 477 + ); 478 + let body = format_message( 479 + strings.migration_verification_body, 480 + &[ 481 + ("code", token), 482 + ("hostname", hostname), 483 + ("verify_page", &verify_page), 484 + ("verify_link", &verify_link), 485 + ], 486 + ); 487 + let subject = format_message( 488 + strings.migration_verification_subject, 489 + &[("hostname", hostname)], 490 + ); 491 + infra_repo.enqueue_comms( 492 + Some(user_id), 493 + tranquil_db_traits::CommsChannel::Email, 494 + CommsType::MigrationVerification, 495 + email, 496 + Some(&subject), 497 + &body, 498 + None, 499 + ).await 533 500 } 534 - } 535 501 536 - pub async fn enqueue_signup_verification( 537 - db: &PgPool, 538 - user_id: Uuid, 539 - channel: &str, 540 - recipient: &str, 541 - code: &str, 542 - locale: Option<&str>, 543 - ) -> Result<Uuid, sqlx::Error> { 544 - let hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 545 - let comms_channel = match channel { 546 - "email" => CommsChannel::Email, 547 - "discord" => CommsChannel::Discord, 548 - "telegram" => CommsChannel::Telegram, 549 - "signal" => CommsChannel::Signal, 550 - _ => CommsChannel::Email, 551 - }; 552 - let strings = get_strings(locale.unwrap_or("en")); 553 - let (verify_page, verify_link) = if comms_channel == CommsChannel::Email { 554 - let encoded_email = urlencoding::encode(recipient); 555 - let encoded_token = urlencoding::encode(code); 556 - ( 557 - format!("https://{}/app/verify", hostname), 558 - format!( 559 - "https://{}/app/verify?token={}&identifier={}", 560 - hostname, encoded_token, encoded_email 561 - ), 562 - ) 563 - } else { 564 - (String::new(), String::new()) 565 - }; 566 - let body = format_message( 567 - strings.signup_verification_body, 568 - &[ 569 - ("code", code), 570 - ("hostname", &hostname), 571 - ("verify_page", &verify_page), 572 - ("verify_link", &verify_link), 573 - ], 574 - ); 575 - let subject = match comms_channel { 576 - CommsChannel::Email => Some(format_message( 577 - strings.signup_verification_subject, 578 - &[("hostname", &hostname)], 579 - )), 580 - _ => None, 581 - }; 582 - enqueue_comms( 583 - db, 584 - NewComms::new( 585 - user_id, 502 + pub async fn enqueue_signup_verification( 503 + infra_repo: &dyn InfraRepository, 504 + user_id: Uuid, 505 + channel: &str, 506 + recipient: &str, 507 + code: &str, 508 + hostname: &str, 509 + ) -> Result<Uuid, DbError> { 510 + let comms_channel = channel_from_str(channel); 511 + let strings = get_strings("en"); 512 + let (verify_page, verify_link) = match comms_channel { 513 + tranquil_db_traits::CommsChannel::Email => { 514 + let encoded_email = urlencoding::encode(recipient); 515 + let encoded_token = urlencoding::encode(code); 516 + ( 517 + format!("https://{}/app/verify", hostname), 518 + format!( 519 + "https://{}/app/verify?token={}&identifier={}", 520 + hostname, encoded_token, encoded_email 521 + ), 522 + ) 523 + } 524 + _ => (String::new(), String::new()), 525 + }; 526 + let body = format_message( 527 + strings.signup_verification_body, 528 + &[ 529 + ("code", code), 530 + ("hostname", hostname), 531 + ("verify_page", &verify_page), 532 + ("verify_link", &verify_link), 533 + ], 534 + ); 535 + let subject = match comms_channel { 536 + tranquil_db_traits::CommsChannel::Email => Some(format_message( 537 + strings.signup_verification_subject, 538 + &[("hostname", hostname)], 539 + )), 540 + _ => None, 541 + }; 542 + infra_repo.enqueue_comms( 543 + Some(user_id), 586 544 comms_channel, 587 545 CommsType::EmailVerification, 588 - recipient.to_string(), 589 - subject, 590 - body, 591 - ), 592 - ) 593 - .await 594 - } 546 + recipient, 547 + subject.as_deref(), 548 + &body, 549 + None, 550 + ).await 551 + } 595 552 596 - pub async fn enqueue_migration_verification( 597 - db: &PgPool, 598 - user_id: Uuid, 599 - email: &str, 600 - token: &str, 601 - hostname: &str, 602 - ) -> Result<Uuid, sqlx::Error> { 603 - let prefs = get_user_comms_prefs(db, user_id).await?; 604 - let strings = get_strings(&prefs.locale); 605 - let encoded_email = urlencoding::encode(email); 606 - let encoded_token = urlencoding::encode(token); 607 - let verify_page = format!("https://{}/app/verify", hostname); 608 - let verify_link = format!( 609 - "https://{}/app/verify?token={}&identifier={}", 610 - hostname, encoded_token, encoded_email 611 - ); 612 - let body = format_message( 613 - strings.migration_verification_body, 614 - &[ 615 - ("code", token), 616 - ("hostname", hostname), 617 - ("verify_page", &verify_page), 618 - ("verify_link", &verify_link), 619 - ], 620 - ); 621 - let subject = format_message( 622 - strings.migration_verification_subject, 623 - &[("hostname", hostname)], 624 - ); 625 - enqueue_comms( 626 - db, 627 - NewComms::email( 628 - user_id, 629 - CommsType::MigrationVerification, 630 - email.to_string(), 631 - subject, 632 - body, 633 - ), 634 - ) 635 - .await 636 - } 553 + pub async fn enqueue_2fa_code( 554 + user_repo: &dyn UserRepository, 555 + infra_repo: &dyn InfraRepository, 556 + user_id: Uuid, 557 + code: &str, 558 + hostname: &str, 559 + ) -> Result<Uuid, DbError> { 560 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 561 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 562 + let body = format_message( 563 + strings.two_factor_code_body, 564 + &[("handle", &prefs.handle), ("code", code)], 565 + ); 566 + let subject = format_message(strings.two_factor_code_subject, &[("hostname", hostname)]); 567 + let channel = channel_from_str(&prefs.preferred_channel); 568 + infra_repo.enqueue_comms( 569 + Some(user_id), 570 + channel, 571 + CommsType::TwoFactorCode, 572 + &prefs.email.unwrap_or_default(), 573 + Some(&subject), 574 + &body, 575 + None, 576 + ).await 577 + } 637 578 638 - pub async fn queue_legacy_login_notification( 639 - db: &PgPool, 640 - user_id: Uuid, 641 - hostname: &str, 642 - client_ip: &str, 643 - channel: CommsChannel, 644 - ) -> Result<Uuid, sqlx::Error> { 645 - let prefs = get_user_comms_prefs(db, user_id).await?; 646 - let strings = get_strings(&prefs.locale); 647 - let timestamp = chrono::Utc::now() 648 - .format("%Y-%m-%d %H:%M:%S UTC") 649 - .to_string(); 650 - let body = format_message( 651 - strings.legacy_login_body, 652 - &[ 653 - ("handle", &prefs.handle), 654 - ("timestamp", &timestamp), 655 - ("ip", client_ip), 656 - ("hostname", hostname), 657 - ], 658 - ); 659 - let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]); 660 - enqueue_comms( 661 - db, 662 - NewComms::new( 663 - user_id, 579 + pub async fn enqueue_legacy_login( 580 + user_repo: &dyn UserRepository, 581 + infra_repo: &dyn InfraRepository, 582 + user_id: Uuid, 583 + hostname: &str, 584 + client_ip: &str, 585 + channel: tranquil_db_traits::CommsChannel, 586 + ) -> Result<Uuid, DbError> { 587 + let prefs = user_repo.get_comms_prefs(user_id).await?.ok_or(DbError::NotFound)?; 588 + let strings = get_strings(prefs.preferred_locale.as_deref().unwrap_or("en")); 589 + let timestamp = chrono::Utc::now() 590 + .format("%Y-%m-%d %H:%M:%S UTC") 591 + .to_string(); 592 + let body = format_message( 593 + strings.legacy_login_body, 594 + &[ 595 + ("handle", &prefs.handle), 596 + ("timestamp", &timestamp), 597 + ("ip", client_ip), 598 + ("hostname", hostname), 599 + ], 600 + ); 601 + let subject = format_message(strings.legacy_login_subject, &[("hostname", hostname)]); 602 + infra_repo.enqueue_comms( 603 + Some(user_id), 664 604 channel, 665 605 CommsType::LegacyLoginAlert, 666 - prefs.email.unwrap_or_default(), 667 - Some(subject), 668 - body, 669 - ), 670 - ) 671 - .await 606 + &prefs.email.unwrap_or_default(), 607 + Some(&subject), 608 + &body, 609 + None, 610 + ).await 611 + } 672 612 }
-143
crates/tranquil-pds/src/delegation/audit.rs
··· 1 - use chrono::{DateTime, Utc}; 2 - use serde::{Deserialize, Serialize}; 3 - use sqlx::PgPool; 4 - use uuid::Uuid; 5 - 6 - #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)] 7 - #[sqlx(type_name = "delegation_action_type", rename_all = "snake_case")] 8 - pub enum DelegationActionType { 9 - GrantCreated, 10 - GrantRevoked, 11 - ScopesModified, 12 - TokenIssued, 13 - RepoWrite, 14 - BlobUpload, 15 - AccountAction, 16 - } 17 - 18 - #[derive(Debug, Clone, Serialize, Deserialize)] 19 - pub struct AuditLogEntry { 20 - pub id: Uuid, 21 - pub delegated_did: String, 22 - pub actor_did: String, 23 - pub controller_did: Option<String>, 24 - pub action_type: DelegationActionType, 25 - pub action_details: Option<serde_json::Value>, 26 - pub ip_address: Option<String>, 27 - pub user_agent: Option<String>, 28 - pub created_at: DateTime<Utc>, 29 - } 30 - 31 - #[allow(clippy::too_many_arguments)] 32 - pub async fn log_delegation_action( 33 - pool: &PgPool, 34 - delegated_did: &str, 35 - actor_did: &str, 36 - controller_did: Option<&str>, 37 - action_type: DelegationActionType, 38 - action_details: Option<serde_json::Value>, 39 - ip_address: Option<&str>, 40 - user_agent: Option<&str>, 41 - ) -> Result<Uuid, sqlx::Error> { 42 - let id = sqlx::query_scalar!( 43 - r#" 44 - INSERT INTO delegation_audit_log 45 - (delegated_did, actor_did, controller_did, action_type, action_details, ip_address, user_agent) 46 - VALUES ($1, $2, $3, $4, $5, $6, $7) 47 - RETURNING id 48 - "#, 49 - delegated_did, 50 - actor_did, 51 - controller_did, 52 - action_type as DelegationActionType, 53 - action_details, 54 - ip_address, 55 - user_agent 56 - ) 57 - .fetch_one(pool) 58 - .await?; 59 - 60 - Ok(id) 61 - } 62 - 63 - pub async fn get_audit_log_for_account( 64 - pool: &PgPool, 65 - delegated_did: &str, 66 - limit: i64, 67 - offset: i64, 68 - ) -> Result<Vec<AuditLogEntry>, sqlx::Error> { 69 - let entries = sqlx::query_as!( 70 - AuditLogEntry, 71 - r#" 72 - SELECT 73 - id, 74 - delegated_did, 75 - actor_did, 76 - controller_did, 77 - action_type as "action_type: DelegationActionType", 78 - action_details, 79 - ip_address, 80 - user_agent, 81 - created_at 82 - FROM delegation_audit_log 83 - WHERE delegated_did = $1 84 - ORDER BY created_at DESC 85 - LIMIT $2 OFFSET $3 86 - "#, 87 - delegated_did, 88 - limit, 89 - offset 90 - ) 91 - .fetch_all(pool) 92 - .await?; 93 - 94 - Ok(entries) 95 - } 96 - 97 - pub async fn get_audit_log_by_controller( 98 - pool: &PgPool, 99 - controller_did: &str, 100 - limit: i64, 101 - offset: i64, 102 - ) -> Result<Vec<AuditLogEntry>, sqlx::Error> { 103 - let entries = sqlx::query_as!( 104 - AuditLogEntry, 105 - r#" 106 - SELECT 107 - id, 108 - delegated_did, 109 - actor_did, 110 - controller_did, 111 - action_type as "action_type: DelegationActionType", 112 - action_details, 113 - ip_address, 114 - user_agent, 115 - created_at 116 - FROM delegation_audit_log 117 - WHERE controller_did = $1 118 - ORDER BY created_at DESC 119 - LIMIT $2 OFFSET $3 120 - "#, 121 - controller_did, 122 - limit, 123 - offset 124 - ) 125 - .fetch_all(pool) 126 - .await?; 127 - 128 - Ok(entries) 129 - } 130 - 131 - pub async fn count_audit_log_entries( 132 - pool: &PgPool, 133 - delegated_did: &str, 134 - ) -> Result<i64, sqlx::Error> { 135 - let count = sqlx::query_scalar!( 136 - r#"SELECT COUNT(*) as "count!" FROM delegation_audit_log WHERE delegated_did = $1"#, 137 - delegated_did 138 - ) 139 - .fetch_one(pool) 140 - .await?; 141 - 142 - Ok(count) 143 - }
-268
crates/tranquil-pds/src/delegation/db.rs
··· 1 - use crate::types::Handle; 2 - use chrono::{DateTime, Utc}; 3 - use serde::{Deserialize, Serialize}; 4 - use sqlx::PgPool; 5 - use uuid::Uuid; 6 - 7 - #[derive(Debug, Clone, Serialize, Deserialize)] 8 - pub struct DelegationGrant { 9 - pub id: Uuid, 10 - pub delegated_did: String, 11 - pub controller_did: String, 12 - pub granted_scopes: String, 13 - pub granted_at: DateTime<Utc>, 14 - pub granted_by: String, 15 - pub revoked_at: Option<DateTime<Utc>>, 16 - pub revoked_by: Option<String>, 17 - } 18 - 19 - #[derive(Debug, Clone, Serialize, Deserialize)] 20 - pub struct DelegatedAccountInfo { 21 - pub did: String, 22 - pub handle: Handle, 23 - pub granted_scopes: String, 24 - pub granted_at: DateTime<Utc>, 25 - } 26 - 27 - #[derive(Debug, Clone, Serialize, Deserialize)] 28 - pub struct ControllerInfo { 29 - pub did: String, 30 - pub handle: Handle, 31 - pub granted_scopes: String, 32 - pub granted_at: DateTime<Utc>, 33 - pub is_active: bool, 34 - } 35 - 36 - pub async fn is_delegated_account(pool: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 37 - let result = sqlx::query_scalar!( 38 - r#"SELECT account_type::text = 'delegated' as "is_delegated!" FROM users WHERE did = $1"#, 39 - did 40 - ) 41 - .fetch_optional(pool) 42 - .await?; 43 - 44 - Ok(result.unwrap_or(false)) 45 - } 46 - 47 - pub async fn create_delegation( 48 - pool: &PgPool, 49 - delegated_did: &str, 50 - controller_did: &str, 51 - granted_scopes: &str, 52 - granted_by: &str, 53 - ) -> Result<Uuid, sqlx::Error> { 54 - let id = sqlx::query_scalar!( 55 - r#" 56 - INSERT INTO account_delegations (delegated_did, controller_did, granted_scopes, granted_by) 57 - VALUES ($1, $2, $3, $4) 58 - RETURNING id 59 - "#, 60 - delegated_did, 61 - controller_did, 62 - granted_scopes, 63 - granted_by 64 - ) 65 - .fetch_one(pool) 66 - .await?; 67 - 68 - Ok(id) 69 - } 70 - 71 - pub async fn revoke_delegation( 72 - pool: &PgPool, 73 - delegated_did: &str, 74 - controller_did: &str, 75 - revoked_by: &str, 76 - ) -> Result<bool, sqlx::Error> { 77 - let result = sqlx::query!( 78 - r#" 79 - UPDATE account_delegations 80 - SET revoked_at = NOW(), revoked_by = $1 81 - WHERE delegated_did = $2 AND controller_did = $3 AND revoked_at IS NULL 82 - "#, 83 - revoked_by, 84 - delegated_did, 85 - controller_did 86 - ) 87 - .execute(pool) 88 - .await?; 89 - 90 - Ok(result.rows_affected() > 0) 91 - } 92 - 93 - pub async fn update_delegation_scopes( 94 - pool: &PgPool, 95 - delegated_did: &str, 96 - controller_did: &str, 97 - new_scopes: &str, 98 - ) -> Result<bool, sqlx::Error> { 99 - let result = sqlx::query!( 100 - r#" 101 - UPDATE account_delegations 102 - SET granted_scopes = $1 103 - WHERE delegated_did = $2 AND controller_did = $3 AND revoked_at IS NULL 104 - "#, 105 - new_scopes, 106 - delegated_did, 107 - controller_did 108 - ) 109 - .execute(pool) 110 - .await?; 111 - 112 - Ok(result.rows_affected() > 0) 113 - } 114 - 115 - pub async fn get_delegation( 116 - pool: &PgPool, 117 - delegated_did: &str, 118 - controller_did: &str, 119 - ) -> Result<Option<DelegationGrant>, sqlx::Error> { 120 - let grant = sqlx::query_as!( 121 - DelegationGrant, 122 - r#" 123 - SELECT id, delegated_did, controller_did, granted_scopes, 124 - granted_at, granted_by, revoked_at, revoked_by 125 - FROM account_delegations 126 - WHERE delegated_did = $1 AND controller_did = $2 AND revoked_at IS NULL 127 - "#, 128 - delegated_did, 129 - controller_did 130 - ) 131 - .fetch_optional(pool) 132 - .await?; 133 - 134 - Ok(grant) 135 - } 136 - 137 - pub async fn get_delegations_for_account( 138 - pool: &PgPool, 139 - delegated_did: &str, 140 - ) -> Result<Vec<ControllerInfo>, sqlx::Error> { 141 - let controllers = sqlx::query_as!( 142 - ControllerInfo, 143 - r#" 144 - SELECT 145 - u.did, 146 - u.handle, 147 - d.granted_scopes, 148 - d.granted_at, 149 - (u.deactivated_at IS NULL AND u.takedown_ref IS NULL) as "is_active!" 150 - FROM account_delegations d 151 - JOIN users u ON u.did = d.controller_did 152 - WHERE d.delegated_did = $1 AND d.revoked_at IS NULL 153 - ORDER BY d.granted_at DESC 154 - "#, 155 - delegated_did 156 - ) 157 - .fetch_all(pool) 158 - .await?; 159 - 160 - Ok(controllers) 161 - } 162 - 163 - pub async fn get_accounts_controlled_by( 164 - pool: &PgPool, 165 - controller_did: &str, 166 - ) -> Result<Vec<DelegatedAccountInfo>, sqlx::Error> { 167 - let accounts = sqlx::query_as!( 168 - DelegatedAccountInfo, 169 - r#" 170 - SELECT 171 - u.did, 172 - u.handle, 173 - d.granted_scopes, 174 - d.granted_at 175 - FROM account_delegations d 176 - JOIN users u ON u.did = d.delegated_did 177 - WHERE d.controller_did = $1 178 - AND d.revoked_at IS NULL 179 - AND u.deactivated_at IS NULL 180 - AND u.takedown_ref IS NULL 181 - ORDER BY d.granted_at DESC 182 - "#, 183 - controller_did 184 - ) 185 - .fetch_all(pool) 186 - .await?; 187 - 188 - Ok(accounts) 189 - } 190 - 191 - pub async fn get_active_controllers_for_account( 192 - pool: &PgPool, 193 - delegated_did: &str, 194 - ) -> Result<Vec<ControllerInfo>, sqlx::Error> { 195 - let controllers = sqlx::query_as!( 196 - ControllerInfo, 197 - r#" 198 - SELECT 199 - u.did, 200 - u.handle, 201 - d.granted_scopes, 202 - d.granted_at, 203 - true as "is_active!" 204 - FROM account_delegations d 205 - JOIN users u ON u.did = d.controller_did 206 - WHERE d.delegated_did = $1 207 - AND d.revoked_at IS NULL 208 - AND u.deactivated_at IS NULL 209 - AND u.takedown_ref IS NULL 210 - ORDER BY d.granted_at DESC 211 - "#, 212 - delegated_did 213 - ) 214 - .fetch_all(pool) 215 - .await?; 216 - 217 - Ok(controllers) 218 - } 219 - 220 - pub async fn count_active_controllers( 221 - pool: &PgPool, 222 - delegated_did: &str, 223 - ) -> Result<i64, sqlx::Error> { 224 - let count = sqlx::query_scalar!( 225 - r#" 226 - SELECT COUNT(*) as "count!" 227 - FROM account_delegations d 228 - JOIN users u ON u.did = d.controller_did 229 - WHERE d.delegated_did = $1 230 - AND d.revoked_at IS NULL 231 - AND u.deactivated_at IS NULL 232 - AND u.takedown_ref IS NULL 233 - "#, 234 - delegated_did 235 - ) 236 - .fetch_one(pool) 237 - .await?; 238 - 239 - Ok(count) 240 - } 241 - 242 - pub async fn has_any_controllers(pool: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 243 - let exists = sqlx::query_scalar!( 244 - r#"SELECT EXISTS( 245 - SELECT 1 FROM account_delegations 246 - WHERE delegated_did = $1 AND revoked_at IS NULL 247 - ) as "exists!""#, 248 - did 249 - ) 250 - .fetch_one(pool) 251 - .await?; 252 - 253 - Ok(exists) 254 - } 255 - 256 - pub async fn controls_any_accounts(pool: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 257 - let exists = sqlx::query_scalar!( 258 - r#"SELECT EXISTS( 259 - SELECT 1 FROM account_delegations 260 - WHERE controller_did = $1 AND revoked_at IS NULL 261 - ) as "exists!""#, 262 - did 263 - ) 264 - .fetch_one(pool) 265 - .await?; 266 - 267 - Ok(exists) 268 - }
+1 -8
crates/tranquil-pds/src/delegation/mod.rs
··· 1 - pub mod audit; 2 - pub mod db; 3 1 pub mod scopes; 4 2 5 - pub use audit::{DelegationActionType, log_delegation_action}; 6 - pub use db::{ 7 - DelegationGrant, controls_any_accounts, create_delegation, get_accounts_controlled_by, 8 - get_delegation, get_delegations_for_account, has_any_controllers, is_delegated_account, 9 - revoke_delegation, update_delegation_scopes, 10 - }; 11 3 pub use scopes::{SCOPE_PRESETS, ScopePreset, intersect_scopes}; 4 + pub use tranquil_db_traits::DelegationActionType;
+10 -8
crates/tranquil-pds/src/main.rs
··· 32 32 33 33 let (shutdown_tx, shutdown_rx) = watch::channel(false); 34 34 35 - let backfill_db = state.db.clone(); 35 + let backfill_repo_repo = state.repo_repo.clone(); 36 36 let backfill_block_store = state.block_store.clone(); 37 37 tokio::spawn(async move { 38 38 tokio::join!( 39 - backfill_genesis_commit_blocks(&backfill_db, backfill_block_store.clone()), 40 - backfill_repo_rev(&backfill_db, backfill_block_store.clone()), 41 - backfill_user_blocks(&backfill_db, backfill_block_store.clone()), 42 - backfill_record_blobs(&backfill_db, backfill_block_store), 39 + backfill_genesis_commit_blocks(backfill_repo_repo.clone(), backfill_block_store.clone()), 40 + backfill_repo_rev(backfill_repo_repo.clone(), backfill_block_store.clone()), 41 + backfill_user_blocks(backfill_repo_repo.clone(), backfill_block_store.clone()), 42 + backfill_record_blobs(backfill_repo_repo, backfill_block_store), 43 43 ); 44 44 }); 45 45 46 - let mut comms_service = CommsService::new(state.db.clone()); 46 + let mut comms_service = CommsService::new(state.infra_repo.clone()); 47 47 48 48 if let Some(email_sender) = EmailSender::from_env() { 49 49 info!("Email comms enabled"); ··· 88 88 let backup_handle = if let Some(backup_storage) = state.backup_storage.clone() { 89 89 info!("Backup service enabled"); 90 90 Some(tokio::spawn(start_backup_tasks( 91 - state.db.clone(), 91 + state.repo_repo.clone(), 92 + state.backup_repo.clone(), 92 93 state.block_store.clone(), 93 94 backup_storage, 94 95 shutdown_rx.clone(), ··· 99 100 }; 100 101 101 102 let scheduled_handle = tokio::spawn(start_scheduled_tasks( 102 - state.db.clone(), 103 + state.user_repo.clone(), 104 + state.blob_repo.clone(), 103 105 state.blob_store.clone(), 104 106 shutdown_rx, 105 107 ));
-46
crates/tranquil-pds/src/oauth/db/client.rs
··· 1 - use super::super::{AuthorizedClientData, OAuthError}; 2 - use super::helpers::{from_json, to_json}; 3 - use sqlx::PgPool; 4 - 5 - pub async fn upsert_authorized_client( 6 - pool: &PgPool, 7 - did: &str, 8 - client_id: &str, 9 - data: &AuthorizedClientData, 10 - ) -> Result<(), OAuthError> { 11 - let data_json = to_json(data)?; 12 - sqlx::query!( 13 - r#" 14 - INSERT INTO oauth_authorized_client (did, client_id, created_at, updated_at, data) 15 - VALUES ($1, $2, NOW(), NOW(), $3) 16 - ON CONFLICT (did, client_id) DO UPDATE SET updated_at = NOW(), data = $3 17 - "#, 18 - did, 19 - client_id, 20 - data_json 21 - ) 22 - .execute(pool) 23 - .await?; 24 - Ok(()) 25 - } 26 - 27 - pub async fn get_authorized_client( 28 - pool: &PgPool, 29 - did: &str, 30 - client_id: &str, 31 - ) -> Result<Option<AuthorizedClientData>, OAuthError> { 32 - let row = sqlx::query_scalar!( 33 - r#" 34 - SELECT data FROM oauth_authorized_client 35 - WHERE did = $1 AND client_id = $2 36 - "#, 37 - did, 38 - client_id 39 - ) 40 - .fetch_optional(pool) 41 - .await?; 42 - match row { 43 - Some(v) => Ok(Some(from_json(v)?)), 44 - None => Ok(None), 45 - } 46 - }
-148
crates/tranquil-pds/src/oauth/db/device.rs
··· 1 - use super::super::{DeviceData, OAuthError}; 2 - use crate::types::Handle; 3 - use chrono::{DateTime, Utc}; 4 - use sqlx::PgPool; 5 - 6 - pub struct DeviceAccountRow { 7 - pub did: String, 8 - pub handle: Handle, 9 - pub email: Option<String>, 10 - pub last_used_at: DateTime<Utc>, 11 - } 12 - 13 - pub async fn create_device( 14 - pool: &PgPool, 15 - device_id: &str, 16 - data: &DeviceData, 17 - ) -> Result<(), OAuthError> { 18 - sqlx::query!( 19 - r#" 20 - INSERT INTO oauth_device (id, session_id, user_agent, ip_address, last_seen_at) 21 - VALUES ($1, $2, $3, $4, $5) 22 - "#, 23 - device_id, 24 - data.session_id, 25 - data.user_agent, 26 - data.ip_address, 27 - data.last_seen_at, 28 - ) 29 - .execute(pool) 30 - .await?; 31 - Ok(()) 32 - } 33 - 34 - pub async fn get_device(pool: &PgPool, device_id: &str) -> Result<Option<DeviceData>, OAuthError> { 35 - let row = sqlx::query!( 36 - r#" 37 - SELECT session_id, user_agent, ip_address, last_seen_at 38 - FROM oauth_device 39 - WHERE id = $1 40 - "#, 41 - device_id 42 - ) 43 - .fetch_optional(pool) 44 - .await?; 45 - Ok(row.map(|r| DeviceData { 46 - session_id: r.session_id, 47 - user_agent: r.user_agent, 48 - ip_address: r.ip_address, 49 - last_seen_at: r.last_seen_at, 50 - })) 51 - } 52 - 53 - pub async fn update_device_last_seen(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 54 - sqlx::query!( 55 - r#" 56 - UPDATE oauth_device 57 - SET last_seen_at = NOW() 58 - WHERE id = $1 59 - "#, 60 - device_id 61 - ) 62 - .execute(pool) 63 - .await?; 64 - Ok(()) 65 - } 66 - 67 - pub async fn delete_device(pool: &PgPool, device_id: &str) -> Result<(), OAuthError> { 68 - sqlx::query!( 69 - r#" 70 - DELETE FROM oauth_device WHERE id = $1 71 - "#, 72 - device_id 73 - ) 74 - .execute(pool) 75 - .await?; 76 - Ok(()) 77 - } 78 - 79 - pub async fn upsert_account_device( 80 - pool: &PgPool, 81 - did: &str, 82 - device_id: &str, 83 - ) -> Result<(), OAuthError> { 84 - sqlx::query!( 85 - r#" 86 - INSERT INTO oauth_account_device (did, device_id, created_at, updated_at) 87 - VALUES ($1, $2, NOW(), NOW()) 88 - ON CONFLICT (did, device_id) DO UPDATE SET updated_at = NOW() 89 - "#, 90 - did, 91 - device_id 92 - ) 93 - .execute(pool) 94 - .await?; 95 - Ok(()) 96 - } 97 - 98 - pub async fn get_device_accounts( 99 - pool: &PgPool, 100 - device_id: &str, 101 - ) -> Result<Vec<DeviceAccountRow>, OAuthError> { 102 - let rows = sqlx::query!( 103 - r#" 104 - SELECT u.did, u.handle, u.email, ad.updated_at as last_used_at 105 - FROM oauth_account_device ad 106 - JOIN users u ON u.did = ad.did 107 - WHERE ad.device_id = $1 108 - AND u.deactivated_at IS NULL 109 - AND u.takedown_ref IS NULL 110 - ORDER BY ad.updated_at DESC 111 - "#, 112 - device_id 113 - ) 114 - .fetch_all(pool) 115 - .await?; 116 - Ok(rows 117 - .into_iter() 118 - .map(|r| DeviceAccountRow { 119 - did: r.did, 120 - handle: r.handle.into(), 121 - email: r.email, 122 - last_used_at: r.last_used_at, 123 - }) 124 - .collect()) 125 - } 126 - 127 - pub async fn verify_account_on_device( 128 - pool: &PgPool, 129 - device_id: &str, 130 - did: &str, 131 - ) -> Result<bool, OAuthError> { 132 - let row = sqlx::query!( 133 - r#" 134 - SELECT 1 as exists 135 - FROM oauth_account_device ad 136 - JOIN users u ON u.did = ad.did 137 - WHERE ad.device_id = $1 138 - AND ad.did = $2 139 - AND u.deactivated_at IS NULL 140 - AND u.takedown_ref IS NULL 141 - "#, 142 - device_id, 143 - did 144 - ) 145 - .fetch_optional(pool) 146 - .await?; 147 - Ok(row.is_some()) 148 - }
-32
crates/tranquil-pds/src/oauth/db/dpop.rs
··· 1 - use super::super::OAuthError; 2 - use sqlx::PgPool; 3 - 4 - pub async fn check_and_record_dpop_jti(pool: &PgPool, jti: &str) -> Result<bool, OAuthError> { 5 - let result = sqlx::query!( 6 - r#" 7 - INSERT INTO oauth_dpop_jti (jti) 8 - VALUES ($1) 9 - ON CONFLICT (jti) DO NOTHING 10 - "#, 11 - jti 12 - ) 13 - .execute(pool) 14 - .await?; 15 - Ok(result.rows_affected() > 0) 16 - } 17 - 18 - pub async fn cleanup_expired_dpop_jtis( 19 - pool: &PgPool, 20 - max_age_secs: i64, 21 - ) -> Result<u64, OAuthError> { 22 - let result = sqlx::query!( 23 - r#" 24 - DELETE FROM oauth_dpop_jti 25 - WHERE created_at < NOW() - INTERVAL '1 second' * $1 26 - "#, 27 - max_age_secs as f64 28 - ) 29 - .execute(pool) 30 - .await?; 31 - Ok(result.rows_affected()) 32 - }
-16
crates/tranquil-pds/src/oauth/db/helpers.rs
··· 1 - use super::super::OAuthError; 2 - use serde::{Serialize, de::DeserializeOwned}; 3 - 4 - pub fn to_json<T: Serialize>(value: &T) -> Result<serde_json::Value, OAuthError> { 5 - serde_json::to_value(value).map_err(|e| { 6 - tracing::error!("JSON serialization error: {}", e); 7 - OAuthError::ServerError("Internal serialization error".to_string()) 8 - }) 9 - } 10 - 11 - pub fn from_json<T: DeserializeOwned>(value: serde_json::Value) -> Result<T, OAuthError> { 12 - serde_json::from_value(value).map_err(|e| { 13 - tracing::error!("JSON deserialization error: {}", e); 14 - OAuthError::ServerError("Internal data corruption".to_string()) 15 - }) 16 - }
+4 -33
crates/tranquil-pds/src/oauth/db/mod.rs
··· 1 - mod client; 2 - mod device; 3 - mod dpop; 4 - mod helpers; 5 - mod request; 6 1 mod scope_preference; 7 2 mod token; 8 3 mod two_factor; 9 4 10 - pub use client::{get_authorized_client, upsert_authorized_client}; 11 - pub use device::{ 12 - DeviceAccountRow, create_device, delete_device, get_device, get_device_accounts, 13 - update_device_last_seen, upsert_account_device, verify_account_on_device, 14 - }; 15 - pub use dpop::{check_and_record_dpop_jti, cleanup_expired_dpop_jtis}; 16 - pub use request::{ 17 - consume_authorization_request_by_code, create_authorization_request, 18 - delete_authorization_request, delete_expired_authorization_requests, get_authorization_request, 19 - get_authorization_request_with_state, mark_request_authenticated, set_authorization_did, 20 - set_controller_did, set_request_did, update_authorization_request, update_request_scope, 21 - }; 22 - pub use scope_preference::{ 23 - ScopePreference, delete_scope_preferences, get_scope_preferences, should_show_consent, 24 - upsert_scope_preferences, 25 - }; 26 - pub use token::{ 27 - RefreshTokenLookup, check_refresh_token_used, count_tokens_for_user, create_token, 28 - delete_oldest_tokens_for_user, delete_token, delete_token_family, enforce_token_limit_for_user, 29 - get_token_by_id, get_token_by_previous_refresh_token, get_token_by_refresh_token, 30 - list_tokens_for_user, lookup_refresh_token, revoke_tokens_for_client, 31 - revoke_tokens_for_controller, rotate_token, 32 - }; 33 - pub use two_factor::{ 34 - TwoFactorChallenge, check_user_2fa_enabled, cleanup_expired_2fa_challenges, 35 - create_2fa_challenge, delete_2fa_challenge, delete_2fa_challenge_by_request_uri, 36 - generate_2fa_code, get_2fa_challenge, increment_2fa_attempts, 37 - }; 5 + pub use scope_preference::{ScopePreference, should_show_consent}; 6 + pub use token::{RefreshTokenLookup, enforce_token_limit_for_user, lookup_refresh_token}; 7 + pub use tranquil_db_traits::{DeviceAccountRow, TwoFactorChallenge}; 8 + pub use two_factor::generate_2fa_code;
-265
crates/tranquil-pds/src/oauth/db/request.rs
··· 1 - use super::super::{ 2 - AuthFlowState, AuthorizationRequestParameters, ClientAuth, OAuthError, RequestData, 3 - }; 4 - use super::helpers::{from_json, to_json}; 5 - use sqlx::PgPool; 6 - 7 - pub async fn get_authorization_request_with_state( 8 - pool: &PgPool, 9 - request_id: &str, 10 - ) -> Result<Option<(RequestData, AuthFlowState)>, OAuthError> { 11 - match get_authorization_request(pool, request_id).await? { 12 - Some(data) => { 13 - let state = AuthFlowState::from_request_data(&data); 14 - Ok(Some((data, state))) 15 - } 16 - None => Ok(None), 17 - } 18 - } 19 - 20 - pub async fn create_authorization_request( 21 - pool: &PgPool, 22 - request_id: &str, 23 - data: &RequestData, 24 - ) -> Result<(), OAuthError> { 25 - let client_auth_json = match &data.client_auth { 26 - Some(ca) => Some(to_json(ca)?), 27 - None => None, 28 - }; 29 - let parameters_json = to_json(&data.parameters)?; 30 - sqlx::query!( 31 - r#" 32 - INSERT INTO oauth_authorization_request 33 - (id, did, device_id, client_id, client_auth, parameters, expires_at, code) 34 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) 35 - "#, 36 - request_id, 37 - data.did, 38 - data.device_id, 39 - data.client_id, 40 - client_auth_json, 41 - parameters_json, 42 - data.expires_at, 43 - data.code, 44 - ) 45 - .execute(pool) 46 - .await?; 47 - Ok(()) 48 - } 49 - 50 - pub async fn get_authorization_request( 51 - pool: &PgPool, 52 - request_id: &str, 53 - ) -> Result<Option<RequestData>, OAuthError> { 54 - let row = sqlx::query!( 55 - r#" 56 - SELECT did, device_id, client_id, client_auth, parameters, expires_at, code, controller_did 57 - FROM oauth_authorization_request 58 - WHERE id = $1 59 - "#, 60 - request_id 61 - ) 62 - .fetch_optional(pool) 63 - .await?; 64 - match row { 65 - Some(r) => { 66 - let client_auth: Option<ClientAuth> = match r.client_auth { 67 - Some(v) => Some(from_json(v)?), 68 - None => None, 69 - }; 70 - let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 71 - Ok(Some(RequestData { 72 - client_id: r.client_id, 73 - client_auth, 74 - parameters, 75 - expires_at: r.expires_at, 76 - did: r.did, 77 - device_id: r.device_id, 78 - code: r.code, 79 - controller_did: r.controller_did, 80 - })) 81 - } 82 - None => Ok(None), 83 - } 84 - } 85 - 86 - pub async fn set_authorization_did( 87 - pool: &PgPool, 88 - request_id: &str, 89 - did: &str, 90 - device_id: Option<&str>, 91 - ) -> Result<(), OAuthError> { 92 - sqlx::query!( 93 - r#" 94 - UPDATE oauth_authorization_request 95 - SET did = $2, device_id = $3 96 - WHERE id = $1 97 - "#, 98 - request_id, 99 - did, 100 - device_id 101 - ) 102 - .execute(pool) 103 - .await?; 104 - Ok(()) 105 - } 106 - 107 - pub async fn update_authorization_request( 108 - pool: &PgPool, 109 - request_id: &str, 110 - did: &str, 111 - device_id: Option<&str>, 112 - code: &str, 113 - ) -> Result<(), OAuthError> { 114 - sqlx::query!( 115 - r#" 116 - UPDATE oauth_authorization_request 117 - SET did = $2, device_id = $3, code = $4 118 - WHERE id = $1 119 - "#, 120 - request_id, 121 - did, 122 - device_id, 123 - code 124 - ) 125 - .execute(pool) 126 - .await?; 127 - Ok(()) 128 - } 129 - 130 - pub async fn consume_authorization_request_by_code( 131 - pool: &PgPool, 132 - code: &str, 133 - ) -> Result<Option<RequestData>, OAuthError> { 134 - let row = sqlx::query!( 135 - r#" 136 - DELETE FROM oauth_authorization_request 137 - WHERE code = $1 138 - RETURNING did, device_id, client_id, client_auth, parameters, expires_at, code, controller_did 139 - "#, 140 - code 141 - ) 142 - .fetch_optional(pool) 143 - .await?; 144 - match row { 145 - Some(r) => { 146 - let client_auth: Option<ClientAuth> = match r.client_auth { 147 - Some(v) => Some(from_json(v)?), 148 - None => None, 149 - }; 150 - let parameters: AuthorizationRequestParameters = from_json(r.parameters)?; 151 - Ok(Some(RequestData { 152 - client_id: r.client_id, 153 - client_auth, 154 - parameters, 155 - expires_at: r.expires_at, 156 - did: r.did, 157 - device_id: r.device_id, 158 - code: r.code, 159 - controller_did: r.controller_did, 160 - })) 161 - } 162 - None => Ok(None), 163 - } 164 - } 165 - 166 - pub async fn delete_authorization_request( 167 - pool: &PgPool, 168 - request_id: &str, 169 - ) -> Result<(), OAuthError> { 170 - sqlx::query!( 171 - r#" 172 - DELETE FROM oauth_authorization_request WHERE id = $1 173 - "#, 174 - request_id 175 - ) 176 - .execute(pool) 177 - .await?; 178 - Ok(()) 179 - } 180 - 181 - pub async fn delete_expired_authorization_requests(pool: &PgPool) -> Result<u64, OAuthError> { 182 - let result = sqlx::query!( 183 - r#" 184 - DELETE FROM oauth_authorization_request 185 - WHERE expires_at < NOW() 186 - "# 187 - ) 188 - .execute(pool) 189 - .await?; 190 - Ok(result.rows_affected()) 191 - } 192 - 193 - pub async fn mark_request_authenticated( 194 - pool: &PgPool, 195 - request_id: &str, 196 - did: &str, 197 - device_id: Option<&str>, 198 - ) -> Result<(), OAuthError> { 199 - sqlx::query!( 200 - r#" 201 - UPDATE oauth_authorization_request 202 - SET did = $2, device_id = $3 203 - WHERE id = $1 204 - "#, 205 - request_id, 206 - did, 207 - device_id 208 - ) 209 - .execute(pool) 210 - .await?; 211 - Ok(()) 212 - } 213 - 214 - pub async fn update_request_scope( 215 - pool: &PgPool, 216 - request_id: &str, 217 - scope: &str, 218 - ) -> Result<(), OAuthError> { 219 - sqlx::query!( 220 - r#" 221 - UPDATE oauth_authorization_request 222 - SET parameters = jsonb_set(parameters, '{scope}', to_jsonb($2::text)) 223 - WHERE id = $1 224 - "#, 225 - request_id, 226 - scope 227 - ) 228 - .execute(pool) 229 - .await?; 230 - Ok(()) 231 - } 232 - 233 - pub async fn set_controller_did( 234 - pool: &PgPool, 235 - request_id: &str, 236 - controller_did: &str, 237 - ) -> Result<(), OAuthError> { 238 - sqlx::query!( 239 - r#" 240 - UPDATE oauth_authorization_request 241 - SET controller_did = $2 242 - WHERE id = $1 243 - "#, 244 - request_id, 245 - controller_did 246 - ) 247 - .execute(pool) 248 - .await?; 249 - Ok(()) 250 - } 251 - 252 - pub async fn set_request_did(pool: &PgPool, request_id: &str, did: &str) -> Result<(), OAuthError> { 253 - sqlx::query!( 254 - r#" 255 - UPDATE oauth_authorization_request 256 - SET did = $2 257 - WHERE id = $1 258 - "#, 259 - request_id, 260 - did 261 - ) 262 - .execute(pool) 263 - .await?; 264 - Ok(()) 265 - }
+10 -78
crates/tranquil-pds/src/oauth/db/scope_preference.rs
··· 1 1 use super::super::OAuthError; 2 - use serde::{Deserialize, Serialize}; 3 - use sqlx::PgPool; 4 - 5 - #[derive(Debug, Clone, Serialize, Deserialize)] 6 - pub struct ScopePreference { 7 - pub scope: String, 8 - pub granted: bool, 9 - } 10 - 11 - pub async fn get_scope_preferences( 12 - pool: &PgPool, 13 - did: &str, 14 - client_id: &str, 15 - ) -> Result<Vec<ScopePreference>, OAuthError> { 16 - let rows = sqlx::query!( 17 - r#" 18 - SELECT scope, granted FROM oauth_scope_preference 19 - WHERE did = $1 AND client_id = $2 20 - "#, 21 - did, 22 - client_id 23 - ) 24 - .fetch_all(pool) 25 - .await?; 26 - 27 - Ok(rows 28 - .into_iter() 29 - .map(|r| ScopePreference { 30 - scope: r.scope, 31 - granted: r.granted, 32 - }) 33 - .collect()) 34 - } 2 + use tranquil_db_traits::OAuthRepository; 3 + use tranquil_types::{ClientId, Did}; 35 4 36 - pub async fn upsert_scope_preferences( 37 - pool: &PgPool, 38 - did: &str, 39 - client_id: &str, 40 - prefs: &[ScopePreference], 41 - ) -> Result<(), OAuthError> { 42 - for pref in prefs { 43 - sqlx::query!( 44 - r#" 45 - INSERT INTO oauth_scope_preference (did, client_id, scope, granted, created_at, updated_at) 46 - VALUES ($1, $2, $3, $4, NOW(), NOW()) 47 - ON CONFLICT (did, client_id, scope) DO UPDATE SET granted = $4, updated_at = NOW() 48 - "#, 49 - did, 50 - client_id, 51 - pref.scope, 52 - pref.granted 53 - ) 54 - .execute(pool) 55 - .await?; 56 - } 57 - Ok(()) 58 - } 5 + pub use tranquil_db_traits::ScopePreference; 59 6 60 7 pub async fn should_show_consent( 61 - pool: &PgPool, 62 - did: &str, 63 - client_id: &str, 8 + oauth_repo: &dyn OAuthRepository, 9 + did: &Did, 10 + client_id: &ClientId, 64 11 requested_scopes: &[String], 65 12 ) -> Result<bool, OAuthError> { 66 13 if requested_scopes.is_empty() { 67 14 return Ok(false); 68 15 } 69 16 70 - let stored_prefs = get_scope_preferences(pool, did, client_id).await?; 17 + let stored_prefs = oauth_repo 18 + .get_scope_preferences(did, client_id) 19 + .await 20 + .map_err(crate::oauth::db_err_to_oauth)?; 71 21 if stored_prefs.is_empty() { 72 22 return Ok(true); 73 23 } ··· 79 29 .iter() 80 30 .any(|scope| !stored_scopes.contains(scope.as_str()))) 81 31 } 82 - 83 - pub async fn delete_scope_preferences( 84 - pool: &PgPool, 85 - did: &str, 86 - client_id: &str, 87 - ) -> Result<(), OAuthError> { 88 - sqlx::query!( 89 - r#" 90 - DELETE FROM oauth_scope_preference 91 - WHERE did = $1 AND client_id = $2 92 - "#, 93 - did, 94 - client_id 95 - ) 96 - .execute(pool) 97 - .await?; 98 - Ok(()) 99 - }
+34 -382
crates/tranquil-pds/src/oauth/db/token.rs
··· 1 - use super::super::{OAuthError, RefreshTokenState, TokenData}; 2 - use super::helpers::{from_json, to_json}; 3 - use chrono::{DateTime, Utc}; 4 - use sqlx::PgPool; 5 - 6 - pub enum RefreshTokenLookup { 7 - Valid { 8 - db_id: i32, 9 - token_data: TokenData, 10 - }, 11 - InGracePeriod { 12 - db_id: i32, 13 - token_data: TokenData, 14 - rotated_at: DateTime<Utc>, 15 - }, 16 - Used { 17 - original_token_id: i32, 18 - }, 19 - Expired { 20 - db_id: i32, 21 - }, 22 - NotFound, 23 - } 24 - 25 - impl RefreshTokenLookup { 26 - pub fn state(&self) -> RefreshTokenState { 27 - match self { 28 - RefreshTokenLookup::Valid { .. } => RefreshTokenState::Valid, 29 - RefreshTokenLookup::InGracePeriod { rotated_at, .. } => { 30 - RefreshTokenState::InGracePeriod { 31 - rotated_at: *rotated_at, 32 - } 33 - } 34 - RefreshTokenLookup::Used { .. } => RefreshTokenState::Used { at: Utc::now() }, 35 - RefreshTokenLookup::Expired { .. } => RefreshTokenState::Expired, 36 - RefreshTokenLookup::NotFound => RefreshTokenState::Revoked, 37 - } 38 - } 39 - } 1 + use super::super::OAuthError; 2 + use tranquil_db_traits::OAuthRepository; 3 + use tranquil_types::{Did, RefreshToken}; 4 + pub use tranquil_db_traits::RefreshTokenLookup; 40 5 41 6 pub async fn lookup_refresh_token( 42 - pool: &PgPool, 43 - refresh_token: &str, 7 + oauth_repo: &dyn OAuthRepository, 8 + refresh_token: &RefreshToken, 44 9 ) -> Result<RefreshTokenLookup, OAuthError> { 45 - if let Some(token_id) = check_refresh_token_used(pool, refresh_token).await? { 46 - if let Some((db_id, token_data)) = 47 - get_token_by_previous_refresh_token(pool, refresh_token).await? 48 - { 10 + let token_id = oauth_repo 11 + .check_refresh_token_used(refresh_token) 12 + .await 13 + .map_err(crate::oauth::db_err_to_oauth)?; 14 + if let Some(token_id) = token_id { 15 + let prev_token = oauth_repo 16 + .get_token_by_previous_refresh_token(refresh_token) 17 + .await 18 + .map_err(crate::oauth::db_err_to_oauth)?; 19 + if let Some((db_id, token_data)) = prev_token { 49 20 let rotated_at = token_data.updated_at; 50 21 return Ok(RefreshTokenLookup::InGracePeriod { 51 22 db_id, ··· 58 29 }); 59 30 } 60 31 61 - match get_token_by_refresh_token(pool, refresh_token).await? { 32 + let token = oauth_repo 33 + .get_token_by_refresh_token(refresh_token) 34 + .await 35 + .map_err(crate::oauth::db_err_to_oauth)?; 36 + match token { 62 37 Some((db_id, token_data)) => { 63 - if token_data.expires_at < Utc::now() { 38 + if token_data.expires_at < chrono::Utc::now() { 64 39 Ok(RefreshTokenLookup::Expired { db_id }) 65 40 } else { 66 41 Ok(RefreshTokenLookup::Valid { db_id, token_data }) ··· 70 45 } 71 46 } 72 47 73 - pub async fn create_token(pool: &PgPool, data: &TokenData) -> Result<i32, OAuthError> { 74 - let client_auth_json = to_json(&data.client_auth)?; 75 - let parameters_json = to_json(&data.parameters)?; 76 - let row = sqlx::query!( 77 - r#" 78 - INSERT INTO oauth_token 79 - (did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 80 - device_id, parameters, details, code, current_refresh_token, scope, controller_did) 81 - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) 82 - RETURNING id 83 - "#, 84 - data.did, 85 - data.token_id, 86 - data.created_at, 87 - data.updated_at, 88 - data.expires_at, 89 - data.client_id, 90 - client_auth_json, 91 - data.device_id, 92 - parameters_json, 93 - data.details, 94 - data.code, 95 - data.current_refresh_token, 96 - data.scope, 97 - data.controller_did, 98 - ) 99 - .fetch_one(pool) 100 - .await?; 101 - Ok(row.id) 102 - } 103 - 104 - pub async fn get_token_by_id( 105 - pool: &PgPool, 106 - token_id: &str, 107 - ) -> Result<Option<TokenData>, OAuthError> { 108 - let row = sqlx::query!( 109 - r#" 110 - SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 111 - device_id, parameters, details, code, current_refresh_token, scope, controller_did 112 - FROM oauth_token 113 - WHERE token_id = $1 114 - "#, 115 - token_id 116 - ) 117 - .fetch_optional(pool) 118 - .await?; 119 - match row { 120 - Some(r) => Ok(Some(TokenData { 121 - did: r.did, 122 - token_id: r.token_id, 123 - created_at: r.created_at, 124 - updated_at: r.updated_at, 125 - expires_at: r.expires_at, 126 - client_id: r.client_id, 127 - client_auth: from_json(r.client_auth)?, 128 - device_id: r.device_id, 129 - parameters: from_json(r.parameters)?, 130 - details: r.details, 131 - code: r.code, 132 - current_refresh_token: r.current_refresh_token, 133 - scope: r.scope, 134 - controller_did: r.controller_did, 135 - })), 136 - None => Ok(None), 137 - } 138 - } 139 - 140 - pub async fn get_token_by_refresh_token( 141 - pool: &PgPool, 142 - refresh_token: &str, 143 - ) -> Result<Option<(i32, TokenData)>, OAuthError> { 144 - let row = sqlx::query!( 145 - r#" 146 - SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 147 - device_id, parameters, details, code, current_refresh_token, scope, controller_did 148 - FROM oauth_token 149 - WHERE current_refresh_token = $1 150 - "#, 151 - refresh_token 152 - ) 153 - .fetch_optional(pool) 154 - .await?; 155 - match row { 156 - Some(r) => Ok(Some(( 157 - r.id, 158 - TokenData { 159 - did: r.did, 160 - token_id: r.token_id, 161 - created_at: r.created_at, 162 - updated_at: r.updated_at, 163 - expires_at: r.expires_at, 164 - client_id: r.client_id, 165 - client_auth: from_json(r.client_auth)?, 166 - device_id: r.device_id, 167 - parameters: from_json(r.parameters)?, 168 - details: r.details, 169 - code: r.code, 170 - current_refresh_token: r.current_refresh_token, 171 - scope: r.scope, 172 - controller_did: r.controller_did, 173 - }, 174 - ))), 175 - None => Ok(None), 176 - } 177 - } 178 - 179 - pub async fn rotate_token( 180 - pool: &PgPool, 181 - old_db_id: i32, 182 - new_refresh_token: &str, 183 - new_expires_at: DateTime<Utc>, 184 - ) -> Result<(), OAuthError> { 185 - let mut tx = pool.begin().await?; 186 - let old_refresh = sqlx::query_scalar!( 187 - r#" 188 - SELECT current_refresh_token FROM oauth_token WHERE id = $1 189 - "#, 190 - old_db_id 191 - ) 192 - .fetch_one(&mut *tx) 193 - .await?; 194 - if let Some(ref old_rt) = old_refresh { 195 - sqlx::query!( 196 - r#" 197 - INSERT INTO oauth_used_refresh_token (refresh_token, token_id) 198 - VALUES ($1, $2) 199 - "#, 200 - old_rt, 201 - old_db_id 202 - ) 203 - .execute(&mut *tx) 204 - .await?; 205 - } 206 - sqlx::query!( 207 - r#" 208 - UPDATE oauth_token 209 - SET current_refresh_token = $2, expires_at = $3, updated_at = NOW(), 210 - previous_refresh_token = $4, rotated_at = NOW() 211 - WHERE id = $1 212 - "#, 213 - old_db_id, 214 - new_refresh_token, 215 - new_expires_at, 216 - old_refresh 217 - ) 218 - .execute(&mut *tx) 219 - .await?; 220 - tx.commit().await?; 221 - Ok(()) 222 - } 223 - 224 - pub async fn check_refresh_token_used( 225 - pool: &PgPool, 226 - refresh_token: &str, 227 - ) -> Result<Option<i32>, OAuthError> { 228 - let row = sqlx::query_scalar!( 229 - r#" 230 - SELECT token_id FROM oauth_used_refresh_token WHERE refresh_token = $1 231 - "#, 232 - refresh_token 233 - ) 234 - .fetch_optional(pool) 235 - .await?; 236 - Ok(row) 237 - } 238 - 239 - const REFRESH_GRACE_PERIOD_SECS: i64 = 60; 240 - 241 - pub async fn get_token_by_previous_refresh_token( 242 - pool: &PgPool, 243 - refresh_token: &str, 244 - ) -> Result<Option<(i32, TokenData)>, OAuthError> { 245 - let grace_cutoff = Utc::now() - chrono::Duration::seconds(REFRESH_GRACE_PERIOD_SECS); 246 - let row = sqlx::query!( 247 - r#" 248 - SELECT id, did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 249 - device_id, parameters, details, code, current_refresh_token, scope, controller_did 250 - FROM oauth_token 251 - WHERE previous_refresh_token = $1 AND rotated_at > $2 252 - "#, 253 - refresh_token, 254 - grace_cutoff 255 - ) 256 - .fetch_optional(pool) 257 - .await?; 258 - match row { 259 - Some(r) => Ok(Some(( 260 - r.id, 261 - TokenData { 262 - did: r.did, 263 - token_id: r.token_id, 264 - created_at: r.created_at, 265 - updated_at: r.updated_at, 266 - expires_at: r.expires_at, 267 - client_id: r.client_id, 268 - client_auth: from_json(r.client_auth)?, 269 - device_id: r.device_id, 270 - parameters: from_json(r.parameters)?, 271 - details: r.details, 272 - code: r.code, 273 - current_refresh_token: r.current_refresh_token, 274 - scope: r.scope, 275 - controller_did: r.controller_did, 276 - }, 277 - ))), 278 - None => Ok(None), 279 - } 280 - } 281 - 282 - pub async fn delete_token(pool: &PgPool, token_id: &str) -> Result<(), OAuthError> { 283 - sqlx::query!( 284 - r#" 285 - DELETE FROM oauth_token WHERE token_id = $1 286 - "#, 287 - token_id 288 - ) 289 - .execute(pool) 290 - .await?; 291 - Ok(()) 292 - } 293 - 294 - pub async fn delete_token_family(pool: &PgPool, db_id: i32) -> Result<(), OAuthError> { 295 - sqlx::query!( 296 - r#" 297 - DELETE FROM oauth_token WHERE id = $1 298 - "#, 299 - db_id 300 - ) 301 - .execute(pool) 302 - .await?; 303 - Ok(()) 304 - } 305 - 306 - pub async fn list_tokens_for_user(pool: &PgPool, did: &str) -> Result<Vec<TokenData>, OAuthError> { 307 - let rows = sqlx::query!( 308 - r#" 309 - SELECT did, token_id, created_at, updated_at, expires_at, client_id, client_auth, 310 - device_id, parameters, details, code, current_refresh_token, scope, controller_did 311 - FROM oauth_token 312 - WHERE did = $1 313 - "#, 314 - did 315 - ) 316 - .fetch_all(pool) 317 - .await?; 318 - rows.into_iter() 319 - .map(|r| { 320 - Ok(TokenData { 321 - did: r.did, 322 - token_id: r.token_id, 323 - created_at: r.created_at, 324 - updated_at: r.updated_at, 325 - expires_at: r.expires_at, 326 - client_id: r.client_id, 327 - client_auth: from_json(r.client_auth)?, 328 - device_id: r.device_id, 329 - parameters: from_json(r.parameters)?, 330 - details: r.details, 331 - code: r.code, 332 - current_refresh_token: r.current_refresh_token, 333 - scope: r.scope, 334 - controller_did: r.controller_did, 335 - }) 336 - }) 337 - .collect() 338 - } 339 - 340 - pub async fn count_tokens_for_user(pool: &PgPool, did: &str) -> Result<i64, OAuthError> { 341 - let count = sqlx::query_scalar!( 342 - r#" 343 - SELECT COUNT(*) as "count!" FROM oauth_token WHERE did = $1 344 - "#, 345 - did 346 - ) 347 - .fetch_one(pool) 348 - .await?; 349 - Ok(count) 350 - } 351 - 352 - pub async fn delete_oldest_tokens_for_user( 353 - pool: &PgPool, 354 - did: &str, 355 - keep_count: i64, 356 - ) -> Result<u64, OAuthError> { 357 - let result = sqlx::query!( 358 - r#" 359 - DELETE FROM oauth_token 360 - WHERE id IN ( 361 - SELECT id FROM oauth_token 362 - WHERE did = $1 363 - ORDER BY updated_at ASC 364 - OFFSET $2 365 - ) 366 - "#, 367 - did, 368 - keep_count 369 - ) 370 - .execute(pool) 371 - .await?; 372 - Ok(result.rows_affected()) 373 - } 374 - 375 48 const MAX_TOKENS_PER_USER: i64 = 100; 376 49 377 - pub async fn enforce_token_limit_for_user(pool: &PgPool, did: &str) -> Result<(), OAuthError> { 378 - let count = count_tokens_for_user(pool, did).await?; 50 + pub async fn enforce_token_limit_for_user( 51 + oauth_repo: &dyn OAuthRepository, 52 + did: &Did, 53 + ) -> Result<(), OAuthError> { 54 + let count = oauth_repo 55 + .count_tokens_for_user(did) 56 + .await 57 + .map_err(crate::oauth::db_err_to_oauth)?; 379 58 if count > MAX_TOKENS_PER_USER { 380 59 let to_keep = MAX_TOKENS_PER_USER - 1; 381 - delete_oldest_tokens_for_user(pool, did, to_keep).await?; 60 + oauth_repo 61 + .delete_oldest_tokens_for_user(did, to_keep) 62 + .await 63 + .map_err(crate::oauth::db_err_to_oauth)?; 382 64 } 383 65 Ok(()) 384 66 } 385 - 386 - pub async fn revoke_tokens_for_client( 387 - pool: &PgPool, 388 - did: &str, 389 - client_id: &str, 390 - ) -> Result<u64, OAuthError> { 391 - let result = sqlx::query!( 392 - "DELETE FROM oauth_token WHERE did = $1 AND client_id = $2", 393 - did, 394 - client_id 395 - ) 396 - .execute(pool) 397 - .await?; 398 - Ok(result.rows_affected()) 399 - } 400 - 401 - pub async fn revoke_tokens_for_controller( 402 - pool: &PgPool, 403 - delegated_did: &str, 404 - controller_did: &str, 405 - ) -> Result<u64, OAuthError> { 406 - let result = sqlx::query!( 407 - "DELETE FROM oauth_token WHERE did = $1 AND controller_did = $2", 408 - delegated_did, 409 - controller_did 410 - ) 411 - .execute(pool) 412 - .await?; 413 - Ok(result.rows_affected()) 414 - }
-137
crates/tranquil-pds/src/oauth/db/two_factor.rs
··· 1 - use super::super::OAuthError; 2 - use chrono::{DateTime, Duration, Utc}; 3 1 use rand::Rng; 4 - use sqlx::PgPool; 5 - use uuid::Uuid; 6 - 7 - pub struct TwoFactorChallenge { 8 - pub id: Uuid, 9 - pub did: String, 10 - pub request_uri: String, 11 - pub code: String, 12 - pub attempts: i32, 13 - pub created_at: DateTime<Utc>, 14 - pub expires_at: DateTime<Utc>, 15 - } 16 2 17 3 pub fn generate_2fa_code() -> String { 18 4 let mut rng = rand::thread_rng(); 19 5 let code: u32 = rng.gen_range(0..1_000_000); 20 6 format!("{:06}", code) 21 7 } 22 - 23 - pub async fn create_2fa_challenge( 24 - pool: &PgPool, 25 - did: &str, 26 - request_uri: &str, 27 - ) -> Result<TwoFactorChallenge, OAuthError> { 28 - let code = generate_2fa_code(); 29 - let expires_at = Utc::now() + Duration::minutes(10); 30 - let row = sqlx::query!( 31 - r#" 32 - INSERT INTO oauth_2fa_challenge (did, request_uri, code, expires_at) 33 - VALUES ($1, $2, $3, $4) 34 - RETURNING id, did, request_uri, code, attempts, created_at, expires_at 35 - "#, 36 - did, 37 - request_uri, 38 - code, 39 - expires_at, 40 - ) 41 - .fetch_one(pool) 42 - .await?; 43 - Ok(TwoFactorChallenge { 44 - id: row.id, 45 - did: row.did, 46 - request_uri: row.request_uri, 47 - code: row.code, 48 - attempts: row.attempts, 49 - created_at: row.created_at, 50 - expires_at: row.expires_at, 51 - }) 52 - } 53 - 54 - pub async fn get_2fa_challenge( 55 - pool: &PgPool, 56 - request_uri: &str, 57 - ) -> Result<Option<TwoFactorChallenge>, OAuthError> { 58 - let row = sqlx::query!( 59 - r#" 60 - SELECT id, did, request_uri, code, attempts, created_at, expires_at 61 - FROM oauth_2fa_challenge 62 - WHERE request_uri = $1 63 - "#, 64 - request_uri 65 - ) 66 - .fetch_optional(pool) 67 - .await?; 68 - Ok(row.map(|r| TwoFactorChallenge { 69 - id: r.id, 70 - did: r.did, 71 - request_uri: r.request_uri, 72 - code: r.code, 73 - attempts: r.attempts, 74 - created_at: r.created_at, 75 - expires_at: r.expires_at, 76 - })) 77 - } 78 - 79 - pub async fn increment_2fa_attempts(pool: &PgPool, id: Uuid) -> Result<i32, OAuthError> { 80 - let row = sqlx::query!( 81 - r#" 82 - UPDATE oauth_2fa_challenge 83 - SET attempts = attempts + 1 84 - WHERE id = $1 85 - RETURNING attempts 86 - "#, 87 - id 88 - ) 89 - .fetch_one(pool) 90 - .await?; 91 - Ok(row.attempts) 92 - } 93 - 94 - pub async fn delete_2fa_challenge(pool: &PgPool, id: Uuid) -> Result<(), OAuthError> { 95 - sqlx::query!( 96 - r#" 97 - DELETE FROM oauth_2fa_challenge WHERE id = $1 98 - "#, 99 - id 100 - ) 101 - .execute(pool) 102 - .await?; 103 - Ok(()) 104 - } 105 - 106 - pub async fn delete_2fa_challenge_by_request_uri( 107 - pool: &PgPool, 108 - request_uri: &str, 109 - ) -> Result<(), OAuthError> { 110 - sqlx::query!( 111 - r#" 112 - DELETE FROM oauth_2fa_challenge WHERE request_uri = $1 113 - "#, 114 - request_uri 115 - ) 116 - .execute(pool) 117 - .await?; 118 - Ok(()) 119 - } 120 - 121 - pub async fn cleanup_expired_2fa_challenges(pool: &PgPool) -> Result<u64, OAuthError> { 122 - let result = sqlx::query!( 123 - r#" 124 - DELETE FROM oauth_2fa_challenge WHERE expires_at < NOW() 125 - "# 126 - ) 127 - .execute(pool) 128 - .await?; 129 - Ok(result.rows_affected()) 130 - } 131 - 132 - pub async fn check_user_2fa_enabled(pool: &PgPool, did: &str) -> Result<bool, OAuthError> { 133 - let row = sqlx::query!( 134 - r#" 135 - SELECT two_factor_enabled 136 - FROM users 137 - WHERE did = $1 138 - "#, 139 - did 140 - ) 141 - .fetch_optional(pool) 142 - .await?; 143 - Ok(row.map(|r| r.two_factor_enabled).unwrap_or(false)) 144 - }
+421 -283
crates/tranquil-pds/src/oauth/endpoints/authorize.rs
··· 1 - use crate::comms::{CommsChannel, channel_display_name, enqueue_2fa_code}; 1 + use crate::comms::{channel_display_name, comms_repo::enqueue_2fa_code}; 2 2 use crate::oauth::{ 3 - AuthFlowState, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, SessionId, db, 3 + AuthFlowState, ClientMetadataCache, Code, DeviceData, DeviceId, OAuthError, SessionId, 4 + db::should_show_consent, 4 5 }; 5 6 use crate::state::{AppState, RateLimitKind}; 6 - use crate::types::{Handle, PlainPassword}; 7 + use tranquil_db_traits::ScopePreference; 8 + use crate::types::{Did, Handle, PlainPassword}; 9 + use tranquil_types::{AuthorizationCode, ClientId, DeviceId as DeviceIdType, RequestId}; 7 10 use axum::{ 8 11 Json, 9 12 extract::{Query, State}, ··· 203 206 ); 204 207 } 205 208 }; 206 - let request_data = match db::get_authorization_request(&state.db, &request_uri).await { 209 + let request_id = RequestId::from(request_uri.clone()); 210 + let request_data = match state.oauth_repo.get_authorization_request(&request_id).await { 207 211 Ok(Some(data)) => data, 208 212 Ok(None) => { 209 213 if wants_json(&headers) { ··· 235 239 } 236 240 }; 237 241 if request_data.expires_at < Utc::now() { 238 - let _ = db::delete_authorization_request(&state.db, &request_uri).await; 242 + let _ = state.oauth_repo.delete_authorization_request(&request_id).await; 239 243 if wants_json(&headers) { 240 244 return ( 241 245 StatusCode::BAD_REQUEST, ··· 273 277 tracing::info!(login_hint = %login_hint, "Checking login_hint for delegation"); 274 278 let pds_hostname = 275 279 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 280 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 276 281 let normalized = if login_hint.contains('@') || login_hint.starts_with("did:") { 277 282 login_hint.clone() 278 283 } else if !login_hint.contains('.') { 279 - format!("{}.{}", login_hint.to_lowercase(), pds_hostname) 284 + format!("{}.{}", login_hint.to_lowercase(), hostname_for_handles) 280 285 } else { 281 286 login_hint.to_lowercase() 282 287 }; 283 288 tracing::info!(normalized = %normalized, "Normalized login_hint"); 284 289 285 - match sqlx::query!( 286 - "SELECT did, password_hash FROM users WHERE handle = $1 OR email = $1", 287 - normalized 288 - ) 289 - .fetch_optional(&state.db) 290 - .await 291 - { 290 + match state.user_repo.get_login_check_by_handle_or_email(&normalized).await { 292 291 Ok(Some(user)) => { 293 292 tracing::info!(did = %user.did, has_password = user.password_hash.is_some(), "Found user for login_hint"); 294 - let is_delegated = crate::delegation::is_delegated_account(&state.db, &user.did) 293 + let is_delegated = state 294 + .delegation_repo 295 + .is_delegated_account(&user.did) 295 296 .await 296 297 .unwrap_or(false); 297 298 let has_password = user.password_hash.is_some(); ··· 319 320 320 321 if !force_new_account 321 322 && let Some(device_id) = extract_device_cookie(&headers) 322 - && let Ok(accounts) = db::get_device_accounts(&state.db, &device_id).await 323 + && let Ok(accounts) = state.oauth_repo.get_device_accounts(&DeviceIdType::from(device_id.clone())).await 323 324 && !accounts.is_empty() 324 325 { 325 326 return redirect_see_other(&format!( ··· 340 341 let request_uri = query 341 342 .request_uri 342 343 .ok_or_else(|| OAuthError::InvalidRequest("request_uri is required".to_string()))?; 343 - let request_data = db::get_authorization_request(&state.db, &request_uri) 344 - .await? 344 + let request_id_json = RequestId::from(request_uri.clone()); 345 + let request_data = state.oauth_repo.get_authorization_request(&request_id_json) 346 + .await 347 + .map_err(crate::oauth::db_err_to_oauth)? 345 348 .ok_or_else(|| OAuthError::InvalidRequest("Invalid or expired request_uri".to_string()))?; 346 349 if request_data.expires_at < Utc::now() { 347 - db::delete_authorization_request(&state.db, &request_uri).await?; 350 + let _ = state.oauth_repo.delete_authorization_request(&request_id_json).await; 348 351 return Err(OAuthError::InvalidRequest( 349 352 "request_uri has expired".to_string(), 350 353 )); ··· 417 420 .into_response(); 418 421 } 419 422 }; 420 - let accounts = match db::get_device_accounts(&state.db, &device_id).await { 423 + let device_id_typed = DeviceIdType::from(device_id.clone()); 424 + let accounts = match state.oauth_repo.get_device_accounts(&device_id_typed).await { 421 425 Ok(accts) => accts, 422 426 Err(_) => { 423 427 return Json(AccountsResponse { ··· 430 434 let account_infos: Vec<AccountInfo> = accounts 431 435 .into_iter() 432 436 .map(|row| AccountInfo { 433 - did: row.did, 437 + did: row.did.to_string(), 434 438 handle: row.handle, 435 439 email: row.email.map(|e| mask_email(&e)), 436 440 }) ··· 469 473 "Too many login attempts. Please try again later.", 470 474 ); 471 475 } 472 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 476 + let form_request_id = RequestId::from(form.request_uri.clone()); 477 + let request_data = match state.oauth_repo.get_authorization_request(&form_request_id).await { 473 478 Ok(Some(data)) => data, 474 479 Ok(None) => { 475 480 if json_response { ··· 502 507 } 503 508 }; 504 509 if request_data.expires_at < Utc::now() { 505 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 510 + let _ = state.oauth_repo.delete_authorization_request(&form_request_id).await; 506 511 if json_response { 507 512 return ( 508 513 axum::http::StatusCode::BAD_REQUEST, ··· 536 541 )) 537 542 }; 538 543 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 544 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 539 545 let normalized_username = form.username.trim(); 540 546 let normalized_username = normalized_username 541 547 .strip_prefix('@') ··· 543 549 let normalized_username = if normalized_username.contains('@') { 544 550 normalized_username.to_string() 545 551 } else if !normalized_username.contains('.') { 546 - format!("{}.{}", normalized_username, pds_hostname) 552 + format!("{}.{}", normalized_username, hostname_for_handles) 547 553 } else { 548 554 normalized_username.to_string() 549 555 }; ··· 553 559 pds_hostname = %pds_hostname, 554 560 "Normalized username for lookup" 555 561 ); 556 - let user = match sqlx::query!( 557 - r#" 558 - SELECT id, did, email, password_hash, password_required, two_factor_enabled, 559 - preferred_comms_channel as "preferred_comms_channel: CommsChannel", 560 - deactivated_at, takedown_ref, 561 - email_verified, discord_verified, telegram_verified, signal_verified, 562 - account_type::text as "account_type!" 563 - FROM users 564 - WHERE handle = $1 OR email = $1 565 - "#, 566 - normalized_username 567 - ) 568 - .fetch_optional(&state.db) 569 - .await 570 - { 562 + let user = match state.user_repo.get_login_info_by_handle_or_email(&normalized_username).await { 571 563 Ok(Some(u)) => u, 572 564 Ok(None) => { 573 565 let _ = bcrypt::verify( ··· 596 588 } 597 589 598 590 if user.account_type == "delegated" { 599 - if db::set_authorization_did(&state.db, &form.request_uri, &user.did, None) 591 + if state.oauth_repo.set_authorization_did(&form_request_id, &user.did, None) 600 592 .await 601 593 .is_err() 602 594 { ··· 622 614 } 623 615 624 616 if !user.password_required { 625 - if db::set_authorization_did(&state.db, &form.request_uri, &user.did, None) 617 + if state.oauth_repo.set_authorization_did(&form_request_id, &user.did, None) 626 618 .await 627 619 .is_err() 628 620 { ··· 661 653 if has_totp { 662 654 let device_cookie = extract_device_cookie(&headers); 663 655 let device_is_trusted = if let Some(ref dev_id) = device_cookie { 664 - crate::api::server::is_device_trusted(&state.db, dev_id, &user.did).await 656 + crate::api::server::is_device_trusted(state.oauth_repo.as_ref(), dev_id, &user.did).await 665 657 } else { 666 658 false 667 659 }; 668 660 669 661 if device_is_trusted { 670 662 if let Some(ref dev_id) = device_cookie { 671 - let _ = crate::api::server::extend_device_trust(&state.db, dev_id).await; 663 + let _ = crate::api::server::extend_device_trust(state.oauth_repo.as_ref(), dev_id).await; 672 664 } 673 665 } else { 674 - if db::set_authorization_did(&state.db, &form.request_uri, &user.did, None) 666 + if state.oauth_repo.set_authorization_did(&form_request_id, &user.did, None) 675 667 .await 676 668 .is_err() 677 669 { ··· 690 682 } 691 683 } 692 684 if user.two_factor_enabled { 693 - let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 694 - match db::create_2fa_challenge(&state.db, &user.did, &form.request_uri).await { 685 + let _ = state.oauth_repo.delete_2fa_challenge_by_request_uri(&form_request_id).await; 686 + match state.oauth_repo.create_2fa_challenge(&user.did, &form_request_id).await { 695 687 Ok(challenge) => { 696 688 let hostname = 697 689 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 698 690 if let Err(e) = 699 - enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await 691 + enqueue_2fa_code(state.user_repo.as_ref(), state.infra_repo.as_ref(), user.id, &challenge.code, &hostname).await 700 692 { 701 693 tracing::warn!( 702 694 did = %user.did, ··· 736 728 ip_address: extract_client_ip(&headers), 737 729 last_seen_at: Utc::now(), 738 730 }; 739 - if db::create_device(&state.db, &new_id.0, &device_data) 731 + let new_device_id_typed = DeviceIdType::from(new_id.0.clone()); 732 + if state.oauth_repo.create_device(&new_device_id_typed, &device_data) 740 733 .await 741 734 .is_ok() 742 735 { ··· 745 738 } 746 739 new_id.0 747 740 }; 748 - let _ = db::upsert_account_device(&state.db, &user.did, &final_device_id).await; 741 + let final_device_typed = DeviceIdType::from(final_device_id.clone()); 742 + let _ = state.oauth_repo.upsert_account_device(&user.did, &final_device_typed).await; 749 743 } 750 - if db::set_authorization_did( 751 - &state.db, 752 - &form.request_uri, 753 - &user.did, 754 - device_id.as_deref(), 755 - ) 756 - .await 757 - .is_err() 744 + let set_auth_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 745 + if state 746 + .oauth_repo 747 + .set_authorization_did(&form_request_id, &user.did, set_auth_device_id.as_ref()) 748 + .await 749 + .is_err() 758 750 { 759 751 return show_login_error("An error occurred. Please try again.", json_response); 760 752 } ··· 767 759 .split_whitespace() 768 760 .map(|s| s.to_string()) 769 761 .collect(); 770 - let needs_consent = db::should_show_consent( 771 - &state.db, 762 + let client_id_typed = ClientId::from(request_data.parameters.client_id.clone()); 763 + let needs_consent = should_show_consent( 764 + state.oauth_repo.as_ref(), 772 765 &user.did, 773 - &request_data.parameters.client_id, 766 + &client_id_typed, 774 767 &requested_scopes, 775 768 ) 776 769 .await ··· 801 794 return redirect_see_other(&consent_url); 802 795 } 803 796 let code = Code::generate(); 804 - if db::update_authorization_request( 805 - &state.db, 806 - &form.request_uri, 797 + let auth_post_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 798 + let auth_post_code = AuthorizationCode::from(code.0.clone()); 799 + if state.oauth_repo.update_authorization_request( 800 + &form_request_id, 807 801 &user.did, 808 - device_id.as_deref(), 809 - &code.0, 802 + auth_post_device_id.as_ref(), 803 + &auth_post_code, 810 804 ) 811 805 .await 812 806 .is_err() ··· 864 858 ) 865 859 .into_response() 866 860 }; 867 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 861 + let select_request_id = RequestId::from(form.request_uri.clone()); 862 + let request_data = match state.oauth_repo.get_authorization_request(&select_request_id).await { 868 863 Ok(Some(data)) => data, 869 864 Ok(None) => { 870 865 return json_error( ··· 882 877 } 883 878 }; 884 879 if request_data.expires_at < Utc::now() { 885 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 880 + let _ = state.oauth_repo.delete_authorization_request(&select_request_id).await; 886 881 return json_error( 887 882 StatusCode::BAD_REQUEST, 888 883 "invalid_request", ··· 899 894 ); 900 895 } 901 896 }; 902 - let account_valid = match db::verify_account_on_device(&state.db, &device_id, &form.did).await { 897 + let did: Did = match form.did.parse() { 898 + Ok(d) => d, 899 + Err(_) => { 900 + return json_error( 901 + StatusCode::BAD_REQUEST, 902 + "invalid_request", 903 + "Invalid DID format.", 904 + ); 905 + } 906 + }; 907 + let verify_device_id = DeviceIdType::from(device_id.clone()); 908 + let account_valid = match state.oauth_repo.verify_account_on_device(&verify_device_id, &did).await { 903 909 Ok(valid) => valid, 904 910 Err(_) => { 905 911 return json_error( ··· 916 922 "This account is not available on this device. Please sign in.", 917 923 ); 918 924 } 919 - let user = match sqlx::query!( 920 - r#" 921 - SELECT id, two_factor_enabled, 922 - preferred_comms_channel as "preferred_comms_channel: CommsChannel", 923 - email_verified, discord_verified, telegram_verified, signal_verified 924 - FROM users 925 - WHERE did = $1 926 - "#, 927 - form.did 928 - ) 929 - .fetch_optional(&state.db) 930 - .await 931 - { 925 + let user = match state.user_repo.get_2fa_status_by_did(&did).await { 932 926 Ok(Some(u)) => u, 933 927 Ok(None) => { 934 928 return json_error( ··· 956 950 "Please verify your account before logging in.", 957 951 ); 958 952 } 959 - let has_totp = crate::api::server::has_totp_enabled(&state, &form.did).await; 953 + let has_totp = crate::api::server::has_totp_enabled(&state, &did).await; 954 + let select_early_device_typed = DeviceIdType::from(device_id.clone()); 960 955 if has_totp { 961 - if db::set_authorization_did(&state.db, &form.request_uri, &form.did, Some(&device_id)) 956 + if state.oauth_repo.set_authorization_did(&select_request_id, &did, Some(&select_early_device_typed)) 962 957 .await 963 958 .is_err() 964 959 { ··· 974 969 .into_response(); 975 970 } 976 971 if user.two_factor_enabled { 977 - let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 978 - match db::create_2fa_challenge(&state.db, &form.did, &form.request_uri).await { 972 + let _ = state.oauth_repo.delete_2fa_challenge_by_request_uri(&select_request_id).await; 973 + match state.oauth_repo.create_2fa_challenge(&did, &select_request_id).await { 979 974 Ok(challenge) => { 980 975 let hostname = 981 976 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 982 977 if let Err(e) = 983 - enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await 978 + enqueue_2fa_code(state.user_repo.as_ref(), state.infra_repo.as_ref(), user.id, &challenge.code, &hostname).await 984 979 { 985 980 tracing::warn!( 986 981 did = %form.did, ··· 1004 999 } 1005 1000 } 1006 1001 } 1007 - let _ = db::upsert_account_device(&state.db, &form.did, &device_id).await; 1002 + let select_device_typed = DeviceIdType::from(device_id.clone()); 1003 + let _ = state.oauth_repo.upsert_account_device(&did, &select_device_typed).await; 1008 1004 let code = Code::generate(); 1009 - if db::update_authorization_request( 1010 - &state.db, 1011 - &form.request_uri, 1012 - &form.did, 1013 - Some(&device_id), 1014 - &code.0, 1005 + let select_code = AuthorizationCode::from(code.0.clone()); 1006 + if state.oauth_repo.update_authorization_request( 1007 + &select_request_id, 1008 + &did, 1009 + Some(&select_device_typed), 1010 + &select_code, 1015 1011 ) 1016 1012 .await 1017 1013 .is_err() ··· 1124 1120 State(state): State<AppState>, 1125 1121 Json(form): Json<AuthorizeDenyForm>, 1126 1122 ) -> Response { 1127 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1123 + let deny_request_id = RequestId::from(form.request_uri.clone()); 1124 + let request_data = match state.oauth_repo.get_authorization_request(&deny_request_id).await { 1128 1125 Ok(Some(data)) => data, 1129 1126 Ok(None) => { 1130 1127 return ( ··· 1147 1144 .into_response(); 1148 1145 } 1149 1146 }; 1150 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 1147 + let _ = state.oauth_repo.delete_authorization_request(&deny_request_id).await; 1151 1148 let redirect_uri = &request_data.parameters.redirect_uri; 1152 1149 let mut redirect_url = redirect_uri.to_string(); 1153 1150 let separator = if redirect_url.contains('?') { '&' } else { '?' }; ··· 1188 1185 State(state): State<AppState>, 1189 1186 Query(query): Query<Authorize2faQuery>, 1190 1187 ) -> Response { 1191 - let challenge = match db::get_2fa_challenge(&state.db, &query.request_uri).await { 1188 + let twofa_request_id = RequestId::from(query.request_uri.clone()); 1189 + let challenge = match state.oauth_repo.get_2fa_challenge(&twofa_request_id).await { 1192 1190 Ok(Some(c)) => c, 1193 1191 Ok(None) => { 1194 1192 return redirect_to_frontend_error( ··· 1204 1202 } 1205 1203 }; 1206 1204 if challenge.expires_at < Utc::now() { 1207 - let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1205 + let _ = state.oauth_repo.delete_2fa_challenge(challenge.id).await; 1208 1206 return redirect_to_frontend_error( 1209 1207 "invalid_request", 1210 1208 "2FA code has expired. Please start over.", 1211 1209 ); 1212 1210 } 1213 - let _request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 1211 + let _request_data = match state.oauth_repo.get_authorization_request(&twofa_request_id).await { 1214 1212 Ok(Some(d)) => d, 1215 1213 Ok(None) => { 1216 1214 return redirect_to_frontend_error( ··· 1279 1277 State(state): State<AppState>, 1280 1278 Query(query): Query<ConsentQuery>, 1281 1279 ) -> Response { 1282 - let (request_data, flow_state) = 1283 - match db::get_authorization_request_with_state(&state.db, &query.request_uri).await { 1284 - Ok(Some(result)) => result, 1280 + let consent_request_id = RequestId::from(query.request_uri.clone()); 1281 + let request_data = 1282 + match state.oauth_repo.get_authorization_request(&consent_request_id).await { 1283 + Ok(Some(data)) => data, 1285 1284 Ok(None) => { 1286 1285 return json_error( 1287 1286 StatusCode::BAD_REQUEST, ··· 1297 1296 ); 1298 1297 } 1299 1298 }; 1299 + let flow_state = AuthFlowState::from_request_data(&request_data); 1300 1300 1301 1301 if let Some(err_response) = validate_auth_flow_state(&flow_state, true) { 1302 1302 if flow_state.is_expired() { 1303 - let _ = db::delete_authorization_request(&state.db, &query.request_uri).await; 1303 + let _ = state.oauth_repo.delete_authorization_request(&consent_request_id).await; 1304 1304 } 1305 1305 return err_response; 1306 1306 } 1307 1307 1308 - let did = flow_state.did().unwrap().to_string(); 1308 + let did_str = flow_state.did().unwrap().to_string(); 1309 + let did: Did = match did_str.parse() { 1310 + Ok(d) => d, 1311 + Err(_) => { 1312 + return json_error( 1313 + StatusCode::BAD_REQUEST, 1314 + "invalid_request", 1315 + "Invalid DID format in request.", 1316 + ); 1317 + } 1318 + }; 1309 1319 let client_cache = ClientMetadataCache::new(3600); 1310 1320 let client_metadata = client_cache 1311 1321 .get(&request_data.parameters.client_id) ··· 1318 1328 .filter(|s| !s.trim().is_empty()) 1319 1329 .unwrap_or("atproto"); 1320 1330 1321 - let delegation_grant = if let Some(ref ctrl_did) = request_data.controller_did { 1322 - crate::delegation::get_delegation(&state.db, &did, ctrl_did) 1331 + let controller_did_parsed: Option<Did> = request_data.controller_did.as_ref().and_then(|s| s.parse().ok()); 1332 + let delegation_grant = if let Some(ref ctrl_did) = controller_did_parsed { 1333 + state 1334 + .delegation_repo 1335 + .get_delegation(&did, ctrl_did) 1323 1336 .await 1324 1337 .ok() 1325 1338 .flatten() ··· 1328 1341 }; 1329 1342 1330 1343 let effective_scope_str = if let Some(ref grant) = delegation_grant { 1331 - crate::delegation::scopes::intersect_scopes(requested_scope_str, &grant.granted_scopes) 1344 + crate::delegation::intersect_scopes(requested_scope_str, &grant.granted_scopes) 1332 1345 } else { 1333 1346 requested_scope_str.to_string() 1334 1347 }; 1335 1348 1336 1349 let requested_scopes: Vec<&str> = effective_scope_str.split_whitespace().collect(); 1350 + let consent_client_id = ClientId::from(request_data.parameters.client_id.clone()); 1337 1351 let preferences = 1338 - db::get_scope_preferences(&state.db, &did, &request_data.parameters.client_id) 1352 + state.oauth_repo.get_scope_preferences(&did, &consent_client_id) 1339 1353 .await 1340 1354 .unwrap_or_default(); 1341 1355 let pref_map: std::collections::HashMap<_, _> = preferences ··· 1344 1358 .collect(); 1345 1359 let requested_scope_strings: Vec<String> = 1346 1360 requested_scopes.iter().map(|s| s.to_string()).collect(); 1347 - let show_consent = db::should_show_consent( 1348 - &state.db, 1361 + let show_consent = should_show_consent( 1362 + state.oauth_repo.as_ref(), 1349 1363 &did, 1350 - &request_data.parameters.client_id, 1364 + &consent_client_id, 1351 1365 &requested_scope_strings, 1352 1366 ) 1353 1367 .await ··· 1389 1403 } 1390 1404 }) 1391 1405 .collect(); 1392 - let (is_delegation, controller_did, controller_handle, delegation_level) = 1393 - if let Some(ref ctrl_did) = request_data.controller_did { 1394 - let ctrl_handle = 1395 - sqlx::query_scalar!("SELECT handle FROM users WHERE did = $1", ctrl_did) 1396 - .fetch_optional(&state.db) 1397 - .await 1398 - .ok() 1399 - .flatten(); 1406 + let (is_delegation, controller_did_resp, controller_handle, delegation_level) = 1407 + if let Some(ref ctrl_did) = controller_did_parsed { 1408 + let ctrl_handle = state 1409 + .user_repo 1410 + .get_handle_by_did(ctrl_did) 1411 + .await 1412 + .ok() 1413 + .flatten() 1414 + .map(|h| h.to_string()); 1400 1415 1401 1416 let level = if let Some(ref grant) = delegation_grant { 1402 - let preset = crate::delegation::SCOPE_PRESETS 1403 - .iter() 1404 - .find(|p| p.scopes == grant.granted_scopes); 1417 + let preset = crate::delegation::SCOPE_PRESETS.iter().find(|p| p.scopes == grant.granted_scopes); 1405 1418 preset 1406 1419 .map(|p| p.label.to_string()) 1407 1420 .unwrap_or_else(|| "Custom".to_string()) ··· 1409 1422 "Unknown".to_string() 1410 1423 }; 1411 1424 1412 - (Some(true), Some(ctrl_did.clone()), ctrl_handle, Some(level)) 1425 + (Some(true), Some(ctrl_did.to_string()), ctrl_handle, Some(level)) 1413 1426 } else { 1414 1427 (None, None, None, None) 1415 1428 }; ··· 1422 1435 logo_uri: client_metadata.as_ref().and_then(|m| m.logo_uri.clone()), 1423 1436 scopes, 1424 1437 show_consent, 1425 - did, 1438 + did: did_str, 1426 1439 is_delegation, 1427 - controller_did, 1440 + controller_did: controller_did_resp, 1428 1441 controller_handle, 1429 1442 delegation_level, 1430 1443 }) ··· 1440 1453 form.approved_scopes, 1441 1454 form.remember 1442 1455 ); 1443 - let (request_data, flow_state) = 1444 - match db::get_authorization_request_with_state(&state.db, &form.request_uri).await { 1445 - Ok(Some(result)) => result, 1456 + let consent_post_request_id = RequestId::from(form.request_uri.clone()); 1457 + let request_data = 1458 + match state.oauth_repo.get_authorization_request(&consent_post_request_id).await { 1459 + Ok(Some(data)) => data, 1446 1460 Ok(None) => { 1447 1461 return json_error( 1448 1462 StatusCode::BAD_REQUEST, ··· 1458 1472 ); 1459 1473 } 1460 1474 }; 1475 + let flow_state = AuthFlowState::from_request_data(&request_data); 1461 1476 1462 1477 if flow_state.is_expired() { 1463 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 1478 + let _ = state.oauth_repo.delete_authorization_request(&consent_post_request_id).await; 1464 1479 return json_error( 1465 1480 StatusCode::BAD_REQUEST, 1466 1481 "invalid_request", ··· 1471 1486 return json_error(StatusCode::FORBIDDEN, "access_denied", "Not authenticated"); 1472 1487 } 1473 1488 1474 - let did = flow_state.did().unwrap().to_string(); 1489 + let did_str = flow_state.did().unwrap().to_string(); 1490 + let did: Did = match did_str.parse() { 1491 + Ok(d) => d, 1492 + Err(_) => { 1493 + return json_error( 1494 + StatusCode::BAD_REQUEST, 1495 + "invalid_request", 1496 + "Invalid DID format", 1497 + ); 1498 + } 1499 + }; 1475 1500 let original_scope_str = request_data 1476 1501 .parameters 1477 1502 .scope 1478 1503 .as_deref() 1479 1504 .unwrap_or("atproto"); 1480 1505 1481 - let delegation_grant = if let Some(ref ctrl_did) = request_data.controller_did { 1482 - crate::delegation::get_delegation(&state.db, &did, ctrl_did) 1506 + let controller_did_parsed: Option<Did> = request_data 1507 + .controller_did 1508 + .as_ref() 1509 + .and_then(|s| s.parse().ok()); 1510 + 1511 + let delegation_grant = match controller_did_parsed.as_ref() { 1512 + Some(ctrl_did) => state 1513 + .delegation_repo 1514 + .get_delegation(&did, ctrl_did) 1483 1515 .await 1484 1516 .ok() 1485 - .flatten() 1486 - } else { 1487 - None 1517 + .flatten(), 1518 + None => None, 1488 1519 }; 1489 1520 1490 1521 let effective_scope_str = if let Some(ref grant) = delegation_grant { 1491 - crate::delegation::scopes::intersect_scopes(original_scope_str, &grant.granted_scopes) 1522 + crate::delegation::intersect_scopes(original_scope_str, &grant.granted_scopes) 1492 1523 } else { 1493 1524 original_scope_str.to_string() 1494 1525 }; ··· 1537 1568 ); 1538 1569 } 1539 1570 if form.remember { 1540 - let preferences: Vec<db::ScopePreference> = requested_scopes 1571 + let preferences: Vec<ScopePreference> = requested_scopes 1541 1572 .iter() 1542 - .map(|s| db::ScopePreference { 1573 + .map(|s| ScopePreference { 1543 1574 scope: s.to_string(), 1544 1575 granted: form.approved_scopes.contains(&s.to_string()), 1545 1576 }) 1546 1577 .collect(); 1547 - let _ = db::upsert_scope_preferences( 1548 - &state.db, 1578 + let consent_post_client_id = ClientId::from(request_data.parameters.client_id.clone()); 1579 + let _ = state.oauth_repo.upsert_scope_preferences( 1549 1580 &did, 1550 - &request_data.parameters.client_id, 1581 + &consent_post_client_id, 1551 1582 &preferences, 1552 1583 ) 1553 1584 .await; 1554 1585 } 1555 1586 if let Err(e) = 1556 - db::update_request_scope(&state.db, &form.request_uri, &approved_scope_str).await 1587 + state.oauth_repo.update_request_scope(&consent_post_request_id, &approved_scope_str).await 1557 1588 { 1558 1589 tracing::warn!("Failed to update request scope: {:?}", e); 1559 1590 } 1560 1591 let code = Code::generate(); 1561 - if db::update_authorization_request( 1562 - &state.db, 1563 - &form.request_uri, 1592 + let consent_post_device_id = request_data.device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 1593 + let consent_post_code = AuthorizationCode::from(code.0.clone()); 1594 + if state.oauth_repo.update_authorization_request( 1595 + &consent_post_request_id, 1564 1596 &did, 1565 - request_data.device_id.as_deref(), 1566 - &code.0, 1597 + consent_post_device_id.as_ref(), 1598 + &consent_post_code, 1567 1599 ) 1568 1600 .await 1569 1601 .is_err() ··· 1616 1648 "Too many attempts. Please try again later.", 1617 1649 ); 1618 1650 } 1619 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1651 + let twofa_post_request_id = RequestId::from(form.request_uri.clone()); 1652 + let request_data = match state.oauth_repo.get_authorization_request(&twofa_post_request_id).await { 1620 1653 Ok(Some(d)) => d, 1621 1654 Ok(None) => { 1622 1655 return json_error( ··· 1634 1667 } 1635 1668 }; 1636 1669 if request_data.expires_at < Utc::now() { 1637 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 1670 + let _ = state.oauth_repo.delete_authorization_request(&twofa_post_request_id).await; 1638 1671 return json_error( 1639 1672 StatusCode::BAD_REQUEST, 1640 1673 "invalid_request", 1641 1674 "Authorization request has expired.", 1642 1675 ); 1643 1676 } 1644 - let challenge = db::get_2fa_challenge(&state.db, &form.request_uri) 1677 + let challenge = state.oauth_repo.get_2fa_challenge(&twofa_post_request_id) 1645 1678 .await 1646 1679 .ok() 1647 1680 .flatten(); 1648 1681 if let Some(challenge) = challenge { 1649 1682 if challenge.expires_at < Utc::now() { 1650 - let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1683 + let _ = state.oauth_repo.delete_2fa_challenge( challenge.id).await; 1651 1684 return json_error( 1652 1685 StatusCode::BAD_REQUEST, 1653 1686 "invalid_request", ··· 1655 1688 ); 1656 1689 } 1657 1690 if challenge.attempts >= MAX_2FA_ATTEMPTS { 1658 - let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1691 + let _ = state.oauth_repo.delete_2fa_challenge( challenge.id).await; 1659 1692 return json_error( 1660 1693 StatusCode::FORBIDDEN, 1661 1694 "access_denied", ··· 1669 1702 .ct_eq(challenge.code.as_bytes()) 1670 1703 .into(); 1671 1704 if !code_valid { 1672 - let _ = db::increment_2fa_attempts(&state.db, challenge.id).await; 1705 + let _ = state.oauth_repo.increment_2fa_attempts(challenge.id).await; 1673 1706 return json_error( 1674 1707 StatusCode::FORBIDDEN, 1675 1708 "invalid_code", 1676 1709 "Invalid verification code. Please try again.", 1677 1710 ); 1678 1711 } 1679 - let _ = db::delete_2fa_challenge(&state.db, challenge.id).await; 1712 + let _ = state.oauth_repo.delete_2fa_challenge(challenge.id).await; 1680 1713 let code = Code::generate(); 1681 1714 let device_id = extract_device_cookie(&headers); 1682 - if db::update_authorization_request( 1683 - &state.db, 1684 - &form.request_uri, 1715 + let twofa_totp_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 1716 + let twofa_totp_code = AuthorizationCode::from(code.0.clone()); 1717 + if state.oauth_repo.update_authorization_request( 1718 + &twofa_post_request_id, 1685 1719 &challenge.did, 1686 - device_id.as_deref(), 1687 - &code.0, 1720 + twofa_totp_device_id.as_ref(), 1721 + &twofa_totp_code, 1688 1722 ) 1689 1723 .await 1690 1724 .is_err() ··· 1706 1740 })) 1707 1741 .into_response(); 1708 1742 } 1709 - let did = match &request_data.did { 1743 + let did_str = match &request_data.did { 1710 1744 Some(d) => d.clone(), 1711 1745 None => { 1712 1746 return json_error( ··· 1716 1750 ); 1717 1751 } 1718 1752 }; 1753 + let did: tranquil_types::Did = match did_str.parse() { 1754 + Ok(d) => d, 1755 + Err(_) => { 1756 + return json_error(StatusCode::BAD_REQUEST, "invalid_request", "Invalid DID format."); 1757 + } 1758 + }; 1719 1759 if !crate::api::server::has_totp_enabled(&state, &did).await { 1720 1760 return json_error( 1721 1761 StatusCode::BAD_REQUEST, ··· 1747 1787 if form.trust_device 1748 1788 && let Some(ref dev_id) = device_id 1749 1789 { 1750 - let _ = crate::api::server::trust_device(&state.db, dev_id).await; 1790 + let _ = crate::api::server::trust_device(state.oauth_repo.as_ref(), dev_id).await; 1751 1791 } 1752 1792 let requested_scope_str = request_data 1753 1793 .parameters ··· 1758 1798 .split_whitespace() 1759 1799 .map(|s| s.to_string()) 1760 1800 .collect(); 1761 - let needs_consent = db::should_show_consent( 1762 - &state.db, 1801 + let twofa_post_client_id = ClientId::from(request_data.parameters.client_id.clone()); 1802 + let needs_consent = should_show_consent( 1803 + state.oauth_repo.as_ref(), 1763 1804 &did, 1764 - &request_data.parameters.client_id, 1805 + &twofa_post_client_id, 1765 1806 &requested_scopes, 1766 1807 ) 1767 1808 .await ··· 1774 1815 return Json(serde_json::json!({"redirect_uri": consent_url})).into_response(); 1775 1816 } 1776 1817 let code = Code::generate(); 1777 - if db::update_authorization_request( 1778 - &state.db, 1779 - &form.request_uri, 1818 + let twofa_final_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 1819 + let twofa_final_code = AuthorizationCode::from(code.0.clone()); 1820 + if state.oauth_repo.update_authorization_request( 1821 + &twofa_post_request_id, 1780 1822 &did, 1781 - device_id.as_deref(), 1782 - &code.0, 1823 + twofa_final_device_id.as_ref(), 1824 + &twofa_final_code, 1783 1825 ) 1784 1826 .await 1785 1827 .is_err() ··· 1819 1861 Query(query): Query<CheckPasskeysQuery>, 1820 1862 ) -> Response { 1821 1863 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1864 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 1822 1865 let normalized_identifier = query.identifier.trim(); 1823 1866 let normalized_identifier = normalized_identifier 1824 1867 .strip_prefix('@') 1825 1868 .unwrap_or(normalized_identifier); 1826 1869 let normalized_identifier = if let Some(bare_handle) = 1827 - normalized_identifier.strip_suffix(&format!(".{}", pds_hostname)) 1870 + normalized_identifier.strip_suffix(&format!(".{}", hostname_for_handles)) 1828 1871 { 1829 1872 bare_handle.to_string() 1830 1873 } else { 1831 1874 normalized_identifier.to_string() 1832 1875 }; 1833 1876 1834 - let user = sqlx::query!( 1835 - "SELECT did FROM users WHERE handle = $1 OR email = $1", 1836 - normalized_identifier 1837 - ) 1838 - .fetch_optional(&state.db) 1839 - .await; 1877 + let user = state 1878 + .user_repo 1879 + .get_login_check_by_handle_or_email(&normalized_identifier) 1880 + .await; 1840 1881 1841 1882 let has_passkeys = match user { 1842 1883 Ok(Some(u)) => crate::api::server::has_passkeys_for_user(&state, &u.did).await, ··· 1862 1903 Query(query): Query<CheckPasskeysQuery>, 1863 1904 ) -> Response { 1864 1905 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 1906 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 1865 1907 let identifier = query.identifier.trim(); 1866 1908 let identifier = identifier.strip_prefix('@').unwrap_or(identifier); 1867 1909 let normalized_identifier = if identifier.contains('@') || identifier.starts_with("did:") { 1868 1910 identifier.to_string() 1869 1911 } else if !identifier.contains('.') { 1870 - format!("{}.{}", identifier.to_lowercase(), pds_hostname) 1912 + format!("{}.{}", identifier.to_lowercase(), hostname_for_handles) 1871 1913 } else { 1872 1914 identifier.to_lowercase() 1873 1915 }; 1874 1916 1875 - let user = sqlx::query!( 1876 - "SELECT did, password_hash FROM users WHERE handle = $1 OR email = $1", 1877 - normalized_identifier 1878 - ) 1879 - .fetch_optional(&state.db) 1880 - .await; 1917 + let user = state 1918 + .user_repo 1919 + .get_login_check_by_handle_or_email(&normalized_identifier) 1920 + .await; 1881 1921 1882 1922 let (has_passkeys, has_totp, has_password, is_delegated, did): ( 1883 1923 bool, ··· 1890 1930 let passkeys = crate::api::server::has_passkeys_for_user(&state, &u.did).await; 1891 1931 let totp = crate::api::server::has_totp_enabled(&state, &u.did).await; 1892 1932 let has_pw = u.password_hash.is_some(); 1893 - let has_controllers = crate::delegation::is_delegated_account(&state.db, &u.did) 1933 + let has_controllers = state 1934 + .delegation_repo 1935 + .is_delegated_account(&u.did) 1894 1936 .await 1895 1937 .unwrap_or(false); 1896 - (passkeys, totp, has_pw, has_controllers, Some(u.did)) 1938 + (passkeys, totp, has_pw, has_controllers, Some(u.did.to_string())) 1897 1939 } 1898 1940 _ => (false, false, false, false, None), 1899 1941 }; ··· 1942 1984 .into_response(); 1943 1985 } 1944 1986 1945 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 1987 + let passkey_start_request_id = RequestId::from(form.request_uri.clone()); 1988 + let request_data = match state.oauth_repo.get_authorization_request(&passkey_start_request_id).await { 1946 1989 Ok(Some(data)) => data, 1947 1990 Ok(None) => { 1948 1991 return ( ··· 1967 2010 }; 1968 2011 1969 2012 if request_data.expires_at < Utc::now() { 1970 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 2013 + let _ = state.oauth_repo.delete_authorization_request(&passkey_start_request_id).await; 1971 2014 return ( 1972 2015 StatusCode::BAD_REQUEST, 1973 2016 Json(serde_json::json!({ ··· 1979 2022 } 1980 2023 1981 2024 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2025 + let hostname_for_handles = pds_hostname.split(':').next().unwrap_or(&pds_hostname); 1982 2026 let normalized_username = form.identifier.trim(); 1983 2027 let normalized_username = normalized_username 1984 2028 .strip_prefix('@') ··· 1986 2030 let normalized_username = if normalized_username.contains('@') { 1987 2031 normalized_username.to_string() 1988 2032 } else if !normalized_username.contains('.') { 1989 - format!("{}.{}", normalized_username, pds_hostname) 2033 + format!("{}.{}", normalized_username, hostname_for_handles) 1990 2034 } else { 1991 2035 normalized_username.to_string() 1992 2036 }; 1993 2037 1994 - let user = match sqlx::query!( 1995 - r#" 1996 - SELECT did, deactivated_at, takedown_ref, 1997 - email_verified, discord_verified, telegram_verified, signal_verified 1998 - FROM users 1999 - WHERE handle = $1 OR email = $1 2000 - "#, 2001 - normalized_username 2002 - ) 2003 - .fetch_optional(&state.db) 2004 - .await 2005 - { 2038 + let user = match state.user_repo.get_login_info_by_handle_or_email(&normalized_username).await { 2006 2039 Ok(Some(u)) => u, 2007 2040 Ok(None) => { 2008 2041 return ( ··· 2064 2097 .into_response(); 2065 2098 } 2066 2099 2067 - let stored_passkeys = 2068 - match crate::auth::webauthn::get_passkeys_for_user(&state.db, &user.did).await { 2069 - Ok(pks) => pks, 2070 - Err(e) => { 2071 - tracing::error!(error = %e, "Failed to get passkeys"); 2072 - return ( 2073 - StatusCode::INTERNAL_SERVER_ERROR, 2074 - Json(serde_json::json!({ 2075 - "error": "server_error", 2076 - "error_description": "An error occurred." 2077 - })), 2078 - ) 2079 - .into_response(); 2080 - } 2081 - }; 2100 + let stored_passkeys = match state.user_repo.get_passkeys_for_user(&user.did).await { 2101 + Ok(pks) => pks, 2102 + Err(e) => { 2103 + tracing::error!(error = %e, "Failed to get passkeys"); 2104 + return ( 2105 + StatusCode::INTERNAL_SERVER_ERROR, 2106 + Json(serde_json::json!({ 2107 + "error": "server_error", 2108 + "error_description": "An error occurred." 2109 + })), 2110 + ) 2111 + .into_response(); 2112 + } 2113 + }; 2082 2114 2083 2115 if stored_passkeys.is_empty() { 2084 2116 return ( ··· 2093 2125 2094 2126 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys 2095 2127 .iter() 2096 - .filter_map(|sp| sp.to_security_key().ok()) 2128 + .filter_map(|sp| serde_json::from_slice(&sp.public_key).ok()) 2097 2129 .collect(); 2098 2130 2099 2131 if passkeys.is_empty() { ··· 2137 2169 } 2138 2170 }; 2139 2171 2140 - if let Err(e) = 2141 - crate::auth::webauthn::save_authentication_state(&state.db, &user.did, &auth_state).await 2172 + let state_json = match serde_json::to_string(&auth_state) { 2173 + Ok(j) => j, 2174 + Err(e) => { 2175 + tracing::error!(error = %e, "Failed to serialize authentication state"); 2176 + return ( 2177 + StatusCode::INTERNAL_SERVER_ERROR, 2178 + Json(serde_json::json!({ 2179 + "error": "server_error", 2180 + "error_description": "An error occurred." 2181 + })), 2182 + ) 2183 + .into_response(); 2184 + } 2185 + }; 2186 + 2187 + if let Err(e) = state 2188 + .user_repo 2189 + .save_webauthn_challenge(&user.did, "authentication", &state_json) 2190 + .await 2142 2191 { 2143 2192 tracing::error!(error = %e, "Failed to save authentication state"); 2144 2193 return ( ··· 2151 2200 .into_response(); 2152 2201 } 2153 2202 2154 - if db::set_authorization_did(&state.db, &form.request_uri, &user.did, None) 2203 + if state.oauth_repo.set_authorization_did(&passkey_start_request_id, &user.did, None) 2155 2204 .await 2156 2205 .is_err() 2157 2206 { ··· 2181 2230 headers: HeaderMap, 2182 2231 Json(form): Json<PasskeyFinishInput>, 2183 2232 ) -> Response { 2184 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 2233 + let passkey_finish_request_id = RequestId::from(form.request_uri.clone()); 2234 + let request_data = match state.oauth_repo.get_authorization_request(&passkey_finish_request_id).await { 2185 2235 Ok(Some(data)) => data, 2186 2236 Ok(None) => { 2187 2237 return ( ··· 2206 2256 }; 2207 2257 2208 2258 if request_data.expires_at < Utc::now() { 2209 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 2259 + let _ = state.oauth_repo.delete_authorization_request(&passkey_finish_request_id).await; 2210 2260 return ( 2211 2261 StatusCode::BAD_REQUEST, 2212 2262 Json(serde_json::json!({ ··· 2217 2267 .into_response(); 2218 2268 } 2219 2269 2220 - let did = match request_data.did { 2270 + let did_str = match request_data.did { 2221 2271 Some(d) => d, 2222 2272 None => { 2223 2273 return ( ··· 2230 2280 .into_response(); 2231 2281 } 2232 2282 }; 2283 + let did: tranquil_types::Did = match did_str.parse() { 2284 + Ok(d) => d, 2285 + Err(_) => { 2286 + return ( 2287 + StatusCode::BAD_REQUEST, 2288 + Json(serde_json::json!({ 2289 + "error": "invalid_request", 2290 + "error_description": "Invalid DID format." 2291 + })), 2292 + ) 2293 + .into_response(); 2294 + } 2295 + }; 2233 2296 2234 - let auth_state = match crate::auth::webauthn::load_authentication_state(&state.db, &did).await { 2297 + let auth_state_json = match state 2298 + .user_repo 2299 + .load_webauthn_challenge(&did, "authentication") 2300 + .await 2301 + { 2235 2302 Ok(Some(s)) => s, 2236 2303 Ok(None) => { 2237 2304 return ( ··· 2255 2322 .into_response(); 2256 2323 } 2257 2324 }; 2325 + 2326 + let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = 2327 + match serde_json::from_str(&auth_state_json) { 2328 + Ok(s) => s, 2329 + Err(e) => { 2330 + tracing::error!(error = %e, "Failed to deserialize authentication state"); 2331 + return ( 2332 + StatusCode::INTERNAL_SERVER_ERROR, 2333 + Json(serde_json::json!({ 2334 + "error": "server_error", 2335 + "error_description": "An error occurred." 2336 + })), 2337 + ) 2338 + .into_response(); 2339 + } 2340 + }; 2258 2341 2259 2342 let credential: webauthn_rs::prelude::PublicKeyCredential = 2260 2343 match serde_json::from_value(form.credential) { ··· 2303 2386 } 2304 2387 }; 2305 2388 2306 - if let Err(e) = crate::auth::webauthn::delete_authentication_state(&state.db, &did).await { 2389 + if let Err(e) = state 2390 + .user_repo 2391 + .delete_webauthn_challenge(&did, "authentication") 2392 + .await 2393 + { 2307 2394 tracing::warn!(error = %e, "Failed to delete authentication state"); 2308 2395 } 2309 2396 2310 2397 if auth_result.needs_update() { 2311 - match crate::auth::webauthn::update_passkey_counter( 2312 - &state.db, 2313 - auth_result.cred_id(), 2314 - auth_result.counter(), 2315 - ) 2316 - .await 2398 + let cred_id_bytes = auth_result.cred_id().as_slice(); 2399 + match state 2400 + .user_repo 2401 + .update_passkey_counter(cred_id_bytes, auth_result.counter() as i32) 2402 + .await 2317 2403 { 2318 2404 Ok(false) => { 2319 2405 tracing::warn!(did = %did, "Passkey counter anomaly detected - possible cloned key"); ··· 2343 2429 .into_response(); 2344 2430 } 2345 2431 2346 - let user = sqlx::query!( 2347 - "SELECT two_factor_enabled, preferred_comms_channel as \"preferred_comms_channel: CommsChannel\", id FROM users WHERE did = $1", 2348 - did 2349 - ) 2350 - .fetch_optional(&state.db) 2351 - .await; 2432 + let user = state.user_repo.get_2fa_status_by_did(&did).await; 2352 2433 2353 2434 if let Ok(Some(user)) = user 2354 2435 && user.two_factor_enabled 2355 2436 { 2356 - let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 2357 - match db::create_2fa_challenge(&state.db, &did, &form.request_uri).await { 2437 + let _ = state.oauth_repo.delete_2fa_challenge_by_request_uri(&passkey_finish_request_id).await; 2438 + match state.oauth_repo.create_2fa_challenge(&did, &passkey_finish_request_id).await { 2358 2439 Ok(challenge) => { 2359 2440 let hostname = 2360 2441 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2361 2442 if let Err(e) = 2362 - enqueue_2fa_code(&state.db, user.id, &challenge.code, &hostname).await 2443 + enqueue_2fa_code(state.user_repo.as_ref(), state.infra_repo.as_ref(), user.id, &challenge.code, &hostname).await 2363 2444 { 2364 2445 tracing::warn!(did = %did, error = %e, "Failed to enqueue 2FA notification"); 2365 2446 } ··· 2394 2475 .map(|s| s.to_string()) 2395 2476 .collect(); 2396 2477 2397 - let needs_consent = db::should_show_consent( 2398 - &state.db, 2478 + let passkey_finish_client_id = ClientId::from(request_data.parameters.client_id.clone()); 2479 + let needs_consent = should_show_consent( 2480 + state.oauth_repo.as_ref(), 2399 2481 &did, 2400 - &request_data.parameters.client_id, 2482 + &passkey_finish_client_id, 2401 2483 &requested_scopes, 2402 2484 ) 2403 2485 .await ··· 2412 2494 } 2413 2495 2414 2496 let code = Code::generate(); 2415 - if db::update_authorization_request( 2416 - &state.db, 2417 - &form.request_uri, 2497 + let passkey_final_device_id = device_id.as_ref().map(|d| DeviceIdType::from(d.clone())); 2498 + let passkey_final_code = AuthorizationCode::from(code.0.clone()); 2499 + if state.oauth_repo.update_authorization_request( 2500 + &passkey_finish_request_id, 2418 2501 &did, 2419 - device_id.as_deref(), 2420 - &code.0, 2502 + passkey_final_device_id.as_ref(), 2503 + &passkey_final_code, 2421 2504 ) 2422 2505 .await 2423 2506 .is_err() ··· 2463 2546 ) -> Response { 2464 2547 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2465 2548 2466 - let request_data = match db::get_authorization_request(&state.db, &query.request_uri).await { 2549 + let auth_passkey_start_request_id = RequestId::from(query.request_uri.clone()); 2550 + let request_data = match state.oauth_repo.get_authorization_request(&auth_passkey_start_request_id).await { 2467 2551 Ok(Some(d)) => d, 2468 2552 Ok(None) => { 2469 2553 return ( ··· 2488 2572 }; 2489 2573 2490 2574 if request_data.expires_at < Utc::now() { 2491 - let _ = db::delete_authorization_request(&state.db, &query.request_uri).await; 2575 + let _ = state.oauth_repo.delete_authorization_request(&auth_passkey_start_request_id).await; 2492 2576 return ( 2493 2577 StatusCode::BAD_REQUEST, 2494 2578 Json(serde_json::json!({ ··· 2499 2583 .into_response(); 2500 2584 } 2501 2585 2502 - let did = match &request_data.did { 2586 + let did_str = match &request_data.did { 2503 2587 Some(d) => d.clone(), 2504 2588 None => { 2505 2589 return ( ··· 2513 2597 } 2514 2598 }; 2515 2599 2516 - let stored_passkeys = match crate::auth::webauthn::get_passkeys_for_user(&state.db, &did).await 2517 - { 2600 + let did: tranquil_types::Did = match did_str.parse() { 2601 + Ok(d) => d, 2602 + Err(_) => { 2603 + return ( 2604 + StatusCode::BAD_REQUEST, 2605 + Json(serde_json::json!({ 2606 + "error": "invalid_request", 2607 + "error_description": "Invalid DID format." 2608 + })), 2609 + ) 2610 + .into_response(); 2611 + } 2612 + }; 2613 + 2614 + let stored_passkeys = match state.user_repo.get_passkeys_for_user(&did).await { 2518 2615 Ok(pks) => pks, 2519 2616 Err(e) => { 2520 2617 tracing::error!("Failed to get passkeys: {:?}", e); 2521 2618 return ( 2522 - StatusCode::INTERNAL_SERVER_ERROR, 2523 - Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 2524 - ) 2525 - .into_response(); 2619 + StatusCode::INTERNAL_SERVER_ERROR, 2620 + Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 2621 + ) 2622 + .into_response(); 2526 2623 } 2527 2624 }; 2528 2625 ··· 2539 2636 2540 2637 let passkeys: Vec<webauthn_rs::prelude::SecurityKey> = stored_passkeys 2541 2638 .iter() 2542 - .filter_map(|sp| sp.to_security_key().ok()) 2639 + .filter_map(|sp| serde_json::from_slice(&sp.public_key).ok()) 2543 2640 .collect(); 2544 2641 2545 2642 if passkeys.is_empty() { ··· 2574 2671 } 2575 2672 }; 2576 2673 2577 - if let Err(e) = 2578 - crate::auth::webauthn::save_authentication_state(&state.db, &did, &auth_state).await 2674 + let state_json = match serde_json::to_string(&auth_state) { 2675 + Ok(j) => j, 2676 + Err(e) => { 2677 + tracing::error!("Failed to serialize authentication state: {:?}", e); 2678 + return ( 2679 + StatusCode::INTERNAL_SERVER_ERROR, 2680 + Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 2681 + ) 2682 + .into_response(); 2683 + } 2684 + }; 2685 + 2686 + if let Err(e) = state 2687 + .user_repo 2688 + .save_webauthn_challenge(&did, "authentication", &state_json) 2689 + .await 2579 2690 { 2580 2691 tracing::error!("Failed to save authentication state: {:?}", e); 2581 2692 return ( ··· 2606 2717 Json(form): Json<AuthorizePasskeySubmit>, 2607 2718 ) -> Response { 2608 2719 let pds_hostname = std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 2720 + let passkey_finish_request_id = RequestId::from(form.request_uri.clone()); 2609 2721 2610 - let request_data = match db::get_authorization_request(&state.db, &form.request_uri).await { 2722 + let request_data = match state.oauth_repo.get_authorization_request(&passkey_finish_request_id).await { 2611 2723 Ok(Some(d)) => d, 2612 2724 Ok(None) => { 2613 2725 return ( ··· 2632 2744 }; 2633 2745 2634 2746 if request_data.expires_at < Utc::now() { 2635 - let _ = db::delete_authorization_request(&state.db, &form.request_uri).await; 2747 + let _ = state.oauth_repo.delete_authorization_request(&passkey_finish_request_id).await; 2636 2748 return ( 2637 2749 StatusCode::BAD_REQUEST, 2638 2750 Json(serde_json::json!({ ··· 2643 2755 .into_response(); 2644 2756 } 2645 2757 2646 - let did = match &request_data.did { 2758 + let did_str = match &request_data.did { 2647 2759 Some(d) => d.clone(), 2648 2760 None => { 2649 2761 return ( ··· 2657 2769 } 2658 2770 }; 2659 2771 2660 - let auth_state = match crate::auth::webauthn::load_authentication_state(&state.db, &did).await { 2772 + let did: tranquil_types::Did = match did_str.parse() { 2773 + Ok(d) => d, 2774 + Err(_) => { 2775 + return ( 2776 + StatusCode::BAD_REQUEST, 2777 + Json(serde_json::json!({ 2778 + "error": "invalid_request", 2779 + "error_description": "Invalid DID format." 2780 + })), 2781 + ) 2782 + .into_response(); 2783 + } 2784 + }; 2785 + 2786 + let auth_state_json = match state 2787 + .user_repo 2788 + .load_webauthn_challenge(&did, "authentication") 2789 + .await 2790 + { 2661 2791 Ok(Some(s)) => s, 2662 2792 Ok(None) => { 2663 2793 return ( ··· 2672 2802 Err(e) => { 2673 2803 tracing::error!("Failed to load authentication state: {:?}", e); 2674 2804 return ( 2805 + StatusCode::INTERNAL_SERVER_ERROR, 2806 + Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 2807 + ) 2808 + .into_response(); 2809 + } 2810 + }; 2811 + 2812 + let auth_state: webauthn_rs::prelude::SecurityKeyAuthentication = 2813 + match serde_json::from_str(&auth_state_json) { 2814 + Ok(s) => s, 2815 + Err(e) => { 2816 + tracing::error!("Failed to deserialize authentication state: {:?}", e); 2817 + return ( 2675 2818 StatusCode::INTERNAL_SERVER_ERROR, 2676 2819 Json(serde_json::json!({"error": "server_error", "error_description": "An error occurred."})), 2677 2820 ) 2678 2821 .into_response(); 2679 - } 2680 - }; 2822 + } 2823 + }; 2681 2824 2682 2825 let credential: webauthn_rs::prelude::PublicKeyCredential = 2683 2826 match serde_json::from_value(form.credential.clone()) { ··· 2722 2865 } 2723 2866 }; 2724 2867 2725 - let _ = crate::auth::webauthn::delete_authentication_state(&state.db, &did).await; 2868 + let _ = state 2869 + .user_repo 2870 + .delete_webauthn_challenge(&did, "authentication") 2871 + .await; 2726 2872 2727 - match crate::auth::webauthn::update_passkey_counter( 2728 - &state.db, 2729 - credential.id.as_ref(), 2730 - auth_result.counter(), 2731 - ) 2732 - .await 2873 + match state 2874 + .user_repo 2875 + .update_passkey_counter(credential.id.as_ref(), auth_result.counter() as i32) 2876 + .await 2733 2877 { 2734 2878 Ok(false) => { 2735 2879 tracing::warn!(did = %did, "Passkey counter anomaly detected - possible cloned key"); ··· 2748 2892 Ok(true) => {} 2749 2893 } 2750 2894 2751 - let has_totp = crate::api::server::has_totp_enabled_db(&state.db, &did).await; 2895 + let has_totp = state.user_repo.has_totp_enabled(&did).await.unwrap_or(false); 2752 2896 if has_totp { 2753 2897 let device_cookie = extract_device_cookie(&headers); 2754 2898 let device_is_trusted = if let Some(ref dev_id) = device_cookie { 2755 - crate::api::server::is_device_trusted(&state.db, dev_id, &did).await 2899 + crate::api::server::is_device_trusted(state.oauth_repo.as_ref(), dev_id, &did).await 2756 2900 } else { 2757 2901 false 2758 2902 }; 2759 2903 2760 2904 if device_is_trusted { 2761 2905 if let Some(ref dev_id) = device_cookie { 2762 - let _ = crate::api::server::extend_device_trust(&state.db, dev_id).await; 2906 + let _ = crate::api::server::extend_device_trust(state.oauth_repo.as_ref(), dev_id).await; 2763 2907 } 2764 2908 } else { 2765 - let user = match sqlx::query!( 2766 - r#"SELECT id, preferred_comms_channel as "preferred_comms_channel: CommsChannel" FROM users WHERE did = $1"#, 2767 - did 2768 - ) 2769 - .fetch_optional(&state.db) 2770 - .await 2771 - { 2909 + let user = match state.user_repo.get_2fa_status_by_did(&did).await { 2772 2910 Ok(Some(u)) => u, 2773 2911 _ => { 2774 2912 return ( ··· 2779 2917 } 2780 2918 }; 2781 2919 2782 - let _ = db::delete_2fa_challenge_by_request_uri(&state.db, &form.request_uri).await; 2783 - match db::create_2fa_challenge(&state.db, &did, &form.request_uri).await { 2920 + let _ = state.oauth_repo.delete_2fa_challenge_by_request_uri(&passkey_finish_request_id).await; 2921 + match state.oauth_repo.create_2fa_challenge(&did, &passkey_finish_request_id).await { 2784 2922 Ok(challenge) => { 2785 2923 if let Err(e) = 2786 - enqueue_2fa_code(&state.db, user.id, &challenge.code, &pds_hostname).await 2924 + enqueue_2fa_code(state.user_repo.as_ref(), state.infra_repo.as_ref(), user.id, &challenge.code, &pds_hostname).await 2787 2925 { 2788 2926 tracing::warn!(did = %did, error = %e, "Failed to enqueue 2FA notification"); 2789 2927 }
+137 -73
crates/tranquil-pds/src/oauth/endpoints/delegation.rs
··· 1 - use crate::delegation; 2 - use crate::oauth::db; 1 + use crate::delegation::DelegationActionType; 3 2 use crate::state::{AppState, RateLimitKind}; 4 3 use crate::types::PlainPassword; 5 4 use crate::util::extract_client_ip; ··· 10 9 response::{IntoResponse, Response}, 11 10 }; 12 11 use serde::{Deserialize, Serialize}; 12 + use tranquil_types::{Did, RequestId}; 13 13 14 14 #[derive(Debug, Deserialize)] 15 15 pub struct DelegationAuthSubmit { ··· 54 54 .into_response(); 55 55 } 56 56 57 - let request = match db::get_authorization_request(&state.db, &form.request_uri).await { 57 + let request_id = RequestId::from(form.request_uri.clone()); 58 + let request = match state 59 + .oauth_repo 60 + .get_authorization_request(&request_id) 61 + .await 62 + { 58 63 Ok(Some(r)) => r, 59 64 Ok(None) => { 60 65 return Json(DelegationAuthResponse { ··· 76 81 } 77 82 }; 78 83 79 - let delegated_did = match form.delegated_did.as_ref().or(request.did.as_ref()) { 84 + let delegated_did_str = match form.delegated_did.as_ref().or(request.did.as_ref()) { 80 85 Some(did) => did.clone(), 81 86 None => { 82 87 return Json(DelegationAuthResponse { ··· 89 94 } 90 95 }; 91 96 92 - if db::set_request_did(&state.db, &form.request_uri, &delegated_did) 97 + let delegated_did: Did = match delegated_did_str.parse() { 98 + Ok(d) => d, 99 + Err(_) => { 100 + return Json(DelegationAuthResponse { 101 + success: false, 102 + needs_totp: None, 103 + redirect_uri: None, 104 + error: Some("Invalid delegated DID".to_string()), 105 + }) 106 + .into_response(); 107 + } 108 + }; 109 + 110 + let controller_did: Did = match form.controller_did.parse() { 111 + Ok(d) => d, 112 + Err(_) => { 113 + return Json(DelegationAuthResponse { 114 + success: false, 115 + needs_totp: None, 116 + redirect_uri: None, 117 + error: Some("Invalid controller DID".to_string()), 118 + }) 119 + .into_response(); 120 + } 121 + }; 122 + 123 + if state 124 + .oauth_repo 125 + .set_request_did(&request_id, &delegated_did) 93 126 .await 94 127 .is_err() 95 128 { 96 129 tracing::warn!("Failed to set delegated DID on authorization request"); 97 130 } 98 131 99 - let grant = 100 - match delegation::get_delegation(&state.db, &delegated_did, &form.controller_did).await { 101 - Ok(Some(g)) => g, 102 - Ok(None) => { 103 - return Json(DelegationAuthResponse { 104 - success: false, 105 - needs_totp: None, 106 - redirect_uri: None, 107 - error: Some("No delegation grant found for this controller".to_string()), 108 - }) 109 - .into_response(); 110 - } 111 - Err(_) => { 112 - return Json(DelegationAuthResponse { 113 - success: false, 114 - needs_totp: None, 115 - redirect_uri: None, 116 - error: Some("Server error".to_string()), 117 - }) 118 - .into_response(); 119 - } 120 - }; 132 + let grant = match state 133 + .delegation_repo 134 + .get_delegation(&delegated_did, &controller_did) 135 + .await 136 + { 137 + Ok(Some(g)) => g, 138 + Ok(None) => { 139 + return Json(DelegationAuthResponse { 140 + success: false, 141 + needs_totp: None, 142 + redirect_uri: None, 143 + error: Some("No delegation grant found for this controller".to_string()), 144 + }) 145 + .into_response(); 146 + } 147 + Err(_) => { 148 + return Json(DelegationAuthResponse { 149 + success: false, 150 + needs_totp: None, 151 + redirect_uri: None, 152 + error: Some("Server error".to_string()), 153 + }) 154 + .into_response(); 155 + } 156 + }; 121 157 122 - let controller = match sqlx::query!( 123 - r#" 124 - SELECT id, did, password_hash, deactivated_at, takedown_ref, 125 - email_verified, discord_verified, telegram_verified, signal_verified 126 - FROM users 127 - WHERE did = $1 128 - "#, 129 - form.controller_did 130 - ) 131 - .fetch_optional(&state.db) 132 - .await 133 - { 158 + let controller = match state.user_repo.get_auth_info_by_did(&controller_did).await { 134 159 Ok(Some(u)) => u, 135 160 Ok(None) => { 136 161 return Json(DelegationAuthResponse { ··· 188 213 .into_response(); 189 214 } 190 215 191 - if db::set_controller_did(&state.db, &form.request_uri, &form.controller_did) 216 + if state 217 + .oauth_repo 218 + .set_controller_did(&request_id, &controller_did) 192 219 .await 193 220 .is_err() 194 221 { ··· 201 228 .into_response(); 202 229 } 203 230 204 - let has_totp = crate::api::server::has_totp_enabled(&state, &form.controller_did).await; 231 + let has_totp = crate::api::server::has_totp_enabled(&state, &controller_did).await; 205 232 if has_totp { 206 233 return Json(DelegationAuthResponse { 207 234 success: true, ··· 221 248 .and_then(|v| v.to_str().ok()) 222 249 .map(|s| s.to_string()); 223 250 224 - let _ = delegation::log_delegation_action( 225 - &state.db, 226 - &delegated_did, 227 - &form.controller_did, 228 - Some(&form.controller_did), 229 - delegation::DelegationActionType::TokenIssued, 230 - Some(serde_json::json!({ 231 - "client_id": request.client_id, 232 - "granted_scopes": grant.granted_scopes 233 - })), 234 - Some(&ip), 235 - user_agent.as_deref(), 236 - ) 237 - .await; 251 + let _ = state 252 + .delegation_repo 253 + .log_delegation_action( 254 + &delegated_did, 255 + &controller_did, 256 + Some(&controller_did), 257 + DelegationActionType::TokenIssued, 258 + Some(serde_json::json!({ 259 + "client_id": request.client_id, 260 + "granted_scopes": grant.granted_scopes 261 + })), 262 + Some(&ip), 263 + user_agent.as_deref(), 264 + ) 265 + .await; 238 266 239 267 Json(DelegationAuthResponse { 240 268 success: true, ··· 276 304 .into_response(); 277 305 } 278 306 279 - let request = match db::get_authorization_request(&state.db, &form.request_uri).await { 307 + let totp_request_id = RequestId::from(form.request_uri.clone()); 308 + let request = match state 309 + .oauth_repo 310 + .get_authorization_request(&totp_request_id) 311 + .await 312 + { 280 313 Ok(Some(r)) => r, 281 314 Ok(None) => { 282 315 return Json(DelegationAuthResponse { ··· 298 331 } 299 332 }; 300 333 301 - let controller_did = match &request.controller_did { 334 + let controller_did_str = match &request.controller_did { 302 335 Some(did) => did.clone(), 303 336 None => { 304 337 return Json(DelegationAuthResponse { ··· 311 344 } 312 345 }; 313 346 314 - let delegated_did = match &request.did { 347 + let controller_did: Did = match controller_did_str.parse() { 348 + Ok(d) => d, 349 + Err(_) => { 350 + return Json(DelegationAuthResponse { 351 + success: false, 352 + needs_totp: None, 353 + redirect_uri: None, 354 + error: Some("Invalid controller DID".to_string()), 355 + }) 356 + .into_response(); 357 + } 358 + }; 359 + 360 + let delegated_did_str = match &request.did { 315 361 Some(did) => did.clone(), 316 362 None => { 317 363 return Json(DelegationAuthResponse { ··· 324 370 } 325 371 }; 326 372 327 - let grant = match delegation::get_delegation(&state.db, &delegated_did, &controller_did).await { 373 + let delegated_did: Did = match delegated_did_str.parse() { 374 + Ok(d) => d, 375 + Err(_) => { 376 + return Json(DelegationAuthResponse { 377 + success: false, 378 + needs_totp: None, 379 + redirect_uri: None, 380 + error: Some("Invalid delegated DID".to_string()), 381 + }) 382 + .into_response(); 383 + } 384 + }; 385 + 386 + let grant = match state 387 + .delegation_repo 388 + .get_delegation(&delegated_did, &controller_did) 389 + .await 390 + { 328 391 Ok(Some(g)) => g, 329 392 _ => { 330 393 return Json(DelegationAuthResponse { ··· 356 419 .and_then(|v| v.to_str().ok()) 357 420 .map(|s| s.to_string()); 358 421 359 - let _ = delegation::log_delegation_action( 360 - &state.db, 361 - &delegated_did, 362 - &controller_did, 363 - Some(&controller_did), 364 - delegation::DelegationActionType::TokenIssued, 365 - Some(serde_json::json!({ 366 - "client_id": request.client_id, 367 - "granted_scopes": grant.granted_scopes 368 - })), 369 - Some(&ip), 370 - user_agent.as_deref(), 371 - ) 372 - .await; 422 + let _ = state 423 + .delegation_repo 424 + .log_delegation_action( 425 + &delegated_did, 426 + &controller_did, 427 + Some(&controller_did), 428 + DelegationActionType::TokenIssued, 429 + Some(serde_json::json!({ 430 + "client_id": request.client_id, 431 + "granted_scopes": grant.granted_scopes 432 + })), 433 + Some(&ip), 434 + user_agent.as_deref(), 435 + ) 436 + .await; 373 437 374 438 Json(DelegationAuthResponse { 375 439 success: true,
+10 -4
crates/tranquil-pds/src/oauth/endpoints/par.rs
··· 1 1 use crate::oauth::{ 2 2 AuthorizationRequestParameters, ClientAuth, ClientMetadataCache, OAuthError, RequestData, 3 - RequestId, db, 3 + RequestId, 4 4 scopes::{ParsedScope, parse_scope}, 5 5 }; 6 6 use crate::state::{AppState, RateLimitKind}; 7 + use tranquil_types::RequestId as RequestIdType; 7 8 use axum::body::Bytes; 8 9 use axum::{Json, extract::State, http::HeaderMap}; 9 10 use chrono::{Duration, Utc}; ··· 131 132 code: None, 132 133 controller_did: None, 133 134 }; 134 - db::create_authorization_request(&state.db, &request_id.0, &request_data).await?; 135 + let request_id_typed = RequestIdType::from(request_id.0.clone()); 136 + state 137 + .oauth_repo 138 + .create_authorization_request(&request_id_typed, &request_data) 139 + .await 140 + .map_err(crate::oauth::db_err_to_oauth)?; 135 141 tokio::spawn({ 136 - let pool = state.db.clone(); 142 + let oauth_repo = state.oauth_repo.clone(); 137 143 async move { 138 - if let Err(e) = db::delete_expired_authorization_requests(&pool).await { 144 + if let Err(e) = oauth_repo.delete_expired_authorization_requests().await { 139 145 tracing::warn!("Failed to cleanup expired authorization requests: {:?}", e); 140 146 } 141 147 }
+60 -16
crates/tranquil-pds/src/oauth/endpoints/token/grants.rs
··· 1 1 use super::helpers::{create_access_token_with_delegation, verify_pkce}; 2 2 use super::types::{TokenGrant, TokenResponse, ValidatedTokenRequest}; 3 3 use crate::config::AuthConfig; 4 - use crate::delegation; 4 + use crate::delegation::intersect_scopes; 5 5 use crate::oauth::{ 6 6 AuthFlowState, ClientAuth, ClientMetadataCache, DPoPVerifier, OAuthError, RefreshToken, 7 7 TokenData, TokenId, 8 - db::{self, RefreshTokenLookup}, 8 + db::{lookup_refresh_token, enforce_token_limit_for_user}, 9 9 scopes::expand_include_scopes, 10 10 verify_client_auth, 11 11 }; 12 12 use crate::state::AppState; 13 + use tranquil_db_traits::RefreshTokenLookup; 14 + use tranquil_types::{AuthorizationCode, Did, RefreshToken as RefreshTokenType}; 13 15 use axum::Json; 14 16 use axum::http::HeaderMap; 15 17 use chrono::{Duration, Utc}; ··· 41 43 )); 42 44 } 43 45 }; 44 - let auth_request = db::consume_authorization_request_by_code(&state.db, &code) 45 - .await? 46 + let auth_code = AuthorizationCode::from(code); 47 + let auth_request = state 48 + .oauth_repo 49 + .consume_authorization_request_by_code(&auth_code) 50 + .await 51 + .map_err(crate::oauth::db_err_to_oauth)? 46 52 .ok_or_else(|| OAuthError::InvalidGrant("Invalid or expired code".to_string()))?; 47 53 48 54 let flow_state = AuthFlowState::from_request_data(&auth_request); ··· 100 106 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 101 107 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 102 108 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 103 - if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 109 + if !state 110 + .oauth_repo 111 + .check_and_record_dpop_jti(&result.jti) 112 + .await 113 + .map_err(crate::oauth::db_err_to_oauth)? 114 + { 104 115 return Err(OAuthError::InvalidDpopProof( 105 116 "DPoP proof has already been used".to_string(), 106 117 )); ··· 125 136 let now = Utc::now(); 126 137 127 138 let (raw_scope, controller_did) = if let Some(ref controller) = auth_request.controller_did { 128 - let grant = delegation::get_delegation(&state.db, &did, controller) 139 + let did_parsed: Did = did.parse().map_err(|_| { 140 + OAuthError::InvalidRequest("Invalid DID format".to_string()) 141 + })?; 142 + let controller_parsed: Did = controller.parse().map_err(|_| { 143 + OAuthError::InvalidRequest("Invalid controller DID format".to_string()) 144 + })?; 145 + let grant = state 146 + .delegation_repo 147 + .get_delegation(&did_parsed, &controller_parsed) 129 148 .await 130 149 .ok() 131 150 .flatten(); ··· 135 154 .scope 136 155 .as_deref() 137 156 .unwrap_or("atproto"); 138 - let intersected = delegation::intersect_scopes(requested, &granted_scopes); 157 + let intersected = intersect_scopes(requested, &granted_scopes); 139 158 (Some(intersected), Some(controller.clone())) 140 159 } else { 141 160 (auth_request.parameters.scope.clone(), None) ··· 182 201 scope: final_scope.clone(), 183 202 controller_did: controller_did.clone(), 184 203 }; 185 - db::create_token(&state.db, &token_data).await?; 204 + state 205 + .oauth_repo 206 + .create_token(&token_data) 207 + .await 208 + .map_err(crate::oauth::db_err_to_oauth)?; 186 209 tracing::info!( 187 210 did = %did, 188 211 token_id = %token_id.0, ··· 190 213 "Authorization code grant completed, token created" 191 214 ); 192 215 tokio::spawn({ 193 - let pool = state.db.clone(); 216 + let oauth_repo = state.oauth_repo.clone(); 194 217 let did_clone = did.clone(); 195 218 async move { 196 - if let Err(e) = db::enforce_token_limit_for_user(&pool, &did_clone).await { 197 - tracing::warn!("Failed to enforce token limit for user: {:?}", e); 219 + if let Ok(did_typed) = did_clone.parse::<tranquil_types::Did>() { 220 + if let Err(e) = enforce_token_limit_for_user(oauth_repo.as_ref(), &did_typed).await { 221 + tracing::warn!("Failed to enforce token limit for user: {:?}", e); 222 + } 198 223 } 199 224 } 200 225 }); ··· 236 261 "Refresh token grant requested" 237 262 ); 238 263 239 - let lookup = db::lookup_refresh_token(&state.db, &refresh_token_str).await?; 264 + let refresh_token_typed = RefreshTokenType::from(refresh_token_str.clone()); 265 + let lookup = lookup_refresh_token(state.oauth_repo.as_ref(), &refresh_token_typed).await?; 240 266 let token_state = lookup.state(); 241 267 tracing::debug!(state = %token_state, "Refresh token state"); 242 268 ··· 281 307 refresh_token_prefix = %token_prefix, 282 308 "Refresh token reuse detected, revoking token family" 283 309 ); 284 - db::delete_token_family(&state.db, original_token_id).await?; 310 + state 311 + .oauth_repo 312 + .delete_token_family(original_token_id) 313 + .await 314 + .map_err(crate::oauth::db_err_to_oauth)?; 285 315 return Err(OAuthError::InvalidGrant( 286 316 "Refresh token reuse detected, token family revoked".to_string(), 287 317 )); 288 318 } 289 319 RefreshTokenLookup::Expired { db_id } => { 290 320 tracing::warn!(refresh_token_prefix = %token_prefix, "Refresh token has expired"); 291 - db::delete_token_family(&state.db, db_id).await?; 321 + state 322 + .oauth_repo 323 + .delete_token_family(db_id) 324 + .await 325 + .map_err(crate::oauth::db_err_to_oauth)?; 292 326 return Err(OAuthError::InvalidGrant( 293 327 "Refresh token has expired".to_string(), 294 328 )); ··· 307 341 std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| "localhost".to_string()); 308 342 let token_endpoint = format!("https://{}/oauth/token", pds_hostname); 309 343 let result = verifier.verify_proof(proof, "POST", &token_endpoint, None)?; 310 - if !db::check_and_record_dpop_jti(&state.db, &result.jti).await? { 344 + if !state 345 + .oauth_repo 346 + .check_and_record_dpop_jti(&result.jti) 347 + .await 348 + .map_err(crate::oauth::db_err_to_oauth)? 349 + { 311 350 return Err(OAuthError::InvalidDpopProof( 312 351 "DPoP proof has already been used".to_string(), 313 352 )); ··· 334 373 REFRESH_TOKEN_EXPIRY_DAYS_CONFIDENTIAL 335 374 }; 336 375 let new_expires_at = Utc::now() + Duration::days(refresh_expiry_days); 337 - db::rotate_token(&state.db, db_id, &new_refresh_token.0, new_expires_at).await?; 376 + let new_refresh_typed = RefreshTokenType::from(new_refresh_token.0.clone()); 377 + state 378 + .oauth_repo 379 + .rotate_token(db_id, &new_refresh_typed, new_expires_at) 380 + .await 381 + .map_err(crate::oauth::db_err_to_oauth)?; 338 382 tracing::info!( 339 383 did = %token_data.did, 340 384 new_expires_at = %new_expires_at,
+22 -5
crates/tranquil-pds/src/oauth/endpoints/token/introspect.rs
··· 1 1 use super::helpers::extract_token_claims; 2 - use crate::oauth::{OAuthError, db}; 2 + use crate::oauth::OAuthError; 3 3 use crate::state::{AppState, RateLimitKind}; 4 4 use axum::extract::State; 5 5 use axum::http::{HeaderMap, StatusCode}; 6 6 use axum::{Form, Json}; 7 7 use chrono::Utc; 8 8 use serde::{Deserialize, Serialize}; 9 + use tranquil_types::{RefreshToken, TokenId}; 9 10 10 11 #[derive(Debug, Deserialize)] 11 12 pub struct RevokeRequest { ··· 28 29 return Err(OAuthError::RateLimited); 29 30 } 30 31 if let Some(token) = &request.token { 31 - if let Some((db_id, _)) = db::get_token_by_refresh_token(&state.db, token).await? { 32 - db::delete_token_family(&state.db, db_id).await?; 32 + let refresh_token = RefreshToken::from(token.clone()); 33 + if let Some((db_id, _)) = state 34 + .oauth_repo 35 + .get_token_by_refresh_token(&refresh_token) 36 + .await 37 + .map_err(crate::oauth::db_err_to_oauth)? 38 + { 39 + state 40 + .oauth_repo 41 + .delete_token_family(db_id) 42 + .await 43 + .map_err(crate::oauth::db_err_to_oauth)?; 33 44 } else { 34 - db::delete_token(&state.db, token).await?; 45 + let token_id = TokenId::from(token.clone()); 46 + state 47 + .oauth_repo 48 + .delete_token(&token_id) 49 + .await 50 + .map_err(crate::oauth::db_err_to_oauth)?; 35 51 } 36 52 } 37 53 Ok(StatusCode::OK) ··· 102 118 Ok(info) => info, 103 119 Err(_) => return Ok(Json(inactive_response)), 104 120 }; 105 - let token_data = match db::get_token_by_id(&state.db, &token_info.sid).await { 121 + let token_id = TokenId::from(token_info.sid.clone()); 122 + let token_data = match state.oauth_repo.get_token_by_id(&token_id).await { 106 123 Ok(Some(data)) => data, 107 124 _ => return Ok(Json(inactive_response)), 108 125 };
+5
crates/tranquil-pds/src/oauth/mod.rs
··· 4 4 pub mod scopes; 5 5 pub mod verify; 6 6 7 + pub fn db_err_to_oauth(err: tranquil_db::DbError) -> OAuthError { 8 + tracing::error!("Database error in OAuth flow: {}", err); 9 + OAuthError::ServerError("An internal error occurred".to_string()) 10 + } 11 + 7 12 pub use tranquil_oauth::{ 8 13 AuthFlowState, AuthorizationRequestParameters, AuthorizationServerMetadata, 9 14 AuthorizedClientData, ClientAuth, ClientMetadata, ClientMetadataCache, Code, DPoPClaims,
+26 -12
crates/tranquil-pds/src/oauth/verify.rs
··· 8 8 use hmac::{Hmac, Mac}; 9 9 use serde_json::json; 10 10 use sha2::Sha256; 11 - use sqlx::PgPool; 12 11 use subtle::ConstantTimeEq; 12 + use tranquil_db_traits::{OAuthRepository, UserRepository}; 13 + use tranquil_types::TokenId; 13 14 14 - use super::db; 15 15 use super::scopes::ScopePermissions; 16 16 use super::{DPoPVerifier, OAuthError}; 17 17 use crate::config::AuthConfig; ··· 34 34 } 35 35 36 36 pub async fn verify_oauth_access_token( 37 - pool: &PgPool, 37 + oauth_repo: &dyn OAuthRepository, 38 38 access_token: &str, 39 39 dpop_proof: Option<&str>, 40 40 http_method: &str, ··· 46 46 has_dpop_proof = dpop_proof.is_some(), 47 47 "Verifying OAuth access token" 48 48 ); 49 - let token_data = db::get_token_by_id(pool, &token_info.token_id) 50 - .await? 49 + let token_id = TokenId::from(token_info.token_id.clone()); 50 + let token_data = oauth_repo 51 + .get_token_by_id(&token_id) 52 + .await 53 + .map_err(crate::oauth::db_err_to_oauth)? 51 54 .ok_or_else(|| { 52 - tracing::warn!(token_id = %token_info.token_id, "Token not found in database"); 55 + tracing::warn!(token_id = %token_id, "Token not found in database"); 53 56 OAuthError::InvalidToken("Token not found or revoked".to_string()) 54 57 })?; 55 58 let now = chrono::Utc::now(); ··· 73 76 tracing::warn!(error = ?e, http_method = %http_method, http_uri = %http_uri, "DPoP proof verification failed"); 74 77 e 75 78 })?; 76 - if !db::check_and_record_dpop_jti(pool, &result.jti).await? { 79 + if !oauth_repo 80 + .check_and_record_dpop_jti(&result.jti) 81 + .await 82 + .map_err(crate::oauth::db_err_to_oauth)? 83 + { 77 84 return Err(OAuthError::InvalidDpopProof( 78 85 "DPoP proof has already been used".to_string(), 79 86 )); ··· 86 93 } 87 94 Ok(VerifyResult { 88 95 did: token_data.did, 89 - token_id: token_info.token_id, 96 + token_id: token_id.to_string(), 90 97 client_id: token_data.client_id, 91 98 scope: token_data.scope, 92 99 }) ··· 271 278 }); 272 279 }; 273 280 let dpop_proof = parts.headers.get("DPoP").and_then(|v| v.to_str().ok()); 274 - if let Ok(result) = try_legacy_auth(&state.db, token).await { 281 + if let Ok(result) = try_legacy_auth(state.user_repo.as_ref(), token).await { 275 282 return Ok(OAuthUser { 276 283 did: result.did, 277 284 client_id: None, ··· 282 289 } 283 290 let http_method = parts.method.as_str(); 284 291 let http_uri = crate::util::build_full_url(&parts.uri.to_string()); 285 - match verify_oauth_access_token(&state.db, token, dpop_proof, http_method, &http_uri).await 292 + match verify_oauth_access_token( 293 + state.oauth_repo.as_ref(), 294 + token, 295 + dpop_proof, 296 + http_method, 297 + &http_uri, 298 + ) 299 + .await 286 300 { 287 301 Ok(result) => { 288 302 let permissions = ScopePermissions::from_scope_string(result.scope.as_deref()); ··· 371 385 did: String, 372 386 } 373 387 374 - async fn try_legacy_auth(pool: &PgPool, token: &str) -> Result<LegacyAuthResult, ()> { 375 - match crate::auth::validate_bearer_token(pool, token).await { 388 + async fn try_legacy_auth(user_repo: &dyn UserRepository, token: &str) -> Result<LegacyAuthResult, ()> { 389 + match crate::auth::validate_bearer_token(user_repo, token).await { 376 390 Ok(user) if !user.is_oauth => Ok(LegacyAuthResult { 377 391 did: user.did.to_string(), 378 392 }),
+174 -354
crates/tranquil-pds/src/scheduled.rs
··· 2 2 use ipld_core::ipld::Ipld; 3 3 use jacquard_repo::commit::Commit; 4 4 use jacquard_repo::storage::BlockStore; 5 - use sqlx::PgPool; 6 5 use std::str::FromStr; 7 6 use std::sync::Arc; 8 7 use std::time::Duration; 9 8 use tokio::sync::watch; 10 9 use tokio::time::interval; 11 10 use tracing::{debug, error, info, warn}; 11 + use tranquil_db_traits::{ 12 + BackupRepository, BlobRepository, BrokenGenesisCommit, RepoRepository, UserRepository, 13 + }; 14 + use tranquil_types::{AtUri, CidLink, Did}; 12 15 13 16 use crate::repo::PostgresBlockStore; 14 17 use crate::storage::{BackupStorage, BlobStorage}; 15 18 use crate::sync::car::encode_car_header; 16 19 17 - async fn update_genesis_blocks_cids(db: &PgPool, blocks_cids: &[String], seq: i64) -> Result<(), sqlx::Error> { 18 - sqlx::query!( 19 - "UPDATE repo_seq SET blocks_cids = $1 WHERE seq = $2", 20 - blocks_cids, 21 - seq 22 - ) 23 - .execute(db) 24 - .await?; 25 - Ok(()) 26 - } 27 - 28 - async fn update_repo_rev(db: &PgPool, rev: &str, user_id: uuid::Uuid) -> Result<(), sqlx::Error> { 29 - sqlx::query!( 30 - "UPDATE repos SET repo_rev = $1 WHERE user_id = $2", 31 - rev, 32 - user_id 33 - ) 34 - .execute(db) 35 - .await?; 36 - Ok(()) 37 - } 38 - 39 - async fn insert_user_blocks(db: &PgPool, user_id: uuid::Uuid, block_cids: &[Vec<u8>]) -> Result<(), sqlx::Error> { 40 - sqlx::query!( 41 - r#" 42 - INSERT INTO user_blocks (user_id, block_cid) 43 - SELECT $1, block_cid FROM UNNEST($2::bytea[]) AS t(block_cid) 44 - ON CONFLICT (user_id, block_cid) DO NOTHING 45 - "#, 46 - user_id, 47 - block_cids 48 - ) 49 - .execute(db) 50 - .await?; 51 - Ok(()) 52 - } 53 - 54 - async fn fetch_user_records(db: &PgPool, user_id: uuid::Uuid) -> Result<Vec<(String, String, String)>, sqlx::Error> { 55 - let rows = sqlx::query!( 56 - "SELECT collection, rkey, record_cid FROM records WHERE repo_id = $1", 57 - user_id 58 - ) 59 - .fetch_all(db) 60 - .await?; 61 - Ok(rows.into_iter().map(|r| (r.collection, r.rkey, r.record_cid)).collect()) 62 - } 63 - 64 - async fn insert_record_blobs(db: &PgPool, user_id: uuid::Uuid, record_uris: &[String], blob_cids: &[String]) -> Result<(), sqlx::Error> { 65 - sqlx::query!( 66 - r#" 67 - INSERT INTO record_blobs (repo_id, record_uri, blob_cid) 68 - SELECT $1, record_uri, blob_cid 69 - FROM UNNEST($2::text[], $3::text[]) AS t(record_uri, blob_cid) 70 - ON CONFLICT (repo_id, record_uri, blob_cid) DO NOTHING 71 - "#, 72 - user_id, 73 - record_uris, 74 - blob_cids 75 - ) 76 - .execute(db) 77 - .await?; 78 - Ok(()) 79 - } 80 - 81 - async fn delete_backup_record(db: &PgPool, id: uuid::Uuid) -> Result<(), sqlx::Error> { 82 - sqlx::query!("DELETE FROM account_backups WHERE id = $1", id) 83 - .execute(db) 84 - .await?; 85 - Ok(()) 86 - } 87 - 88 - async fn fetch_old_backups( 89 - db: &PgPool, 90 - user_id: uuid::Uuid, 91 - retention_count: i64, 92 - ) -> Result<Vec<(uuid::Uuid, String)>, sqlx::Error> { 93 - let rows = sqlx::query!( 94 - r#" 95 - SELECT id, storage_key 96 - FROM account_backups 97 - WHERE user_id = $1 98 - ORDER BY created_at DESC 99 - OFFSET $2 100 - "#, 101 - user_id, 102 - retention_count 103 - ) 104 - .fetch_all(db) 105 - .await?; 106 - Ok(rows.into_iter().map(|r| (r.id, r.storage_key)).collect()) 107 - } 108 - 109 - async fn insert_backup_record( 110 - db: &PgPool, 111 - user_id: uuid::Uuid, 112 - storage_key: &str, 113 - repo_root_cid: &str, 114 - repo_rev: &str, 115 - block_count: i32, 116 - size_bytes: i64, 117 - ) -> Result<(), sqlx::Error> { 118 - sqlx::query!( 119 - r#" 120 - INSERT INTO account_backups (user_id, storage_key, repo_root_cid, repo_rev, block_count, size_bytes) 121 - VALUES ($1, $2, $3, $4, $5, $6) 122 - "#, 123 - user_id, 124 - storage_key, 125 - repo_root_cid, 126 - repo_rev, 127 - block_count, 128 - size_bytes 129 - ) 130 - .execute(db) 131 - .await?; 132 - Ok(()) 133 - } 134 - 135 - struct GenesisCommitRow { 136 - seq: i64, 137 - did: String, 138 - commit_cid: Option<String>, 139 - } 140 - 141 20 async fn process_genesis_commit( 142 - db: &PgPool, 21 + repo_repo: &dyn RepoRepository, 143 22 block_store: &PostgresBlockStore, 144 - row: GenesisCommitRow, 145 - ) -> Result<(String, i64), (i64, &'static str)> { 23 + row: BrokenGenesisCommit, 24 + ) -> Result<(Did, i64), (i64, &'static str)> { 146 25 let commit_cid_str = row.commit_cid.ok_or((row.seq, "missing commit_cid"))?; 147 26 let commit_cid = Cid::from_str(&commit_cid_str).map_err(|_| (row.seq, "invalid CID"))?; 148 27 let block = block_store ··· 152 31 .ok_or((row.seq, "block not found"))?; 153 32 let commit = Commit::from_cbor(&block).map_err(|_| (row.seq, "failed to parse commit"))?; 154 33 let blocks_cids = vec![commit.data.to_string(), commit_cid.to_string()]; 155 - update_genesis_blocks_cids(db, &blocks_cids, row.seq) 34 + repo_repo 35 + .update_seq_blocks_cids(row.seq, &blocks_cids) 156 36 .await 157 37 .map_err(|_| (row.seq, "failed to update"))?; 158 38 Ok((row.did, row.seq)) 159 39 } 160 40 161 - pub async fn backfill_genesis_commit_blocks(db: &PgPool, block_store: PostgresBlockStore) { 162 - let broken_genesis_commits = match sqlx::query!( 163 - r#" 164 - SELECT seq, did, commit_cid 165 - FROM repo_seq 166 - WHERE event_type = 'commit' 167 - AND prev_cid IS NULL 168 - AND (blocks_cids IS NULL OR array_length(blocks_cids, 1) IS NULL OR array_length(blocks_cids, 1) = 0) 169 - "# 170 - ) 171 - .fetch_all(db) 172 - .await 173 - { 41 + pub async fn backfill_genesis_commit_blocks( 42 + repo_repo: Arc<dyn RepoRepository>, 43 + block_store: PostgresBlockStore, 44 + ) { 45 + let broken_genesis_commits = match repo_repo.get_broken_genesis_commits().await { 174 46 Ok(rows) => rows, 175 47 Err(e) => { 176 - error!("Failed to query repo_seq for genesis commit backfill: {}", e); 48 + error!( 49 + "Failed to query repo_seq for genesis commit backfill: {:?}", 50 + e 51 + ); 177 52 return; 178 53 } 179 54 }; ··· 189 64 ); 190 65 191 66 let results = futures::future::join_all(broken_genesis_commits.into_iter().map(|row| { 192 - process_genesis_commit( 193 - db, 194 - &block_store, 195 - GenesisCommitRow { 196 - seq: row.seq, 197 - did: row.did, 198 - commit_cid: row.commit_cid, 199 - }, 200 - ) 67 + let repo_repo = repo_repo.clone(); 68 + let block_store = block_store.clone(); 69 + async move { process_genesis_commit(repo_repo.as_ref(), &block_store, row).await } 201 70 })) 202 71 .await; 203 72 ··· 219 88 } 220 89 221 90 async fn process_repo_rev( 222 - db: &PgPool, 91 + repo_repo: &dyn RepoRepository, 223 92 block_store: &PostgresBlockStore, 224 93 user_id: uuid::Uuid, 225 94 repo_root_cid: String, ··· 233 102 .ok_or(user_id)?; 234 103 let commit = Commit::from_cbor(&block).map_err(|_| user_id)?; 235 104 let rev = commit.rev().to_string(); 236 - update_repo_rev(db, &rev, user_id) 105 + repo_repo 106 + .update_repo_rev(user_id, &rev) 237 107 .await 238 108 .map_err(|_| user_id)?; 239 109 Ok(user_id) 240 110 } 241 111 242 - pub async fn backfill_repo_rev(db: &PgPool, block_store: PostgresBlockStore) { 243 - let repos_missing_rev = 244 - match sqlx::query!("SELECT user_id, repo_root_cid FROM repos WHERE repo_rev IS NULL") 245 - .fetch_all(db) 246 - .await 247 - { 248 - Ok(rows) => rows, 249 - Err(e) => { 250 - error!("Failed to query repos for backfill: {}", e); 251 - return; 252 - } 253 - }; 112 + pub async fn backfill_repo_rev( 113 + repo_repo: Arc<dyn RepoRepository>, 114 + block_store: PostgresBlockStore, 115 + ) { 116 + let repos_missing_rev = match repo_repo.get_repos_without_rev().await { 117 + Ok(rows) => rows, 118 + Err(e) => { 119 + error!("Failed to query repos for backfill: {:?}", e); 120 + return; 121 + } 122 + }; 254 123 255 124 if repos_missing_rev.is_empty() { 256 125 debug!("No repos need repo_rev backfill"); ··· 263 132 ); 264 133 265 134 let results = futures::future::join_all(repos_missing_rev.into_iter().map(|repo| { 266 - process_repo_rev(db, &block_store, repo.user_id, repo.repo_root_cid) 135 + let repo_repo = repo_repo.clone(); 136 + let block_store = block_store.clone(); 137 + async move { 138 + process_repo_rev(repo_repo.as_ref(), &block_store, repo.user_id, repo.repo_root_cid.to_string()) 139 + .await 140 + } 267 141 })) 268 142 .await; 269 143 270 - let (success, failed) = results 271 - .iter() 272 - .fold((0, 0), |(s, f), r| match r { 273 - Ok(_) => (s + 1, f), 274 - Err(user_id) => { 275 - warn!(user_id = %user_id, "Failed to update repo_rev"); 276 - (s, f + 1) 277 - } 278 - }); 144 + let (success, failed) = results.iter().fold((0, 0), |(s, f), r| match r { 145 + Ok(_) => (s + 1, f), 146 + Err(user_id) => { 147 + warn!(user_id = %user_id, "Failed to update repo_rev"); 148 + (s, f + 1) 149 + } 150 + }); 279 151 280 152 info!(success, failed, "Completed repo_rev backfill"); 281 153 } 282 154 283 155 async fn process_user_blocks( 284 - db: &PgPool, 156 + repo_repo: &dyn RepoRepository, 285 157 block_store: &PostgresBlockStore, 286 158 user_id: uuid::Uuid, 287 159 repo_root_cid: String, 160 + repo_rev: Option<String>, 288 161 ) -> Result<(uuid::Uuid, usize), uuid::Uuid> { 289 162 let root_cid = Cid::from_str(&repo_root_cid).map_err(|_| user_id)?; 290 163 let block_cids = collect_current_repo_blocks(block_store, &root_cid) ··· 294 167 return Err(user_id); 295 168 } 296 169 let count = block_cids.len(); 297 - insert_user_blocks(db, user_id, &block_cids) 170 + let rev = repo_rev.unwrap_or_else(|| "0".to_string()); 171 + repo_repo 172 + .insert_user_blocks(user_id, &block_cids, &rev) 298 173 .await 299 174 .map_err(|_| user_id)?; 300 175 Ok((user_id, count)) 301 176 } 302 177 303 - pub async fn backfill_user_blocks(db: &PgPool, block_store: PostgresBlockStore) { 304 - let users_without_blocks = match sqlx::query!( 305 - r#" 306 - SELECT u.id as user_id, r.repo_root_cid 307 - FROM users u 308 - JOIN repos r ON r.user_id = u.id 309 - WHERE NOT EXISTS (SELECT 1 FROM user_blocks ub WHERE ub.user_id = u.id) 310 - "# 311 - ) 312 - .fetch_all(db) 313 - .await 314 - { 178 + pub async fn backfill_user_blocks( 179 + repo_repo: Arc<dyn RepoRepository>, 180 + block_store: PostgresBlockStore, 181 + ) { 182 + let users_without_blocks = match repo_repo.get_users_without_blocks().await { 315 183 Ok(rows) => rows, 316 184 Err(e) => { 317 - error!("Failed to query users for user_blocks backfill: {}", e); 185 + error!("Failed to query users for user_blocks backfill: {:?}", e); 318 186 return; 319 187 } 320 188 }; ··· 330 198 ); 331 199 332 200 let results = futures::future::join_all(users_without_blocks.into_iter().map(|user| { 333 - process_user_blocks(db, &block_store, user.user_id, user.repo_root_cid) 201 + let repo_repo = repo_repo.clone(); 202 + let block_store = block_store.clone(); 203 + async move { 204 + process_user_blocks( 205 + repo_repo.as_ref(), 206 + &block_store, 207 + user.user_id, 208 + user.repo_root_cid.to_string(), 209 + user.repo_rev, 210 + ) 211 + .await 212 + } 334 213 })) 335 214 .await; 336 215 ··· 401 280 } 402 281 403 282 async fn process_record_blobs( 404 - db: &PgPool, 283 + repo_repo: &dyn RepoRepository, 405 284 block_store: &PostgresBlockStore, 406 285 user_id: uuid::Uuid, 407 - did: String, 408 - ) -> Result<(uuid::Uuid, String, usize), (uuid::Uuid, &'static str)> { 409 - let records = fetch_user_records(db, user_id) 286 + did: Did, 287 + ) -> Result<(uuid::Uuid, Did, usize), (uuid::Uuid, &'static str)> { 288 + let records = repo_repo 289 + .get_all_records(user_id) 410 290 .await 411 291 .map_err(|_| (user_id, "failed to fetch records"))?; 412 292 413 - let mut batch_record_uris: Vec<String> = Vec::new(); 414 - let mut batch_blob_cids: Vec<String> = Vec::new(); 293 + let mut batch_record_uris: Vec<AtUri> = Vec::new(); 294 + let mut batch_blob_cids: Vec<CidLink> = Vec::new(); 415 295 416 - futures::future::join_all(records.into_iter().map(|(collection, rkey, record_cid)| { 296 + futures::future::join_all(records.into_iter().map(|record| { 417 297 let did = did.clone(); 418 298 async move { 419 - let cid = Cid::from_str(&record_cid).ok()?; 299 + let cid = Cid::from_str(&record.record_cid).ok()?; 420 300 let block_bytes = block_store.get(&cid).await.ok()??; 421 301 let record_ipld: Ipld = serde_ipld_dagcbor::from_slice(&block_bytes).ok()?; 422 302 let blob_refs = crate::sync::import::find_blob_refs_ipld(&record_ipld, 0); ··· 424 304 blob_refs 425 305 .into_iter() 426 306 .map(|blob_ref| { 427 - let record_uri = format!("at://{}/{}/{}", did, collection, rkey); 428 - (record_uri, blob_ref.cid) 307 + let record_uri = 308 + AtUri::from_parts(did.as_str(), record.collection.as_str(), record.rkey.as_str()); 309 + (record_uri, CidLink::new_unchecked(blob_ref.cid)) 429 310 }) 430 311 .collect::<Vec<_>>(), 431 312 ) ··· 442 323 443 324 let blob_refs_found = batch_record_uris.len(); 444 325 if !batch_record_uris.is_empty() { 445 - insert_record_blobs(db, user_id, &batch_record_uris, &batch_blob_cids) 326 + repo_repo 327 + .insert_record_blobs(user_id, &batch_record_uris, &batch_blob_cids) 446 328 .await 447 329 .map_err(|_| (user_id, "failed to insert"))?; 448 330 } 449 331 Ok((user_id, did, blob_refs_found)) 450 332 } 451 333 452 - pub async fn backfill_record_blobs(db: &PgPool, block_store: PostgresBlockStore) { 453 - let users_needing_backfill = match sqlx::query!( 454 - r#" 455 - SELECT DISTINCT u.id as user_id, u.did 456 - FROM users u 457 - JOIN records r ON r.repo_id = u.id 458 - WHERE NOT EXISTS (SELECT 1 FROM record_blobs rb WHERE rb.repo_id = u.id) 459 - LIMIT 100 460 - "# 461 - ) 462 - .fetch_all(db) 463 - .await 464 - { 334 + pub async fn backfill_record_blobs( 335 + repo_repo: Arc<dyn RepoRepository>, 336 + block_store: PostgresBlockStore, 337 + ) { 338 + let users_needing_backfill = match repo_repo.get_users_needing_record_blobs_backfill(100).await { 465 339 Ok(rows) => rows, 466 340 Err(e) => { 467 - error!("Failed to query users for record_blobs backfill: {}", e); 341 + error!("Failed to query users for record_blobs backfill: {:?}", e); 468 342 return; 469 343 } 470 344 }; ··· 480 354 ); 481 355 482 356 let results = futures::future::join_all(users_needing_backfill.into_iter().map(|user| { 483 - process_record_blobs(db, &block_store, user.user_id, user.did) 357 + let repo_repo = repo_repo.clone(); 358 + let block_store = block_store.clone(); 359 + async move { 360 + process_record_blobs(repo_repo.as_ref(), &block_store, user.user_id, user.did).await 361 + } 484 362 })) 485 363 .await; 486 364 ··· 501 379 } 502 380 503 381 pub async fn start_scheduled_tasks( 504 - db: PgPool, 382 + user_repo: Arc<dyn UserRepository>, 383 + blob_repo: Arc<dyn BlobRepository>, 505 384 blob_store: Arc<dyn BlobStorage>, 506 385 mut shutdown_rx: watch::Receiver<bool>, 507 386 ) { ··· 529 408 } 530 409 } 531 410 _ = ticker.tick() => { 532 - if let Err(e) = process_scheduled_deletions(&db, blob_store.as_ref()).await { 411 + if let Err(e) = process_scheduled_deletions( 412 + user_repo.as_ref(), 413 + blob_repo.as_ref(), 414 + blob_store.as_ref(), 415 + ).await { 533 416 error!("Error processing scheduled deletions: {}", e); 534 417 } 535 418 } ··· 538 421 } 539 422 540 423 async fn process_scheduled_deletions( 541 - db: &PgPool, 424 + user_repo: &dyn UserRepository, 425 + blob_repo: &dyn BlobRepository, 542 426 blob_store: &dyn BlobStorage, 543 427 ) -> Result<(), String> { 544 - let accounts_to_delete = sqlx::query!( 545 - r#" 546 - SELECT did, handle 547 - FROM users 548 - WHERE delete_after IS NOT NULL 549 - AND delete_after < NOW() 550 - AND deactivated_at IS NOT NULL 551 - LIMIT 100 552 - "# 553 - ) 554 - .fetch_all(db) 555 - .await 556 - .map_err(|e| format!("DB error fetching accounts to delete: {}", e))?; 428 + let accounts_to_delete = user_repo 429 + .get_accounts_scheduled_for_deletion(100) 430 + .await 431 + .map_err(|e| format!("DB error fetching accounts to delete: {:?}", e))?; 557 432 558 433 if accounts_to_delete.is_empty() { 559 434 debug!("No accounts scheduled for deletion"); ··· 566 441 ); 567 442 568 443 futures::future::join_all(accounts_to_delete.into_iter().map(|account| async move { 569 - let result = delete_account_data(db, blob_store, &account.did, &account.handle).await; 444 + let result = 445 + delete_account_data(user_repo, blob_repo, blob_store, account.id, &account.did).await; 570 446 (account.did, account.handle, result) 571 447 })) 572 448 .await ··· 580 456 } 581 457 582 458 async fn delete_account_data( 583 - db: &PgPool, 459 + user_repo: &dyn UserRepository, 460 + blob_repo: &dyn BlobRepository, 584 461 blob_store: &dyn BlobStorage, 585 - did: &str, 586 - _handle: &str, 462 + user_id: uuid::Uuid, 463 + did: &Did, 587 464 ) -> Result<(), String> { 588 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 589 - .fetch_one(db) 465 + let blob_storage_keys = blob_repo 466 + .get_blob_storage_keys_by_user(user_id) 590 467 .await 591 - .map_err(|e| format!("DB error fetching user: {}", e))?; 592 - 593 - let blob_storage_keys: Vec<String> = sqlx::query_scalar!( 594 - r#"SELECT storage_key as "storage_key!" FROM blobs WHERE created_by_user = $1"#, 595 - user_id 596 - ) 597 - .fetch_all(db) 598 - .await 599 - .map_err(|e| format!("DB error fetching blob keys: {}", e))?; 468 + .map_err(|e| format!("DB error fetching blob keys: {:?}", e))?; 600 469 601 470 futures::future::join_all(blob_storage_keys.iter().map(|storage_key| async move { 602 471 (storage_key, blob_store.delete(storage_key).await) ··· 608 477 warn!(storage_key = %key, error = %e, "Failed to delete blob from storage (continuing anyway)"); 609 478 }); 610 479 611 - let mut tx = db 612 - .begin() 613 - .await 614 - .map_err(|e| format!("Failed to begin transaction: {}", e))?; 615 - 616 - sqlx::query!("DELETE FROM blobs WHERE created_by_user = $1", user_id) 617 - .execute(&mut *tx) 618 - .await 619 - .map_err(|e| format!("Failed to delete blobs: {}", e))?; 620 - 621 - sqlx::query!("DELETE FROM users WHERE id = $1", user_id) 622 - .execute(&mut *tx) 623 - .await 624 - .map_err(|e| format!("Failed to delete user: {}", e))?; 625 - 626 - let account_seq = sqlx::query_scalar!( 627 - r#" 628 - INSERT INTO repo_seq (did, event_type, active, status) 629 - VALUES ($1, 'account', false, 'deleted') 630 - RETURNING seq 631 - "#, 632 - did 633 - ) 634 - .fetch_one(&mut *tx) 635 - .await 636 - .map_err(|e| format!("Failed to sequence account deletion: {}", e))?; 637 - 638 - sqlx::query!( 639 - "DELETE FROM repo_seq WHERE did = $1 AND seq != $2", 640 - did, 641 - account_seq 642 - ) 643 - .execute(&mut *tx) 644 - .await 645 - .map_err(|e| format!("Failed to cleanup sequences: {}", e))?; 646 - 647 - tx.commit() 480 + let _account_seq = user_repo 481 + .delete_account_with_firehose(user_id, did) 648 482 .await 649 - .map_err(|e| format!("Failed to commit transaction: {}", e))?; 650 - 651 - sqlx::query(&format!("NOTIFY repo_updates, '{}'", account_seq)) 652 - .execute(db) 653 - .await 654 - .map_err(|e| format!("Failed to notify: {}", e))?; 483 + .map_err(|e| format!("Failed to delete account: {:?}", e))?; 655 484 656 485 info!( 657 486 did = %did, ··· 663 492 } 664 493 665 494 pub async fn start_backup_tasks( 666 - db: PgPool, 495 + repo_repo: Arc<dyn RepoRepository>, 496 + backup_repo: Arc<dyn BackupRepository>, 667 497 block_store: PostgresBlockStore, 668 498 backup_storage: Arc<BackupStorage>, 669 499 mut shutdown_rx: watch::Receiver<bool>, ··· 688 518 } 689 519 } 690 520 _ = ticker.tick() => { 691 - if let Err(e) = process_scheduled_backups(&db, &block_store, &backup_storage).await { 521 + if let Err(e) = process_scheduled_backups( 522 + repo_repo.as_ref(), 523 + backup_repo.as_ref(), 524 + &block_store, 525 + &backup_storage, 526 + ).await { 692 527 error!("Error processing scheduled backups: {}", e); 693 528 } 694 529 } ··· 711 546 } 712 547 713 548 async fn process_single_backup( 714 - db: &PgPool, 549 + repo_repo: &dyn RepoRepository, 550 + backup_repo: &dyn BackupRepository, 715 551 block_store: &PostgresBlockStore, 716 552 backup_storage: &BackupStorage, 717 553 user_id: uuid::Uuid, ··· 729 565 Err(_) => return BackupOutcome::Skipped(did, "invalid repo_root_cid"), 730 566 }; 731 567 732 - let car_bytes = match generate_full_backup(db, block_store, user_id, &head_cid).await { 568 + let car_bytes = match generate_full_backup(repo_repo, block_store, user_id, &head_cid).await { 733 569 Ok(bytes) => bytes, 734 570 Err(e) => return BackupOutcome::Failed(did, format!("CAR generation: {}", e)), 735 571 }; ··· 742 578 Err(e) => return BackupOutcome::Failed(did, format!("S3 upload: {}", e)), 743 579 }; 744 580 745 - if let Err(e) = insert_backup_record( 746 - db, 747 - user_id, 748 - &storage_key, 749 - &repo_root_cid, 750 - &repo_rev, 751 - block_count, 752 - size_bytes, 753 - ) 754 - .await 581 + if let Err(e) = backup_repo 582 + .insert_backup( 583 + user_id, 584 + &storage_key, 585 + &repo_root_cid, 586 + &repo_rev, 587 + block_count, 588 + size_bytes, 589 + ) 590 + .await 755 591 { 756 592 if let Err(rollback_err) = backup_storage.delete_backup(&storage_key).await { 757 593 error!( ··· 761 597 "Failed to rollback orphaned backup from S3" 762 598 ); 763 599 } 764 - return BackupOutcome::Failed(did, format!("DB insert: {}", e)); 600 + return BackupOutcome::Failed(did, format!("DB insert: {:?}", e)); 765 601 } 766 602 767 603 BackupOutcome::Success(BackupResult { ··· 774 610 } 775 611 776 612 async fn process_scheduled_backups( 777 - db: &PgPool, 613 + repo_repo: &dyn RepoRepository, 614 + backup_repo: &dyn BackupRepository, 778 615 block_store: &PostgresBlockStore, 779 616 backup_storage: &BackupStorage, 780 617 ) -> Result<(), String> { 781 618 let backup_interval_secs = BackupStorage::interval_secs() as i64; 782 619 let retention_count = BackupStorage::retention_count(); 783 620 784 - let users_needing_backup = sqlx::query!( 785 - r#" 786 - SELECT u.id as user_id, u.did, r.repo_root_cid, r.repo_rev 787 - FROM users u 788 - JOIN repos r ON r.user_id = u.id 789 - WHERE u.backup_enabled = true 790 - AND u.deactivated_at IS NULL 791 - AND ( 792 - NOT EXISTS ( 793 - SELECT 1 FROM account_backups ab WHERE ab.user_id = u.id 794 - ) 795 - OR ( 796 - SELECT MAX(ab.created_at) FROM account_backups ab WHERE ab.user_id = u.id 797 - ) < NOW() - make_interval(secs => $1) 798 - ) 799 - LIMIT 50 800 - "#, 801 - backup_interval_secs as f64 802 - ) 803 - .fetch_all(db) 804 - .await 805 - .map_err(|e| format!("DB error fetching users for backup: {}", e))?; 621 + let users_needing_backup = backup_repo 622 + .get_users_needing_backup(backup_interval_secs, 50) 623 + .await 624 + .map_err(|e| format!("DB error fetching users for backup: {:?}", e))?; 806 625 807 626 if users_needing_backup.is_empty() { 808 627 debug!("No accounts need backup"); ··· 816 635 817 636 let results = futures::future::join_all(users_needing_backup.into_iter().map(|user| { 818 637 process_single_backup( 819 - db, 638 + repo_repo, 639 + backup_repo, 820 640 block_store, 821 641 backup_storage, 822 - user.user_id, 823 - user.did, 824 - user.repo_root_cid, 642 + user.id, 643 + user.did.to_string(), 644 + user.repo_root_cid.to_string(), 825 645 user.repo_rev, 826 646 ) 827 647 })) ··· 838 658 "Created backup" 839 659 ); 840 660 if let Err(e) = 841 - cleanup_old_backups(db, backup_storage, result.user_id, retention_count).await 661 + cleanup_old_backups(backup_repo, backup_storage, result.user_id, retention_count) 662 + .await 842 663 { 843 664 warn!(did = %result.did, error = %e, "Failed to cleanup old backups"); 844 665 } ··· 905 726 } 906 727 907 728 pub async fn generate_repo_car_from_user_blocks( 908 - db: &PgPool, 729 + repo_repo: &dyn tranquil_db_traits::RepoRepository, 909 730 block_store: &PostgresBlockStore, 910 731 user_id: uuid::Uuid, 911 732 _head_cid: &Cid, 912 733 ) -> Result<Vec<u8>, String> { 913 734 use std::str::FromStr; 914 735 915 - let repo_root_cid_str: String = sqlx::query_scalar!( 916 - "SELECT repo_root_cid FROM repos WHERE user_id = $1", 917 - user_id 918 - ) 919 - .fetch_optional(db) 920 - .await 921 - .map_err(|e| format!("Failed to fetch repo: {}", e))? 922 - .ok_or_else(|| "Repository not found".to_string())?; 736 + let repo_root_cid_str: String = repo_repo 737 + .get_repo_root_cid_by_user_id(user_id) 738 + .await 739 + .map_err(|e| format!("Failed to fetch repo: {:?}", e))? 740 + .ok_or_else(|| "Repository not found".to_string())? 741 + .to_string(); 923 742 924 743 let actual_head_cid = 925 744 Cid::from_str(&repo_root_cid_str).map_err(|e| format!("Invalid repo_root_cid: {}", e))?; ··· 928 747 } 929 748 930 749 pub async fn generate_full_backup( 931 - db: &PgPool, 750 + repo_repo: &dyn tranquil_db_traits::RepoRepository, 932 751 block_store: &PostgresBlockStore, 933 752 user_id: uuid::Uuid, 934 753 head_cid: &Cid, 935 754 ) -> Result<Vec<u8>, String> { 936 - generate_repo_car_from_user_blocks(db, block_store, user_id, head_cid).await 755 + generate_repo_car_from_user_blocks(repo_repo, block_store, user_id, head_cid).await 937 756 } 938 757 939 758 pub fn count_car_blocks(car_bytes: &[u8]) -> i32 { ··· 977 796 } 978 797 979 798 async fn cleanup_old_backups( 980 - db: &PgPool, 799 + backup_repo: &dyn BackupRepository, 981 800 backup_storage: &BackupStorage, 982 801 user_id: uuid::Uuid, 983 802 retention_count: u32, 984 803 ) -> Result<(), String> { 985 - let old_backups = fetch_old_backups(db, user_id, retention_count as i64) 804 + let old_backups = backup_repo 805 + .get_old_backups(user_id, retention_count as i64) 986 806 .await 987 - .map_err(|e| format!("DB error fetching old backups: {}", e))?; 807 + .map_err(|e| format!("DB error fetching old backups: {:?}", e))?; 988 808 989 - let results = futures::future::join_all(old_backups.into_iter().map(|(id, storage_key)| async move { 990 - match backup_storage.delete_backup(&storage_key).await { 991 - Ok(()) => match delete_backup_record(db, id).await { 809 + let results = futures::future::join_all(old_backups.into_iter().map(|backup| async move { 810 + match backup_storage.delete_backup(&backup.storage_key).await { 811 + Ok(()) => match backup_repo.delete_backup(backup.id).await { 992 812 Ok(()) => Ok(()), 993 - Err(e) => Err(format!("DB delete failed for {}: {}", storage_key, e)), 813 + Err(e) => Err(format!("DB delete failed for {}: {:?}", backup.storage_key, e)), 994 814 }, 995 815 Err(e) => { 996 816 warn!( 997 - storage_key = %storage_key, 817 + storage_key = %backup.storage_key, 998 818 error = %e, 999 819 "Failed to delete old backup from storage, skipping DB cleanup to avoid orphan" 1000 820 );
+29 -3
crates/tranquil-pds/src/state.rs
··· 10 10 use std::error::Error; 11 11 use std::sync::Arc; 12 12 use tokio::sync::broadcast; 13 + use tranquil_db::{ 14 + BacklinkRepository, BackupRepository, BlobRepository, DelegationRepository, InfraRepository, 15 + OAuthRepository, PostgresRepositories, RepoEventNotifier, RepoRepository, SessionRepository, 16 + UserRepository, 17 + }; 13 18 14 19 #[derive(Clone)] 15 20 pub struct AppState { 16 - pub db: PgPool, 21 + pub repos: Arc<PostgresRepositories>, 22 + pub user_repo: Arc<dyn UserRepository>, 23 + pub oauth_repo: Arc<dyn OAuthRepository>, 24 + pub session_repo: Arc<dyn SessionRepository>, 25 + pub delegation_repo: Arc<dyn DelegationRepository>, 26 + pub repo_repo: Arc<dyn RepoRepository>, 27 + pub blob_repo: Arc<dyn BlobRepository>, 28 + pub infra_repo: Arc<dyn InfraRepository>, 29 + pub backup_repo: Arc<dyn BackupRepository>, 30 + pub backlink_repo: Arc<dyn BacklinkRepository>, 31 + pub event_notifier: Arc<dyn RepoEventNotifier>, 17 32 pub block_store: PostgresBlockStore, 18 33 pub blob_store: Arc<dyn BlobStorage>, 19 34 pub backup_storage: Option<Arc<BackupStorage>>, ··· 133 148 pub async fn from_db(db: PgPool) -> Self { 134 149 AuthConfig::init(); 135 150 136 - let block_store = PostgresBlockStore::new(db.clone()); 151 + let repos = Arc::new(PostgresRepositories::new(db.clone())); 152 + let block_store = PostgresBlockStore::new(db); 137 153 let blob_store = S3BlobStorage::new().await; 138 154 let backup_storage = BackupStorage::new().await.map(Arc::new); 139 155 ··· 149 165 let did_resolver = Arc::new(DidResolver::new()); 150 166 151 167 Self { 152 - db, 168 + user_repo: repos.user.clone(), 169 + oauth_repo: repos.oauth.clone(), 170 + session_repo: repos.session.clone(), 171 + delegation_repo: repos.delegation.clone(), 172 + repo_repo: repos.repo.clone(), 173 + blob_repo: repos.blob.clone(), 174 + infra_repo: repos.infra.clone(), 175 + backup_repo: repos.backup.clone(), 176 + backlink_repo: repos.backlink.clone(), 177 + event_notifier: repos.event_notifier.clone(), 178 + repos, 153 179 block_store, 154 180 blob_store: Arc::new(blob_store), 155 181 backup_storage,
+47 -56
crates/tranquil-pds/src/sync/blob.rs
··· 11 11 }; 12 12 use serde::{Deserialize, Serialize}; 13 13 use tracing::error; 14 + use tranquil_types::{CidLink, Did}; 14 15 15 16 #[derive(Deserialize)] 16 17 pub struct GetBlobParams { ··· 22 23 State(state): State<AppState>, 23 24 Query(params): Query<GetBlobParams>, 24 25 ) -> Response { 25 - let did = params.did.trim(); 26 - let cid = params.cid.trim(); 27 - if did.is_empty() { 26 + let did_str = params.did.trim(); 27 + let cid_str = params.cid.trim(); 28 + if did_str.is_empty() { 28 29 return ApiError::InvalidRequest("did is required".into()).into_response(); 29 30 } 30 - if cid.is_empty() { 31 + if cid_str.is_empty() { 31 32 return ApiError::InvalidRequest("cid is required".into()).into_response(); 32 33 } 34 + let did: Did = match did_str.parse() { 35 + Ok(d) => d, 36 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 37 + }; 38 + let cid: CidLink = match cid_str.parse() { 39 + Ok(c) => c, 40 + Err(_) => return ApiError::InvalidRequest("invalid cid".into()).into_response(), 41 + }; 33 42 34 - let _account = match assert_repo_availability(&state.db, did, false).await { 43 + let _account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 35 44 Ok(a) => a, 36 45 Err(e) => return e.into_response(), 37 46 }; 38 47 39 - let blob_result = sqlx::query!( 40 - "SELECT storage_key, mime_type, size_bytes FROM blobs WHERE cid = $1", 41 - cid 42 - ) 43 - .fetch_optional(&state.db) 44 - .await; 48 + let blob_result = state.blob_repo.get_blob_metadata(&cid).await; 45 49 match blob_result { 46 - Ok(Some(row)) => { 47 - let storage_key = &row.storage_key; 48 - let mime_type = &row.mime_type; 49 - let size_bytes = row.size_bytes; 50 - match state.blob_store.get(storage_key).await { 50 + Ok(Some(metadata)) => { 51 + match state.blob_store.get(&metadata.storage_key).await { 51 52 Ok(data) => Response::builder() 52 53 .status(StatusCode::OK) 53 - .header(header::CONTENT_TYPE, mime_type) 54 - .header(header::CONTENT_LENGTH, size_bytes.to_string()) 54 + .header(header::CONTENT_TYPE, &metadata.mime_type) 55 + .header(header::CONTENT_LENGTH, metadata.size_bytes.to_string()) 55 56 .header("x-content-type-options", "nosniff") 56 57 .header("content-security-policy", "default-src 'none'; sandbox") 57 58 .body(Body::from(data)) ··· 65 66 Ok(None) => ApiError::BlobNotFound(Some("Blob not found".into())).into_response(), 66 67 Err(e) => { 67 68 error!("DB error in get_blob: {:?}", e); 68 - ApiError::InternalError(Some("Database error".into())).into_response() 69 + ApiError::InternalError(Some(format!("Database error: {}", e))).into_response() 69 70 } 70 71 } 71 72 } ··· 89 90 State(state): State<AppState>, 90 91 Query(params): Query<ListBlobsParams>, 91 92 ) -> Response { 92 - let did = params.did.trim(); 93 - if did.is_empty() { 93 + let did_str = params.did.trim(); 94 + if did_str.is_empty() { 94 95 return ApiError::InvalidRequest("did is required".into()).into_response(); 95 96 } 97 + let did: Did = match did_str.parse() { 98 + Ok(d) => d, 99 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 100 + }; 96 101 97 - let account = match assert_repo_availability(&state.db, did, false).await { 102 + let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 98 103 Ok(a) => a, 99 104 Err(e) => return e.into_response(), 100 105 }; ··· 103 108 let cursor_cid = params.cursor.as_deref().unwrap_or(""); 104 109 let user_id = account.user_id; 105 110 106 - let cids_result: Result<Vec<String>, sqlx::Error> = if let Some(since) = &params.since { 107 - sqlx::query_scalar!( 108 - r#" 109 - SELECT DISTINCT unnest(blobs) as "cid!" 110 - FROM repo_seq 111 - WHERE did = $1 AND rev > $2 AND blobs IS NOT NULL 112 - "#, 113 - did, 114 - since 115 - ) 116 - .fetch_all(&state.db) 117 - .await 118 - .map(|mut cids| { 119 - cids.sort(); 120 - cids.into_iter() 121 - .filter(|c| c.as_str() > cursor_cid) 122 - .take((limit + 1) as usize) 123 - .collect() 124 - }) 111 + let cids_result: Result<Vec<String>, _> = if let Some(since) = &params.since { 112 + state 113 + .blob_repo 114 + .list_blobs_since_rev(&did, since) 115 + .await 116 + .map(|cids| { 117 + let mut cid_strs: Vec<String> = cids.into_iter().map(|c| c.to_string()).collect(); 118 + cid_strs.sort(); 119 + cid_strs 120 + .into_iter() 121 + .filter(|c| c.as_str() > cursor_cid) 122 + .take((limit + 1) as usize) 123 + .collect() 124 + }) 125 125 } else { 126 - sqlx::query!( 127 - r#" 128 - SELECT cid FROM blobs 129 - WHERE created_by_user = $1 AND cid > $2 130 - ORDER BY cid ASC 131 - LIMIT $3 132 - "#, 133 - user_id, 134 - cursor_cid, 135 - limit + 1 136 - ) 137 - .fetch_all(&state.db) 138 - .await 139 - .map(|rows| rows.into_iter().map(|r| r.cid).collect()) 126 + state 127 + .blob_repo 128 + .list_blobs_by_user(user_id, Some(cursor_cid), limit + 1) 129 + .await 130 + .map(|cids| cids.into_iter().map(|c| c.to_string()).collect()) 140 131 }; 141 132 match cids_result { 142 133 Ok(cids) => { ··· 154 145 } 155 146 Err(e) => { 156 147 error!("DB error in list_blobs: {:?}", e); 157 - ApiError::InternalError(Some("Database error".into())).into_response() 148 + ApiError::InternalError(Some(format!("Database error: {}", e))).into_response() 158 149 } 159 150 } 160 151 }
+27 -26
crates/tranquil-pds/src/sync/commit.rs
··· 13 13 use serde::{Deserialize, Serialize}; 14 14 use std::str::FromStr; 15 15 use tracing::error; 16 + use tranquil_types::Did; 16 17 17 18 async fn get_rev_from_commit(state: &AppState, cid_str: &str) -> Option<String> { 18 19 let cid = Cid::from_str(cid_str).ok()?; ··· 36 37 State(state): State<AppState>, 37 38 Query(params): Query<GetLatestCommitParams>, 38 39 ) -> Response { 39 - let did = params.did.trim(); 40 - if did.is_empty() { 40 + let did_str = params.did.trim(); 41 + if did_str.is_empty() { 41 42 return ApiError::InvalidRequest("did is required".into()).into_response(); 42 43 } 44 + let did: Did = match did_str.parse() { 45 + Ok(d) => d, 46 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 47 + }; 43 48 44 - let account = match assert_repo_availability(&state.db, did, false).await { 49 + let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 45 50 Ok(a) => a, 46 51 Err(e) => return e.into_response(), 47 52 }; ··· 53 58 let Some(rev) = get_rev_from_commit(&state, &repo_root_cid).await else { 54 59 error!( 55 60 "Failed to parse commit for DID {}: CID {}", 56 - did, repo_root_cid 61 + did_str, repo_root_cid 57 62 ); 58 63 return ApiError::InternalError(Some("Failed to read repo commit".into())).into_response(); 59 64 }; ··· 97 102 Query(params): Query<ListReposParams>, 98 103 ) -> Response { 99 104 let limit = params.limit.unwrap_or(50).clamp(1, 1000); 100 - let cursor_did = params.cursor.as_deref().unwrap_or(""); 101 - let result = sqlx::query!( 102 - r#" 103 - SELECT u.did, u.deactivated_at, u.takedown_ref, r.repo_root_cid, r.repo_rev 104 - FROM repos r 105 - JOIN users u ON r.user_id = u.id 106 - WHERE u.did > $1 107 - ORDER BY u.did ASC 108 - LIMIT $2 109 - "#, 110 - cursor_did, 111 - limit + 1 112 - ) 113 - .fetch_all(&state.db) 114 - .await; 105 + let cursor_did: Option<Did> = params 106 + .cursor 107 + .as_ref() 108 + .and_then(|s| s.parse().ok()); 109 + let cursor_ref = cursor_did.as_ref(); 110 + let result = state.repo_repo.list_repos_paginated(cursor_ref, limit + 1).await; 115 111 match result { 116 112 Ok(rows) => { 117 113 let has_more = rows.len() as i64 > limit; 118 114 let mut repos: Vec<RepoInfo> = Vec::new(); 119 115 for row in rows.iter().take(limit as usize) { 120 - let rev = match get_rev_from_commit(&state, &row.repo_root_cid).await { 116 + let cid_str = row.repo_root_cid.to_string(); 117 + let rev = match get_rev_from_commit(&state, &cid_str).await { 121 118 Some(r) => r, 122 119 None => { 123 120 if let Some(ref stored_rev) = row.repo_rev { ··· 140 137 AccountStatus::Active 141 138 }; 142 139 repos.push(RepoInfo { 143 - did: row.did.clone(), 144 - head: row.repo_root_cid.clone(), 140 + did: row.did.to_string(), 141 + head: cid_str, 145 142 rev, 146 143 active: status.is_active(), 147 144 status: status.as_str().map(String::from), ··· 187 184 State(state): State<AppState>, 188 185 Query(params): Query<GetRepoStatusParams>, 189 186 ) -> Response { 190 - let did = params.did.trim(); 191 - if did.is_empty() { 187 + let did_str = params.did.trim(); 188 + if did_str.is_empty() { 192 189 return ApiError::InvalidRequest("did is required".into()).into_response(); 193 190 } 191 + let did: Did = match did_str.parse() { 192 + Ok(d) => d, 193 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 194 + }; 194 195 195 - let account = match get_account_with_status(&state.db, did).await { 196 + let account = match get_account_with_status(state.repo_repo.as_ref(), &did).await { 196 197 Ok(Some(a)) => a, 197 198 Ok(None) => { 198 - return ApiError::RepoNotFound(Some(format!("Could not find repo for DID: {}", did))) 199 + return ApiError::RepoNotFound(Some(format!("Could not find repo for DID: {}", did_str))) 199 200 .into_response(); 200 201 } 201 202 Err(e) => {
+20 -10
crates/tranquil-pds/src/sync/deprecated.rs
··· 2 2 use crate::state::AppState; 3 3 use crate::sync::car::encode_car_header; 4 4 use crate::sync::util::assert_repo_availability; 5 + use tranquil_types::Did; 5 6 use axum::{ 6 7 Json, 7 8 extract::{Query, State}, ··· 27 28 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 28 29 let http_uri = "/"; 29 30 match crate::auth::validate_token_with_dpop( 30 - &state.db, 31 + state.user_repo.as_ref(), 32 + state.oauth_repo.as_ref(), 31 33 &extracted.token, 32 34 extracted.is_dpop, 33 35 dpop_proof, ··· 58 60 headers: HeaderMap, 59 61 Query(params): Query<GetHeadParams>, 60 62 ) -> Response { 61 - let did = params.did.trim(); 62 - if did.is_empty() { 63 + let did_str = params.did.trim(); 64 + if did_str.is_empty() { 63 65 return ApiError::InvalidRequest("did is required".into()).into_response(); 64 66 } 65 - let is_admin_or_self = check_admin_or_self(&state, &headers, did).await; 66 - let account = match assert_repo_availability(&state.db, did, is_admin_or_self).await { 67 + let did: Did = match did_str.parse() { 68 + Ok(d) => d, 69 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 70 + }; 71 + let is_admin_or_self = check_admin_or_self(&state, &headers, did_str).await; 72 + let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, is_admin_or_self).await { 67 73 Ok(a) => a, 68 74 Err(e) => return e.into_response(), 69 75 }; 70 76 match account.repo_root_cid { 71 77 Some(root) => (StatusCode::OK, Json(GetHeadOutput { root })).into_response(), 72 - None => ApiError::RepoNotFound(Some(format!("Could not find root for DID: {}", did))) 78 + None => ApiError::RepoNotFound(Some(format!("Could not find root for DID: {}", did_str))) 73 79 .into_response(), 74 80 } 75 81 } ··· 84 90 headers: HeaderMap, 85 91 Query(params): Query<GetCheckoutParams>, 86 92 ) -> Response { 87 - let did = params.did.trim(); 88 - if did.is_empty() { 93 + let did_str = params.did.trim(); 94 + if did_str.is_empty() { 89 95 return ApiError::InvalidRequest("did is required".into()).into_response(); 90 96 } 91 - let is_admin_or_self = check_admin_or_self(&state, &headers, did).await; 92 - let account = match assert_repo_availability(&state.db, did, is_admin_or_self).await { 97 + let did: Did = match did_str.parse() { 98 + Ok(d) => d, 99 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 100 + }; 101 + let is_admin_or_self = check_admin_or_self(&state, &headers, did_str).await; 102 + let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, is_admin_or_self).await { 93 103 Ok(a) => a, 94 104 Err(e) => return e.into_response(), 95 105 };
+1 -21
crates/tranquil-pds/src/sync/firehose.rs
··· 1 - use chrono::{DateTime, Utc}; 2 - use serde::{Deserialize, Serialize}; 3 - use serde_json::Value; 4 - 5 - #[derive(Debug, Clone, Serialize, Deserialize)] 6 - pub struct SequencedEvent { 7 - pub seq: i64, 8 - pub did: String, 9 - pub created_at: DateTime<Utc>, 10 - pub event_type: String, 11 - pub commit_cid: Option<String>, 12 - pub prev_cid: Option<String>, 13 - pub prev_data_cid: Option<String>, 14 - pub ops: Option<Value>, 15 - pub blobs: Option<Vec<String>>, 16 - pub blocks_cids: Option<Vec<String>>, 17 - pub handle: Option<String>, 18 - pub active: Option<bool>, 19 - pub status: Option<String>, 20 - pub rev: Option<String>, 21 - } 1 + pub use tranquil_db_traits::SequencedEvent;
+4 -4
crates/tranquil-pds/src/sync/frame.rs
··· 198 198 type Error = CommitFrameError; 199 199 200 200 fn try_from(event: SequencedEvent) -> Result<Self, Self::Error> { 201 - let commit_cid_str = event.commit_cid.ok_or_else(|| { 201 + let commit_cid = event.commit_cid.ok_or_else(|| { 202 202 CommitFrameError::InvalidCommitCid("Missing commit_cid in event".to_string()) 203 203 })?; 204 204 let builder = CommitFrameBuilder::new( 205 205 event.seq, 206 - event.did, 207 - &commit_cid_str, 208 - event.prev_cid.as_deref(), 206 + event.did.to_string(), 207 + commit_cid.as_str(), 208 + event.prev_cid.as_ref().map(|c| c.as_str()), 209 209 event.ops.unwrap_or_default(), 210 210 event.blobs.unwrap_or_default(), 211 211 event.created_at,
+38 -56
crates/tranquil-pds/src/sync/import.rs
··· 3 3 use ipld_core::ipld::Ipld; 4 4 use iroh_car::CarReader; 5 5 use serde_json::Value as JsonValue; 6 - use sqlx::PgPool; 7 6 use std::collections::HashMap; 8 7 use std::io::Cursor; 8 + use std::sync::Arc; 9 9 use thiserror::Error; 10 10 use tracing::debug; 11 + use tranquil_db::{ImportBlock, ImportRecord, ImportRepoError, RepoRepository}; 11 12 use uuid::Uuid; 12 13 13 14 #[derive(Error, Debug)] ··· 21 22 #[error("Invalid CBOR: {0}")] 22 23 InvalidCbor(String), 23 24 #[error("Database error: {0}")] 24 - Database(#[from] sqlx::Error), 25 + Database(String), 25 26 #[error("Block store error: {0}")] 26 27 BlockStore(String), 27 28 #[error("Import size limit exceeded")] ··· 36 37 VerificationFailed(#[from] super::verify::VerifyError), 37 38 #[error("DID mismatch: CAR is for {car_did}, but authenticated as {auth_did}")] 38 39 DidMismatch { car_did: String, auth_did: String }, 40 + } 41 + 42 + impl From<ImportRepoError> for ImportError { 43 + fn from(e: ImportRepoError) -> Self { 44 + match e { 45 + ImportRepoError::RepoNotFound => ImportError::RepoNotFound, 46 + ImportRepoError::ConcurrentModification => ImportError::ConcurrentModification, 47 + ImportRepoError::Database(msg) => ImportError::Database(msg), 48 + } 49 + } 39 50 } 40 51 41 52 #[derive(Debug, Clone)] ··· 307 318 } 308 319 309 320 pub async fn apply_import( 310 - db: &PgPool, 321 + repo_repo: &Arc<dyn RepoRepository>, 311 322 user_id: Uuid, 312 323 root: Cid, 313 324 blocks: HashMap<Cid, Bytes>, ··· 329 340 records.len(), 330 341 user_id 331 342 ); 332 - let mut tx = db.begin().await?; 333 - let repo = sqlx::query!( 334 - "SELECT repo_root_cid FROM repos WHERE user_id = $1 FOR UPDATE NOWAIT", 335 - user_id 336 - ) 337 - .fetch_optional(&mut *tx) 338 - .await 339 - .map_err(|e| { 340 - if let sqlx::Error::Database(ref db_err) = e 341 - && db_err.code().as_deref() == Some("55P03") 342 - { 343 - return ImportError::ConcurrentModification; 344 - } 345 - ImportError::Database(e) 346 - })?; 347 - if repo.is_none() { 348 - return Err(ImportError::RepoNotFound); 349 - } 350 - let block_chunks: Vec<Vec<(&Cid, &Bytes)>> = blocks 343 + 344 + let import_blocks: Vec<ImportBlock> = blocks 351 345 .iter() 352 - .collect::<Vec<_>>() 353 - .chunks(100) 354 - .map(|c| c.to_vec()) 346 + .map(|(cid, data)| ImportBlock { 347 + cid_bytes: cid.to_bytes(), 348 + data: data.to_vec(), 349 + }) 350 + .collect(); 351 + 352 + let import_records: Vec<ImportRecord> = records 353 + .iter() 354 + .filter_map(|r| { 355 + let collection = r.collection.parse().ok()?; 356 + let rkey = r.rkey.parse().ok()?; 357 + let record_cid = r.cid.to_string().parse().ok()?; 358 + Some(ImportRecord { 359 + collection, 360 + rkey, 361 + record_cid, 362 + }) 363 + }) 355 364 .collect(); 356 - for chunk in block_chunks { 357 - for (cid, data) in chunk { 358 - let cid_bytes = cid.to_bytes(); 359 - sqlx::query!( 360 - "INSERT INTO blocks (cid, data) VALUES ($1, $2) ON CONFLICT (cid) DO NOTHING", 361 - &cid_bytes, 362 - data.as_ref() 363 - ) 364 - .execute(&mut *tx) 365 - .await?; 366 - } 367 - } 368 - sqlx::query!("DELETE FROM records WHERE repo_id = $1", user_id) 369 - .execute(&mut *tx) 365 + 366 + repo_repo 367 + .import_repo_data(user_id, &import_blocks, &import_records) 370 368 .await?; 371 - for record in &records { 372 - let record_cid_str = record.cid.to_string(); 373 - sqlx::query!( 374 - r#" 375 - INSERT INTO records (repo_id, collection, rkey, record_cid) 376 - VALUES ($1, $2, $3, $4) 377 - ON CONFLICT (repo_id, collection, rkey) DO UPDATE SET record_cid = $4 378 - "#, 379 - user_id, 380 - record.collection, 381 - record.rkey, 382 - record_cid_str 383 - ) 384 - .execute(&mut *tx) 385 - .await?; 386 - } 387 - tx.commit().await?; 369 + 388 370 debug!( 389 371 "Successfully imported {} blocks and {} records", 390 372 blocks.len(),
+48 -61
crates/tranquil-pds/src/sync/listener.rs
··· 1 1 use crate::state::AppState; 2 2 use crate::sync::firehose::SequencedEvent; 3 - use sqlx::postgres::PgListener; 4 3 use std::sync::atomic::{AtomicI64, Ordering}; 5 4 use tracing::{debug, error, info, warn}; 6 5 7 6 static LAST_BROADCAST_SEQ: AtomicI64 = AtomicI64::new(0); 8 7 9 8 pub async fn start_sequencer_listener(state: AppState) { 10 - let initial_seq = sqlx::query_scalar!("SELECT COALESCE(MAX(seq), 0) as max FROM repo_seq") 11 - .fetch_one(&state.db) 12 - .await 13 - .unwrap_or(Some(0)) 14 - .unwrap_or(0); 9 + let initial_seq = state.repo_repo.get_max_seq().await.unwrap_or(0); 15 10 LAST_BROADCAST_SEQ.store(initial_seq, Ordering::SeqCst); 16 11 info!(initial_seq = initial_seq, "Initialized sequencer listener"); 17 12 tokio::spawn(async move { ··· 26 21 } 27 22 28 23 async fn listen_loop(state: AppState) -> anyhow::Result<()> { 29 - let mut listener = PgListener::connect_with(&state.db).await?; 30 - listener.listen("repo_updates").await?; 31 - info!("Connected to Postgres and listening for 'repo_updates'"); 24 + let mut receiver = state 25 + .event_notifier 26 + .subscribe() 27 + .await 28 + .map_err(|e| anyhow::anyhow!("Failed to subscribe to events: {:?}", e))?; 29 + info!("Connected to database and listening for repo updates"); 32 30 let catchup_start = LAST_BROADCAST_SEQ.load(Ordering::SeqCst); 33 - let events = sqlx::query_as!( 34 - SequencedEvent, 35 - r#" 36 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids, handle, active, status, rev 37 - FROM repo_seq 38 - WHERE seq > $1 39 - ORDER BY seq ASC 40 - "#, 41 - catchup_start 42 - ) 43 - .fetch_all(&state.db) 44 - .await?; 31 + let events = state 32 + .repo_repo 33 + .get_events_since_seq(catchup_start, None) 34 + .await 35 + .map_err(|e| anyhow::anyhow!("Failed to fetch catchup events: {:?}", e))?; 45 36 if !events.is_empty() { 46 37 info!( 47 38 count = events.len(), ··· 50 41 ); 51 42 events.into_iter().for_each(|event| { 52 43 let seq = event.seq; 53 - let _ = state.firehose_tx.send(event); 44 + let firehose_event = to_firehose_event(event); 45 + let _ = state.firehose_tx.send(firehose_event); 54 46 LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 55 47 }); 56 48 } 57 49 loop { 58 - let notification = listener.recv().await?; 59 - let payload = notification.payload(); 60 - debug!(payload = %payload, "Received postgres notification"); 61 - let seq_id: i64 = match payload.parse() { 62 - Ok(id) => id, 63 - Err(e) => { 64 - warn!( 65 - "Received invalid payload in repo_updates: '{}'. Error: {}", 66 - payload, e 67 - ); 68 - continue; 69 - } 50 + let Some(seq_id) = receiver.recv().await else { 51 + return Err(anyhow::anyhow!("Event receiver disconnected")); 70 52 }; 53 + debug!(seq = seq_id, "Received event notification"); 71 54 let last_seq = LAST_BROADCAST_SEQ.load(Ordering::SeqCst); 72 55 if seq_id <= last_seq { 73 56 debug!( ··· 78 61 continue; 79 62 } 80 63 if seq_id > last_seq + 1 { 81 - let gap_events = sqlx::query_as!( 82 - SequencedEvent, 83 - r#" 84 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids, handle, active, status, rev 85 - FROM repo_seq 86 - WHERE seq > $1 AND seq < $2 87 - ORDER BY seq ASC 88 - "#, 89 - last_seq, 90 - seq_id 91 - ) 92 - .fetch_all(&state.db) 93 - .await?; 64 + let gap_events = state 65 + .repo_repo 66 + .get_events_in_seq_range(last_seq, seq_id) 67 + .await 68 + .unwrap_or_default(); 94 69 if !gap_events.is_empty() { 95 70 debug!(count = gap_events.len(), "Filling sequence gap"); 96 71 gap_events.into_iter().for_each(|event| { 97 72 let seq = event.seq; 98 - let _ = state.firehose_tx.send(event); 73 + let firehose_event = to_firehose_event(event); 74 + let _ = state.firehose_tx.send(firehose_event); 99 75 LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 100 76 }); 101 77 } 102 78 } 103 - let event = sqlx::query_as!( 104 - SequencedEvent, 105 - r#" 106 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids, handle, active, status, rev 107 - FROM repo_seq 108 - WHERE seq = $1 109 - "#, 110 - seq_id 111 - ) 112 - .fetch_optional(&state.db) 113 - .await?; 79 + let event = state.repo_repo.get_event_by_seq(seq_id).await.ok().flatten(); 114 80 if let Some(event) = event { 115 - match state.firehose_tx.send(event) { 81 + let seq = event.seq; 82 + let firehose_event = to_firehose_event(event); 83 + match state.firehose_tx.send(firehose_event) { 116 84 Ok(receiver_count) => { 117 85 debug!( 118 86 seq = seq_id, ··· 124 92 warn!(seq = seq_id, error = %e, "Failed to broadcast event (no receivers?)"); 125 93 } 126 94 } 127 - LAST_BROADCAST_SEQ.store(seq_id, Ordering::SeqCst); 95 + LAST_BROADCAST_SEQ.store(seq, Ordering::SeqCst); 128 96 } else { 129 97 warn!( 130 98 seq = seq_id, ··· 133 101 } 134 102 } 135 103 } 104 + 105 + fn to_firehose_event(event: tranquil_db_traits::SequencedEvent) -> SequencedEvent { 106 + SequencedEvent { 107 + seq: event.seq, 108 + did: event.did, 109 + created_at: event.created_at, 110 + event_type: event.event_type, 111 + commit_cid: event.commit_cid, 112 + prev_cid: event.prev_cid, 113 + prev_data_cid: event.prev_data_cid, 114 + ops: event.ops, 115 + blobs: event.blobs, 116 + blocks_cids: event.blocks_cids, 117 + handle: event.handle, 118 + active: event.active, 119 + status: event.status, 120 + rev: event.rev, 121 + } 122 + }
+37 -40
crates/tranquil-pds/src/sync/repo.rs
··· 14 14 use std::io::Write; 15 15 use std::str::FromStr; 16 16 use tracing::error; 17 + use tranquil_types::Did; 17 18 18 19 fn parse_get_blocks_query(query_string: &str) -> Result<(String, Vec<String>), String> { 19 20 let did = crate::util::parse_repeated_query_param(Some(query_string), "did") ··· 29 30 return ApiError::InvalidRequest("Missing query parameters".into()).into_response(); 30 31 }; 31 32 32 - let (did, cid_strings) = match parse_get_blocks_query(&query_string) { 33 + let (did_str, cid_strings) = match parse_get_blocks_query(&query_string) { 33 34 Ok(parsed) => parsed, 34 35 Err(msg) => return ApiError::InvalidRequest(msg).into_response(), 35 36 }; 37 + let did: Did = match did_str.parse() { 38 + Ok(d) => d, 39 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 40 + }; 36 41 37 - let _account = match assert_repo_availability(&state.db, &did, false).await { 42 + let _account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 38 43 Ok(a) => a, 39 44 Err(e) => return e.into_response(), 40 45 }; ··· 119 124 State(state): State<AppState>, 120 125 Query(query): Query<GetRepoQuery>, 121 126 ) -> Response { 122 - let account = match assert_repo_availability(&state.db, &query.did, false).await { 127 + let did: Did = match query.did.parse() { 128 + Ok(d) => d, 129 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 130 + }; 131 + let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 123 132 Ok(a) => a, 124 133 Err(e) => return e.into_response(), 125 134 }; ··· 133 142 }; 134 143 135 144 if let Some(since) = &query.since { 136 - return get_repo_since(&state, &query.did, &head_cid, since).await; 145 + return get_repo_since(&state, &did, &head_cid, since).await; 137 146 } 138 147 139 148 let car_bytes = match generate_repo_car_from_user_blocks( 140 - &state.db, 149 + state.repo_repo.as_ref(), 141 150 &state.block_store, 142 151 account.user_id, 143 152 &head_cid, ··· 159 168 .into_response() 160 169 } 161 170 162 - async fn get_repo_since(state: &AppState, did: &str, head_cid: &Cid, since: &str) -> Response { 163 - let events = sqlx::query!( 164 - r#" 165 - SELECT blocks_cids, commit_cid 166 - FROM repo_seq 167 - WHERE did = $1 AND rev > $2 168 - ORDER BY seq DESC 169 - "#, 170 - did, 171 - since 172 - ) 173 - .fetch_all(&state.db) 174 - .await; 171 + async fn get_repo_since(state: &AppState, did: &Did, head_cid: &Cid, since: &str) -> Response { 172 + let user_id = match state.user_repo.get_id_by_did(did).await { 173 + Ok(Some(id)) => id, 174 + Ok(None) => { 175 + return ApiError::RepoNotFound(Some(format!("Could not find repo for DID: {}", did))) 176 + .into_response(); 177 + } 178 + Err(e) => { 179 + error!("DB error looking up user: {:?}", e); 180 + return ApiError::InternalError(Some("Database error".into())).into_response(); 181 + } 182 + }; 175 183 176 - let events = match events { 177 - Ok(e) => e, 184 + let block_cid_bytes = match state.repo_repo.get_user_block_cids_since_rev(user_id, since).await 185 + { 186 + Ok(cids) => cids, 178 187 Err(e) => { 179 188 error!("DB error in get_repo_since: {:?}", e); 180 189 return ApiError::InternalError(Some("Database error".into())).into_response(); 181 190 } 182 191 }; 183 192 184 - let block_cids: Vec<Cid> = events 193 + let block_cids: Vec<Cid> = block_cid_bytes 185 194 .iter() 186 - .flat_map(|event| { 187 - let block_cids = event 188 - .blocks_cids 189 - .as_ref() 190 - .map(|cids| cids.iter().filter_map(|s| Cid::from_str(s).ok()).collect()) 191 - .unwrap_or_else(Vec::new); 192 - let commit_cid = event 193 - .commit_cid 194 - .as_ref() 195 - .and_then(|s| Cid::from_str(s).ok()); 196 - block_cids.into_iter().chain(commit_cid) 197 - }) 198 - .fold(Vec::new(), |mut acc, cid| { 199 - if !acc.contains(&cid) { 200 - acc.push(cid); 201 - } 202 - acc 203 - }); 195 + .filter_map(|bytes| Cid::try_from(bytes.as_slice()).ok()) 196 + .collect(); 204 197 205 198 let mut car_bytes = match encode_car_header(head_cid) { 206 199 Ok(h) => h, ··· 269 262 use std::collections::BTreeMap; 270 263 use std::sync::Arc; 271 264 272 - let account = match assert_repo_availability(&state.db, &query.did, false).await { 265 + let did: Did = match query.did.parse() { 266 + Ok(d) => d, 267 + Err(_) => return ApiError::InvalidRequest("invalid did".into()).into_response(), 268 + }; 269 + let account = match assert_repo_availability(state.repo_repo.as_ref(), &did, false).await { 273 270 Ok(a) => a, 274 271 Err(e) => return e.into_response(), 275 272 };
+17 -56
crates/tranquil-pds/src/sync/subscribe_repos.rs
··· 72 72 let mut last_seen: i64 = -1; 73 73 74 74 if let Some(cursor) = params.cursor { 75 - let current_seq = sqlx::query_scalar!("SELECT MAX(seq) FROM repo_seq") 76 - .fetch_one(&state.db) 77 - .await 78 - .ok() 79 - .flatten() 80 - .unwrap_or(0); 75 + let current_seq = state.repo_repo.get_max_seq().await.unwrap_or(0); 81 76 82 77 if cursor > current_seq { 83 78 if let Ok(error_bytes) = ··· 91 86 92 87 let backfill_time = chrono::Utc::now() - chrono::Duration::hours(get_backfill_hours()); 93 88 94 - let first_event = sqlx::query_as!( 95 - SequencedEvent, 96 - r#" 97 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids, handle, active, status, rev 98 - FROM repo_seq 99 - WHERE seq > $1 100 - ORDER BY seq ASC 101 - LIMIT 1 102 - "#, 103 - cursor 104 - ) 105 - .fetch_optional(&state.db) 106 - .await 107 - .ok() 108 - .flatten(); 89 + let first_event = state 90 + .repo_repo 91 + .get_events_since_cursor(cursor, 1) 92 + .await 93 + .ok() 94 + .and_then(|events| events.into_iter().next()); 109 95 110 96 let mut current_cursor = cursor; 111 97 ··· 119 105 let _ = socket.send(Message::Binary(info_bytes.into())).await; 120 106 } 121 107 122 - let earliest = sqlx::query_scalar!( 123 - "SELECT MIN(seq) FROM repo_seq WHERE created_at >= $1", 124 - backfill_time 125 - ) 126 - .fetch_one(&state.db) 127 - .await 128 - .ok() 129 - .flatten(); 108 + let earliest = state.repo_repo.get_min_seq_since(backfill_time).await.ok().flatten(); 130 109 131 110 if let Some(earliest_seq) = earliest { 132 111 current_cursor = earliest_seq - 1; ··· 136 115 last_seen = current_cursor; 137 116 138 117 loop { 139 - let events = sqlx::query_as!( 140 - SequencedEvent, 141 - r#" 142 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids, handle, active, status, rev 143 - FROM repo_seq 144 - WHERE seq > $1 145 - ORDER BY seq ASC 146 - LIMIT $2 147 - "#, 148 - current_cursor, 149 - BACKFILL_BATCH_SIZE 150 - ) 151 - .fetch_all(&state.db) 152 - .await; 118 + let events = state 119 + .repo_repo 120 + .get_events_since_cursor(current_cursor, BACKFILL_BATCH_SIZE) 121 + .await; 153 122 match events { 154 123 Ok(events) => { 155 124 if events.is_empty() { ··· 186 155 } 187 156 } 188 157 Err(e) => { 189 - error!("Failed to fetch backfill events: {}", e); 158 + error!("Failed to fetch backfill events: {:?}", e); 190 159 socket.close().await.ok(); 191 160 return Err(()); 192 161 } 193 162 } 194 163 } 195 164 196 - let cutover_events = sqlx::query_as!( 197 - SequencedEvent, 198 - r#" 199 - SELECT seq, did, created_at, event_type, commit_cid, prev_cid, prev_data_cid, ops, blobs, blocks_cids, handle, active, status, rev 200 - FROM repo_seq 201 - WHERE seq > $1 202 - ORDER BY seq ASC 203 - "#, 204 - last_seen 205 - ) 206 - .fetch_all(&state.db) 207 - .await; 165 + let cutover_events = state 166 + .repo_repo 167 + .get_events_since_seq(last_seen, None) 168 + .await; 208 169 209 170 if let Ok(events) = cutover_events 210 171 && !events.is_empty()
+31 -39
crates/tranquil-pds/src/sync/util.rs
··· 12 12 use jacquard_repo::commit::Commit; 13 13 use jacquard_repo::storage::BlockStore; 14 14 use serde::Serialize; 15 - use sqlx::PgPool; 15 + use tranquil_db_traits::RepoRepository; 16 + use tranquil_types::Did; 16 17 use std::collections::{BTreeMap, HashMap}; 17 18 use std::io::Cursor; 18 19 use std::str::FromStr; ··· 134 135 } 135 136 136 137 pub async fn get_account_with_status( 137 - db: &PgPool, 138 - did: &str, 139 - ) -> Result<Option<RepoAccount>, sqlx::Error> { 140 - let row = sqlx::query!( 141 - r#" 142 - SELECT u.id, u.did, u.deactivated_at, u.takedown_ref, r.repo_root_cid 143 - FROM users u 144 - LEFT JOIN repos r ON r.user_id = u.id 145 - WHERE u.did = $1 146 - "#, 147 - did 148 - ) 149 - .fetch_optional(db) 150 - .await?; 138 + repo_repo: &dyn RepoRepository, 139 + did: &Did, 140 + ) -> Result<Option<RepoAccount>, tranquil_db_traits::DbError> { 141 + let row = repo_repo.get_account_with_repo(did).await?; 151 142 152 143 Ok(row.map(|r| { 153 144 let status = if r.takedown_ref.is_some() { ··· 159 150 }; 160 151 161 152 RepoAccount { 162 - did: r.did, 163 - user_id: r.id, 153 + did: r.did.to_string(), 154 + user_id: r.user_id, 164 155 status, 165 - repo_root_cid: Some(r.repo_root_cid), 156 + repo_root_cid: r.repo_root_cid.map(|c| c.to_string()), 166 157 } 167 158 })) 168 159 } 169 160 170 161 pub async fn assert_repo_availability( 171 - db: &PgPool, 172 - did: &str, 162 + repo_repo: &dyn RepoRepository, 163 + did: &Did, 173 164 is_admin_or_self: bool, 174 165 ) -> Result<RepoAccount, RepoAvailabilityError> { 175 - let account = get_account_with_status(db, did) 166 + let account = get_account_with_status(repo_repo, did) 176 167 .await 177 168 .map_err(|e| RepoAvailabilityError::Internal(e.to_string()))?; 178 169 170 + let did_str = did.to_string(); 179 171 let account = match account { 180 172 Some(a) => a, 181 - None => return Err(RepoAvailabilityError::NotFound(did.to_string())), 173 + None => return Err(RepoAvailabilityError::NotFound(did_str)), 182 174 }; 183 175 184 176 if is_admin_or_self { ··· 186 178 } 187 179 188 180 match account.status { 189 - AccountStatus::Takendown => return Err(RepoAvailabilityError::Takendown(did.to_string())), 181 + AccountStatus::Takendown => return Err(RepoAvailabilityError::Takendown(did_str)), 190 182 AccountStatus::Deactivated => { 191 - return Err(RepoAvailabilityError::Deactivated(did.to_string())); 183 + return Err(RepoAvailabilityError::Deactivated(did_str)); 192 184 } 193 185 _ => {} 194 186 } ··· 239 231 240 232 fn format_identity_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 241 233 let frame = IdentityFrame { 242 - did: event.did.clone(), 243 - handle: event.handle.clone(), 234 + did: event.did.to_string(), 235 + handle: event.handle.as_ref().map(|h| h.to_string()), 244 236 seq: event.seq, 245 237 time: format_atproto_time(event.created_at), 246 238 }; ··· 256 248 257 249 fn format_account_event(event: &SequencedEvent) -> Result<Vec<u8>, anyhow::Error> { 258 250 let frame = AccountFrame { 259 - did: event.did.clone(), 251 + did: event.did.to_string(), 260 252 active: event.active.unwrap_or(true), 261 253 status: event.status.clone(), 262 254 seq: event.seq, ··· 303 295 }; 304 296 let car_bytes = write_car_blocks(commit_cid, Some(commit_bytes), BTreeMap::new()).await?; 305 297 let frame = SyncFrame { 306 - did: event.did.clone(), 298 + did: event.did.to_string(), 307 299 rev, 308 300 blocks: car_bytes, 309 301 seq: event.seq, ··· 330 322 _ => {} 331 323 } 332 324 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 333 - let prev_cid_str = event.prev_cid.clone(); 334 - let prev_data_cid_str = event.prev_data_cid.clone(); 325 + let prev_cid_link = event.prev_cid.clone(); 326 + let prev_data_cid_link = event.prev_data_cid.clone(); 335 327 let mut frame: CommitFrame = event 336 328 .try_into() 337 329 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 338 - if let Some(ref pdc) = prev_data_cid_str 339 - && let Ok(cid) = Cid::from_str(pdc) 330 + if let Some(ref pdc) = prev_data_cid_link 331 + && let Ok(cid) = Cid::from_str(pdc.as_str()) 340 332 { 341 333 frame.prev_data = Some(cid); 342 334 } 343 335 let commit_cid = frame.commit; 344 - let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 336 + let prev_cid = prev_cid_link.as_ref().and_then(|c| Cid::from_str(c.as_str()).ok()); 345 337 let mut all_cids: Vec<Cid> = block_cids_str 346 338 .iter() 347 339 .filter_map(|s| Cid::from_str(s).ok()) ··· 443 435 BTreeMap::new(), 444 436 ))?; 445 437 let frame = SyncFrame { 446 - did: event.did.clone(), 438 + did: event.did.to_string(), 447 439 rev, 448 440 blocks: car_bytes, 449 441 seq: event.seq, ··· 470 462 _ => {} 471 463 } 472 464 let block_cids_str = event.blocks_cids.clone().unwrap_or_default(); 473 - let prev_cid_str = event.prev_cid.clone(); 474 - let prev_data_cid_str = event.prev_data_cid.clone(); 465 + let prev_cid_link = event.prev_cid.clone(); 466 + let prev_data_cid_link = event.prev_data_cid.clone(); 475 467 let mut frame: CommitFrame = event 476 468 .try_into() 477 469 .map_err(|e| anyhow::anyhow!("Invalid event: {}", e))?; 478 - if let Some(ref pdc) = prev_data_cid_str 479 - && let Ok(cid) = Cid::from_str(pdc) 470 + if let Some(ref pdc) = prev_data_cid_link 471 + && let Ok(cid) = Cid::from_str(pdc.as_str()) 480 472 { 481 473 frame.prev_data = Some(cid); 482 474 } 483 475 let commit_cid = frame.commit; 484 - let prev_cid = prev_cid_str.as_ref().and_then(|s| Cid::from_str(s).ok()); 476 + let prev_cid = prev_cid_link.as_ref().and_then(|c| Cid::from_str(c.as_str()).ok()); 485 477 let mut all_cids: Vec<Cid> = block_cids_str 486 478 .iter() 487 479 .filter_map(|s| Cid::from_str(s).ok())
-64
crates/tranquil-pds/src/util.rs
··· 3 3 use ipld_core::ipld::Ipld; 4 4 use rand::Rng; 5 5 use serde_json::Value as JsonValue; 6 - use sqlx::PgPool; 7 6 use std::collections::BTreeMap; 8 7 use std::str::FromStr; 9 8 use std::sync::OnceLock; 10 - use uuid::Uuid; 11 - 12 - use crate::types::{Did, Handle}; 13 9 14 10 const BASE32_ALPHABET: &str = "abcdefghijklmnopqrstuvwxyz234567"; 15 11 const DEFAULT_MAX_BLOB_SIZE: usize = 10 * 1024 * 1024 * 1024; ··· 41 37 }) 42 38 .collect::<Vec<_>>() 43 39 .join("-") 44 - } 45 - 46 - #[derive(Debug)] 47 - pub enum DbLookupError { 48 - NotFound, 49 - DatabaseError(sqlx::Error), 50 - } 51 - 52 - impl From<sqlx::Error> for DbLookupError { 53 - fn from(e: sqlx::Error) -> Self { 54 - DbLookupError::DatabaseError(e) 55 - } 56 - } 57 - 58 - pub async fn get_user_id_by_did(db: &PgPool, did: &str) -> Result<Uuid, DbLookupError> { 59 - sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 60 - .fetch_optional(db) 61 - .await? 62 - .ok_or(DbLookupError::NotFound) 63 - } 64 - 65 - pub struct UserInfo { 66 - pub id: Uuid, 67 - pub did: Did, 68 - pub handle: Handle, 69 - } 70 - 71 - pub async fn get_user_by_did(db: &PgPool, did: &str) -> Result<UserInfo, DbLookupError> { 72 - sqlx::query_as!( 73 - UserInfo, 74 - "SELECT id, did, handle FROM users WHERE did = $1", 75 - did 76 - ) 77 - .fetch_optional(db) 78 - .await? 79 - .ok_or(DbLookupError::NotFound) 80 - } 81 - 82 - pub async fn get_user_by_identifier( 83 - db: &PgPool, 84 - identifier: &str, 85 - ) -> Result<UserInfo, DbLookupError> { 86 - sqlx::query_as!( 87 - UserInfo, 88 - "SELECT id, did, handle FROM users WHERE did = $1 OR handle = $1", 89 - identifier 90 - ) 91 - .fetch_optional(db) 92 - .await? 93 - .ok_or(DbLookupError::NotFound) 94 - } 95 - 96 - pub async fn is_account_migrated(db: &PgPool, did: &str) -> Result<bool, sqlx::Error> { 97 - let row = sqlx::query!( 98 - r#"SELECT (migrated_to_pds IS NOT NULL AND deactivated_at IS NOT NULL) as "migrated!: bool" FROM users WHERE did = $1"#, 99 - did 100 - ) 101 - .fetch_optional(db) 102 - .await?; 103 - Ok(row.map(|r| r.migrated).unwrap_or(false)) 104 40 } 105 41 106 42 pub fn parse_repeated_query_param(query: Option<&str>, key: &str) -> Vec<String> {
+24 -18
crates/tranquil-pds/tests/account_notifications.rs
··· 1 1 mod common; 2 2 use common::{base_url, client, create_account_and_login, get_test_db_pool}; 3 3 use serde_json::{Value, json}; 4 - use tranquil_pds::comms::{CommsType, NewComms, enqueue_comms}; 4 + use sqlx::Row; 5 5 6 6 #[tokio::test] 7 7 async fn test_get_notification_history() { ··· 10 10 let pool = get_test_db_pool().await; 11 11 let (token, did) = create_account_and_login(&client).await; 12 12 13 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 13 + let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1") 14 + .bind(&did) 14 15 .fetch_one(pool) 15 16 .await 16 17 .expect("User not found"); 17 18 18 19 for i in 0..3 { 19 - let comms = NewComms::email( 20 - user_id, 21 - CommsType::Welcome, 22 - "test@example.com".to_string(), 23 - format!("Subject {}", i), 24 - format!("Body {}", i), 25 - ); 26 - enqueue_comms(pool, comms).await.expect("Failed to enqueue"); 20 + sqlx::query( 21 + r#"INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body) 22 + VALUES ($1, 'email', 'welcome', $2, $3, $4)"#, 23 + ) 24 + .bind(user_id) 25 + .bind("test@example.com") 26 + .bind(format!("Subject {}", i)) 27 + .bind(format!("Body {}", i)) 28 + .execute(pool) 29 + .await 30 + .expect("Failed to enqueue"); 27 31 } 28 32 29 33 let resp = client ··· 69 73 ); 70 74 71 75 let pool = get_test_db_pool().await; 72 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 76 + let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1") 77 + .bind(&did) 73 78 .fetch_one(pool) 74 79 .await 75 80 .expect("User not found"); 76 81 77 - let row = sqlx::query!( 82 + let row = sqlx::query( 78 83 "SELECT body, metadata FROM comms_queue WHERE user_id = $1 AND comms_type = 'channel_verification' ORDER BY created_at DESC LIMIT 1", 79 - user_id 80 84 ) 85 + .bind(user_id) 81 86 .fetch_one(pool) 82 87 .await 83 88 .expect("Verification code not found"); 84 89 85 - let code = row 86 - .metadata 90 + let metadata: Option<serde_json::Value> = row.get("metadata"); 91 + let code = metadata 87 92 .as_ref() 88 93 .and_then(|m| m.get("code")) 89 94 .and_then(|c| c.as_str()) ··· 203 208 .contains(&json!("email")) 204 209 ); 205 210 206 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 211 + let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1") 212 + .bind(&did) 207 213 .fetch_one(pool) 208 214 .await 209 215 .expect("User not found"); 210 216 211 - let body_text: String = sqlx::query_scalar!( 217 + let body_text: String = sqlx::query_scalar( 212 218 "SELECT body FROM comms_queue WHERE user_id = $1 AND comms_type = 'email_update' ORDER BY created_at DESC LIMIT 1", 213 - user_id 214 219 ) 220 + .bind(user_id) 215 221 .fetch_one(pool) 216 222 .await 217 223 .expect("Verification code not found");
+12 -2
crates/tranquil-pds/tests/common/mod.rs
··· 76 76 *APP_PORT.get().expect("APP_PORT not initialized") 77 77 } 78 78 79 + #[allow(dead_code)] 80 + pub fn pds_hostname() -> String { 81 + std::env::var("PDS_HOSTNAME").unwrap_or_else(|_| format!("pds.test:{}", app_port())) 82 + } 83 + 84 + #[allow(dead_code)] 85 + pub fn pds_endpoint() -> String { 86 + format!("https://{}", pds_hostname()) 87 + } 88 + 79 89 pub async fn base_url() -> &'static str { 80 90 SERVER_URL.get_or_init(|| { 81 91 let (tx, rx) = std::sync::mpsc::channel(); ··· 457 467 let addr = listener.local_addr().unwrap(); 458 468 APP_PORT.set(addr.port()).ok(); 459 469 unsafe { 460 - std::env::set_var("PDS_HOSTNAME", addr.to_string()); 470 + std::env::set_var("PDS_HOSTNAME", format!("pds.test:{}", addr.port())); 461 471 } 462 472 let rate_limiters = RateLimiters::new() 463 473 .with_login_limit(10000) ··· 474 484 tokio::spawn(async move { 475 485 axum::serve(listener, app).await.unwrap(); 476 486 }); 477 - format!("http://{}", addr) 487 + format!("http://localhost:{}", addr.port()) 478 488 } 479 489 480 490 #[allow(dead_code)]
+15 -12
crates/tranquil-pds/tests/did_web.rs
··· 94 94 #[tokio::test] 95 95 async fn test_external_did_web_no_local_doc() { 96 96 let client = client(); 97 + let base = base_url().await; 97 98 let mock_server = MockServer::start().await; 98 99 let mock_uri = mock_server.uri(); 99 100 let mock_addr = mock_uri.trim_start_matches("http://"); 100 101 let did = format!("did:web:{}", mock_addr.replace(":", "%3A")); 101 102 let handle = format!("xw{}", &uuid::Uuid::new_v4().simple().to_string()[..12]); 102 - let pds_endpoint = base_url().await.replace("http://", "https://"); 103 + let pds_endpoint = common::pds_endpoint(); 103 104 104 105 let reserve_res = client 105 106 .post(format!( 106 107 "{}/xrpc/com.atproto.server.reserveSigningKey", 107 - base_url().await 108 + base 108 109 )) 109 110 .json(&json!({ "did": did })) 110 111 .send() ··· 150 151 let res = client 151 152 .post(format!( 152 153 "{}/xrpc/com.atproto.server.createAccount", 153 - base_url().await 154 + base 154 155 )) 155 156 .json(&payload) 156 157 .send() ··· 161 162 panic!("createAccount failed: {:?}", body); 162 163 } 163 164 let res = client 164 - .get(format!("{}/u/{}/did.json", base_url().await, handle)) 165 + .get(format!("{}/u/{}/did.json", base, handle)) 165 166 .send() 166 167 .await 167 168 .expect("Failed to fetch DID doc"); ··· 383 384 #[tokio::test] 384 385 async fn test_did_web_byod_flow() { 385 386 let client = client(); 387 + let base = base_url().await; 386 388 let mock_server = MockServer::start().await; 387 389 let mock_uri = mock_server.uri(); 388 390 let mock_addr = mock_uri.trim_start_matches("http://"); ··· 393 395 unique_id 394 396 ); 395 397 let handle = format!("by{}", &uuid::Uuid::new_v4().simple().to_string()[..12]); 396 - let pds_endpoint = base_url().await.replace("http://", "https://"); 397 - let pds_did = format!("did:web:{}", pds_endpoint.trim_start_matches("https://")); 398 + let pds_endpoint = common::pds_endpoint(); 399 + let pds_hostname = common::pds_hostname(); 400 + let pds_did = format!("did:web:{}", pds_hostname); 398 401 399 402 let temp_key = SigningKey::random(&mut rand::thread_rng()); 400 403 let public_key_multibase = signing_key_to_multibase(&temp_key); ··· 430 433 let res = client 431 434 .post(format!( 432 435 "{}/xrpc/com.atproto.server.createAccount", 433 - base_url().await 436 + base 434 437 )) 435 438 .header("Authorization", format!("Bearer {}", service_jwt)) 436 439 .json(&payload) ··· 454 457 let res = client 455 458 .get(format!( 456 459 "{}/xrpc/com.atproto.server.checkAccountStatus", 457 - base_url().await 460 + base 458 461 )) 459 462 .bearer_auth(&access_jwt) 460 463 .send() ··· 470 473 let res = client 471 474 .get(format!( 472 475 "{}/xrpc/com.atproto.identity.getRecommendedDidCredentials", 473 - base_url().await 476 + base 474 477 )) 475 478 .bearer_auth(&access_jwt) 476 479 .send() ··· 493 496 let res = client 494 497 .post(format!( 495 498 "{}/xrpc/com.atproto.server.activateAccount", 496 - base_url().await 499 + base 497 500 )) 498 501 .bearer_auth(&access_jwt) 499 502 .send() ··· 508 511 let res = client 509 512 .get(format!( 510 513 "{}/xrpc/com.atproto.server.checkAccountStatus", 511 - base_url().await 514 + base 512 515 )) 513 516 .bearer_auth(&access_jwt) 514 517 .send() ··· 524 527 let res = client 525 528 .post(format!( 526 529 "{}/xrpc/com.atproto.repo.createRecord", 527 - base_url().await 530 + base 528 531 )) 529 532 .bearer_auth(&access_jwt) 530 533 .json(&json!({
+131
crates/tranquil-pds/tests/firehose_validation.rs
··· 850 850 "Should have received commits even with outdated cursor" 851 851 ); 852 852 } 853 + 854 + #[tokio::test] 855 + async fn test_firehose_car_contains_mst_blocks() { 856 + let client = client(); 857 + let (token, did) = create_account_and_login(&client).await; 858 + 859 + for i in 0..3 { 860 + let post_payload = json!({ 861 + "repo": did, 862 + "collection": "app.bsky.feed.post", 863 + "record": { 864 + "$type": "app.bsky.feed.post", 865 + "text": format!("Setup post {}", i), 866 + "createdAt": chrono::Utc::now().to_rfc3339(), 867 + } 868 + }); 869 + client 870 + .post(format!( 871 + "{}/xrpc/com.atproto.repo.createRecord", 872 + base_url().await 873 + )) 874 + .bearer_auth(&token) 875 + .json(&post_payload) 876 + .send() 877 + .await 878 + .expect("Failed to create setup post"); 879 + tokio::time::sleep(std::time::Duration::from_millis(50)).await; 880 + } 881 + 882 + let url = format!( 883 + "ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos", 884 + app_port() 885 + ); 886 + let (mut ws_stream, _) = connect_async(&url).await.expect("Failed to connect"); 887 + tokio::time::sleep(std::time::Duration::from_millis(100)).await; 888 + 889 + let post_payload = json!({ 890 + "repo": did, 891 + "collection": "app.bsky.feed.post", 892 + "record": { 893 + "$type": "app.bsky.feed.post", 894 + "text": "Test post for MST block validation", 895 + "createdAt": chrono::Utc::now().to_rfc3339(), 896 + } 897 + }); 898 + let res = client 899 + .post(format!( 900 + "{}/xrpc/com.atproto.repo.createRecord", 901 + base_url().await 902 + )) 903 + .bearer_auth(&token) 904 + .json(&post_payload) 905 + .send() 906 + .await 907 + .expect("Failed to create post"); 908 + assert_eq!(res.status(), StatusCode::OK); 909 + let create_result: Value = res.json().await.unwrap(); 910 + let record_cid_str = create_result["cid"].as_str().unwrap(); 911 + let expected_record_cid: Cid = record_cid_str.parse().unwrap(); 912 + 913 + let mut frame_opt: Option<CommitFrame> = None; 914 + let timeout = tokio::time::timeout(std::time::Duration::from_secs(10), async { 915 + loop { 916 + let msg = ws_stream.next().await.unwrap().unwrap(); 917 + let raw_bytes = match msg { 918 + tungstenite::Message::Binary(bin) => bin, 919 + _ => continue, 920 + }; 921 + if let Ok((_, f)) = parse_frame(&raw_bytes) 922 + && f.repo == did 923 + && f.ops.iter().any(|op| op.cid == Some(expected_record_cid)) 924 + { 925 + frame_opt = Some(f); 926 + break; 927 + } 928 + } 929 + }) 930 + .await; 931 + assert!(timeout.is_ok(), "Timed out waiting for firehose event"); 932 + let frame = frame_opt.expect("No matching frame found"); 933 + 934 + let mut car_reader = CarReader::new(Cursor::new(&frame.blocks)).await.unwrap(); 935 + 936 + let mut block_count = 0; 937 + let mut found_commit = false; 938 + let mut found_record = false; 939 + let mut mst_block_count = 0; 940 + 941 + while let Ok(Some((cid, data))) = car_reader.next_block().await { 942 + block_count += 1; 943 + 944 + if cid == frame.commit { 945 + found_commit = true; 946 + continue; 947 + } 948 + 949 + if cid == expected_record_cid { 950 + found_record = true; 951 + continue; 952 + } 953 + 954 + if data.len() > 10 && data.len() < 5000 { 955 + mst_block_count += 1; 956 + } 957 + } 958 + 959 + println!("CAR block analysis:"); 960 + println!(" Total blocks: {}", block_count); 961 + println!(" Found commit: {}", found_commit); 962 + println!(" Found record: {}", found_record); 963 + println!(" MST/other blocks: {}", mst_block_count); 964 + 965 + assert!(found_commit, "CAR must contain commit block"); 966 + assert!(found_record, "CAR must contain record block"); 967 + 968 + assert!( 969 + block_count >= 3, 970 + "CAR should contain at least commit + record + MST node(s), got {} blocks. \ 971 + This may indicate firehose is not including all relevant blocks.", 972 + block_count 973 + ); 974 + 975 + assert!( 976 + mst_block_count >= 1, 977 + "CAR should contain MST node blocks for repo validation, got {} MST blocks. \ 978 + Firehose must include relevant MST blocks, not just new ones.", 979 + mst_block_count 980 + ); 981 + 982 + ws_stream.send(tungstenite::Message::Close(None)).await.ok(); 983 + }
+12 -9
crates/tranquil-pds/tests/identity.rs
··· 48 48 #[tokio::test] 49 49 async fn test_resolve_handle_not_found() { 50 50 let client = client(); 51 - let params = [("handle", "nonexistent_handle_12345")]; 51 + let _base = base_url().await; 52 + let params = [("handle", "nonexistent.handle.test")]; 52 53 let res = client 53 54 .get(format!( 54 55 "{}/xrpc/com.atproto.identity.resolveHandle", 55 - base_url().await 56 + _base 56 57 )) 57 58 .query(&params) 58 59 .send() ··· 99 100 let mock_addr = mock_uri.trim_start_matches("http://"); 100 101 let did = format!("did:web:{}", mock_addr.replace(":", "%3A")); 101 102 let handle = format!("wu{}", &uuid::Uuid::new_v4().simple().to_string()[..12]); 102 - let pds_endpoint = base_url().await.replace("http://", "https://"); 103 + let base = base_url().await; 104 + let pds_endpoint = common::pds_endpoint(); 103 105 104 106 let reserve_res = client 105 107 .post(format!( 106 108 "{}/xrpc/com.atproto.server.reserveSigningKey", 107 - base_url().await 109 + base 108 110 )) 109 111 .json(&json!({ "did": did })) 110 112 .send() ··· 149 151 let res = client 150 152 .post(format!( 151 153 "{}/xrpc/com.atproto.server.createAccount", 152 - base_url().await 154 + base 153 155 )) 154 156 .json(&payload) 155 157 .send() ··· 169 171 .expect("createAccount response was not JSON"); 170 172 assert_eq!(body["did"], did); 171 173 let res = client 172 - .get(format!("{}/u/{}/did.json", base_url().await, handle)) 174 + .get(format!("{}/u/{}/did.json", base, handle)) 173 175 .send() 174 176 .await 175 177 .expect("Failed to fetch DID doc"); ··· 217 219 #[tokio::test] 218 220 async fn test_did_web_lifecycle() { 219 221 let client = client(); 222 + let base = base_url().await; 220 223 let mock_server = MockServer::start().await; 221 224 let mock_uri = mock_server.uri(); 222 225 let mock_addr = mock_uri.trim_start_matches("http://"); 223 226 let handle = format!("lc{}", &uuid::Uuid::new_v4().simple().to_string()[..12]); 224 227 let did = format!("did:web:{}:u:{}", mock_addr.replace(":", "%3A"), handle); 225 228 let email = format!("{}@test.com", handle); 226 - let pds_endpoint = base_url().await.replace("http://", "https://"); 229 + let pds_endpoint = common::pds_endpoint(); 227 230 228 231 let reserve_res = client 229 232 .post(format!( 230 233 "{}/xrpc/com.atproto.server.reserveSigningKey", 231 - base_url().await 234 + base 232 235 )) 233 236 .json(&json!({ "did": did })) 234 237 .send() ··· 273 276 let res = client 274 277 .post(format!( 275 278 "{}/xrpc/com.atproto.server.createAccount", 276 - base_url().await 279 + base 277 280 )) 278 281 .json(&create_payload) 279 282 .send()
+2 -2
crates/tranquil-pds/tests/lifecycle_record.rs
··· 132 132 .expect("Failed to send stale update"); 133 133 assert_eq!( 134 134 stale_res.status(), 135 - StatusCode::CONFLICT, 136 - "Stale update should cause 409" 135 + StatusCode::BAD_REQUEST, 136 + "Stale update should cause 400 InvalidSwap" 137 137 ); 138 138 let good_update_payload = json!({ 139 139 "repo": did,
+52 -74
crates/tranquil-pds/tests/notifications.rs
··· 1 1 mod common; 2 - use tranquil_pds::comms::{ 3 - CommsChannel, CommsStatus, CommsType, NewComms, enqueue_comms, enqueue_welcome, 4 - }; 2 + use sqlx::Row; 3 + use tranquil_pds::comms::{CommsChannel, CommsStatus, CommsType}; 5 4 6 5 #[tokio::test] 7 6 async fn test_enqueue_comms() { 8 7 let pool = common::get_test_db_pool().await; 9 8 let (_, did) = common::create_account_and_login(&common::client()).await; 10 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 9 + let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1") 10 + .bind(&did) 11 11 .fetch_one(pool) 12 12 .await 13 13 .expect("User not found"); 14 - let item = NewComms::email( 15 - user_id, 16 - CommsType::Welcome, 17 - "test@example.com".to_string(), 18 - "Test Subject".to_string(), 19 - "Test body".to_string(), 20 - ); 21 - let comms_id = enqueue_comms(pool, item) 22 - .await 23 - .expect("Failed to enqueue comms"); 24 - let row = sqlx::query!( 25 - r#" 26 - SELECT 27 - id, user_id, recipient, subject, body, 28 - channel as "channel: CommsChannel", 29 - comms_type as "comms_type: CommsType", 30 - status as "status: CommsStatus" 31 - FROM comms_queue 32 - WHERE id = $1 33 - "#, 34 - comms_id 14 + let comms_id: uuid::Uuid = sqlx::query_scalar( 15 + r#"INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body) 16 + VALUES ($1, 'email', 'welcome', $2, $3, $4) 17 + RETURNING id"#, 35 18 ) 19 + .bind(user_id) 20 + .bind("test@example.com") 21 + .bind("Test Subject") 22 + .bind("Test body") 36 23 .fetch_one(pool) 37 24 .await 38 - .expect("Comms not found"); 39 - assert_eq!(row.user_id, user_id); 40 - assert_eq!(row.recipient, "test@example.com"); 41 - assert_eq!(row.subject.as_deref(), Some("Test Subject")); 42 - assert_eq!(row.body, "Test body"); 43 - assert_eq!(row.channel, CommsChannel::Email); 44 - assert_eq!(row.comms_type, CommsType::Welcome); 45 - assert_eq!(row.status, CommsStatus::Pending); 46 - } 47 - 48 - #[tokio::test] 49 - async fn test_enqueue_welcome() { 50 - let pool = common::get_test_db_pool().await; 51 - let (_, did) = common::create_account_and_login(&common::client()).await; 52 - let user_row = sqlx::query!("SELECT id, email, handle FROM users WHERE did = $1", did) 53 - .fetch_one(pool) 54 - .await 55 - .expect("User not found"); 56 - let comms_id = enqueue_welcome(pool, user_row.id, "example.com") 57 - .await 58 - .expect("Failed to enqueue welcome comms"); 59 - let row = sqlx::query!( 25 + .expect("Failed to enqueue comms"); 26 + let row = sqlx::query( 60 27 r#" 61 - SELECT 62 - recipient, subject, body, 63 - comms_type as "comms_type: CommsType" 28 + SELECT id, user_id, recipient, subject, body, channel, comms_type, status 64 29 FROM comms_queue 65 30 WHERE id = $1 66 31 "#, 67 - comms_id 68 32 ) 33 + .bind(comms_id) 69 34 .fetch_one(pool) 70 35 .await 71 36 .expect("Comms not found"); 72 - assert_eq!(Some(row.recipient), user_row.email); 73 - assert_eq!(row.subject.as_deref(), Some("Welcome to example.com")); 74 - assert!(row.body.contains(&format!("@{}", user_row.handle))); 75 - assert_eq!(row.comms_type, CommsType::Welcome); 37 + let row_user_id: uuid::Uuid = row.get("user_id"); 38 + let row_recipient: String = row.get("recipient"); 39 + let row_subject: Option<String> = row.get("subject"); 40 + let row_body: String = row.get("body"); 41 + let row_channel: CommsChannel = row.get("channel"); 42 + let row_comms_type: CommsType = row.get("comms_type"); 43 + let row_status: CommsStatus = row.get("status"); 44 + assert_eq!(row_user_id, user_id); 45 + assert_eq!(row_recipient, "test@example.com"); 46 + assert_eq!(row_subject.as_deref(), Some("Test Subject")); 47 + assert_eq!(row_body, "Test body"); 48 + assert_eq!(row_channel, CommsChannel::Email); 49 + assert_eq!(row_comms_type, CommsType::Welcome); 50 + assert_eq!(row_status, CommsStatus::Pending); 76 51 } 77 52 78 53 #[tokio::test] 79 54 async fn test_comms_queue_status_index() { 80 55 let pool = common::get_test_db_pool().await; 81 56 let (_, did) = common::create_account_and_login(&common::client()).await; 82 - let user_id: uuid::Uuid = sqlx::query_scalar!("SELECT id FROM users WHERE did = $1", did) 57 + let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1") 58 + .bind(&did) 83 59 .fetch_one(pool) 84 60 .await 85 61 .expect("User not found"); 86 - let initial_count: i64 = sqlx::query_scalar!( 62 + let initial_count: i64 = sqlx::query_scalar( 87 63 "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1", 88 - user_id 89 64 ) 65 + .bind(user_id) 90 66 .fetch_one(pool) 91 67 .await 92 - .expect("Failed to count") 93 - .unwrap_or(0); 94 - for i in 0..5 { 95 - let item = NewComms::email( 96 - user_id, 97 - CommsType::PasswordReset, 98 - format!("test{}@example.com", i), 99 - "Test".to_string(), 100 - "Body".to_string(), 101 - ); 102 - enqueue_comms(pool, item).await.expect("Failed to enqueue"); 103 - } 104 - let final_count: i64 = sqlx::query_scalar!( 68 + .expect("Failed to count"); 69 + let inserts = (0..5).map(|i| { 70 + sqlx::query( 71 + r#"INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body) 72 + VALUES ($1, 'email', 'password_reset', $2, $3, $4)"#, 73 + ) 74 + .bind(user_id) 75 + .bind(format!("test{}@example.com", i)) 76 + .bind("Test") 77 + .bind("Body") 78 + .execute(pool) 79 + }); 80 + futures::future::try_join_all(inserts) 81 + .await 82 + .expect("Failed to enqueue"); 83 + let final_count: i64 = sqlx::query_scalar( 105 84 "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1", 106 - user_id 107 85 ) 86 + .bind(user_id) 108 87 .fetch_one(pool) 109 88 .await 110 - .expect("Failed to count") 111 - .unwrap_or(0); 89 + .expect("Failed to count"); 112 90 assert_eq!(final_count - initial_count, 5); 113 91 }
+2 -2
crates/tranquil-pds/tests/sync_conformance.rs
··· 352 352 let initial_body: Value = initial_commit_res.json().await.unwrap(); 353 353 let initial_rev = initial_body["rev"].as_str().unwrap(); 354 354 355 + create_post(&client, &did, &jwt, "Test post for since param").await; 356 + 355 357 let full_repo_res = client 356 358 .get(format!( 357 359 "{}/xrpc/com.atproto.sync.getRepo", ··· 364 366 assert_eq!(full_repo_res.status(), StatusCode::OK); 365 367 let full_repo_bytes = full_repo_res.bytes().await.unwrap(); 366 368 let full_repo_size = full_repo_bytes.len(); 367 - 368 - create_post(&client, &did, &jwt, "Test post for since param").await; 369 369 370 370 let partial_repo_res = client 371 371 .get(format!(
+379
crates/tranquil-types/src/lib.rs
··· 639 639 pub fn into_inner(self) -> String { 640 640 self.0 641 641 } 642 + 643 + pub fn did(&self) -> Option<&str> { 644 + self.0 645 + .strip_prefix("at://") 646 + .and_then(|s| s.split('/').next()) 647 + } 648 + 649 + pub fn collection(&self) -> Option<&str> { 650 + self.0 651 + .strip_prefix("at://") 652 + .and_then(|s| s.split('/').nth(1)) 653 + } 654 + 655 + pub fn rkey(&self) -> Option<&str> { 656 + self.0 657 + .strip_prefix("at://") 658 + .and_then(|s| s.split('/').nth(2)) 659 + } 642 660 } 643 661 644 662 impl AsRef<str> for AtUri { ··· 1439 1457 } 1440 1458 } 1441 1459 1460 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1461 + #[serde(transparent)] 1462 + #[sqlx(transparent)] 1463 + pub struct TokenId(String); 1464 + 1465 + impl TokenId { 1466 + pub fn new(s: impl Into<String>) -> Self { 1467 + Self(s.into()) 1468 + } 1469 + 1470 + pub fn as_str(&self) -> &str { 1471 + &self.0 1472 + } 1473 + 1474 + pub fn into_inner(self) -> String { 1475 + self.0 1476 + } 1477 + } 1478 + 1479 + impl AsRef<str> for TokenId { 1480 + fn as_ref(&self) -> &str { 1481 + &self.0 1482 + } 1483 + } 1484 + 1485 + impl Deref for TokenId { 1486 + type Target = str; 1487 + 1488 + fn deref(&self) -> &Self::Target { 1489 + &self.0 1490 + } 1491 + } 1492 + 1493 + impl From<String> for TokenId { 1494 + fn from(s: String) -> Self { 1495 + Self(s) 1496 + } 1497 + } 1498 + 1499 + impl fmt::Display for TokenId { 1500 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1501 + write!(f, "{}", self.0) 1502 + } 1503 + } 1504 + 1505 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1506 + #[serde(transparent)] 1507 + #[sqlx(transparent)] 1508 + pub struct ClientId(String); 1509 + 1510 + impl ClientId { 1511 + pub fn new(s: impl Into<String>) -> Self { 1512 + Self(s.into()) 1513 + } 1514 + 1515 + pub fn as_str(&self) -> &str { 1516 + &self.0 1517 + } 1518 + 1519 + pub fn into_inner(self) -> String { 1520 + self.0 1521 + } 1522 + } 1523 + 1524 + impl AsRef<str> for ClientId { 1525 + fn as_ref(&self) -> &str { 1526 + &self.0 1527 + } 1528 + } 1529 + 1530 + impl Deref for ClientId { 1531 + type Target = str; 1532 + 1533 + fn deref(&self) -> &Self::Target { 1534 + &self.0 1535 + } 1536 + } 1537 + 1538 + impl From<String> for ClientId { 1539 + fn from(s: String) -> Self { 1540 + Self(s) 1541 + } 1542 + } 1543 + 1544 + impl fmt::Display for ClientId { 1545 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1546 + write!(f, "{}", self.0) 1547 + } 1548 + } 1549 + 1550 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1551 + #[serde(transparent)] 1552 + #[sqlx(transparent)] 1553 + pub struct DeviceId(String); 1554 + 1555 + impl DeviceId { 1556 + pub fn new(s: impl Into<String>) -> Self { 1557 + Self(s.into()) 1558 + } 1559 + 1560 + pub fn as_str(&self) -> &str { 1561 + &self.0 1562 + } 1563 + 1564 + pub fn into_inner(self) -> String { 1565 + self.0 1566 + } 1567 + } 1568 + 1569 + impl AsRef<str> for DeviceId { 1570 + fn as_ref(&self) -> &str { 1571 + &self.0 1572 + } 1573 + } 1574 + 1575 + impl Deref for DeviceId { 1576 + type Target = str; 1577 + 1578 + fn deref(&self) -> &Self::Target { 1579 + &self.0 1580 + } 1581 + } 1582 + 1583 + impl From<String> for DeviceId { 1584 + fn from(s: String) -> Self { 1585 + Self(s) 1586 + } 1587 + } 1588 + 1589 + impl fmt::Display for DeviceId { 1590 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1591 + write!(f, "{}", self.0) 1592 + } 1593 + } 1594 + 1595 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1596 + #[serde(transparent)] 1597 + #[sqlx(transparent)] 1598 + pub struct RequestId(String); 1599 + 1600 + impl RequestId { 1601 + pub fn new(s: impl Into<String>) -> Self { 1602 + Self(s.into()) 1603 + } 1604 + 1605 + pub fn as_str(&self) -> &str { 1606 + &self.0 1607 + } 1608 + 1609 + pub fn into_inner(self) -> String { 1610 + self.0 1611 + } 1612 + } 1613 + 1614 + impl AsRef<str> for RequestId { 1615 + fn as_ref(&self) -> &str { 1616 + &self.0 1617 + } 1618 + } 1619 + 1620 + impl Deref for RequestId { 1621 + type Target = str; 1622 + 1623 + fn deref(&self) -> &Self::Target { 1624 + &self.0 1625 + } 1626 + } 1627 + 1628 + impl From<String> for RequestId { 1629 + fn from(s: String) -> Self { 1630 + Self(s) 1631 + } 1632 + } 1633 + 1634 + impl fmt::Display for RequestId { 1635 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1636 + write!(f, "{}", self.0) 1637 + } 1638 + } 1639 + 1640 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1641 + #[serde(transparent)] 1642 + #[sqlx(transparent)] 1643 + pub struct Jti(String); 1644 + 1645 + impl Jti { 1646 + pub fn new(s: impl Into<String>) -> Self { 1647 + Self(s.into()) 1648 + } 1649 + 1650 + pub fn as_str(&self) -> &str { 1651 + &self.0 1652 + } 1653 + 1654 + pub fn into_inner(self) -> String { 1655 + self.0 1656 + } 1657 + } 1658 + 1659 + impl AsRef<str> for Jti { 1660 + fn as_ref(&self) -> &str { 1661 + &self.0 1662 + } 1663 + } 1664 + 1665 + impl Deref for Jti { 1666 + type Target = str; 1667 + 1668 + fn deref(&self) -> &Self::Target { 1669 + &self.0 1670 + } 1671 + } 1672 + 1673 + impl From<String> for Jti { 1674 + fn from(s: String) -> Self { 1675 + Self(s) 1676 + } 1677 + } 1678 + 1679 + impl fmt::Display for Jti { 1680 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1681 + write!(f, "{}", self.0) 1682 + } 1683 + } 1684 + 1685 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1686 + #[serde(transparent)] 1687 + #[sqlx(transparent)] 1688 + pub struct AuthorizationCode(String); 1689 + 1690 + impl AuthorizationCode { 1691 + pub fn new(s: impl Into<String>) -> Self { 1692 + Self(s.into()) 1693 + } 1694 + 1695 + pub fn as_str(&self) -> &str { 1696 + &self.0 1697 + } 1698 + 1699 + pub fn into_inner(self) -> String { 1700 + self.0 1701 + } 1702 + } 1703 + 1704 + impl AsRef<str> for AuthorizationCode { 1705 + fn as_ref(&self) -> &str { 1706 + &self.0 1707 + } 1708 + } 1709 + 1710 + impl Deref for AuthorizationCode { 1711 + type Target = str; 1712 + 1713 + fn deref(&self) -> &Self::Target { 1714 + &self.0 1715 + } 1716 + } 1717 + 1718 + impl From<String> for AuthorizationCode { 1719 + fn from(s: String) -> Self { 1720 + Self(s) 1721 + } 1722 + } 1723 + 1724 + impl fmt::Display for AuthorizationCode { 1725 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1726 + write!(f, "{}", self.0) 1727 + } 1728 + } 1729 + 1730 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1731 + #[serde(transparent)] 1732 + #[sqlx(transparent)] 1733 + pub struct RefreshToken(String); 1734 + 1735 + impl RefreshToken { 1736 + pub fn new(s: impl Into<String>) -> Self { 1737 + Self(s.into()) 1738 + } 1739 + 1740 + pub fn as_str(&self) -> &str { 1741 + &self.0 1742 + } 1743 + 1744 + pub fn into_inner(self) -> String { 1745 + self.0 1746 + } 1747 + } 1748 + 1749 + impl AsRef<str> for RefreshToken { 1750 + fn as_ref(&self) -> &str { 1751 + &self.0 1752 + } 1753 + } 1754 + 1755 + impl Deref for RefreshToken { 1756 + type Target = str; 1757 + 1758 + fn deref(&self) -> &Self::Target { 1759 + &self.0 1760 + } 1761 + } 1762 + 1763 + impl From<String> for RefreshToken { 1764 + fn from(s: String) -> Self { 1765 + Self(s) 1766 + } 1767 + } 1768 + 1769 + impl fmt::Display for RefreshToken { 1770 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1771 + write!(f, "{}", self.0) 1772 + } 1773 + } 1774 + 1775 + #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] 1776 + #[serde(transparent)] 1777 + #[sqlx(transparent)] 1778 + pub struct InviteCode(String); 1779 + 1780 + impl InviteCode { 1781 + pub fn new(s: impl Into<String>) -> Self { 1782 + Self(s.into()) 1783 + } 1784 + 1785 + pub fn as_str(&self) -> &str { 1786 + &self.0 1787 + } 1788 + 1789 + pub fn into_inner(self) -> String { 1790 + self.0 1791 + } 1792 + } 1793 + 1794 + impl AsRef<str> for InviteCode { 1795 + fn as_ref(&self) -> &str { 1796 + &self.0 1797 + } 1798 + } 1799 + 1800 + impl Deref for InviteCode { 1801 + type Target = str; 1802 + 1803 + fn deref(&self) -> &Self::Target { 1804 + &self.0 1805 + } 1806 + } 1807 + 1808 + impl From<String> for InviteCode { 1809 + fn from(s: String) -> Self { 1810 + Self(s) 1811 + } 1812 + } 1813 + 1814 + impl fmt::Display for InviteCode { 1815 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 1816 + write!(f, "{}", self.0) 1817 + } 1818 + } 1819 + 1442 1820 #[cfg(test)] 1443 1821 mod tests { 1444 1822 use super::*; ··· 1486 1864 assert!(Handle::new("user.bsky.social").is_ok()); 1487 1865 assert!(Handle::new("test.example.com").is_ok()); 1488 1866 assert!(Handle::new("invalid handle with spaces").is_err()); 1867 + assert!(Handle::new("alice.pds.test").is_ok()); 1489 1868 } 1490 1869 1491 1870 #[test]