···11mod common;
22-use common::{base_url, client, create_account_and_login, get_test_db_pool};
22+use common::{base_url, client, create_account_and_login, get_test_repos};
33use serde_json::{Value, json};
44+use tranquil_db_traits::{CommsChannel, CommsType};
55+use tranquil_types::Did;
4657#[tokio::test]
68async fn test_get_notification_history() {
79 let client = client();
810 let base = base_url().await;
99- let pool = get_test_db_pool().await;
1111+ let repos = get_test_repos().await;
1012 let (token, did) = create_account_and_login(&client).await;
11131212- let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1")
1313- .bind(&did)
1414- .fetch_one(pool)
1414+ let user_id = repos
1515+ .user
1616+ .get_id_by_did(&Did::new(did).unwrap())
1517 .await
1818+ .expect("DB error")
1619 .expect("User not found");
17201821 for i in 0..3 {
1919- sqlx::query(
2020- r#"INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body)
2121- VALUES ($1, 'email', 'welcome', $2, $3, $4)"#,
2222- )
2323- .bind(user_id)
2424- .bind("test@example.com")
2525- .bind(format!("Subject {}", i))
2626- .bind(format!("Body {}", i))
2727- .execute(pool)
2828- .await
2929- .expect("Failed to enqueue");
2222+ repos
2323+ .infra
2424+ .enqueue_comms(
2525+ Some(user_id),
2626+ CommsChannel::Email,
2727+ CommsType::Welcome,
2828+ "test@example.com",
2929+ Some(&format!("Subject {}", i)),
3030+ &format!("Body {}", i),
3131+ None,
3232+ )
3333+ .await
3434+ .expect("Failed to enqueue");
3035 }
31363237 let resp = client
···140145async fn test_update_email_via_notification_prefs() {
141146 let client = client();
142147 let base = base_url().await;
143143- let pool = get_test_db_pool().await;
148148+ let repos = get_test_repos().await;
144149 let (token, did) = create_account_and_login(&client).await;
145150146151 let unique_email = format!("newemail_{}@example.com", uuid::Uuid::new_v4());
···163168 .contains(&json!("email"))
164169 );
165170166166- let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1")
167167- .bind(&did)
168168- .fetch_one(pool)
171171+ let user_id = repos
172172+ .user
173173+ .get_id_by_did(&Did::new(did).unwrap())
169174 .await
175175+ .expect("DB error")
170176 .expect("User not found");
171177172172- let body_text: String = sqlx::query_scalar(
173173- "SELECT body FROM comms_queue WHERE user_id = $1 AND comms_type = 'email_update' ORDER BY created_at DESC LIMIT 1",
174174- )
175175- .bind(user_id)
176176- .fetch_one(pool)
177177- .await
178178- .expect("Verification code not found");
178178+ let comms = repos
179179+ .infra
180180+ .get_latest_comms_for_user(user_id, CommsType::EmailUpdate, 1)
181181+ .await
182182+ .expect("DB error");
183183+ let body_text = comms
184184+ .first()
185185+ .map(|c| c.body.clone())
186186+ .expect("Verification code not found");
179187180188 let code = body_text
181189 .lines()
+34-21
crates/tranquil-pds/tests/admin_email.rs
···2233use reqwest::StatusCode;
44use serde_json::{Value, json};
55+use tranquil_db_traits::CommsType;
66+use tranquil_types::Did;
5768#[tokio::test]
79async fn test_send_email_success() {
810 let client = common::client();
911 let base_url = common::base_url().await;
1010- let pool = common::get_test_db_pool().await;
1212+ let repos = common::get_test_repos().await;
1113 let (access_jwt, did) = common::create_admin_account_and_login(&client).await;
1214 let res = client
1315 .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
···2426 assert_eq!(res.status(), StatusCode::OK);
2527 let body: Value = res.json().await.expect("Invalid JSON");
2628 assert_eq!(body["sent"], true);
2727- let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
2828- .fetch_one(pool)
2929+ let user_id = repos
3030+ .user
3131+ .get_id_by_did(&Did::new(did).unwrap())
2932 .await
3333+ .expect("DB error")
3034 .expect("User not found");
3131- let notification = sqlx::query!(
3232- "SELECT subject, body, comms_type as \"comms_type: String\" FROM comms_queue WHERE user_id = $1 AND comms_type = 'admin_email' ORDER BY created_at DESC LIMIT 1",
3333- user.id
3434- )
3535- .fetch_one(pool)
3636- .await
3737- .expect("Notification not found");
3535+ let comms = repos
3636+ .infra
3737+ .get_latest_comms_for_user(user_id, CommsType::AdminEmail, 1)
3838+ .await
3939+ .expect("DB error");
4040+ let notification = comms.first().expect("Notification not found");
3841 assert_eq!(notification.subject.as_deref(), Some("Test Admin Email"));
3942 assert!(
4043 notification
···4750async fn test_send_email_default_subject() {
4851 let client = common::client();
4952 let base_url = common::base_url().await;
5050- let pool = common::get_test_db_pool().await;
5353+ let repos = common::get_test_repos().await;
5154 let (access_jwt, did) = common::create_admin_account_and_login(&client).await;
5255 let res = client
5356 .post(format!("{}/xrpc/com.atproto.admin.sendEmail", base_url))
···6366 assert_eq!(res.status(), StatusCode::OK);
6467 let body: Value = res.json().await.expect("Invalid JSON");
6568 assert_eq!(body["sent"], true);
6666- let user = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
6767- .fetch_one(pool)
6969+ let user_id = repos
7070+ .user
7171+ .get_id_by_did(&Did::new(did).unwrap())
6872 .await
7373+ .expect("DB error")
6974 .expect("User not found");
7070- let notification = sqlx::query!(
7171- "SELECT subject FROM comms_queue WHERE user_id = $1 AND comms_type = 'admin_email' AND body = 'Email without subject' LIMIT 1",
7272- user.id
7373- )
7474- .fetch_one(pool)
7575- .await
7676- .expect("Notification not found");
7575+ let comms = repos
7676+ .infra
7777+ .get_latest_comms_for_user(user_id, CommsType::AdminEmail, 10)
7878+ .await
7979+ .expect("DB error");
8080+ let notification = comms
8181+ .iter()
8282+ .find(|c| c.body == "Email without subject")
8383+ .expect("Notification not found");
7784 assert!(notification.subject.is_some());
7878- assert!(notification.subject.unwrap().contains("Message from"));
8585+ assert!(
8686+ notification
8787+ .subject
8888+ .as_ref()
8989+ .unwrap()
9090+ .contains("Message from")
9191+ );
7992}
80938194#[tokio::test]
+4-3
crates/tranquil-pds/tests/auth_extractor.rs
···215215 let did = account["did"].as_str().unwrap().to_string();
216216 verify_new_account(&http_client, &did).await;
217217218218- let pool = common::get_test_db_pool().await;
219219- sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
220220- .execute(pool)
218218+ let repos = common::get_test_repos().await;
219219+ repos
220220+ .user
221221+ .set_admin_status(&tranquil_types::Did::new(did.clone()).unwrap(), true)
221222 .await
222223 .expect("Failed to mark user as admin");
223224
+105-43
crates/tranquil-pds/tests/common/mod.rs
···2828static TEST_DB_POOL: OnceLock<sqlx::PgPool> = OnceLock::new();
2929static TEST_TEMP_DIR: OnceLock<PathBuf> = OnceLock::new();
3030static CLUSTER: OnceLock<Vec<ServerInstance>> = OnceLock::new();
3131+static TEST_REPOS: OnceLock<Arc<tranquil_db::PostgresRepositories>> = OnceLock::new();
3232+3333+#[allow(dead_code)]
3434+pub fn is_store_backend() -> bool {
3535+ std::env::var("TRANQUIL_TEST_BACKEND")
3636+ .map(|v| v == "store")
3737+ .unwrap_or(false)
3838+}
31393240#[allow(dead_code)]
3341pub struct ServerConfig {
3434- pub pool: sqlx::PgPool,
4242+ pub pool: Option<sqlx::PgPool>,
3543 pub cache: Option<(Arc<dyn Cache>, Arc<dyn DistributedRateLimiter>)>,
3644}
3745···123131 SERVER_URL.get_or_init(|| {
124132 let (tx, rx) = std::sync::mpsc::channel();
125133 std::thread::spawn(move || {
134134+ let _ = tracing_subscriber::fmt()
135135+ .with_env_filter(
136136+ tracing_subscriber::EnvFilter::try_from_default_env()
137137+ .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("warn")),
138138+ )
139139+ .try_init();
126140 unsafe {
127141 std::env::set_var("TRANQUIL_PDS_ALLOW_INSECURE_SECRETS", "1");
128142 }
···141155 }
142156 let rt = tokio::runtime::Runtime::new().unwrap();
143157 rt.block_on(async move {
144144- if has_external_infra() {
158158+ if is_store_backend() {
159159+ let url = setup_store_backend().await;
160160+ tx.send(url).unwrap();
161161+ } else if has_external_infra() {
145162 let url = setup_with_external_infra().await;
146163 tx.send(url).unwrap();
147164 } else {
···557574 .with_oauth_authorize_limit(10000)
558575 .with_oauth_token_limit(10000);
559576 let cache_refs = config.cache.as_ref().map(|(c, r)| (c.clone(), r.clone()));
560560- let mut state = AppState::from_db(config.pool, CancellationToken::new())
561561- .await
562562- .with_rate_limiters(rate_limiters);
577577+ let mut state = match config.pool {
578578+ Some(pool) => AppState::from_db(pool, CancellationToken::new()).await,
579579+ None => AppState::from_store(CancellationToken::new()).await,
580580+ };
581581+ state = state.with_rate_limiters(rate_limiters);
582582+ TEST_REPOS.set(state.repos.clone()).ok();
563583 if let Some((cache, distributed_rate_limiter)) = config.cache {
564584 state = state.with_cache(cache, distributed_rate_limiter);
565585 }
···590610 }
591611}
592612613613+async fn setup_store_backend() -> String {
614614+ let temp_dir =
615615+ std::env::temp_dir().join(format!("tranquil-pds-store-{}", uuid::Uuid::new_v4()));
616616+ let blob_path = temp_dir.join("blobs");
617617+ let backup_path = temp_dir.join("backups");
618618+ let store_path = temp_dir.join("store");
619619+ std::fs::create_dir_all(&blob_path).expect("failed to create blob temp directory");
620620+ std::fs::create_dir_all(&backup_path).expect("failed to create backup temp directory");
621621+ std::fs::create_dir_all(&store_path).expect("failed to create store temp directory");
622622+ TEST_TEMP_DIR.set(temp_dir).ok();
623623+ let plc_url = setup_mock_plc_directory().await;
624624+ unsafe {
625625+ std::env::set_var("BLOB_STORAGE_BACKEND", "filesystem");
626626+ std::env::set_var("BLOB_STORAGE_PATH", blob_path.to_str().unwrap());
627627+ std::env::set_var("BACKUP_STORAGE_BACKEND", "filesystem");
628628+ std::env::set_var("BACKUP_STORAGE_PATH", backup_path.to_str().unwrap());
629629+ std::env::set_var("MAX_IMPORT_SIZE", "100000000");
630630+ std::env::set_var("SKIP_IMPORT_VERIFICATION", "true");
631631+ std::env::set_var("PLC_DIRECTORY_URL", &plc_url);
632632+ std::env::set_var("REPO_BACKEND", "tranquil-store");
633633+ std::env::set_var("TRANQUIL_STORE_DATA_DIR", store_path.to_str().unwrap());
634634+ std::env::set_var("DATABASE_URL", "postgres://unused/unused");
635635+ }
636636+ register_mock_appview().await;
637637+ let instance = spawn_server(ServerConfig {
638638+ pool: None,
639639+ cache: None,
640640+ })
641641+ .await;
642642+ APP_PORT.set(instance.port).ok();
643643+ instance.url
644644+}
645645+593646async fn spawn_app(database_url: String) -> String {
594647 let pool = PgPoolOptions::new()
595648 .max_connections(10)
···608661 .await
609662 .expect("Failed to create test pool");
610663 TEST_DB_POOL.set(test_pool).ok();
611611- let instance = spawn_server(ServerConfig { pool, cache: None }).await;
664664+ let instance = spawn_server(ServerConfig {
665665+ pool: Some(pool),
666666+ cache: None,
667667+ })
668668+ .await;
612669 APP_PORT.set(instance.port).ok();
613670 instance.url
614671}
···659716 let mut instances: Vec<ServerInstance> = Vec::with_capacity(node_count);
660717 for (cache, rate_limiter) in ripple_nodes {
661718 let server_config = ServerConfig {
662662- pool: pool.clone(),
719719+ pool: Some(pool.clone()),
663720 cache: Some((cache, rate_limiter)),
664721 };
665722 let instance = spawn_server(server_config).await;
···799856}
800857801858#[allow(dead_code)]
802802-pub async fn verify_new_account(client: &Client, did: &str) -> String {
803803- let pool = get_test_db_pool().await;
804804- let body_text: String = sqlx::query_scalar!(
805805- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
806806- did
807807- )
808808- .fetch_one(pool)
809809- .await
810810- .expect("Failed to get verification code");
859859+pub async fn get_test_repos() -> &'static Arc<tranquil_db::PostgresRepositories> {
860860+ base_url().await;
861861+ TEST_REPOS.get().expect("TEST_REPOS not initialized")
862862+}
811863864864+fn extract_verification_code(body_text: &str) -> String {
812865 let lines: Vec<&str> = body_text.lines().collect();
813813- let verification_code = lines
866866+ lines
814867 .iter()
815868 .enumerate()
816869 .find(|(_, line)| line.contains("verification code is:") || line.contains("code is:"))
···821874 .find(|line| line.trim().starts_with("MX"))
822875 .map(|s| s.trim().to_string())
823876 })
824824- .unwrap_or_else(|| body_text.clone());
877877+ .unwrap_or_else(|| body_text.to_string())
878878+}
879879+880880+async fn get_verification_body_for_did(did: &str) -> String {
881881+ use tranquil_db_traits::CommsType;
882882+ use tranquil_types::Did;
883883+884884+ let repos = get_test_repos().await;
885885+ let user = repos
886886+ .user
887887+ .get_by_did(&Did::new(did.to_string()).unwrap())
888888+ .await
889889+ .expect("failed to look up user")
890890+ .expect("user not found");
891891+ let comms = repos
892892+ .infra
893893+ .get_latest_comms_for_user(user.id, CommsType::EmailVerification, 1)
894894+ .await
895895+ .expect("failed to get comms");
896896+ comms
897897+ .first()
898898+ .map(|c| c.body.clone())
899899+ .expect("no email_verification comms found")
900900+}
901901+902902+#[allow(dead_code)]
903903+pub async fn verify_new_account(client: &Client, did: &str) -> String {
904904+ let body_text = get_verification_body_for_did(did).await;
905905+ let verification_code = extract_verification_code(&body_text);
825906826907 let confirm_payload = json!({
827908 "did": did,
···9561037 if res.status() == StatusCode::OK {
9571038 let body: Value = res.json().await.expect("Invalid JSON");
9581039 let did = body["did"].as_str().expect("No did").to_string();
959959- let pool = get_test_db_pool().await;
9601040 if make_admin {
961961- sqlx::query!("UPDATE users SET is_admin = TRUE WHERE did = $1", &did)
962962- .execute(pool)
10411041+ let repos = get_test_repos().await;
10421042+ repos
10431043+ .user
10441044+ .set_admin_status(&tranquil_types::Did::new(did.clone()).unwrap(), true)
9631045 .await
9641046 .expect("Failed to mark user as admin");
9651047 }
···9691051 {
9701052 return (access_jwt.to_string(), did);
9711053 }
972972- let body_text: String = sqlx::query_scalar!(
973973- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
974974- &did
975975- )
976976- .fetch_one(pool)
977977- .await
978978- .expect("Failed to get verification from comms_queue");
979979- let lines: Vec<&str> = body_text.lines().collect();
980980- let verification_code = lines
981981- .iter()
982982- .enumerate()
983983- .find(|(_, line): &(usize, &&str)| {
984984- line.contains("verification code is:") || line.contains("code is:")
985985- })
986986- .and_then(|(i, _)| lines.get(i + 1).map(|s: &&str| s.trim().to_string()))
987987- .or_else(|| {
988988- body_text
989989- .lines()
990990- .find(|line| line.trim().starts_with("MX"))
991991- .map(|s| s.trim().to_string())
992992- })
993993- .unwrap_or_else(|| body_text.clone());
10541054+ let body_text = get_verification_body_for_did(&did).await;
10551055+ let verification_code = extract_verification_code(&body_text);
99410569951057 let confirm_payload = json!({
9961058 "did": did,
+51-60
crates/tranquil-pds/tests/delete_account.rs
···5151 .await
5252 .expect("Failed to request account deletion");
5353 assert_eq!(request_delete_res.status(), StatusCode::OK);
5454- let pool = get_test_db_pool().await;
5555- let row = sqlx::query!(
5656- "SELECT token FROM account_deletion_requests WHERE did = $1",
5757- did
5858- )
5959- .fetch_one(pool)
6060- .await
6161- .expect("Failed to query deletion token");
6262- let token = row.token;
5454+ let repos = get_test_repos().await;
5555+ let deletion_request = repos
5656+ .infra
5757+ .get_deletion_request_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
5858+ .await
5959+ .unwrap()
6060+ .unwrap();
6161+ let token = deletion_request.token;
6362 let delete_payload = json!({
6463 "did": did,
6564 "password": password,
···7574 .await
7675 .expect("Failed to delete account");
7776 assert_eq!(delete_res.status(), StatusCode::OK);
7878- let user_row = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
7979- .fetch_optional(pool)
7777+ let user = repos
7878+ .user
7979+ .get_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
8080 .await
8181- .expect("Failed to query user");
8282- assert!(user_row.is_none(), "User should be deleted from database");
8181+ .unwrap();
8282+ assert!(user.is_none(), "User should be deleted from database");
8383 let session_res = client
8484 .get(format!("{}/xrpc/com.atproto.server.getSession", base_url))
8585 .bearer_auth(&jwt)
···108108 .await
109109 .expect("Failed to request account deletion");
110110 assert_eq!(request_delete_res.status(), StatusCode::OK);
111111- let pool = get_test_db_pool().await;
112112- let row = sqlx::query!(
113113- "SELECT token FROM account_deletion_requests WHERE did = $1",
114114- did
115115- )
116116- .fetch_one(pool)
117117- .await
118118- .expect("Failed to query deletion token");
119119- let token = row.token;
111111+ let repos = get_test_repos().await;
112112+ let deletion_request = repos
113113+ .infra
114114+ .get_deletion_request_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
115115+ .await
116116+ .unwrap()
117117+ .unwrap();
118118+ let token = deletion_request.token;
120119 let delete_payload = json!({
121120 "did": did,
122121 "password": "wrong-password",
···198197 .await
199198 .expect("Failed to request account deletion");
200199 assert_eq!(request_delete_res.status(), StatusCode::OK);
201201- let pool = get_test_db_pool().await;
202202- let row = sqlx::query!(
203203- "SELECT token FROM account_deletion_requests WHERE did = $1",
204204- did
205205- )
206206- .fetch_one(pool)
207207- .await
208208- .expect("Failed to query deletion token");
209209- let token = row.token;
210210- sqlx::query!(
211211- "UPDATE account_deletion_requests SET expires_at = NOW() - INTERVAL '1 hour' WHERE token = $1",
212212- token
213213- )
214214- .execute(pool)
215215- .await
216216- .expect("Failed to expire token");
200200+ let repos = get_test_repos().await;
201201+ let deletion_request = repos
202202+ .infra
203203+ .get_deletion_request_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
204204+ .await
205205+ .unwrap()
206206+ .unwrap();
207207+ let token = deletion_request.token;
208208+ repos.infra.expire_deletion_request(&token).await.unwrap();
217209 let delete_payload = json!({
218210 "did": did,
219211 "password": password,
···257249 .await
258250 .expect("Failed to request account deletion");
259251 assert_eq!(request_delete_res.status(), StatusCode::OK);
260260- let pool = get_test_db_pool().await;
261261- let row = sqlx::query!(
262262- "SELECT token FROM account_deletion_requests WHERE did = $1",
263263- did1
264264- )
265265- .fetch_one(pool)
266266- .await
267267- .expect("Failed to query deletion token");
268268- let token = row.token;
252252+ let repos = get_test_repos().await;
253253+ let deletion_request = repos
254254+ .infra
255255+ .get_deletion_request_by_did(&tranquil_types::Did::new(did1.clone()).unwrap())
256256+ .await
257257+ .unwrap()
258258+ .unwrap();
259259+ let token = deletion_request.token;
269260 let delete_payload = json!({
270261 "did": did2,
271262 "password": password2,
···318309 .await
319310 .expect("Failed to request account deletion");
320311 assert_eq!(request_delete_res.status(), StatusCode::OK);
321321- let pool = get_test_db_pool().await;
322322- let row = sqlx::query!(
323323- "SELECT token FROM account_deletion_requests WHERE did = $1",
324324- did
325325- )
326326- .fetch_one(pool)
327327- .await
328328- .expect("Failed to query deletion token");
329329- let token = row.token;
312312+ let repos = get_test_repos().await;
313313+ let deletion_request = repos
314314+ .infra
315315+ .get_deletion_request_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
316316+ .await
317317+ .unwrap()
318318+ .unwrap();
319319+ let token = deletion_request.token;
330320 let delete_payload = json!({
331321 "did": did,
332322 "password": app_password,
···342332 .await
343333 .expect("Failed to delete account");
344334 assert_eq!(delete_res.status(), StatusCode::OK);
345345- let user_row = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
346346- .fetch_optional(pool)
335335+ let user = repos
336336+ .user
337337+ .get_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
347338 .await
348348- .expect("Failed to query user");
349349- assert!(user_row.is_none(), "User should be deleted from database");
339339+ .unwrap();
340340+ assert!(user.is_none(), "User should be deleted from database");
350341}
351342352343#[tokio::test]
+89-50
crates/tranquil-pds/tests/email_update.rs
···11mod common;
22use reqwest::StatusCode;
33use serde_json::{Value, json};
44-use sqlx::PgPool;
44+use tranquil_db_traits::CommsType;
55+use tranquil_types::Did;
5666-async fn get_email_update_token(pool: &PgPool, did: &str) -> String {
77- let body_text: String = sqlx::query_scalar!(
88- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_update' ORDER BY created_at DESC LIMIT 1",
99- did
1010- )
1111- .fetch_one(pool)
1212- .await
1313- .expect("Verification not found");
77+async fn get_email_update_token(did: &str) -> String {
88+ let repos = common::get_test_repos().await;
99+ let parsed_did = Did::new(did.to_string()).unwrap();
1010+ let user = repos
1111+ .user
1212+ .get_by_did(&parsed_did)
1313+ .await
1414+ .expect("failed to look up user")
1515+ .expect("user not found");
1616+ let comms = repos
1717+ .infra
1818+ .get_latest_comms_for_user(user.id, CommsType::EmailUpdate, 1)
1919+ .await
2020+ .expect("failed to get comms");
2121+ let body_text = comms.first().expect("Verification not found").body.clone();
14221523 body_text
1624 .lines()
···8290async fn test_update_email_flow_success() {
8391 let client = common::client();
8492 let base_url = common::base_url().await;
8585- let pool = common::get_test_db_pool().await;
9393+ let repos = common::get_test_repos().await;
8694 let handle = format!("eu{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
8795 let email = format!("{}@example.com", handle);
8896 let (access_jwt, did) = create_verified_account(&client, base_url, &handle, &email).await;
···101109 let body: Value = res.json().await.expect("Invalid JSON");
102110 assert_eq!(body["tokenRequired"], true);
103111104104- let code = get_email_update_token(pool, &did).await;
112112+ let code = get_email_update_token(&did).await;
105113106114 let res = client
107115 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url))
···115123 .expect("Failed to update email");
116124 assert_eq!(res.status(), StatusCode::OK);
117125118118- let user_email: Option<String> =
119119- sqlx::query_scalar!("SELECT email FROM users WHERE did = $1", did)
120120- .fetch_one(pool)
121121- .await
122122- .expect("User not found");
126126+ let parsed_did = Did::new(did).unwrap();
127127+ let user_email = repos
128128+ .user
129129+ .get_email_info_by_did(&parsed_did)
130130+ .await
131131+ .expect("failed to look up user")
132132+ .expect("user not found")
133133+ .email;
123134 assert_eq!(user_email, Some(new_email));
124135}
125136···239250async fn test_confirm_email_confirms_existing_email() {
240251 let client = common::client();
241252 let base_url = common::base_url().await;
242242- let pool = common::get_test_db_pool().await;
253253+ let repos = common::get_test_repos().await;
243254 let handle = format!("ec{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
244255 let email = format!("{}@example.com", handle);
245256···264275 .expect("No accessJwt")
265276 .to_string();
266277267267- let body_text: String = sqlx::query_scalar!(
268268- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
269269- did
270270- )
271271- .fetch_one(pool)
272272- .await
273273- .expect("Verification email not found");
278278+ let parsed_did = Did::new(did.clone()).unwrap();
279279+ let user = repos
280280+ .user
281281+ .get_by_did(&parsed_did)
282282+ .await
283283+ .expect("failed to look up user")
284284+ .expect("user not found");
285285+ let comms = repos
286286+ .infra
287287+ .get_latest_comms_for_user(user.id, CommsType::EmailVerification, 1)
288288+ .await
289289+ .expect("failed to get comms");
290290+ let body_text = comms
291291+ .first()
292292+ .expect("Verification email not found")
293293+ .body
294294+ .clone();
274295275296 let code = body_text
276297 .lines()
···290311 .expect("Failed to confirm email");
291312 assert_eq!(res.status(), StatusCode::OK);
292313293293- let verified: bool =
294294- sqlx::query_scalar!("SELECT email_verified FROM users WHERE did = $1", did)
295295- .fetch_one(pool)
296296- .await
297297- .expect("User not found");
314314+ let verified = repos
315315+ .user
316316+ .get_email_info_by_did(&parsed_did)
317317+ .await
318318+ .expect("failed to look up user")
319319+ .expect("user not found")
320320+ .email_verified;
298321 assert!(verified);
299322}
300323···302325async fn test_confirm_email_rejects_wrong_email() {
303326 let client = common::client();
304327 let base_url = common::base_url().await;
305305- let pool = common::get_test_db_pool().await;
328328+ let repos = common::get_test_repos().await;
306329 let handle = format!("ew{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
307330 let email = format!("{}@example.com", handle);
308331···327350 .expect("No accessJwt")
328351 .to_string();
329352330330- let body_text: String = sqlx::query_scalar!(
331331- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
332332- did
333333- )
334334- .fetch_one(pool)
335335- .await
336336- .expect("Verification email not found");
353353+ let parsed_did = Did::new(did).unwrap();
354354+ let user = repos
355355+ .user
356356+ .get_by_did(&parsed_did)
357357+ .await
358358+ .expect("failed to look up user")
359359+ .expect("user not found");
360360+ let comms = repos
361361+ .infra
362362+ .get_latest_comms_for_user(user.id, CommsType::EmailVerification, 1)
363363+ .await
364364+ .expect("failed to get comms");
365365+ let body_text = comms
366366+ .first()
367367+ .expect("Verification email not found")
368368+ .body
369369+ .clone();
337370338371 let code = body_text
339372 .lines()
···402435async fn test_unverified_account_can_update_email_without_token() {
403436 let client = common::client();
404437 let base_url = common::base_url().await;
405405- let pool = common::get_test_db_pool().await;
438438+ let repos = common::get_test_repos().await;
406439 let handle = format!("ev{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
407440 let email = format!("{}@example.com", handle);
408441···457490 "Unverified account should be able to update email without token"
458491 );
459492460460- let user_email: Option<String> =
461461- sqlx::query_scalar!("SELECT email FROM users WHERE did = $1", did)
462462- .fetch_one(pool)
463463- .await
464464- .expect("User not found");
493493+ let parsed_did = Did::new(did).unwrap();
494494+ let user_email = repos
495495+ .user
496496+ .get_email_info_by_did(&parsed_did)
497497+ .await
498498+ .expect("failed to look up user")
499499+ .expect("user not found")
500500+ .email;
465501 assert_eq!(user_email, Some(new_email));
466502}
467503···469505async fn test_update_email_to_same_as_another_user_allowed() {
470506 let client = common::client();
471507 let base_url = common::base_url().await;
472472- let pool = common::get_test_db_pool().await;
508508+ let repos = common::get_test_repos().await;
473509474510 let handle1 = format!("d1{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
475511 let email1 = format!("{}@example.com", handle1);
···490526 .expect("Failed to request email update");
491527 assert_eq!(res.status(), StatusCode::OK);
492528493493- let code = get_email_update_token(pool, &did2).await;
529529+ let code = get_email_update_token(&did2).await;
494530495531 let res = client
496532 .post(format!("{}/xrpc/com.atproto.server.updateEmail", base_url))
···508544 "Multiple accounts can share the same email address"
509545 );
510546511511- let user_email: Option<String> =
512512- sqlx::query_scalar!("SELECT email FROM users WHERE did = $1", did2)
513513- .fetch_one(pool)
514514- .await
515515- .expect("User not found");
547547+ let parsed_did = Did::new(did2).unwrap();
548548+ let user_email = repos
549549+ .user
550550+ .get_email_info_by_did(&parsed_did)
551551+ .await
552552+ .expect("failed to look up user")
553553+ .expect("user not found")
554554+ .email;
516555 assert_eq!(user_email, Some(email1.clone()));
517556}
+2-5
crates/tranquil-pds/tests/firehose_validation.rs
···800800801801 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
802802803803- let pool = get_test_db_pool().await;
804804- let max_seq: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
805805- .fetch_one(pool)
806806- .await
807807- .unwrap();
803803+ let repos = get_test_repos().await;
804804+ let max_seq = repos.repo.get_max_seq().await.unwrap().as_i64();
808805 let outdated_cursor = (max_seq - 100).max(1);
809806 let url = format!(
810807 "ws://127.0.0.1:{}/xrpc/com.atproto.sync.subscribeRepos?cursor={}",
+7-15
crates/tranquil-pds/tests/helpers/mod.rs
···482482483483#[allow(dead_code)]
484484pub async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
485485- let db_url = get_db_connection_string().await;
486486- let pool = sqlx::PgPool::connect(&db_url).await.ok()?;
487487- let row = sqlx::query!(
488488- r#"
489489- SELECT k.key_bytes, k.encryption_version
490490- FROM user_keys k
491491- JOIN users u ON k.user_id = u.id
492492- WHERE u.did = $1
493493- "#,
494494- did
495495- )
496496- .fetch_optional(&pool)
497497- .await
498498- .ok()??;
499499- tranquil_pds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()
485485+ let repos = super::common::get_test_repos().await;
486486+ let key_info = repos
487487+ .user
488488+ .get_user_key_by_did(&tranquil_types::Did::new(did.to_string()).ok()?)
489489+ .await
490490+ .ok()??;
491491+ tranquil_pds::config::decrypt_key(&key_info.key_bytes, key_info.encryption_version).ok()
500492}
+1
crates/tranquil-pds/tests/invite.rs
···203203#[tokio::test]
204204async fn test_create_invite_codes_non_admin() {
205205 let client = client();
206206+ let _ = create_account_and_login(&client).await;
206207 let (access_jwt, _did) = create_account_and_login(&client).await;
207208 let payload = json!({
208209 "useCount": 2
+14-6
crates/tranquil-pds/tests/jwt_security.rs
···22mod common;
33use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
44use chrono::{Duration, Utc};
55-use common::{base_url, client, create_account_and_login, get_test_db_pool};
55+use common::{base_url, client, create_account_and_login, get_test_repos};
66use k256::SecretKey;
77use k256::ecdsa::{Signature, SigningKey, signature::Signer};
88use rand::rngs::OsRng;
···691691 let account: Value = create_res.json().await.unwrap();
692692 let did = account["did"].as_str().unwrap();
693693694694- let pool = get_test_db_pool().await;
695695- let body_text: String = sqlx::query_scalar!(
696696- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
697697- did
698698- ).fetch_one(pool).await.unwrap();
694694+ let repos = get_test_repos().await;
695695+ let user = repos
696696+ .user
697697+ .get_by_did(&tranquil_types::Did::new(did.to_string()).unwrap())
698698+ .await
699699+ .unwrap()
700700+ .unwrap();
701701+ let comms = repos
702702+ .infra
703703+ .get_latest_comms_for_user(user.id, tranquil_db_traits::CommsType::EmailVerification, 1)
704704+ .await
705705+ .unwrap();
706706+ let body_text = comms.first().unwrap().body.clone();
699707 let lines: Vec<&str> = body_text.lines().collect();
700708 let code = lines
701709 .iter()
+79-112
crates/tranquil-pds/tests/legacy_2fa.rs
···11mod common;
2233-use common::{base_url, client, create_account_and_login, get_test_db_pool};
33+use common::{base_url, client, create_account_and_login, get_test_repos};
44use reqwest::StatusCode;
55use serde_json::{Value, json};
66+use tranquil_db_traits::CommsType;
77+use tranquil_types::Did;
6879async fn enable_totp_for_user(did: &str) {
88- let pool = get_test_db_pool().await;
99- let secret = vec![0u8; 20];
1010- sqlx::query(
1111- r#"INSERT INTO user_totp (did, secret_encrypted, encryption_version, verified, created_at)
1212- VALUES ($1, $2, 1, TRUE, NOW())
1313- ON CONFLICT (did) DO UPDATE SET verified = TRUE"#,
1414- )
1515- .bind(did)
1616- .bind(&secret)
1717- .execute(pool)
1818- .await
1919- .expect("Failed to enable TOTP");
1010+ let repos = get_test_repos().await;
1111+ repos
1212+ .user
1313+ .enable_totp_verified(&Did::new(did.to_string()).unwrap(), &[0u8; 20])
1414+ .await
1515+ .unwrap();
2016}
21172218async fn set_allow_legacy_login(did: &str, allow: bool) {
2323- let pool = get_test_db_pool().await;
2424- sqlx::query("UPDATE users SET allow_legacy_login = $1 WHERE did = $2")
2525- .bind(allow)
2626- .bind(did)
2727- .execute(pool)
1919+ let repos = get_test_repos().await;
2020+ repos
2121+ .user
2222+ .update_legacy_login(&Did::new(did.to_string()).unwrap(), allow)
2823 .await
2929- .expect("Failed to set allow_legacy_login");
2424+ .unwrap();
3025}
31263227async fn get_2fa_code_from_queue(did: &str) -> Option<String> {
3333- let pool = get_test_db_pool().await;
3434- let row: Option<(String,)> = sqlx::query_as(
3535- r#"SELECT body FROM comms_queue
3636- WHERE user_id = (SELECT id FROM users WHERE did = $1)
3737- AND comms_type = 'two_factor_code'
3838- ORDER BY created_at DESC LIMIT 1"#,
3939- )
4040- .bind(did)
4141- .fetch_optional(pool)
4242- .await
4343- .ok()
4444- .flatten();
2828+ let repos = get_test_repos().await;
2929+ let parsed_did = Did::new(did.to_string()).unwrap();
3030+ let user_id = repos
3131+ .user
3232+ .get_id_by_did(&parsed_did)
3333+ .await
3434+ .expect("DB error")
3535+ .expect("User not found");
45364646- row.and_then(|(body,)| {
4747- body.lines()
3737+ let comms = repos
3838+ .infra
3939+ .get_latest_comms_for_user(user_id, CommsType::TwoFactorCode, 1)
4040+ .await
4141+ .ok()?;
4242+4343+ comms.first().and_then(|c| {
4444+ c.body
4545+ .lines()
4846 .find(|line: &&str| line.chars().all(|c: char| c.is_ascii_digit()) && line.len() == 8)
4947 .map(|s: &str| s.to_string())
5048 .or_else(|| {
5151- body.split_whitespace()
4949+ c.body
5050+ .split_whitespace()
5251 .find(|word: &&str| {
5352 word.chars().all(|c: char| c.is_ascii_digit()) && word.len() == 8
5453 })
···5857}
59586059async fn clear_2fa_challenges_for_user(did: &str) {
6161- let pool = get_test_db_pool().await;
6262- let _ = sqlx::query(
6363- "DELETE FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'two_factor_code'",
6464- )
6565- .bind(did)
6666- .execute(pool)
6767- .await;
6060+ let repos = get_test_repos().await;
6161+ let parsed_did = Did::new(did.to_string()).unwrap();
6262+ let user_id = repos
6363+ .user
6464+ .get_id_by_did(&parsed_did)
6565+ .await
6666+ .expect("DB error")
6767+ .expect("User not found");
6868+6969+ let _ = repos
7070+ .infra
7171+ .delete_comms_by_type_for_user(user_id, CommsType::TwoFactorCode)
7272+ .await;
6873}
69747075async fn set_email_auth_factor(did: &str, enabled: bool) {
7171- let pool = get_test_db_pool().await;
7272- let user_id: uuid::Uuid =
7373- sqlx::query_scalar::<_, uuid::Uuid>("SELECT id FROM users WHERE did = $1")
7474- .bind(did)
7575- .fetch_one(pool)
7676- .await
7777- .expect("Failed to get user id");
7878- let pool = get_test_db_pool().await;
7979- let _ = sqlx::query(
8080- "DELETE FROM account_preferences WHERE user_id = $1 AND name = 'email_auth_factor'",
8181- )
8282- .bind(user_id)
8383- .execute(pool)
8484- .await;
8585- let pool = get_test_db_pool().await;
8686- sqlx::query(
8787- "INSERT INTO account_preferences (user_id, name, value_json) VALUES ($1, 'email_auth_factor', $2::jsonb)",
8888- )
8989- .bind(user_id)
9090- .bind(serde_json::json!(enabled))
9191- .execute(pool)
9292- .await
9393- .expect("Failed to set email_auth_factor");
7676+ let repos = get_test_repos().await;
7777+ let parsed_did = Did::new(did.to_string()).unwrap();
7878+ let user_id = repos
7979+ .user
8080+ .get_id_by_did(&parsed_did)
8181+ .await
8282+ .expect("DB error")
8383+ .expect("User not found");
8484+8585+ repos
8686+ .infra
8787+ .upsert_account_preference(user_id, "email_auth_factor", serde_json::json!(enabled))
8888+ .await
8989+ .expect("Failed to set email_auth_factor");
9090+}
9191+9292+async fn get_handle(did: &str) -> String {
9393+ let repos = get_test_repos().await;
9494+ repos
9595+ .user
9696+ .get_handle_by_did(&Did::new(did.to_string()).unwrap())
9797+ .await
9898+ .expect("DB error")
9999+ .expect("Handle not found")
100100+ .to_string()
94101}
9510296103#[tokio::test]
···102109 enable_totp_for_user(&did).await;
103110 set_allow_legacy_login(&did, true).await;
104111105105- let pool = get_test_db_pool().await;
106106- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
107107- .bind(&did)
108108- .fetch_one(pool)
109109- .await
110110- .expect("Failed to get handle");
112112+ let handle = get_handle(&did).await;
111113112114 let login_payload = json!({
113115 "identifier": handle,
···141143 set_allow_legacy_login(&did, true).await;
142144 clear_2fa_challenges_for_user(&did).await;
143145144144- let pool = get_test_db_pool().await;
145145- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
146146- .bind(&did)
147147- .fetch_one(pool)
148148- .await
149149- .expect("Failed to get handle");
146146+ let handle = get_handle(&did).await;
150147151148 let login_payload = json!({
152149 "identifier": handle,
···194191 set_allow_legacy_login(&did, true).await;
195192 clear_2fa_challenges_for_user(&did).await;
196193197197- let pool = get_test_db_pool().await;
198198- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
199199- .bind(&did)
200200- .fetch_one(pool)
201201- .await
202202- .expect("Failed to get handle");
194194+ let handle = get_handle(&did).await;
203195204196 let resp = client
205197 .post(format!("{}/xrpc/com.atproto.server.createSession", base))
···245237 enable_totp_for_user(&did).await;
246238 set_allow_legacy_login(&did, false).await;
247239248248- let pool = get_test_db_pool().await;
249249- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
250250- .bind(&did)
251251- .fetch_one(pool)
252252- .await
253253- .expect("Failed to get handle");
240240+ let handle = get_handle(&did).await;
254241255242 let login_payload = json!({
256243 "identifier": handle,
···274261 let base = base_url().await;
275262 let (_token, did) = create_account_and_login(&client).await;
276263277277- let pool = get_test_db_pool().await;
278278- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
279279- .bind(&did)
280280- .fetch_one(pool)
281281- .await
282282- .expect("Failed to get handle");
264264+ let handle = get_handle(&did).await;
283265284266 let login_payload = json!({
285267 "identifier": handle,
···307289 set_allow_legacy_login(&did, true).await;
308290 clear_2fa_challenges_for_user(&did).await;
309291310310- let pool = get_test_db_pool().await;
311311- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
312312- .bind(&did)
313313- .fetch_one(pool)
314314- .await
315315- .expect("Failed to get handle");
292292+ let handle = get_handle(&did).await;
316293317294 let resp = client
318295 .post(format!("{}/xrpc/com.atproto.server.createSession", base))
···404381 set_email_auth_factor(&did, true).await;
405382 clear_2fa_challenges_for_user(&did).await;
406383407407- let pool = get_test_db_pool().await;
408408- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
409409- .bind(&did)
410410- .fetch_one(pool)
411411- .await
412412- .expect("Failed to get handle");
384384+ let handle = get_handle(&did).await;
413385414386 let login_payload = json!({
415387 "identifier": handle,
···457429458430 set_email_auth_factor(&did, false).await;
459431460460- let pool = get_test_db_pool().await;
461461- let handle: String = sqlx::query_scalar::<_, String>("SELECT handle FROM users WHERE did = $1")
462462- .bind(&did)
463463- .fetch_one(pool)
464464- .await
465465- .expect("Failed to get handle");
432432+ let handle = get_handle(&did).await;
466433467434 let login_payload = json!({
468435 "identifier": handle,
+18-14
crates/tranquil-pds/tests/lifecycle_session.rs
···577577 .await
578578 .expect("Failed to request account deletion");
579579 assert_eq!(res.status(), StatusCode::OK);
580580- let db_url = get_db_connection_string().await;
581581- let pool = sqlx::PgPool::connect(&db_url)
580580+ let repos = get_test_repos().await;
581581+ let deletion_request = repos
582582+ .infra
583583+ .get_deletion_request_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
582584 .await
583583- .expect("Failed to connect to test DB");
584584- let row = sqlx::query!(
585585- "SELECT token, expires_at FROM account_deletion_requests WHERE did = $1",
586586- did
587587- )
588588- .fetch_optional(&pool)
589589- .await
590590- .expect("Failed to query DB");
591591- assert!(row.is_some(), "Deletion token should exist in DB");
592592- let row = row.unwrap();
593593- assert!(!row.token.is_empty(), "Token should not be empty");
594594- assert!(row.expires_at > Utc::now(), "Token should not be expired");
585585+ .expect("Failed to query DB");
586586+ assert!(
587587+ deletion_request.is_some(),
588588+ "Deletion token should exist in DB"
589589+ );
590590+ let deletion_request = deletion_request.unwrap();
591591+ assert!(
592592+ !deletion_request.token.is_empty(),
593593+ "Token should not be empty"
594594+ );
595595+ assert!(
596596+ deletion_request.expires_at > Utc::now(),
597597+ "Token should not be expired"
598598+ );
595599}
+63-74
crates/tranquil-pds/tests/notifications.rs
···11mod common;
22-use sqlx::Row;
33-use tranquil_pds::comms::{CommsChannel, CommsStatus, CommsType};
22+use tranquil_db_traits::{CommsChannel, CommsStatus, CommsType};
33+use tranquil_types::Did;
4455#[tokio::test]
66async fn test_enqueue_comms() {
77- let pool = common::get_test_db_pool().await;
77+ let repos = common::get_test_repos().await;
88 let (_, did) = common::create_account_and_login(&common::client()).await;
99- let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1")
1010- .bind(&did)
1111- .fetch_one(pool)
99+ let user_id = repos
1010+ .user
1111+ .get_id_by_did(&Did::new(did).unwrap())
1212 .await
1313+ .expect("DB error")
1314 .expect("User not found");
1414- let comms_id: uuid::Uuid = sqlx::query_scalar(
1515- r#"INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body)
1616- VALUES ($1, 'email', 'welcome', $2, $3, $4)
1717- RETURNING id"#,
1818- )
1919- .bind(user_id)
2020- .bind("test@example.com")
2121- .bind("Test Subject")
2222- .bind("Test body")
2323- .fetch_one(pool)
2424- .await
2525- .expect("Failed to enqueue comms");
2626- let row = sqlx::query(
2727- r#"
2828- SELECT id, user_id, recipient, subject, body, channel, comms_type, status
2929- FROM comms_queue
3030- WHERE id = $1
3131- "#,
3232- )
3333- .bind(comms_id)
3434- .fetch_one(pool)
3535- .await
3636- .expect("Comms not found");
3737- let row_user_id: uuid::Uuid = row.get("user_id");
3838- let row_recipient: String = row.get("recipient");
3939- let row_subject: Option<String> = row.get("subject");
4040- let row_body: String = row.get("body");
4141- let row_channel: CommsChannel = row.get("channel");
4242- let row_comms_type: CommsType = row.get("comms_type");
4343- let row_status: CommsStatus = row.get("status");
4444- assert_eq!(row_user_id, user_id);
4545- assert_eq!(row_recipient, "test@example.com");
4646- assert_eq!(row_subject.as_deref(), Some("Test Subject"));
4747- assert_eq!(row_body, "Test body");
4848- assert_eq!(row_channel, CommsChannel::Email);
4949- assert_eq!(row_comms_type, CommsType::Welcome);
5050- assert_eq!(row_status, CommsStatus::Pending);
1515+ repos
1616+ .infra
1717+ .enqueue_comms(
1818+ Some(user_id),
1919+ CommsChannel::Email,
2020+ CommsType::Welcome,
2121+ "test@example.com",
2222+ Some("Test Subject"),
2323+ "Test body",
2424+ None,
2525+ )
2626+ .await
2727+ .expect("Failed to enqueue comms");
2828+ let comms = repos
2929+ .infra
3030+ .get_latest_comms_for_user(user_id, CommsType::Welcome, 1)
3131+ .await
3232+ .expect("DB error");
3333+ let row = comms.first().expect("Comms not found");
3434+ assert_eq!(row.user_id, Some(user_id));
3535+ assert_eq!(row.recipient, "test@example.com");
3636+ assert_eq!(row.subject.as_deref(), Some("Test Subject"));
3737+ assert_eq!(row.body, "Test body");
3838+ assert_eq!(row.channel, CommsChannel::Email);
3939+ assert_eq!(row.comms_type, CommsType::Welcome);
4040+ assert_eq!(row.status, CommsStatus::Pending);
5141}
52425343#[tokio::test]
5444async fn test_comms_queue_status_index() {
5555- let pool = common::get_test_db_pool().await;
4545+ let repos = common::get_test_repos().await;
5646 let (_, did) = common::create_account_and_login(&common::client()).await;
5757- let user_id: uuid::Uuid = sqlx::query_scalar("SELECT id FROM users WHERE did = $1")
5858- .bind(&did)
5959- .fetch_one(pool)
4747+ let user_id = repos
4848+ .user
4949+ .get_id_by_did(&Did::new(did).unwrap())
6050 .await
5151+ .expect("DB error")
6152 .expect("User not found");
6262- let initial_count: i64 = sqlx::query_scalar(
6363- "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
6464- )
6565- .bind(user_id)
6666- .fetch_one(pool)
6767- .await
6868- .expect("Failed to count");
6969- let inserts = (0..5).map(|i| {
7070- sqlx::query(
7171- r#"INSERT INTO comms_queue (user_id, channel, comms_type, recipient, subject, body)
7272- VALUES ($1, 'email', 'password_reset', $2, $3, $4)"#,
7373- )
7474- .bind(user_id)
7575- .bind(format!("test{}@example.com", i))
7676- .bind("Test")
7777- .bind("Body")
7878- .execute(pool)
7979- });
8080- futures::future::try_join_all(inserts)
5353+ let initial_count = repos
5454+ .infra
5555+ .count_comms_by_type(user_id, CommsType::PasswordReset)
8156 .await
8282- .expect("Failed to enqueue");
8383- let final_count: i64 = sqlx::query_scalar(
8484- "SELECT COUNT(*) FROM comms_queue WHERE status = 'pending' AND user_id = $1",
8585- )
8686- .bind(user_id)
8787- .fetch_one(pool)
8888- .await
8989- .expect("Failed to count");
5757+ .expect("Failed to count");
5858+ for i in 0..5 {
5959+ let recipient = format!("test{}@example.com", i);
6060+ repos
6161+ .infra
6262+ .enqueue_comms(
6363+ Some(user_id),
6464+ CommsChannel::Email,
6565+ CommsType::PasswordReset,
6666+ &recipient,
6767+ Some("Test"),
6868+ "Body",
6969+ None,
7070+ )
7171+ .await
7272+ .expect("Failed to enqueue");
7373+ }
7474+ let final_count = repos
7575+ .infra
7676+ .count_comms_by_type(user_id, CommsType::PasswordReset)
7777+ .await
7878+ .expect("Failed to count");
9079 assert_eq!(final_count - initial_count, 5);
9180}
+26-25
crates/tranquil-pds/tests/oauth.rs
···11mod common;
22mod helpers;
33use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
44-use common::{base_url, client, get_test_db_pool};
44+use common::{base_url, client, get_test_repos};
55use helpers::verify_new_account;
66use reqwest::{StatusCode, redirect};
77use serde_json::{Value, json};
88use sha2::{Digest, Sha256};
99+use tranquil_types::{Did, RequestId};
910use wiremock::matchers::{method, path};
1011use wiremock::{Mock, MockServer, ResponseTemplate};
1112···449450 let account: Value = create_res.json().await.unwrap();
450451 let user_did = account["did"].as_str().unwrap();
451452 verify_new_account(&http_client, user_did).await;
452452- let pool = get_test_db_pool().await;
453453- sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
454454- .bind(user_did)
455455- .execute(pool)
453453+ let repos = get_test_repos().await;
454454+ repos
455455+ .user
456456+ .set_two_factor_enabled(&Did::new(user_did.to_string()).unwrap(), true)
456457 .await
457458 .unwrap();
458459 let redirect_uri = "https://example.com/2fa-callback";
···508509 .contains("Invalid")
509510 || body["error"].as_str().unwrap_or("") == "invalid_code"
510511 );
511511- let twofa_code: String =
512512- sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
513513- .bind(request_uri)
514514- .fetch_one(pool)
515515- .await
516516- .unwrap();
512512+ let twofa_code: String = repos
513513+ .oauth
514514+ .get_2fa_challenge_code(&RequestId::new(request_uri.to_string()))
515515+ .await
516516+ .unwrap()
517517+ .unwrap();
517518 let twofa_res = http_client
518519 .post(format!("{}/oauth/authorize/2fa", url))
519520 .header("Content-Type", "application/json")
···574575 let account: Value = create_res.json().await.unwrap();
575576 let user_did = account["did"].as_str().unwrap();
576577 verify_new_account(&http_client, user_did).await;
577577- let pool = get_test_db_pool().await;
578578- sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
579579- .bind(user_did)
580580- .execute(pool)
578578+ let repos = get_test_repos().await;
579579+ repos
580580+ .user
581581+ .set_two_factor_enabled(&Did::new(user_did.to_string()).unwrap(), true)
581582 .await
582583 .unwrap();
583584 let redirect_uri = "https://example.com/2fa-lockout-callback";
···748749 .json::<Value>()
749750 .await
750751 .unwrap();
751751- let pool = get_test_db_pool().await;
752752- sqlx::query("UPDATE users SET two_factor_enabled = true WHERE did = $1")
753753- .bind(&user_did)
754754- .execute(pool)
752752+ let repos = get_test_repos().await;
753753+ repos
754754+ .user
755755+ .set_two_factor_enabled(&Did::new(user_did.to_string()).unwrap(), true)
755756 .await
756757 .unwrap();
757758 let (code_verifier2, code_challenge2) = generate_pkce();
···789790 select_body["needs_2fa"].as_bool().unwrap_or(false),
790791 "Should need 2FA"
791792 );
792792- let twofa_code: String =
793793- sqlx::query_scalar("SELECT code FROM oauth_2fa_challenge WHERE request_uri = $1")
794794- .bind(request_uri2)
795795- .fetch_one(pool)
796796- .await
797797- .unwrap();
793793+ let twofa_code: String = repos
794794+ .oauth
795795+ .get_2fa_challenge_code(&RequestId::new(request_uri2.to_string()))
796796+ .await
797797+ .unwrap()
798798+ .unwrap();
798799 let twofa_res = http_client
799800 .post(format!("{}/oauth/authorize/2fa", url))
800801 .header("cookie", &device_cookie)
+64-74
crates/tranquil-pds/tests/password_reset.rs
···33use helpers::verify_new_account;
44use reqwest::StatusCode;
55use serde_json::{Value, json};
66+use tranquil_db_traits::CommsType;
6778#[tokio::test]
89async fn test_request_password_reset_creates_code() {
910 let client = common::client();
1011 let base_url = common::base_url().await;
1111- let pool = common::get_test_db_pool().await;
1212+ let repos = common::get_test_repos().await;
1213 let handle = format!("pr{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
1314 let email = format!("{}@example.com", handle);
1415 let payload = json!({
···3637 .await
3738 .expect("Failed to request password reset");
3839 assert_eq!(res.status(), StatusCode::OK);
3939- let user = sqlx::query!(
4040- "SELECT password_reset_code, password_reset_code_expires_at FROM users WHERE email = $1",
4141- email
4242- )
4343- .fetch_one(pool)
4444- .await
4545- .expect("User not found");
4646- assert!(user.password_reset_code.is_some());
4747- assert!(user.password_reset_code_expires_at.is_some());
4848- let code = user.password_reset_code.unwrap();
4040+ let info = repos
4141+ .user
4242+ .get_password_reset_info(&email)
4343+ .await
4444+ .expect("failed to look up user")
4545+ .expect("user not found");
4646+ assert!(info.code.is_some());
4747+ assert!(info.expires_at.is_some());
4848+ let code = info.code.unwrap();
4949 assert!(code.contains('-'));
5050 assert_eq!(code.len(), 11);
5151}
···7070async fn test_reset_password_with_valid_token() {
7171 let client = common::client();
7272 let base_url = common::base_url().await;
7373- let pool = common::get_test_db_pool().await;
7373+ let repos = common::get_test_repos().await;
7474 let handle = format!("pr2{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
7575 let email = format!("{}@example.com", handle);
7676 let old_password = "Oldpass123!";
···103103 .await
104104 .expect("Failed to request password reset");
105105 assert_eq!(res.status(), StatusCode::OK);
106106- let user = sqlx::query!(
107107- "SELECT password_reset_code FROM users WHERE email = $1",
108108- email
109109- )
110110- .fetch_one(pool)
111111- .await
112112- .expect("User not found");
113113- let token = user.password_reset_code.expect("No reset code");
106106+ let info = repos
107107+ .user
108108+ .get_password_reset_info(&email)
109109+ .await
110110+ .expect("failed to look up user")
111111+ .expect("user not found");
112112+ let token = info.code.expect("No reset code");
114113 let res = client
115114 .post(format!(
116115 "{}/xrpc/com.atproto.server.resetPassword",
···124123 .await
125124 .expect("Failed to reset password");
126125 assert_eq!(res.status(), StatusCode::OK);
127127- let user = sqlx::query!(
128128- "SELECT password_reset_code, password_reset_code_expires_at FROM users WHERE email = $1",
129129- email
130130- )
131131- .fetch_one(pool)
132132- .await
133133- .expect("User not found");
134134- assert!(user.password_reset_code.is_none());
135135- assert!(user.password_reset_code_expires_at.is_none());
126126+ let info = repos
127127+ .user
128128+ .get_password_reset_info(&email)
129129+ .await
130130+ .expect("failed to look up user")
131131+ .expect("user not found");
132132+ assert!(info.code.is_none());
133133+ assert!(info.expires_at.is_none());
136134 let res = client
137135 .post(format!(
138136 "{}/xrpc/com.atproto.server.createSession",
···186184async fn test_reset_password_with_expired_token() {
187185 let client = common::client();
188186 let base_url = common::base_url().await;
189189- let pool = common::get_test_db_pool().await;
187187+ let repos = common::get_test_repos().await;
190188 let handle = format!("pr3{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
191189 let email = format!("{}@example.com", handle);
192190 let payload = json!({
···214212 .await
215213 .expect("Failed to request password reset");
216214 assert_eq!(res.status(), StatusCode::OK);
217217- let user = sqlx::query!(
218218- "SELECT password_reset_code FROM users WHERE email = $1",
219219- email
220220- )
221221- .fetch_one(pool)
222222- .await
223223- .expect("User not found");
224224- let token = user.password_reset_code.expect("No reset code");
225225- sqlx::query!(
226226- "UPDATE users SET password_reset_code_expires_at = NOW() - INTERVAL '1 hour' WHERE email = $1",
227227- email
228228- )
229229- .execute(pool)
230230- .await
231231- .expect("Failed to expire token");
215215+ let info = repos
216216+ .user
217217+ .get_password_reset_info(&email)
218218+ .await
219219+ .expect("failed to look up user")
220220+ .expect("user not found");
221221+ let token = info.code.expect("No reset code");
222222+ repos
223223+ .user
224224+ .expire_password_reset_code(&email)
225225+ .await
226226+ .expect("Failed to expire token");
232227 let res = client
233228 .post(format!(
234229 "{}/xrpc/com.atproto.server.resetPassword",
···250245async fn test_reset_password_invalidates_sessions() {
251246 let client = common::client();
252247 let base_url = common::base_url().await;
253253- let pool = common::get_test_db_pool().await;
248248+ let repos = common::get_test_repos().await;
254249 let handle = format!("pr4{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
255250 let email = format!("{}@example.com", handle);
256251 let payload = json!({
···288283 .await
289284 .expect("Failed to request password reset");
290285 assert_eq!(res.status(), StatusCode::OK);
291291- let user = sqlx::query!(
292292- "SELECT password_reset_code FROM users WHERE email = $1",
293293- email
294294- )
295295- .fetch_one(pool)
296296- .await
297297- .expect("User not found");
298298- let token = user.password_reset_code.expect("No reset code");
286286+ let info = repos
287287+ .user
288288+ .get_password_reset_info(&email)
289289+ .await
290290+ .expect("failed to look up user")
291291+ .expect("user not found");
292292+ let token = info.code.expect("No reset code");
299293 let res = client
300294 .post(format!(
301295 "{}/xrpc/com.atproto.server.resetPassword",
···338332339333#[tokio::test]
340334async fn test_reset_password_creates_notification() {
341341- let pool = common::get_test_db_pool().await;
335335+ let repos = common::get_test_repos().await;
342336 let client = common::client();
343337 let base_url = common::base_url().await;
344338 let handle = format!("pr5{}", &uuid::Uuid::new_v4().simple().to_string()[..12]);
···358352 .await
359353 .expect("Failed to create account");
360354 assert_eq!(res.status(), StatusCode::OK);
361361- let user = sqlx::query!("SELECT id FROM users WHERE email = $1", email)
362362- .fetch_one(pool)
355355+ let user = repos
356356+ .user
357357+ .get_by_email(&email)
363358 .await
364364- .expect("User not found");
365365- let initial_count: i64 = sqlx::query_scalar!(
366366- "SELECT COUNT(*) FROM comms_queue WHERE user_id = $1 AND comms_type = 'password_reset'",
367367- user.id
368368- )
369369- .fetch_one(pool)
370370- .await
371371- .expect("Failed to count")
372372- .unwrap_or(0);
359359+ .expect("failed to look up user")
360360+ .expect("user not found");
361361+ let initial_count = repos
362362+ .infra
363363+ .count_comms_by_type(user.id, CommsType::PasswordReset)
364364+ .await
365365+ .expect("Failed to count");
373366 let res = client
374367 .post(format!(
375368 "{}/xrpc/com.atproto.server.requestPasswordReset",
···380373 .await
381374 .expect("Failed to request password reset");
382375 assert_eq!(res.status(), StatusCode::OK);
383383- let final_count: i64 = sqlx::query_scalar!(
384384- "SELECT COUNT(*) FROM comms_queue WHERE user_id = $1 AND comms_type = 'password_reset'",
385385- user.id
386386- )
387387- .fetch_one(pool)
388388- .await
389389- .expect("Failed to count")
390390- .unwrap_or(0);
376376+ let final_count = repos
377377+ .infra
378378+ .count_comms_by_type(user.id, CommsType::PasswordReset)
379379+ .await
380380+ .expect("Failed to count");
391381 assert_eq!(final_count - initial_count, 1);
392382}
+15-34
crates/tranquil-pds/tests/plc_migration.rs
···33use k256::ecdsa::SigningKey;
44use reqwest::StatusCode;
55use serde_json::{Value, json};
66-use sqlx::PgPool;
66+use tranquil_types::Did;
77use wiremock::matchers::{method, path};
88use wiremock::{Mock, MockServer, ResponseTemplate};
99···3636}
37373838async fn get_user_signing_key(did: &str) -> Option<Vec<u8>> {
3939- let db_url = get_db_connection_string().await;
4040- let pool = PgPool::connect(&db_url).await.ok()?;
4141- let row = sqlx::query!(
4242- r#"
4343- SELECT k.key_bytes, k.encryption_version
4444- FROM user_keys k
4545- JOIN users u ON k.user_id = u.id
4646- WHERE u.did = $1
4747- "#,
4848- did
4949- )
5050- .fetch_optional(&pool)
5151- .await
5252- .ok()??;
5353- tranquil_pds::config::decrypt_key(&row.key_bytes, row.encryption_version).ok()
3939+ let repos = get_test_repos().await;
4040+ let parsed_did = Did::new(did.to_string()).ok()?;
4141+ let key_info = repos.user.get_user_key_by_did(&parsed_did).await.ok()??;
4242+ tranquil_pds::config::decrypt_key(&key_info.key_bytes, key_info.encryption_version).ok()
5443}
55445645async fn get_plc_token_from_db(did: &str) -> Option<String> {
5757- let db_url = get_db_connection_string().await;
5858- let pool = PgPool::connect(&db_url).await.ok()?;
5959- sqlx::query_scalar!(
6060- r#"
6161- SELECT t.token
6262- FROM plc_operation_tokens t
6363- JOIN users u ON t.user_id = u.id
6464- WHERE u.did = $1
6565- "#,
6666- did
6767- )
6868- .fetch_optional(&pool)
6969- .await
7070- .ok()?
4646+ let repos = get_test_repos().await;
4747+ let parsed_did = Did::new(did.to_string()).ok()?;
4848+ let tokens = repos.infra.get_plc_tokens_by_did(&parsed_did).await.ok()?;
4949+ tokens.into_iter().next().map(|t| t.token)
7150}
72517352async fn get_user_handle(did: &str) -> Option<String> {
7474- let db_url = get_db_connection_string().await;
7575- let pool = PgPool::connect(&db_url).await.ok()?;
7676- sqlx::query_scalar!(r#"SELECT handle FROM users WHERE did = $1"#, did)
7777- .fetch_optional(&pool)
5353+ let repos = get_test_repos().await;
5454+ let parsed_did = Did::new(did.to_string()).ok()?;
5555+ repos
5656+ .user
5757+ .get_handle_by_did(&parsed_did)
7858 .await
7959 .ok()?
6060+ .map(|h| h.to_string())
8061}
81628263fn create_mock_last_op(
+37-21
crates/tranquil-pds/tests/plc_operations.rs
···22use common::*;
33use reqwest::StatusCode;
44use serde_json::json;
55-use sqlx::PgPool;
55+use tranquil_types::Did;
6677#[tokio::test]
88async fn test_plc_operation_auth() {
···176176 .await
177177 .unwrap();
178178 assert_eq!(res.status(), StatusCode::OK);
179179- let db_url = get_db_connection_string().await;
180180- let pool = PgPool::connect(&db_url).await.unwrap();
181181- let row = sqlx::query!(
182182- "SELECT t.token, t.expires_at FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1",
183183- did
184184- ).fetch_optional(&pool).await.unwrap();
185185- assert!(row.is_some(), "PLC token should be created in database");
186186- let row = row.unwrap();
187187- assert_eq!(row.token.len(), 11, "Token should be in format xxxxx-xxxxx");
188188- assert!(row.token.contains('-'), "Token should contain hyphen");
179179+ let repos = get_test_repos().await;
180180+ let parsed_did = Did::new(did.clone()).unwrap();
181181+ let tokens = repos
182182+ .infra
183183+ .get_plc_tokens_by_did(&parsed_did)
184184+ .await
185185+ .unwrap();
189186 assert!(
190190- row.expires_at > chrono::Utc::now(),
187187+ !tokens.is_empty(),
188188+ "PLC token should be created in database"
189189+ );
190190+ let first = &tokens[0];
191191+ assert_eq!(
192192+ first.token.len(),
193193+ 11,
194194+ "Token should be in format xxxxx-xxxxx"
195195+ );
196196+ assert!(first.token.contains('-'), "Token should contain hyphen");
197197+ assert!(
198198+ first.expires_at > chrono::Utc::now(),
191199 "Token should not be expired"
192200 );
193193- let diff = row.expires_at - chrono::Utc::now();
201201+ let diff = first.expires_at - chrono::Utc::now();
194202 assert!(
195203 diff.num_minutes() >= 9 && diff.num_minutes() <= 11,
196204 "Token should expire in ~10 minutes"
197205 );
198198- let token1 = row.token.clone();
206206+ let token1 = first.token.clone();
199207 let res = client
200208 .post(format!(
201209 "{}/xrpc/com.atproto.identity.requestPlcOperationSignature",
···206214 .await
207215 .unwrap();
208216 assert_eq!(res.status(), StatusCode::OK);
209209- let token2 = sqlx::query_scalar!(
210210- "SELECT t.token FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
211211- ).fetch_one(&pool).await.unwrap();
212212- assert_ne!(token1, token2, "Second request should generate a new token");
213213- let count: i64 = sqlx::query_scalar!(
214214- "SELECT COUNT(*) as \"count!\" FROM plc_operation_tokens t JOIN users u ON t.user_id = u.id WHERE u.did = $1", did
215215- ).fetch_one(&pool).await.unwrap();
217217+ let tokens2 = repos
218218+ .infra
219219+ .get_plc_tokens_by_did(&parsed_did)
220220+ .await
221221+ .unwrap();
222222+ let token2 = &tokens2[0].token;
223223+ assert_ne!(
224224+ token1, *token2,
225225+ "Second request should generate a new token"
226226+ );
227227+ let count = repos
228228+ .infra
229229+ .count_plc_tokens_by_did(&parsed_did)
230230+ .await
231231+ .unwrap();
216232 assert_eq!(count, 1, "Should only have one token per user");
217233}
+16-41
crates/tranquil-pds/tests/repo_lifecycle.rs
···5858 let client = client();
5959 let (token, did) = create_account_and_login(&client).await;
60606161- let pool = get_test_db_pool().await;
6262- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
6363- .fetch_one(pool)
6464- .await
6565- .unwrap();
6161+ let repos = get_test_repos().await;
6262+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
6663 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
6764 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
6865···136133 let v1_cid_str = v1_body["cid"].as_str().unwrap();
137134 let v1_cid = Cid::from_str(v1_cid_str).unwrap();
138135139139- let pool = get_test_db_pool().await;
140140- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
141141- .fetch_one(pool)
142142- .await
143143- .unwrap();
136136+ let repos = get_test_repos().await;
137137+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
144138 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
145139 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
146140···208202 let collection = parts[parts.len() - 2];
209203 let rkey = parts[parts.len() - 1];
210204211211- let pool = get_test_db_pool().await;
212212- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
213213- .fetch_one(pool)
214214- .await
215215- .unwrap();
205205+ let repos = get_test_repos().await;
206206+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
216207 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
217208 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
218209···254245 let client = client();
255246 let (token, did) = create_account_and_login(&client).await;
256247257257- let pool = get_test_db_pool().await;
258258- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
259259- .fetch_one(pool)
260260- .await
261261- .unwrap();
248248+ let repos = get_test_repos().await;
249249+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
262250 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
263251 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
264252···326314 let client = client();
327315 let (token, did) = create_account_and_login(&client).await;
328316329329- let pool = get_test_db_pool().await;
330330- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
331331- .fetch_one(pool)
332332- .await
333333- .unwrap();
317317+ let repos = get_test_repos().await;
318318+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
334319 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
335320 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
336321···410395 bytes: std::borrow::Cow::Owned(pubkey_bytes.as_bytes().to_vec()),
411396 };
412397413413- let pool = get_test_db_pool().await;
414414- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
415415- .fetch_one(pool)
416416- .await
417417- .unwrap();
398398+ let repos = get_test_repos().await;
399399+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
418400 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
419401 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
420402···461443 let client = client();
462444 let (token, did) = create_account_and_login(&client).await;
463445464464- let pool = get_test_db_pool().await;
465465- let baseline_seq: i64 =
466466- sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
467467- .fetch_one(pool)
468468- .await
469469- .unwrap();
446446+ let repos = get_test_repos().await;
447447+ let baseline_seq = repos.repo.get_max_seq().await.unwrap().as_i64();
470448471449 let mut expected_cids: Vec<String> = Vec::with_capacity(5);
472450 let texts = [
···517495 let (alice_token, alice_did) = create_account_and_login(&client).await;
518496 let (bob_token, bob_did) = create_account_and_login(&client).await;
519497520520- let pool = get_test_db_pool().await;
521521- let cursor: i64 = sqlx::query_scalar::<_, i64>("SELECT COALESCE(MAX(seq), 0) FROM repo_seq")
522522- .fetch_one(pool)
523523- .await
524524- .unwrap();
498498+ let repos = get_test_repos().await;
499499+ let cursor = repos.repo.get_max_seq().await.unwrap().as_i64();
525500 let consumer = FirehoseConsumer::connect_with_cursor(app_port(), cursor).await;
526501 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
527502
+32-16
crates/tranquil-pds/tests/ripple_cluster.rs
···9797 .expect("no accessJwt")
9898 .to_string();
9999100100- let pool = common::get_test_db_pool().await;
101101- let body_text: String = sqlx::query_scalar!(
102102- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
103103- &did
104104- )
105105- .fetch_one(pool)
106106- .await
107107- .expect("verification code not found");
100100+ let repos = common::get_test_repos().await;
101101+ let user = repos
102102+ .user
103103+ .get_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
104104+ .await
105105+ .expect("failed to look up user")
106106+ .expect("user not found");
107107+ let comms = repos
108108+ .infra
109109+ .get_latest_comms_for_user(user.id, tranquil_db_traits::CommsType::EmailVerification, 1)
110110+ .await
111111+ .expect("failed to get comms");
112112+ let body_text = comms
113113+ .first()
114114+ .map(|c| c.body.clone())
115115+ .expect("no email_verification comms found");
108116109117 let lines: Vec<&str> = body_text.lines().collect();
110118 let verification_code = lines
···624632 .expect("no accessJwt")
625633 .to_string();
626634627627- let pool = common::get_test_db_pool().await;
628628- let body_text: String = sqlx::query_scalar!(
629629- "SELECT body FROM comms_queue WHERE user_id = (SELECT id FROM users WHERE did = $1) AND comms_type = 'email_verification' ORDER BY created_at DESC LIMIT 1",
630630- &did
631631- )
632632- .fetch_one(pool)
633633- .await
634634- .expect("verification code not found");
635635+ let repos = common::get_test_repos().await;
636636+ let user = repos
637637+ .user
638638+ .get_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
639639+ .await
640640+ .expect("failed to look up user")
641641+ .expect("user not found");
642642+ let comms = repos
643643+ .infra
644644+ .get_latest_comms_for_user(user.id, tranquil_db_traits::CommsType::EmailVerification, 1)
645645+ .await
646646+ .expect("failed to get comms");
647647+ let body_text = comms
648648+ .first()
649649+ .map(|c| c.body.clone())
650650+ .expect("no email_verification comms found");
635651636652 let lines: Vec<&str> = body_text.lines().collect();
637653 let verification_code = lines
+22-25
crates/tranquil-pds/tests/signing_key.rs
···3131async fn test_reserve_signing_key_with_did() {
3232 let client = common::client();
3333 let base_url = common::base_url().await;
3434- let pool = common::get_test_db_pool().await;
3434+ let repos = common::get_test_repos().await;
3535 let target_did = "did:plc:test123456";
3636 let res = client
3737 .post(format!(
···4646 let body: Value = res.json().await.expect("Response was not valid JSON");
4747 let signing_key = body["signingKey"].as_str().unwrap();
4848 assert!(signing_key.starts_with("did:key:z"));
4949- let row = sqlx::query!(
5050- "SELECT did, public_key_did_key FROM reserved_signing_keys WHERE public_key_did_key = $1",
5151- signing_key
5252- )
5353- .fetch_one(pool)
5454- .await
5555- .expect("Reserved key not found in database");
5656- assert_eq!(row.did.as_deref(), Some(target_did));
4949+ let row = repos
5050+ .infra
5151+ .get_reserved_signing_key_full(signing_key)
5252+ .await
5353+ .expect("db error")
5454+ .expect("Reserved key not found in database");
5555+ assert_eq!(row.did.as_ref().map(|d| d.as_str()), Some(target_did));
5756 assert_eq!(row.public_key_did_key, signing_key);
5857}
5958···6160async fn test_reserve_signing_key_stores_private_key() {
6261 let client = common::client();
6362 let base_url = common::base_url().await;
6464- let pool = common::get_test_db_pool().await;
6363+ let repos = common::get_test_repos().await;
6564 let res = client
6665 .post(format!(
6766 "{}/xrpc/com.atproto.server.reserveSigningKey",
···7473 assert_eq!(res.status(), StatusCode::OK);
7574 let body: Value = res.json().await.expect("Response was not valid JSON");
7675 let signing_key = body["signingKey"].as_str().unwrap();
7777- let row = sqlx::query!(
7878- "SELECT private_key_bytes, expires_at, used_at FROM reserved_signing_keys WHERE public_key_did_key = $1",
7979- signing_key
8080- )
8181- .fetch_one(pool)
8282- .await
8383- .expect("Reserved key not found in database");
7676+ let row = repos
7777+ .infra
7878+ .get_reserved_signing_key_full(signing_key)
7979+ .await
8080+ .expect("db error")
8181+ .expect("Reserved key not found in database");
8482 assert_eq!(
8583 row.private_key_bytes.len(),
8684 32,
···151149async fn test_create_account_with_reserved_signing_key() {
152150 let client = common::client();
153151 let base_url = common::base_url().await;
154154- let pool = common::get_test_db_pool().await;
152152+ let repos = common::get_test_repos().await;
155153 let res = client
156154 .post(format!(
157155 "{}/xrpc/com.atproto.server.reserveSigningKey",
···185183 let did = body["did"].as_str().unwrap();
186184 let access_jwt = verify_new_account(&client, did).await;
187185 assert!(!access_jwt.is_empty());
188188- let reserved = sqlx::query!(
189189- "SELECT used_at FROM reserved_signing_keys WHERE public_key_did_key = $1",
190190- signing_key
191191- )
192192- .fetch_one(pool)
193193- .await
194194- .expect("Reserved key not found");
186186+ let reserved = repos
187187+ .infra
188188+ .get_reserved_signing_key_full(signing_key)
189189+ .await
190190+ .expect("db error")
191191+ .expect("Reserved key not found");
195192 assert!(
196193 reserved.used_at.is_some(),
197194 "Reserved key should be marked as used"
+304-565
crates/tranquil-pds/tests/sso.rs
···11mod common;
2233-use common::{base_url, client, create_account_and_login, get_test_db_pool};
33+use common::{base_url, client, create_account_and_login, get_test_repos};
44use reqwest::StatusCode;
55use serde_json::{Value, json};
66-use tranquil_db_traits::SsoProviderType;
77-use tranquil_types::Did;
66+use tranquil_db_traits::{CommsChannel, SsoAction, SsoProviderType};
77+use tranquil_oauth::{
88+ AuthorizationRequestParameters, CodeChallengeMethod, RequestData, ResponseType,
99+};
1010+use tranquil_types::{Did, RequestId};
811912#[tokio::test]
1013async fn test_sso_providers_endpoint() {
···226229#[tokio::test]
227230async fn test_external_identity_repository_crud() {
228231 let _url = base_url().await;
229229- let pool = get_test_db_pool().await;
232232+ let repos = get_test_repos().await;
233233+ let client = client();
230234231231- let did: Did = format!(
232232- "did:plc:test{}",
233233- &uuid::Uuid::new_v4().simple().to_string()[..12]
234234- )
235235- .parse()
236236- .expect("valid test DID");
235235+ let (_token, did_string) = create_account_and_login(&client).await;
236236+ let did: Did = did_string.parse().expect("valid DID");
237237+237238 let provider = SsoProviderType::Github;
238239 let provider_user_id = format!("github_user_{}", uuid::Uuid::new_v4().simple());
239240240240- sqlx::query!(
241241- "INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
242242- did.as_str(),
243243- format!("test{}", &uuid::Uuid::new_v4().simple().to_string()[..8]),
244244- format!(
245245- "test{}@example.com",
246246- &uuid::Uuid::new_v4().simple().to_string()[..8]
241241+ let id = repos
242242+ .sso
243243+ .create_external_identity(
244244+ &did,
245245+ provider,
246246+ &provider_user_id,
247247+ Some("testuser"),
248248+ Some("test@github.com"),
247249 )
248248- )
249249- .execute(pool)
250250- .await
251251- .unwrap();
250250+ .await
251251+ .unwrap();
252252253253- let id: uuid::Uuid = sqlx::query_scalar!(
254254- r#"
255255- INSERT INTO external_identities (did, provider, provider_user_id, provider_username, provider_email)
256256- VALUES ($1, $2, $3, $4, $5)
257257- RETURNING id
258258- "#,
259259- did.as_str(),
260260- provider as SsoProviderType,
261261- &provider_user_id,
262262- Some("testuser"),
263263- Some("test@github.com"),
264264- )
265265- .fetch_one(pool)
266266- .await
267267- .unwrap();
268268-269269- let found = sqlx::query!(
270270- r#"
271271- SELECT id, did, provider as "provider: SsoProviderType", provider_user_id, provider_username, provider_email
272272- FROM external_identities
273273- WHERE provider = $1 AND provider_user_id = $2
274274- "#,
275275- provider as SsoProviderType,
276276- &provider_user_id,
277277- )
278278- .fetch_optional(pool)
279279- .await
280280- .unwrap();
253253+ let found = repos
254254+ .sso
255255+ .get_external_identity_by_provider(provider, &provider_user_id)
256256+ .await
257257+ .unwrap();
281258282259 assert!(found.is_some());
283260 let found = found.unwrap();
284261 assert_eq!(found.id, id);
285285- assert_eq!(found.did, did.as_str());
286286- assert_eq!(found.provider_username, Some("testuser".to_string()));
262262+ assert_eq!(found.did, did);
263263+ assert_eq!(
264264+ found.provider_username.as_ref().unwrap().as_str(),
265265+ "testuser"
266266+ );
287267288288- let identities = sqlx::query!(
289289- r#"
290290- SELECT id FROM external_identities WHERE did = $1
291291- "#,
292292- did.as_str(),
293293- )
294294- .fetch_all(pool)
295295- .await
296296- .unwrap();
268268+ let identities = repos
269269+ .sso
270270+ .get_external_identities_by_did(&did)
271271+ .await
272272+ .unwrap();
297273298274 assert_eq!(identities.len(), 1);
299275300300- sqlx::query!(
301301- r#"
302302- UPDATE external_identities
303303- SET provider_username = $2, last_login_at = NOW()
304304- WHERE id = $1
305305- "#,
306306- id,
307307- "updated_username",
308308- )
309309- .execute(pool)
310310- .await
311311- .unwrap();
276276+ repos
277277+ .sso
278278+ .update_external_identity_login(id, Some("updated_username"), None)
279279+ .await
280280+ .unwrap();
312281313313- let updated = sqlx::query!(
314314- r#"SELECT provider_username, last_login_at FROM external_identities WHERE id = $1"#,
315315- id,
316316- )
317317- .fetch_one(pool)
318318- .await
319319- .unwrap();
282282+ let updated = repos
283283+ .sso
284284+ .get_external_identity_by_provider(provider, &provider_user_id)
285285+ .await
286286+ .unwrap()
287287+ .unwrap();
320288321289 assert_eq!(
322322- updated.provider_username,
323323- Some("updated_username".to_string())
290290+ updated.provider_username.as_ref().unwrap().as_str(),
291291+ "updated_username"
324292 );
325293 assert!(updated.last_login_at.is_some());
326294327327- let deleted = sqlx::query!(
328328- r#"DELETE FROM external_identities WHERE id = $1 AND did = $2"#,
329329- id,
330330- did.as_str(),
331331- )
332332- .execute(pool)
333333- .await
334334- .unwrap();
335335-336336- assert_eq!(deleted.rows_affected(), 1);
295295+ let deleted = repos.sso.delete_external_identity(id, &did).await.unwrap();
296296+ assert!(deleted);
337297338338- let not_found = sqlx::query!(r#"SELECT id FROM external_identities WHERE id = $1"#, id,)
339339- .fetch_optional(pool)
298298+ let not_found = repos
299299+ .sso
300300+ .get_external_identity_by_provider(provider, &provider_user_id)
340301 .await
341302 .unwrap();
342303···346307#[tokio::test]
347308async fn test_external_identity_unique_constraints() {
348309 let _url = base_url().await;
349349- let pool = get_test_db_pool().await;
310310+ let repos = get_test_repos().await;
311311+ let client = client();
350312351351- let did1: Did = format!(
352352- "did:plc:uc1{}",
353353- &uuid::Uuid::new_v4().simple().to_string()[..10]
354354- )
355355- .parse()
356356- .expect("valid test DID");
357357- let did2: Did = format!(
358358- "did:plc:uc2{}",
359359- &uuid::Uuid::new_v4().simple().to_string()[..10]
360360- )
361361- .parse()
362362- .expect("valid test DID");
313313+ let (_token1, did1_string) = create_account_and_login(&client).await;
314314+ let did1: Did = did1_string.parse().expect("valid DID");
315315+ let (_token2, did2_string) = create_account_and_login(&client).await;
316316+ let did2: Did = did2_string.parse().expect("valid DID");
317317+363318 let provider_user_id = format!("unique_test_{}", uuid::Uuid::new_v4().simple());
364319365365- sqlx::query!(
366366- "INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
367367- did1.as_str(),
368368- format!("uc1{}", &uuid::Uuid::new_v4().simple().to_string()[..8]),
369369- format!(
370370- "uc1{}@example.com",
371371- &uuid::Uuid::new_v4().simple().to_string()[..8]
320320+ repos
321321+ .sso
322322+ .create_external_identity(
323323+ &did1,
324324+ SsoProviderType::Github,
325325+ &provider_user_id,
326326+ None,
327327+ None,
372328 )
373373- )
374374- .execute(pool)
375375- .await
376376- .unwrap();
329329+ .await
330330+ .unwrap();
377331378378- sqlx::query!(
379379- "INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
380380- did2.as_str(),
381381- format!("uc2{}", &uuid::Uuid::new_v4().simple().to_string()[..8]),
382382- format!(
383383- "uc2{}@example.com",
384384- &uuid::Uuid::new_v4().simple().to_string()[..8]
332332+ let duplicate_provider_user = repos
333333+ .sso
334334+ .create_external_identity(
335335+ &did2,
336336+ SsoProviderType::Github,
337337+ &provider_user_id,
338338+ None,
339339+ None,
385340 )
386386- )
387387- .execute(pool)
388388- .await
389389- .unwrap();
390390-391391- sqlx::query!(
392392- r#"
393393- INSERT INTO external_identities (did, provider, provider_user_id)
394394- VALUES ($1, $2, $3)
395395- "#,
396396- did1.as_str(),
397397- SsoProviderType::Github as SsoProviderType,
398398- &provider_user_id,
399399- )
400400- .execute(pool)
401401- .await
402402- .unwrap();
403403-404404- let duplicate_provider_user = sqlx::query!(
405405- r#"
406406- INSERT INTO external_identities (did, provider, provider_user_id)
407407- VALUES ($1, $2, $3)
408408- "#,
409409- did2.as_str(),
410410- SsoProviderType::Github as SsoProviderType,
411411- &provider_user_id,
412412- )
413413- .execute(pool)
414414- .await;
341341+ .await;
415342416343 assert!(duplicate_provider_user.is_err());
417344418418- let duplicate_did_provider = sqlx::query!(
419419- r#"
420420- INSERT INTO external_identities (did, provider, provider_user_id)
421421- VALUES ($1, $2, $3)
422422- "#,
423423- did1.as_str(),
424424- SsoProviderType::Github as SsoProviderType,
425425- "different_user_id",
426426- )
427427- .execute(pool)
428428- .await;
345345+ let duplicate_did_provider = repos
346346+ .sso
347347+ .create_external_identity(
348348+ &did1,
349349+ SsoProviderType::Github,
350350+ "different_user_id",
351351+ None,
352352+ None,
353353+ )
354354+ .await;
429355430356 assert!(duplicate_did_provider.is_err());
431357432358 let discord_user_id = format!("discord_user_{}", uuid::Uuid::new_v4().simple());
433433- let different_provider = sqlx::query!(
434434- r#"
435435- INSERT INTO external_identities (did, provider, provider_user_id)
436436- VALUES ($1, $2, $3)
437437- "#,
438438- did1.as_str(),
439439- SsoProviderType::Discord as SsoProviderType,
440440- &discord_user_id,
441441- )
442442- .execute(pool)
443443- .await;
359359+ let different_provider = repos
360360+ .sso
361361+ .create_external_identity(
362362+ &did1,
363363+ SsoProviderType::Discord,
364364+ &discord_user_id,
365365+ None,
366366+ None,
367367+ )
368368+ .await;
444369445370 assert!(
446371 different_provider.is_ok(),
···452377#[tokio::test]
453378async fn test_sso_auth_state_lifecycle() {
454379 let _url = base_url().await;
455455- let pool = get_test_db_pool().await;
380380+ let repos = get_test_repos().await;
456381457382 let state = format!("test_state_{}", uuid::Uuid::new_v4().simple());
458383 let request_uri = "urn:ietf:params:oauth:request_uri:test123";
459384460460- sqlx::query!(
461461- r#"
462462- INSERT INTO sso_auth_state (state, request_uri, provider, action, nonce, code_verifier)
463463- VALUES ($1, $2, $3, $4, $5, $6)
464464- "#,
465465- &state,
466466- request_uri,
467467- SsoProviderType::Github as SsoProviderType,
468468- "login",
469469- Some("test_nonce"),
470470- Some("test_verifier"),
471471- )
472472- .execute(pool)
473473- .await
474474- .unwrap();
475475-476476- let found = sqlx::query!(
477477- r#"
478478- SELECT state, request_uri, provider as "provider: SsoProviderType", action, nonce, code_verifier
479479- FROM sso_auth_state
480480- WHERE state = $1
481481- "#,
482482- &state,
483483- )
484484- .fetch_optional(pool)
485485- .await
486486- .unwrap();
385385+ repos
386386+ .sso
387387+ .create_sso_auth_state(
388388+ &state,
389389+ request_uri,
390390+ SsoProviderType::Github,
391391+ SsoAction::Login,
392392+ Some("test_nonce"),
393393+ Some("test_verifier"),
394394+ None,
395395+ )
396396+ .await
397397+ .unwrap();
487398488488- assert!(found.is_some());
489489- let found = found.unwrap();
490490- assert_eq!(found.request_uri, request_uri);
491491- assert_eq!(found.action, "login");
492492- assert_eq!(found.nonce, Some("test_nonce".to_string()));
493493- assert_eq!(found.code_verifier, Some("test_verifier".to_string()));
494494-495495- let consumed = sqlx::query!(
496496- r#"
497497- DELETE FROM sso_auth_state
498498- WHERE state = $1 AND expires_at > NOW()
499499- RETURNING state, request_uri
500500- "#,
501501- &state,
502502- )
503503- .fetch_optional(pool)
504504- .await
505505- .unwrap();
399399+ let consumed = repos.sso.consume_sso_auth_state(&state).await.unwrap();
506400507401 assert!(consumed.is_some());
508508-509509- let not_found = sqlx::query!(
510510- r#"SELECT state FROM sso_auth_state WHERE state = $1"#,
511511- &state,
512512- )
513513- .fetch_optional(pool)
514514- .await
515515- .unwrap();
516516-517517- assert!(not_found.is_none());
518518-519519- let double_consume = sqlx::query!(
520520- r#"
521521- DELETE FROM sso_auth_state
522522- WHERE state = $1 AND expires_at > NOW()
523523- RETURNING state
524524- "#,
525525- &state,
526526- )
527527- .fetch_optional(pool)
528528- .await
529529- .unwrap();
402402+ let consumed = consumed.unwrap();
403403+ assert_eq!(consumed.request_uri, request_uri);
404404+ assert_eq!(consumed.action, SsoAction::Login);
405405+ assert_eq!(consumed.nonce.as_deref(), Some("test_nonce"));
406406+ assert_eq!(consumed.code_verifier.as_deref(), Some("test_verifier"));
530407408408+ let double_consume = repos.sso.consume_sso_auth_state(&state).await.unwrap();
531409 assert!(double_consume.is_none());
532410}
533411534412#[tokio::test]
535413async fn test_sso_auth_state_expiration() {
536414 let _url = base_url().await;
537537- let pool = get_test_db_pool().await;
538538-539539- let state = format!("expired_state_{}", uuid::Uuid::new_v4().simple());
540540-541541- sqlx::query!(
542542- r#"
543543- INSERT INTO sso_auth_state (state, request_uri, provider, action, expires_at)
544544- VALUES ($1, $2, $3, $4, NOW() - INTERVAL '1 hour')
545545- "#,
546546- &state,
547547- "urn:test:expired",
548548- SsoProviderType::Github as SsoProviderType,
549549- "login",
550550- )
551551- .execute(pool)
552552- .await
553553- .unwrap();
415415+ let repos = get_test_repos().await;
554416555555- let consumed = sqlx::query!(
556556- r#"
557557- DELETE FROM sso_auth_state
558558- WHERE state = $1 AND expires_at > NOW()
559559- RETURNING state
560560- "#,
561561- &state,
562562- )
563563- .fetch_optional(pool)
564564- .await
565565- .unwrap();
417417+ let consumed = repos
418418+ .sso
419419+ .consume_sso_auth_state("nonexistent_state_token")
420420+ .await
421421+ .unwrap();
566422567423 assert!(consumed.is_none());
568424569569- let cleaned = sqlx::query!(r#"DELETE FROM sso_auth_state WHERE expires_at < NOW()"#,)
570570- .execute(pool)
571571- .await
572572- .unwrap();
425425+ let cleaned = repos.sso.cleanup_expired_sso_auth_states().await.unwrap();
573426574574- assert!(cleaned.rows_affected() >= 1);
427427+ assert!(cleaned == 0 || cleaned >= 1);
575428}
576429577430#[tokio::test]
578431async fn test_delete_external_identity_wrong_did() {
579432 let _url = base_url().await;
580580- let pool = get_test_db_pool().await;
433433+ let repos = get_test_repos().await;
434434+ let client = client();
581435582582- let did: Did = format!(
583583- "did:plc:del{}",
584584- &uuid::Uuid::new_v4().simple().to_string()[..10]
585585- )
586586- .parse()
587587- .expect("valid test DID");
436436+ let (_token, did_string) = create_account_and_login(&client).await;
437437+ let did: Did = did_string.parse().expect("valid DID");
588438 let wrong_did: Did = "did:plc:wrongdid12345".parse().expect("valid test DID");
589439590590- sqlx::query!(
591591- "INSERT INTO users (did, handle, email, password_hash) VALUES ($1, $2, $3, 'hash')",
592592- did.as_str(),
593593- format!("del{}", &uuid::Uuid::new_v4().simple().to_string()[..8]),
594594- format!(
595595- "del{}@example.com",
596596- &uuid::Uuid::new_v4().simple().to_string()[..8]
597597- )
598598- )
599599- .execute(pool)
600600- .await
601601- .unwrap();
440440+ let provider_user_id = format!("delete_test_{}", uuid::Uuid::new_v4().simple());
602441603603- let id: uuid::Uuid = sqlx::query_scalar!(
604604- r#"
605605- INSERT INTO external_identities (did, provider, provider_user_id)
606606- VALUES ($1, $2, $3)
607607- RETURNING id
608608- "#,
609609- did.as_str(),
610610- SsoProviderType::Github as SsoProviderType,
611611- format!("delete_test_{}", uuid::Uuid::new_v4().simple()),
612612- )
613613- .fetch_one(pool)
614614- .await
615615- .unwrap();
442442+ let id = repos
443443+ .sso
444444+ .create_external_identity(&did, SsoProviderType::Github, &provider_user_id, None, None)
445445+ .await
446446+ .unwrap();
616447617617- let wrong_delete = sqlx::query!(
618618- r#"DELETE FROM external_identities WHERE id = $1 AND did = $2"#,
619619- id,
620620- wrong_did.as_str(),
621621- )
622622- .execute(pool)
623623- .await
624624- .unwrap();
448448+ let deleted = repos
449449+ .sso
450450+ .delete_external_identity(id, &wrong_did)
451451+ .await
452452+ .unwrap();
625453626626- assert_eq!(wrong_delete.rows_affected(), 0);
454454+ assert!(!deleted);
627455628628- let still_exists = sqlx::query!(r#"SELECT id FROM external_identities WHERE id = $1"#, id,)
629629- .fetch_optional(pool)
456456+ let still_exists = repos
457457+ .sso
458458+ .get_external_identity_by_provider(SsoProviderType::Github, &provider_user_id)
630459 .await
631460 .unwrap();
632461···636465#[tokio::test]
637466async fn test_sso_pending_registration_lifecycle() {
638467 let _url = base_url().await;
639639- let pool = get_test_db_pool().await;
468468+ let repos = get_test_repos().await;
640469641470 let token = format!("pending_token_{}", uuid::Uuid::new_v4().simple());
642471 let request_uri = "urn:ietf:params:oauth:request_uri:pendingtest";
643472 let provider_user_id = format!("pending_user_{}", uuid::Uuid::new_v4().simple());
644473645645- sqlx::query!(
646646- r#"
647647- INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, provider_username, provider_email)
648648- VALUES ($1, $2, $3, $4, $5, $6)
649649- "#,
650650- &token,
651651- request_uri,
652652- SsoProviderType::Github as SsoProviderType,
653653- &provider_user_id,
654654- Some("pendinguser"),
655655- Some("pending@github.com"),
656656- )
657657- .execute(pool)
658658- .await
659659- .unwrap();
474474+ repos
475475+ .sso
476476+ .create_pending_registration(
477477+ &token,
478478+ request_uri,
479479+ SsoProviderType::Github,
480480+ &provider_user_id,
481481+ Some("pendinguser"),
482482+ Some("pending@github.com"),
483483+ false,
484484+ )
485485+ .await
486486+ .unwrap();
660487661661- let found = sqlx::query!(
662662- r#"
663663- SELECT token, request_uri, provider as "provider: SsoProviderType", provider_user_id,
664664- provider_username, provider_email
665665- FROM sso_pending_registration
666666- WHERE token = $1 AND expires_at > NOW()
667667- "#,
668668- &token,
669669- )
670670- .fetch_optional(pool)
671671- .await
672672- .unwrap();
488488+ let found = repos.sso.get_pending_registration(&token).await.unwrap();
673489674490 assert!(found.is_some());
675491 let found = found.unwrap();
676492 assert_eq!(found.request_uri, request_uri);
677677- assert_eq!(found.provider_username, Some("pendinguser".to_string()));
678678- assert_eq!(found.provider_email, Some("pending@github.com".to_string()));
493493+ assert_eq!(
494494+ found.provider_username.as_ref().unwrap().as_str(),
495495+ "pendinguser"
496496+ );
497497+ assert_eq!(
498498+ found.provider_email.as_ref().unwrap().as_str(),
499499+ "pending@github.com"
500500+ );
679501680680- let consumed = sqlx::query!(
681681- r#"
682682- DELETE FROM sso_pending_registration
683683- WHERE token = $1 AND expires_at > NOW()
684684- RETURNING token, request_uri
685685- "#,
686686- &token,
687687- )
688688- .fetch_optional(pool)
689689- .await
690690- .unwrap();
502502+ let consumed = repos
503503+ .sso
504504+ .consume_pending_registration(&token)
505505+ .await
506506+ .unwrap();
691507692508 assert!(consumed.is_some());
693509694694- let double_consume = sqlx::query!(
695695- r#"
696696- DELETE FROM sso_pending_registration
697697- WHERE token = $1 AND expires_at > NOW()
698698- RETURNING token
699699- "#,
700700- &token,
701701- )
702702- .fetch_optional(pool)
703703- .await
704704- .unwrap();
510510+ let double_consume = repos
511511+ .sso
512512+ .consume_pending_registration(&token)
513513+ .await
514514+ .unwrap();
705515706516 assert!(double_consume.is_none());
707517}
···709519#[tokio::test]
710520async fn test_sso_pending_registration_expiration() {
711521 let _url = base_url().await;
712712- let pool = get_test_db_pool().await;
522522+ let repos = get_test_repos().await;
713523714714- let token = format!("expired_pending_{}", uuid::Uuid::new_v4().simple());
524524+ let consumed = repos
525525+ .sso
526526+ .get_pending_registration("nonexistent_pending_token")
527527+ .await
528528+ .unwrap();
715529716716- sqlx::query!(
717717- r#"
718718- INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, expires_at)
719719- VALUES ($1, $2, $3, $4, NOW() - INTERVAL '1 hour')
720720- "#,
721721- &token,
722722- "urn:test:expired_pending",
723723- SsoProviderType::Github as SsoProviderType,
724724- "expired_provider_user",
725725- )
726726- .execute(pool)
727727- .await
728728- .unwrap();
530530+ assert!(consumed.is_none());
729531730730- let consumed = sqlx::query!(
731731- r#"
732732- SELECT token FROM sso_pending_registration
733733- WHERE token = $1 AND expires_at > NOW()
734734- "#,
735735- &token,
736736- )
737737- .fetch_optional(pool)
738738- .await
739739- .unwrap();
532532+ let cleaned = repos
533533+ .sso
534534+ .cleanup_expired_pending_registrations()
535535+ .await
536536+ .unwrap();
740537741741- assert!(consumed.is_none());
538538+ assert!(cleaned == 0 || cleaned >= 1);
742539}
743540744541#[tokio::test]
···763560764561#[tokio::test]
765562async fn test_sso_complete_registration_expired_token() {
766766- let _url = base_url().await;
767767- let pool = get_test_db_pool().await;
768768-769769- let token = format!("expired_reg_token_{}", uuid::Uuid::new_v4().simple());
770770-771771- sqlx::query!(
772772- r#"
773773- INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, expires_at)
774774- VALUES ($1, $2, $3, $4, NOW() - INTERVAL '1 hour')
775775- "#,
776776- &token,
777777- "urn:test:expired_registration",
778778- SsoProviderType::Github as SsoProviderType,
779779- "expired_user_123",
780780- )
781781- .execute(pool)
782782- .await
783783- .unwrap();
563563+ let url = base_url().await;
564564+ let client = client();
784565785785- let client = client();
786566 let res = client
787787- .post(format!("{}/oauth/sso/complete-registration", _url))
567567+ .post(format!("{}/oauth/sso/complete-registration", url))
788568 .json(&json!({
789789- "token": token,
569569+ "token": format!("expired_reg_token_{}", uuid::Uuid::new_v4().simple()),
790570 "handle": "newuser"
791571 }))
792572 .send()
···837617 assert_eq!(body["error"], "InvalidRequest");
838618}
839619620620+fn test_request_data() -> RequestData {
621621+ RequestData {
622622+ client_id: "https://test.example.com".to_string(),
623623+ client_auth: None,
624624+ parameters: AuthorizationRequestParameters {
625625+ response_type: ResponseType::Code,
626626+ client_id: "https://test.example.com".to_string(),
627627+ redirect_uri: "https://test.example.com/callback".to_string(),
628628+ scope: Some("atproto".to_string()),
629629+ state: Some("teststate".to_string()),
630630+ code_challenge: "testchallenge".to_string(),
631631+ code_challenge_method: CodeChallengeMethod::S256,
632632+ response_mode: None,
633633+ login_hint: None,
634634+ dpop_jkt: None,
635635+ prompt: None,
636636+ extra: None,
637637+ },
638638+ expires_at: chrono::Utc::now() + chrono::Duration::hours(1),
639639+ did: None,
640640+ device_id: None,
641641+ code: None,
642642+ controller_did: None,
643643+ }
644644+}
645645+840646#[tokio::test]
841647async fn test_sso_complete_registration_success() {
842648 let url = base_url().await;
843843- let pool = get_test_db_pool().await;
649649+ let repos = get_test_repos().await;
844650 let client = client();
845651846652 let token = format!("success_reg_token_{}", uuid::Uuid::new_v4().simple());
···849655 let provider_email = format!("sso_{}@example.com", uuid::Uuid::new_v4().simple());
850656851657 let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4());
658658+ let request_id = RequestId::new(&request_uri);
852659853853- sqlx::query!(
854854- r#"
855855- INSERT INTO oauth_authorization_request (id, client_id, parameters, expires_at)
856856- VALUES ($1, 'https://test.example.com', $2, NOW() + INTERVAL '1 hour')
857857- "#,
858858- &request_uri,
859859- serde_json::json!({
860860- "redirect_uri": "https://test.example.com/callback",
861861- "scope": "atproto",
862862- "state": "teststate",
863863- "code_challenge": "testchallenge",
864864- "code_challenge_method": "S256"
865865- }),
866866- )
867867- .execute(pool)
868868- .await
869869- .unwrap();
660660+ repos
661661+ .oauth
662662+ .create_authorization_request(&request_id, &test_request_data())
663663+ .await
664664+ .unwrap();
870665871871- sqlx::query!(
872872- r#"
873873- INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, provider_username, provider_email, provider_email_verified)
874874- VALUES ($1, $2, $3, $4, $5, $6, $7)
875875- "#,
876876- &token,
877877- &request_uri,
878878- SsoProviderType::Github as SsoProviderType,
879879- &provider_user_id,
880880- Some("ssouser"),
881881- Some(&provider_email),
882882- true,
883883- )
884884- .execute(pool)
885885- .await
886886- .unwrap();
666666+ repos
667667+ .sso
668668+ .create_pending_registration(
669669+ &token,
670670+ &request_uri,
671671+ SsoProviderType::Github,
672672+ &provider_user_id,
673673+ Some("ssouser"),
674674+ Some(&provider_email),
675675+ true,
676676+ )
677677+ .await
678678+ .unwrap();
887679888680 let res = client
889681 .post(format!("{}/oauth/sso/complete-registration", url))
···925717 redirect_url
926718 );
927719928928- let pending_consumed = sqlx::query!(
929929- r#"SELECT token FROM sso_pending_registration WHERE token = $1"#,
930930- &token,
931931- )
932932- .fetch_optional(pool)
933933- .await
934934- .unwrap();
720720+ let pending_consumed = repos.sso.get_pending_registration(&token).await.unwrap();
935721936722 assert!(
937723 pending_consumed.is_none(),
938724 "Pending registration should be consumed after successful registration"
939725 );
940726941941- let user_exists = sqlx::query!(
942942- r#"SELECT did, email_verified FROM users WHERE did = $1"#,
943943- did_str,
944944- )
945945- .fetch_optional(pool)
946946- .await
947947- .unwrap();
948948-949949- assert!(user_exists.is_some(), "User should exist in database");
950950- let user = user_exists.unwrap();
951951- assert!(
952952- user.email_verified,
953953- "Email should be auto-verified when provider verified it"
954954- );
955955-956956- let external_identity = sqlx::query!(
957957- r#"
958958- SELECT provider_user_id, provider_email_verified
959959- FROM external_identities
960960- WHERE did = $1 AND provider = $2
961961- "#,
962962- did_str,
963963- SsoProviderType::Github as SsoProviderType,
964964- )
965965- .fetch_optional(pool)
966966- .await
967967- .unwrap();
727727+ let did: Did = did_str.parse().expect("valid DID from response");
728728+ let external_identities = repos
729729+ .sso
730730+ .get_external_identities_by_did(&did)
731731+ .await
732732+ .unwrap();
968733969734 assert!(
970970- external_identity.is_some(),
735735+ !external_identities.is_empty(),
971736 "External identity should be created"
972737 );
973973- let ext_id = external_identity.unwrap();
974974- assert_eq!(ext_id.provider_user_id, provider_user_id);
975975- assert!(ext_id.provider_email_verified);
738738+ let ext_id = &external_identities[0];
739739+ assert_eq!(ext_id.provider_user_id.as_str(), provider_user_id);
976740}
977741978742#[tokio::test]
979743async fn test_sso_complete_registration_multichannel_discord() {
980744 let url = base_url().await;
981981- let pool = get_test_db_pool().await;
745745+ let repos = get_test_repos().await;
982746 let client = client();
983747984748 let token = format!("discord_reg_token_{}", uuid::Uuid::new_v4().simple());
···990754 let discord_id = "123456789012345678";
991755992756 let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4());
757757+ let request_id = RequestId::new(&request_uri);
993758994994- sqlx::query!(
995995- r#"
996996- INSERT INTO oauth_authorization_request (id, client_id, parameters, expires_at)
997997- VALUES ($1, 'https://test.example.com', $2, NOW() + INTERVAL '1 hour')
998998- "#,
999999- &request_uri,
10001000- serde_json::json!({
10011001- "redirect_uri": "https://test.example.com/callback",
10021002- "scope": "atproto",
10031003- "state": "teststate",
10041004- "code_challenge": "testchallenge",
10051005- "code_challenge_method": "S256"
10061006- }),
10071007- )
10081008- .execute(pool)
10091009- .await
10101010- .unwrap();
759759+ repos
760760+ .oauth
761761+ .create_authorization_request(&request_id, &test_request_data())
762762+ .await
763763+ .unwrap();
101176410121012- sqlx::query!(
10131013- r#"
10141014- INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, provider_username, provider_email_verified)
10151015- VALUES ($1, $2, $3, $4, $5, $6)
10161016- "#,
10171017- &token,
10181018- &request_uri,
10191019- SsoProviderType::Discord as SsoProviderType,
10201020- &provider_user_id,
10211021- Some("discorduser"),
10221022- false,
10231023- )
10241024- .execute(pool)
10251025- .await
10261026- .unwrap();
765765+ repos
766766+ .sso
767767+ .create_pending_registration(
768768+ &token,
769769+ &request_uri,
770770+ SsoProviderType::Discord,
771771+ &provider_user_id,
772772+ Some("discorduser"),
773773+ None,
774774+ false,
775775+ )
776776+ .await
777777+ .unwrap();
10277781028779 let res = client
1029780 .post(format!("{}/oauth/sso/complete-registration", url))
···1049800 );
10508011051802 let did_str = body["did"].as_str().unwrap();
10521052- let user = sqlx::query!(
10531053- r#"SELECT preferred_comms_channel as "preferred_comms_channel: String", discord_username FROM users WHERE did = $1"#,
10541054- did_str,
10551055- )
10561056- .fetch_one(pool)
10571057- .await
10581058- .unwrap();
10591059-10601060- assert_eq!(user.preferred_comms_channel, "discord");
10611061- assert_eq!(user.discord_username, Some(discord_id.to_string()));
803803+ let did: Did = did_str.parse().expect("valid DID from response");
804804+ let user = repos
805805+ .user
806806+ .get_resend_verification_by_did(&did)
807807+ .await
808808+ .unwrap();
809809+ assert!(user.is_some(), "User should exist");
810810+ let user = user.unwrap();
811811+ assert_eq!(user.channel, CommsChannel::Discord);
812812+ assert_eq!(user.discord_username.as_deref(), Some(discord_id));
1062813}
10638141064815#[tokio::test]
···1105856#[tokio::test]
1106857async fn test_sso_complete_registration_missing_channel_data() {
1107858 let url = base_url().await;
11081108- let pool = get_test_db_pool().await;
859859+ let repos = get_test_repos().await;
1109860 let client = client();
11108611111862 let token = format!("missing_channel_{}", uuid::Uuid::new_v4().simple());
1112863 let handle_prefix = format!("missch{}", &uuid::Uuid::new_v4().simple().to_string()[..6]);
11138641114865 let request_uri = format!("urn:ietf:params:oauth:request_uri:{}", uuid::Uuid::new_v4());
866866+ let request_id = RequestId::new(&request_uri);
111586711161116- sqlx::query!(
11171117- r#"
11181118- INSERT INTO oauth_authorization_request (id, client_id, parameters, expires_at)
11191119- VALUES ($1, 'https://test.example.com', $2, NOW() + INTERVAL '1 hour')
11201120- "#,
11211121- &request_uri,
11221122- serde_json::json!({
11231123- "redirect_uri": "https://test.example.com/callback",
11241124- "scope": "atproto",
11251125- "state": "teststate",
11261126- "code_challenge": "testchallenge",
11271127- "code_challenge_method": "S256"
11281128- }),
11291129- )
11301130- .execute(pool)
11311131- .await
11321132- .unwrap();
868868+ repos
869869+ .oauth
870870+ .create_authorization_request(&request_id, &test_request_data())
871871+ .await
872872+ .unwrap();
113387311341134- sqlx::query!(
11351135- r#"
11361136- INSERT INTO sso_pending_registration (token, request_uri, provider, provider_user_id, provider_email_verified)
11371137- VALUES ($1, $2, $3, $4, $5)
11381138- "#,
11391139- &token,
11401140- &request_uri,
11411141- SsoProviderType::Github as SsoProviderType,
11421142- "missing_channel_user",
11431143- false,
11441144- )
11451145- .execute(pool)
11461146- .await
11471147- .unwrap();
874874+ repos
875875+ .sso
876876+ .create_pending_registration(
877877+ &token,
878878+ &request_uri,
879879+ SsoProviderType::Github,
880880+ "missing_channel_user",
881881+ None,
882882+ None,
883883+ false,
884884+ )
885885+ .await
886886+ .unwrap();
11488871149888 let res = client
1150889 .post(format!("{}/oauth/sso/complete-registration", url))
+1524
crates/tranquil-pds/tests/store_parity.rs
···11+mod common;
22+mod helpers;
33+44+use std::sync::Arc;
55+use tranquil_db::PostgresRepositories;
66+use tranquil_db_traits::{Backlink, BacklinkPath, CommsChannel, CommsType};
77+use tranquil_types::{AtUri, CidLink, Did, Handle, Nsid, Rkey};
88+use uuid::Uuid;
99+1010+async fn create_store_repos() -> Arc<PostgresRepositories> {
1111+ let temp_dir = std::env::temp_dir().join(format!("tranquil-parity-{}", uuid::Uuid::new_v4()));
1212+ std::fs::create_dir_all(&temp_dir).expect("failed to create parity temp dir");
1313+1414+ let metastore_dir = temp_dir.join("metastore");
1515+ let segments_dir = temp_dir.join("eventlog/segments");
1616+ let bs_data = temp_dir.join("blockstore/data");
1717+ let bs_index = temp_dir.join("blockstore/index");
1818+ std::fs::create_dir_all(&metastore_dir).unwrap();
1919+ std::fs::create_dir_all(&segments_dir).unwrap();
2020+ std::fs::create_dir_all(&bs_data).unwrap();
2121+ std::fs::create_dir_all(&bs_index).unwrap();
2222+2323+ use tranquil_store::RealIO;
2424+ use tranquil_store::blockstore::{BlockStoreConfig, TranquilBlockStore};
2525+ use tranquil_store::eventlog::{EventLog, EventLogBridge, EventLogConfig};
2626+ use tranquil_store::metastore::client::MetastoreClient;
2727+ use tranquil_store::metastore::handler::HandlerPool;
2828+ use tranquil_store::metastore::partitions::Partition;
2929+ use tranquil_store::metastore::{Metastore, MetastoreConfig};
3030+3131+ let metastore =
3232+ Metastore::open(&metastore_dir, MetastoreConfig::default()).expect("metastore open");
3333+3434+ let blockstore = TranquilBlockStore::open(BlockStoreConfig {
3535+ data_dir: bs_data,
3636+ index_dir: bs_index,
3737+ max_file_size: tranquil_store::blockstore::DEFAULT_MAX_FILE_SIZE,
3838+ group_commit: Default::default(),
3939+ })
4040+ .expect("blockstore open");
4141+4242+ let event_log = Arc::new(
4343+ EventLog::open(
4444+ EventLogConfig {
4545+ segments_dir,
4646+ ..EventLogConfig::default()
4747+ },
4848+ RealIO::new(),
4949+ )
5050+ .expect("eventlog open"),
5151+ );
5252+5353+ let bridge = Arc::new(EventLogBridge::new(Arc::clone(&event_log)));
5454+ let indexes = metastore.partition(Partition::Indexes).clone();
5555+ let event_ops = metastore.event_ops(Arc::clone(&bridge));
5656+ event_ops
5757+ .recover_metastore_mutations(&indexes)
5858+ .expect("metastore mutation recovery failed");
5959+6060+ let notifier = bridge.notifier();
6161+6262+ let pool = Arc::new(HandlerPool::spawn::<RealIO>(
6363+ metastore,
6464+ bridge,
6565+ Some(blockstore),
6666+ Some(2),
6767+ ));
6868+6969+ let client = MetastoreClient::<RealIO>::new(pool);
7070+7171+ Arc::new(PostgresRepositories {
7272+ pool: None,
7373+ repo: Arc::new(client.clone()),
7474+ backlink: Arc::new(client.clone()),
7575+ blob: Arc::new(client.clone()),
7676+ user: Arc::new(client.clone()),
7777+ session: Arc::new(client.clone()),
7878+ oauth: Arc::new(client.clone()),
7979+ infra: Arc::new(client.clone()),
8080+ delegation: Arc::new(client.clone()),
8181+ sso: Arc::new(client),
8282+ event_notifier: Arc::new(notifier),
8383+ })
8484+}
8585+8686+async fn create_pg_repos() -> Arc<PostgresRepositories> {
8787+ let db_url = common::get_db_connection_string().await;
8888+ let pool = sqlx::postgres::PgPoolOptions::new()
8989+ .max_connections(5)
9090+ .connect(&db_url)
9191+ .await
9292+ .expect("failed to connect for parity test");
9393+ Arc::new(PostgresRepositories::new(pool))
9494+}
9595+9696+struct ParityFixture {
9797+ pg: Arc<PostgresRepositories>,
9898+ store: Arc<PostgresRepositories>,
9999+}
100100+101101+impl ParityFixture {
102102+ async fn new() -> Self {
103103+ Self {
104104+ pg: create_pg_repos().await,
105105+ store: create_store_repos().await,
106106+ }
107107+ }
108108+}
109109+110110+fn test_did(suffix: &str) -> Did {
111111+ Did::new(format!("did:plc:parity{suffix}")).unwrap()
112112+}
113113+114114+fn test_handle(suffix: &str) -> Handle {
115115+ Handle::new(format!("parity-{suffix}.test")).unwrap()
116116+}
117117+118118+fn test_cid(seed: u8) -> CidLink {
119119+ CidLink::from(helpers::make_cid(&[seed]))
120120+}
121121+122122+fn test_nsid(name: &str) -> Nsid {
123123+ Nsid::new(format!("app.bsky.feed.{name}")).unwrap()
124124+}
125125+126126+fn test_rkey(s: &str) -> Rkey {
127127+ Rkey::new(s).unwrap()
128128+}
129129+130130+fn test_at_uri(did: &Did, collection: &Nsid, rkey: &Rkey) -> AtUri {
131131+ AtUri::new(format!(
132132+ "at://{}/{}/{}",
133133+ did.as_str(),
134134+ collection.as_str(),
135135+ rkey.as_str()
136136+ ))
137137+ .unwrap()
138138+}
139139+140140+async fn seed_repo(
141141+ repos: &PostgresRepositories,
142142+ did: &Did,
143143+ handle: &Handle,
144144+ root_cid: &CidLink,
145145+ user_id: Uuid,
146146+) {
147147+ repos
148148+ .repo
149149+ .create_repo(user_id, did, handle, root_cid, "rev0")
150150+ .await
151151+ .unwrap();
152152+}
153153+154154+async fn seed_records(
155155+ repos: &PostgresRepositories,
156156+ repo_id: Uuid,
157157+ collection: &Nsid,
158158+ records: &[(Rkey, CidLink)],
159159+) {
160160+ let collections: Vec<Nsid> = records.iter().map(|_| collection.clone()).collect();
161161+ let rkeys: Vec<Rkey> = records.iter().map(|(r, _)| r.clone()).collect();
162162+ let cids: Vec<CidLink> = records.iter().map(|(_, c)| c.clone()).collect();
163163+ repos
164164+ .repo
165165+ .upsert_records(repo_id, &collections, &rkeys, &cids, "rev1")
166166+ .await
167167+ .unwrap();
168168+}
169169+170170+#[tokio::test]
171171+async fn parity_server_config() {
172172+ let f = ParityFixture::new().await;
173173+174174+ f.pg.infra
175175+ .upsert_server_config("parity_key", "parity_value")
176176+ .await
177177+ .unwrap();
178178+ f.store
179179+ .infra
180180+ .upsert_server_config("parity_key", "parity_value")
181181+ .await
182182+ .unwrap();
183183+184184+ let pg_val = f.pg.infra.get_server_config("parity_key").await.unwrap();
185185+ let store_val = f.store.infra.get_server_config("parity_key").await.unwrap();
186186+ assert_eq!(pg_val, store_val);
187187+188188+ f.pg.infra.delete_server_config("parity_key").await.unwrap();
189189+ f.store
190190+ .infra
191191+ .delete_server_config("parity_key")
192192+ .await
193193+ .unwrap();
194194+195195+ let pg_gone = f.pg.infra.get_server_config("parity_key").await.unwrap();
196196+ let store_gone = f.store.infra.get_server_config("parity_key").await.unwrap();
197197+ assert_eq!(pg_gone, None);
198198+ assert_eq!(store_gone, None);
199199+}
200200+201201+#[tokio::test]
202202+async fn parity_health_check() {
203203+ let f = ParityFixture::new().await;
204204+205205+ let pg_health = f.pg.infra.health_check().await.unwrap();
206206+ let store_health = f.store.infra.health_check().await.unwrap();
207207+ assert!(pg_health);
208208+ assert!(store_health);
209209+}
210210+211211+#[tokio::test]
212212+async fn parity_rkey_sort_order() {
213213+ let f = ParityFixture::new().await;
214214+ let uid = Uuid::new_v4();
215215+ let did = test_did("rkey");
216216+ let handle = test_handle("rkey");
217217+ let root_cid = test_cid(0);
218218+ let collection = test_nsid("post");
219219+220220+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
221221+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
222222+223223+ let records: Vec<(Rkey, CidLink)> = (0u8..10)
224224+ .map(|i| {
225225+ let rkey = test_rkey(&format!("3l{i}aaaaaaaa{i}"));
226226+ let cid = test_cid(i + 1);
227227+ (rkey, cid)
228228+ })
229229+ .collect();
230230+231231+ seed_records(&f.pg, uid, &collection, &records).await;
232232+ seed_records(&f.store, uid, &collection, &records).await;
233233+234234+ let pg_fwd =
235235+ f.pg.repo
236236+ .list_records(uid, &collection, None, 100, false, None, None)
237237+ .await
238238+ .unwrap();
239239+ let store_fwd = f
240240+ .store
241241+ .repo
242242+ .list_records(uid, &collection, None, 100, false, None, None)
243243+ .await
244244+ .unwrap();
245245+246246+ let pg_rkeys: Vec<&str> = pg_fwd.iter().map(|r| r.rkey.as_str()).collect();
247247+ let store_rkeys: Vec<&str> = store_fwd.iter().map(|r| r.rkey.as_str()).collect();
248248+ assert_eq!(pg_rkeys, store_rkeys, "forward rkey order mismatch");
249249+250250+ let pg_rev =
251251+ f.pg.repo
252252+ .list_records(uid, &collection, None, 100, true, None, None)
253253+ .await
254254+ .unwrap();
255255+ let store_rev = f
256256+ .store
257257+ .repo
258258+ .list_records(uid, &collection, None, 100, true, None, None)
259259+ .await
260260+ .unwrap();
261261+262262+ let pg_rkeys_rev: Vec<&str> = pg_rev.iter().map(|r| r.rkey.as_str()).collect();
263263+ let store_rkeys_rev: Vec<&str> = store_rev.iter().map(|r| r.rkey.as_str()).collect();
264264+ assert_eq!(pg_rkeys_rev, store_rkeys_rev, "reverse rkey order mismatch");
265265+266266+ let pg_cids: Vec<&str> = pg_fwd.iter().map(|r| r.record_cid.as_str()).collect();
267267+ let store_cids: Vec<&str> = store_fwd.iter().map(|r| r.record_cid.as_str()).collect();
268268+ assert_eq!(pg_cids, store_cids, "cid mapping mismatch");
269269+}
270270+271271+#[tokio::test]
272272+async fn parity_cursor_pagination() {
273273+ let f = ParityFixture::new().await;
274274+ let uid = Uuid::new_v4();
275275+ let did = test_did("cursor");
276276+ let handle = test_handle("cursor");
277277+ let root_cid = test_cid(0);
278278+ let collection = test_nsid("post");
279279+280280+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
281281+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
282282+283283+ let records: Vec<(Rkey, CidLink)> = (0u8..20)
284284+ .map(|i| {
285285+ let rkey = test_rkey(&format!("3l{:02}aaaaaaaaa", i));
286286+ let cid = test_cid(i + 1);
287287+ (rkey, cid)
288288+ })
289289+ .collect();
290290+291291+ seed_records(&f.pg, uid, &collection, &records).await;
292292+ seed_records(&f.store, uid, &collection, &records).await;
293293+294294+ let mut pg_all = Vec::new();
295295+ let mut store_all = Vec::new();
296296+ let mut pg_cursor: Option<Rkey> = None;
297297+ let mut store_cursor: Option<Rkey> = None;
298298+ let limit = 5i64;
299299+ let mut pages = 0;
300300+301301+ loop {
302302+ let pg_page =
303303+ f.pg.repo
304304+ .list_records(
305305+ uid,
306306+ &collection,
307307+ pg_cursor.as_ref(),
308308+ limit,
309309+ false,
310310+ None,
311311+ None,
312312+ )
313313+ .await
314314+ .unwrap();
315315+ let store_page = f
316316+ .store
317317+ .repo
318318+ .list_records(
319319+ uid,
320320+ &collection,
321321+ store_cursor.as_ref(),
322322+ limit,
323323+ false,
324324+ None,
325325+ None,
326326+ )
327327+ .await
328328+ .unwrap();
329329+330330+ assert_eq!(
331331+ pg_page.len(),
332332+ store_page.len(),
333333+ "page size mismatch at page {pages}"
334334+ );
335335+336336+ let pg_rkeys: Vec<&str> = pg_page.iter().map(|r| r.rkey.as_str()).collect();
337337+ let store_rkeys: Vec<&str> = store_page.iter().map(|r| r.rkey.as_str()).collect();
338338+ assert_eq!(
339339+ pg_rkeys, store_rkeys,
340340+ "page content mismatch at page {pages}"
341341+ );
342342+343343+ pg_all.extend(pg_page.iter().map(|r| r.rkey.clone()));
344344+ store_all.extend(store_page.iter().map(|r| r.rkey.clone()));
345345+346346+ if pg_page.len() < limit as usize {
347347+ break;
348348+ }
349349+350350+ pg_cursor = pg_page.last().map(|r| r.rkey.clone());
351351+ store_cursor = store_page.last().map(|r| r.rkey.clone());
352352+ pages += 1;
353353+ }
354354+355355+ assert_eq!(pg_all.len(), 20);
356356+ assert_eq!(store_all.len(), 20);
357357+}
358358+359359+#[tokio::test]
360360+async fn parity_cursor_pagination_reverse() {
361361+ let f = ParityFixture::new().await;
362362+ let uid = Uuid::new_v4();
363363+ let did = test_did("currev");
364364+ let handle = test_handle("currev");
365365+ let root_cid = test_cid(0);
366366+ let collection = test_nsid("post");
367367+368368+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
369369+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
370370+371371+ let records: Vec<(Rkey, CidLink)> = (0u8..15)
372372+ .map(|i| {
373373+ let rkey = test_rkey(&format!("3l{:02}aaaaaaaaa", i));
374374+ let cid = test_cid(i + 1);
375375+ (rkey, cid)
376376+ })
377377+ .collect();
378378+379379+ seed_records(&f.pg, uid, &collection, &records).await;
380380+ seed_records(&f.store, uid, &collection, &records).await;
381381+382382+ let mut pg_all = Vec::new();
383383+ let mut store_all = Vec::new();
384384+ let mut pg_cursor: Option<Rkey> = None;
385385+ let mut store_cursor: Option<Rkey> = None;
386386+ let limit = 4i64;
387387+388388+ loop {
389389+ let pg_page =
390390+ f.pg.repo
391391+ .list_records(
392392+ uid,
393393+ &collection,
394394+ pg_cursor.as_ref(),
395395+ limit,
396396+ true,
397397+ None,
398398+ None,
399399+ )
400400+ .await
401401+ .unwrap();
402402+ let store_page = f
403403+ .store
404404+ .repo
405405+ .list_records(
406406+ uid,
407407+ &collection,
408408+ store_cursor.as_ref(),
409409+ limit,
410410+ true,
411411+ None,
412412+ None,
413413+ )
414414+ .await
415415+ .unwrap();
416416+417417+ let pg_rkeys: Vec<&str> = pg_page.iter().map(|r| r.rkey.as_str()).collect();
418418+ let store_rkeys: Vec<&str> = store_page.iter().map(|r| r.rkey.as_str()).collect();
419419+ assert_eq!(pg_rkeys, store_rkeys, "reverse page mismatch");
420420+421421+ pg_all.extend(pg_page.iter().map(|r| r.rkey.clone()));
422422+ store_all.extend(store_page.iter().map(|r| r.rkey.clone()));
423423+424424+ if pg_page.len() < limit as usize {
425425+ break;
426426+ }
427427+428428+ pg_cursor = pg_page.last().map(|r| r.rkey.clone());
429429+ store_cursor = store_page.last().map(|r| r.rkey.clone());
430430+ }
431431+432432+ assert_eq!(pg_all.len(), 15);
433433+ assert_eq!(store_all.len(), 15);
434434+}
435435+436436+#[tokio::test]
437437+async fn parity_rkey_range_query() {
438438+ let f = ParityFixture::new().await;
439439+ let uid = Uuid::new_v4();
440440+ let did = test_did("range");
441441+ let handle = test_handle("range");
442442+ let root_cid = test_cid(0);
443443+ let collection = test_nsid("post");
444444+445445+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
446446+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
447447+448448+ let records: Vec<(Rkey, CidLink)> = (0u8..10)
449449+ .map(|i| {
450450+ let rkey = test_rkey(&format!("3l{:02}aaaaaaaaa", i));
451451+ let cid = test_cid(i + 1);
452452+ (rkey, cid)
453453+ })
454454+ .collect();
455455+456456+ seed_records(&f.pg, uid, &collection, &records).await;
457457+ seed_records(&f.store, uid, &collection, &records).await;
458458+459459+ let start = test_rkey("3l03aaaaaaaaa");
460460+ let end = test_rkey("3l07aaaaaaaaa");
461461+462462+ let pg_range =
463463+ f.pg.repo
464464+ .list_records(uid, &collection, None, 100, false, Some(&start), Some(&end))
465465+ .await
466466+ .unwrap();
467467+ let store_range = f
468468+ .store
469469+ .repo
470470+ .list_records(uid, &collection, None, 100, false, Some(&start), Some(&end))
471471+ .await
472472+ .unwrap();
473473+474474+ let pg_rkeys: Vec<&str> = pg_range.iter().map(|r| r.rkey.as_str()).collect();
475475+ let store_rkeys: Vec<&str> = store_range.iter().map(|r| r.rkey.as_str()).collect();
476476+ assert_eq!(pg_rkeys, store_rkeys, "range query mismatch");
477477+}
478478+479479+#[tokio::test]
480480+async fn parity_collection_listing() {
481481+ let f = ParityFixture::new().await;
482482+ let uid = Uuid::new_v4();
483483+ let did = test_did("colls");
484484+ let handle = test_handle("colls");
485485+ let root_cid = test_cid(0);
486486+487487+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
488488+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
489489+490490+ let post_ns = test_nsid("post");
491491+ let like_ns = Nsid::new("app.bsky.feed.like").unwrap();
492492+ let repost_ns = Nsid::new("app.bsky.feed.repost").unwrap();
493493+ let follow_ns = Nsid::new("app.bsky.graph.follow").unwrap();
494494+495495+ let post_records = vec![(test_rkey("3laaaaaaaaa01"), test_cid(1))];
496496+ let like_records = vec![(test_rkey("3laaaaaaaaa02"), test_cid(2))];
497497+ let repost_records = vec![(test_rkey("3laaaaaaaaa03"), test_cid(3))];
498498+ let follow_records = vec![(test_rkey("3laaaaaaaaa04"), test_cid(4))];
499499+500500+ seed_records(&f.pg, uid, &post_ns, &post_records).await;
501501+ seed_records(&f.pg, uid, &like_ns, &like_records).await;
502502+ seed_records(&f.pg, uid, &repost_ns, &repost_records).await;
503503+ seed_records(&f.pg, uid, &follow_ns, &follow_records).await;
504504+505505+ seed_records(&f.store, uid, &post_ns, &post_records).await;
506506+ seed_records(&f.store, uid, &like_ns, &like_records).await;
507507+ seed_records(&f.store, uid, &repost_ns, &repost_records).await;
508508+ seed_records(&f.store, uid, &follow_ns, &follow_records).await;
509509+510510+ let mut pg_colls: Vec<String> =
511511+ f.pg.repo
512512+ .list_collections(uid)
513513+ .await
514514+ .unwrap()
515515+ .into_iter()
516516+ .map(|n| n.as_str().to_owned())
517517+ .collect();
518518+ pg_colls.sort();
519519+520520+ let mut store_colls: Vec<String> = f
521521+ .store
522522+ .repo
523523+ .list_collections(uid)
524524+ .await
525525+ .unwrap()
526526+ .into_iter()
527527+ .map(|n| n.as_str().to_owned())
528528+ .collect();
529529+ store_colls.sort();
530530+531531+ assert_eq!(pg_colls, store_colls, "collection listing mismatch");
532532+ assert_eq!(pg_colls.len(), 4);
533533+534534+ let pg_count = f.pg.repo.count_records(uid).await.unwrap();
535535+ let store_count = f.store.repo.count_records(uid).await.unwrap();
536536+ assert_eq!(pg_count, store_count, "record count mismatch");
537537+ assert_eq!(pg_count, 4);
538538+}
539539+540540+#[tokio::test]
541541+async fn parity_record_get_and_delete() {
542542+ let f = ParityFixture::new().await;
543543+ let uid = Uuid::new_v4();
544544+ let did = test_did("getdel");
545545+ let handle = test_handle("getdel");
546546+ let root_cid = test_cid(0);
547547+ let collection = test_nsid("post");
548548+ let rkey = test_rkey("3laaaaaaaaa01");
549549+ let cid = test_cid(1);
550550+551551+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
552552+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
553553+554554+ seed_records(&f.pg, uid, &collection, &[(rkey.clone(), cid.clone())]).await;
555555+ seed_records(&f.store, uid, &collection, &[(rkey.clone(), cid.clone())]).await;
556556+557557+ let pg_cid =
558558+ f.pg.repo
559559+ .get_record_cid(uid, &collection, &rkey)
560560+ .await
561561+ .unwrap();
562562+ let store_cid = f
563563+ .store
564564+ .repo
565565+ .get_record_cid(uid, &collection, &rkey)
566566+ .await
567567+ .unwrap();
568568+ assert_eq!(pg_cid, store_cid, "get_record_cid mismatch");
569569+ assert!(pg_cid.is_some());
570570+571571+ f.pg.repo
572572+ .delete_records(uid, &[collection.clone()], &[rkey.clone()])
573573+ .await
574574+ .unwrap();
575575+ f.store
576576+ .repo
577577+ .delete_records(uid, &[collection.clone()], &[rkey.clone()])
578578+ .await
579579+ .unwrap();
580580+581581+ let pg_gone =
582582+ f.pg.repo
583583+ .get_record_cid(uid, &collection, &rkey)
584584+ .await
585585+ .unwrap();
586586+ let store_gone = f
587587+ .store
588588+ .repo
589589+ .get_record_cid(uid, &collection, &rkey)
590590+ .await
591591+ .unwrap();
592592+ assert_eq!(pg_gone, None);
593593+ assert_eq!(store_gone, None);
594594+}
595595+596596+#[tokio::test]
597597+async fn parity_backlink_queries() {
598598+ let f = ParityFixture::new().await;
599599+ let uid = Uuid::new_v4();
600600+ let did = test_did("blink");
601601+ let handle = test_handle("blink");
602602+ let root_cid = test_cid(0);
603603+ let like_ns = Nsid::new("app.bsky.feed.like").unwrap();
604604+ let target_did = test_did("target");
605605+606606+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
607607+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
608608+609609+ let rkey1 = test_rkey("3laaaaaaaaa01");
610610+ let rkey2 = test_rkey("3laaaaaaaaa02");
611611+ let uri1 = test_at_uri(&did, &like_ns, &rkey1);
612612+ let uri2 = test_at_uri(&did, &like_ns, &rkey2);
613613+ let target_uri = format!(
614614+ "at://{}/app.bsky.feed.post/3laaaaaaaaa99",
615615+ target_did.as_str()
616616+ );
617617+618618+ let backlinks = vec![
619619+ Backlink {
620620+ uri: uri1.clone(),
621621+ path: BacklinkPath::Subject,
622622+ link_to: target_uri.clone(),
623623+ },
624624+ Backlink {
625625+ uri: uri2.clone(),
626626+ path: BacklinkPath::Subject,
627627+ link_to: target_uri.clone(),
628628+ },
629629+ ];
630630+631631+ f.pg.backlink.add_backlinks(uid, &backlinks).await.unwrap();
632632+ f.store
633633+ .backlink
634634+ .add_backlinks(uid, &backlinks)
635635+ .await
636636+ .unwrap();
637637+638638+ let conflict_backlink = Backlink {
639639+ uri: uri1.clone(),
640640+ path: BacklinkPath::Subject,
641641+ link_to: target_uri.clone(),
642642+ };
643643+644644+ let pg_conflicts =
645645+ f.pg.backlink
646646+ .get_backlink_conflicts(uid, &like_ns, &[conflict_backlink.clone()])
647647+ .await
648648+ .unwrap();
649649+ let store_conflicts = f
650650+ .store
651651+ .backlink
652652+ .get_backlink_conflicts(uid, &like_ns, &[conflict_backlink])
653653+ .await
654654+ .unwrap();
655655+656656+ assert_eq!(
657657+ pg_conflicts.len(),
658658+ store_conflicts.len(),
659659+ "backlink conflict count mismatch"
660660+ );
661661+662662+ f.pg.backlink.remove_backlinks_by_uri(&uri1).await.unwrap();
663663+ f.store
664664+ .backlink
665665+ .remove_backlinks_by_uri(&uri1)
666666+ .await
667667+ .unwrap();
668668+669669+ let post_removal = Backlink {
670670+ uri: uri1.clone(),
671671+ path: BacklinkPath::Subject,
672672+ link_to: target_uri.clone(),
673673+ };
674674+675675+ let pg_after =
676676+ f.pg.backlink
677677+ .get_backlink_conflicts(uid, &like_ns, &[post_removal.clone()])
678678+ .await
679679+ .unwrap();
680680+ let store_after = f
681681+ .store
682682+ .backlink
683683+ .get_backlink_conflicts(uid, &like_ns, &[post_removal])
684684+ .await
685685+ .unwrap();
686686+687687+ assert_eq!(
688688+ pg_after.len(),
689689+ store_after.len(),
690690+ "backlink conflicts after removal mismatch"
691691+ );
692692+}
693693+694694+#[tokio::test]
695695+async fn parity_backlink_remove_by_repo() {
696696+ let f = ParityFixture::new().await;
697697+ let uid = Uuid::new_v4();
698698+ let did = test_did("blrep");
699699+ let handle = test_handle("blrep");
700700+ let root_cid = test_cid(0);
701701+ let like_ns = Nsid::new("app.bsky.feed.like").unwrap();
702702+703703+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
704704+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
705705+706706+ let rkey = test_rkey("3laaaaaaaaa01");
707707+ let uri = test_at_uri(&did, &like_ns, &rkey);
708708+ let backlinks = vec![Backlink {
709709+ uri: uri.clone(),
710710+ path: BacklinkPath::Subject,
711711+ link_to: "at://did:plc:sometarget/app.bsky.feed.post/abc".to_owned(),
712712+ }];
713713+714714+ f.pg.backlink.add_backlinks(uid, &backlinks).await.unwrap();
715715+ f.store
716716+ .backlink
717717+ .add_backlinks(uid, &backlinks)
718718+ .await
719719+ .unwrap();
720720+721721+ f.pg.backlink.remove_backlinks_by_repo(uid).await.unwrap();
722722+ f.store
723723+ .backlink
724724+ .remove_backlinks_by_repo(uid)
725725+ .await
726726+ .unwrap();
727727+728728+ let probe = Backlink {
729729+ uri,
730730+ path: BacklinkPath::Subject,
731731+ link_to: "at://did:plc:sometarget/app.bsky.feed.post/abc".to_owned(),
732732+ };
733733+ let pg_after =
734734+ f.pg.backlink
735735+ .get_backlink_conflicts(uid, &like_ns, &[probe.clone()])
736736+ .await
737737+ .unwrap();
738738+ let store_after = f
739739+ .store
740740+ .backlink
741741+ .get_backlink_conflicts(uid, &like_ns, &[probe])
742742+ .await
743743+ .unwrap();
744744+ assert_eq!(pg_after.len(), 0);
745745+ assert_eq!(store_after.len(), 0);
746746+}
747747+748748+#[tokio::test]
749749+async fn parity_blob_metadata() {
750750+ let f = ParityFixture::new().await;
751751+ let uid = Uuid::new_v4();
752752+ let did = test_did("blob");
753753+ let handle = test_handle("blob");
754754+ let root_cid = test_cid(0);
755755+756756+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
757757+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
758758+759759+ let blob_cid1 = test_cid(101);
760760+ let blob_cid2 = test_cid(102);
761761+ let blob_cid3 = test_cid(103);
762762+763763+ let blobs = [
764764+ (&blob_cid1, "image/png", 1024i64, "blobs/a.png"),
765765+ (&blob_cid2, "image/jpeg", 2048, "blobs/b.jpg"),
766766+ (&blob_cid3, "application/pdf", 4096, "blobs/c.pdf"),
767767+ ];
768768+769769+ blobs.iter().for_each(|(cid, mime, size, key)| {
770770+ let pg = Arc::clone(&f.pg);
771771+ let store = Arc::clone(&f.store);
772772+ let cid = (*cid).clone();
773773+ let size = *size;
774774+ let mime = mime.to_string();
775775+ let key = key.to_string();
776776+ tokio::task::block_in_place(|| {
777777+ tokio::runtime::Handle::current().block_on(async {
778778+ pg.blob
779779+ .insert_blob(&cid, &mime, size, uid, &key)
780780+ .await
781781+ .unwrap();
782782+ store
783783+ .blob
784784+ .insert_blob(&cid, &mime, size, uid, &key)
785785+ .await
786786+ .unwrap();
787787+ });
788788+ });
789789+ });
790790+791791+ let pg_meta =
792792+ f.pg.blob
793793+ .get_blob_metadata(&blob_cid1)
794794+ .await
795795+ .unwrap()
796796+ .unwrap();
797797+ let store_meta = f
798798+ .store
799799+ .blob
800800+ .get_blob_metadata(&blob_cid1)
801801+ .await
802802+ .unwrap()
803803+ .unwrap();
804804+ assert_eq!(pg_meta.mime_type, store_meta.mime_type);
805805+ assert_eq!(pg_meta.size_bytes, store_meta.size_bytes);
806806+ assert_eq!(pg_meta.storage_key, store_meta.storage_key);
807807+808808+ let pg_key = f.pg.blob.get_blob_storage_key(&blob_cid2).await.unwrap();
809809+ let store_key = f.store.blob.get_blob_storage_key(&blob_cid2).await.unwrap();
810810+ assert_eq!(pg_key, store_key);
811811+812812+ let pg_count = f.pg.blob.count_blobs_by_user(uid).await.unwrap();
813813+ let store_count = f.store.blob.count_blobs_by_user(uid).await.unwrap();
814814+ assert_eq!(pg_count, store_count);
815815+ assert_eq!(pg_count, 3);
816816+817817+ let pg_list = f.pg.blob.list_blobs_by_user(uid, None, 100).await.unwrap();
818818+ let store_list = f
819819+ .store
820820+ .blob
821821+ .list_blobs_by_user(uid, None, 100)
822822+ .await
823823+ .unwrap();
824824+ assert_eq!(pg_list.len(), store_list.len());
825825+}
826826+827827+#[tokio::test]
828828+async fn parity_blob_pagination() {
829829+ let f = ParityFixture::new().await;
830830+ let uid = Uuid::new_v4();
831831+ let did = test_did("blobpg");
832832+ let handle = test_handle("blobpg");
833833+ let root_cid = test_cid(0);
834834+835835+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
836836+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
837837+838838+ (0u8..8).for_each(|i| {
839839+ let cid = test_cid(200 + i);
840840+ let key = format!("blobs/pg_{i}.bin");
841841+ let pg = Arc::clone(&f.pg);
842842+ let store = Arc::clone(&f.store);
843843+ tokio::task::block_in_place(|| {
844844+ tokio::runtime::Handle::current().block_on(async {
845845+ pg.blob
846846+ .insert_blob(
847847+ &cid,
848848+ "application/octet-stream",
849849+ 512 * (i as i64 + 1),
850850+ uid,
851851+ &key,
852852+ )
853853+ .await
854854+ .unwrap();
855855+ store
856856+ .blob
857857+ .insert_blob(
858858+ &cid,
859859+ "application/octet-stream",
860860+ 512 * (i as i64 + 1),
861861+ uid,
862862+ &key,
863863+ )
864864+ .await
865865+ .unwrap();
866866+ });
867867+ });
868868+ });
869869+870870+ let mut pg_all = Vec::new();
871871+ let mut store_all = Vec::new();
872872+ let mut pg_cursor: Option<String> = None;
873873+ let mut store_cursor: Option<String> = None;
874874+ let limit = 3i64;
875875+876876+ loop {
877877+ let pg_page =
878878+ f.pg.blob
879879+ .list_blobs_by_user(uid, pg_cursor.as_deref(), limit)
880880+ .await
881881+ .unwrap();
882882+ let store_page = f
883883+ .store
884884+ .blob
885885+ .list_blobs_by_user(uid, store_cursor.as_deref(), limit)
886886+ .await
887887+ .unwrap();
888888+889889+ assert_eq!(pg_page.len(), store_page.len(), "blob page size mismatch");
890890+891891+ let pg_cids: Vec<&str> = pg_page.iter().map(|c| c.as_str()).collect();
892892+ let store_cids: Vec<&str> = store_page.iter().map(|c| c.as_str()).collect();
893893+ assert_eq!(pg_cids, store_cids, "blob page content mismatch");
894894+895895+ pg_all.extend(pg_page.iter().map(|c| c.as_str().to_owned()));
896896+ store_all.extend(store_page.iter().map(|c| c.as_str().to_owned()));
897897+898898+ if pg_page.len() < limit as usize {
899899+ break;
900900+ }
901901+902902+ pg_cursor = pg_page.last().map(|c| c.as_str().to_owned());
903903+ store_cursor = store_page.last().map(|c| c.as_str().to_owned());
904904+ }
905905+906906+ assert_eq!(pg_all.len(), 8);
907907+ assert_eq!(store_all.len(), 8);
908908+}
909909+910910+#[tokio::test]
911911+async fn parity_blob_duplicate_insert() {
912912+ let f = ParityFixture::new().await;
913913+ let uid = Uuid::new_v4();
914914+ let did = test_did("blobdup");
915915+ let handle = test_handle("blobdup");
916916+ let root_cid = test_cid(0);
917917+918918+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
919919+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
920920+921921+ let cid = test_cid(150);
922922+923923+ let pg_first =
924924+ f.pg.blob
925925+ .insert_blob(&cid, "image/png", 1024, uid, "blobs/dup.png")
926926+ .await
927927+ .unwrap();
928928+ let store_first = f
929929+ .store
930930+ .blob
931931+ .insert_blob(&cid, "image/png", 1024, uid, "blobs/dup.png")
932932+ .await
933933+ .unwrap();
934934+ assert_eq!(pg_first, store_first);
935935+936936+ let pg_dup =
937937+ f.pg.blob
938938+ .insert_blob(&cid, "image/png", 1024, uid, "blobs/dup.png")
939939+ .await
940940+ .unwrap();
941941+ let store_dup = f
942942+ .store
943943+ .blob
944944+ .insert_blob(&cid, "image/png", 1024, uid, "blobs/dup.png")
945945+ .await
946946+ .unwrap();
947947+ assert_eq!(pg_dup, store_dup);
948948+}
949949+950950+#[tokio::test]
951951+async fn parity_get_all_records() {
952952+ let f = ParityFixture::new().await;
953953+ let uid = Uuid::new_v4();
954954+ let did = test_did("allrec");
955955+ let handle = test_handle("allrec");
956956+ let root_cid = test_cid(0);
957957+958958+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
959959+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
960960+961961+ let post_ns = test_nsid("post");
962962+ let like_ns = Nsid::new("app.bsky.feed.like").unwrap();
963963+964964+ let posts = vec![
965965+ (test_rkey("3laaaaaaaaa01"), test_cid(1)),
966966+ (test_rkey("3laaaaaaaaa02"), test_cid(2)),
967967+ ];
968968+ let likes = vec![(test_rkey("3laaaaaaaaa03"), test_cid(3))];
969969+970970+ seed_records(&f.pg, uid, &post_ns, &posts).await;
971971+ seed_records(&f.pg, uid, &like_ns, &likes).await;
972972+ seed_records(&f.store, uid, &post_ns, &posts).await;
973973+ seed_records(&f.store, uid, &like_ns, &likes).await;
974974+975975+ let mut pg_all = f.pg.repo.get_all_records(uid).await.unwrap();
976976+ let mut store_all = f.store.repo.get_all_records(uid).await.unwrap();
977977+978978+ pg_all.sort_by(|a, b| {
979979+ a.collection
980980+ .as_str()
981981+ .cmp(b.collection.as_str())
982982+ .then(a.rkey.as_str().cmp(b.rkey.as_str()))
983983+ });
984984+ store_all.sort_by(|a, b| {
985985+ a.collection
986986+ .as_str()
987987+ .cmp(b.collection.as_str())
988988+ .then(a.rkey.as_str().cmp(b.rkey.as_str()))
989989+ });
990990+991991+ assert_eq!(pg_all.len(), store_all.len());
992992+ pg_all.iter().zip(store_all.iter()).for_each(|(p, s)| {
993993+ assert_eq!(p.collection.as_str(), s.collection.as_str());
994994+ assert_eq!(p.rkey.as_str(), s.rkey.as_str());
995995+ assert_eq!(p.record_cid.as_str(), s.record_cid.as_str());
996996+ });
997997+}
998998+999999+#[tokio::test]
10001000+async fn parity_comms_queue() {
10011001+ let f = ParityFixture::new().await;
10021002+ let uid = Uuid::new_v4();
10031003+10041004+ let pg_id =
10051005+ f.pg.infra
10061006+ .enqueue_comms(
10071007+ Some(uid),
10081008+ CommsChannel::Email,
10091009+ CommsType::Welcome,
10101010+ "test@example.com",
10111011+ Some("Welcome"),
10121012+ "Welcome body",
10131013+ None,
10141014+ )
10151015+ .await
10161016+ .unwrap();
10171017+10181018+ let store_id = f
10191019+ .store
10201020+ .infra
10211021+ .enqueue_comms(
10221022+ Some(uid),
10231023+ CommsChannel::Email,
10241024+ CommsType::Welcome,
10251025+ "test@example.com",
10261026+ Some("Welcome"),
10271027+ "Welcome body",
10281028+ None,
10291029+ )
10301030+ .await
10311031+ .unwrap();
10321032+10331033+ assert_ne!(pg_id, Uuid::nil());
10341034+ assert_ne!(store_id, Uuid::nil());
10351035+10361036+ let pg_latest =
10371037+ f.pg.infra
10381038+ .get_latest_comms_for_user(uid, CommsType::Welcome, 10)
10391039+ .await
10401040+ .unwrap();
10411041+ let store_latest = f
10421042+ .store
10431043+ .infra
10441044+ .get_latest_comms_for_user(uid, CommsType::Welcome, 10)
10451045+ .await
10461046+ .unwrap();
10471047+10481048+ assert_eq!(pg_latest.len(), store_latest.len());
10491049+ assert_eq!(pg_latest[0].body, store_latest[0].body);
10501050+10511051+ let pg_count =
10521052+ f.pg.infra
10531053+ .count_comms_by_type(uid, CommsType::Welcome)
10541054+ .await
10551055+ .unwrap();
10561056+ let store_count = f
10571057+ .store
10581058+ .infra
10591059+ .count_comms_by_type(uid, CommsType::Welcome)
10601060+ .await
10611061+ .unwrap();
10621062+ assert_eq!(pg_count, store_count);
10631063+ assert_eq!(pg_count, 1);
10641064+}
10651065+10661066+#[tokio::test]
10671067+async fn parity_invite_codes() {
10681068+ let f = ParityFixture::new().await;
10691069+ let code = format!("parity-invite-{}", Uuid::new_v4());
10701070+10711071+ let pg_created = f.pg.infra.create_invite_code(&code, 5, None).await.unwrap();
10721072+ let store_created = f
10731073+ .store
10741074+ .infra
10751075+ .create_invite_code(&code, 5, None)
10761076+ .await
10771077+ .unwrap();
10781078+ assert_eq!(pg_created, store_created);
10791079+10801080+ let pg_uses =
10811081+ f.pg.infra
10821082+ .get_invite_code_available_uses(&code)
10831083+ .await
10841084+ .unwrap();
10851085+ let store_uses = f
10861086+ .store
10871087+ .infra
10881088+ .get_invite_code_available_uses(&code)
10891089+ .await
10901090+ .unwrap();
10911091+ assert_eq!(pg_uses, store_uses);
10921092+ assert_eq!(pg_uses, Some(5));
10931093+}
10941094+10951095+#[tokio::test]
10961096+async fn parity_account_preferences() {
10971097+ let f = ParityFixture::new().await;
10981098+ let uid = Uuid::new_v4();
10991099+ let did = test_did("prefs");
11001100+ let handle = test_handle("prefs");
11011101+ let root_cid = test_cid(0);
11021102+11031103+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
11041104+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
11051105+11061106+ let pref_value = serde_json::json!({
11071107+ "$type": "app.bsky.actor.defs#adultContentPref",
11081108+ "enabled": false
11091109+ });
11101110+11111111+ f.pg.infra
11121112+ .upsert_account_preference(
11131113+ uid,
11141114+ "app.bsky.actor.defs#adultContentPref/0",
11151115+ pref_value.clone(),
11161116+ )
11171117+ .await
11181118+ .unwrap();
11191119+ f.store
11201120+ .infra
11211121+ .upsert_account_preference(uid, "app.bsky.actor.defs#adultContentPref/0", pref_value)
11221122+ .await
11231123+ .unwrap();
11241124+11251125+ let mut pg_prefs = f.pg.infra.get_account_preferences(uid).await.unwrap();
11261126+ let mut store_prefs = f.store.infra.get_account_preferences(uid).await.unwrap();
11271127+11281128+ pg_prefs.sort_by(|a, b| a.0.cmp(&b.0));
11291129+ store_prefs.sort_by(|a, b| a.0.cmp(&b.0));
11301130+11311131+ assert_eq!(pg_prefs.len(), store_prefs.len());
11321132+ pg_prefs.iter().zip(store_prefs.iter()).for_each(|(p, s)| {
11331133+ assert_eq!(p.0, s.0);
11341134+ assert_eq!(p.1, s.1);
11351135+ });
11361136+}
11371137+11381138+#[tokio::test]
11391139+async fn parity_record_upsert_overwrites() {
11401140+ let f = ParityFixture::new().await;
11411141+ let uid = Uuid::new_v4();
11421142+ let did = test_did("upsert");
11431143+ let handle = test_handle("upsert");
11441144+ let root_cid = test_cid(0);
11451145+ let collection = test_nsid("post");
11461146+ let rkey = test_rkey("3laaaaaaaaa01");
11471147+11481148+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
11491149+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
11501150+11511151+ let cid_v1 = test_cid(1);
11521152+ seed_records(&f.pg, uid, &collection, &[(rkey.clone(), cid_v1.clone())]).await;
11531153+ seed_records(
11541154+ &f.store,
11551155+ uid,
11561156+ &collection,
11571157+ &[(rkey.clone(), cid_v1.clone())],
11581158+ )
11591159+ .await;
11601160+11611161+ let cid_v2 = test_cid(2);
11621162+ seed_records(&f.pg, uid, &collection, &[(rkey.clone(), cid_v2.clone())]).await;
11631163+ seed_records(
11641164+ &f.store,
11651165+ uid,
11661166+ &collection,
11671167+ &[(rkey.clone(), cid_v2.clone())],
11681168+ )
11691169+ .await;
11701170+11711171+ let pg_cid =
11721172+ f.pg.repo
11731173+ .get_record_cid(uid, &collection, &rkey)
11741174+ .await
11751175+ .unwrap();
11761176+ let store_cid = f
11771177+ .store
11781178+ .repo
11791179+ .get_record_cid(uid, &collection, &rkey)
11801180+ .await
11811181+ .unwrap();
11821182+ assert_eq!(pg_cid, store_cid);
11831183+ assert_eq!(pg_cid.unwrap().as_str(), cid_v2.as_str());
11841184+11851185+ let pg_count = f.pg.repo.count_records(uid).await.unwrap();
11861186+ let store_count = f.store.repo.count_records(uid).await.unwrap();
11871187+ assert_eq!(pg_count, 1);
11881188+ assert_eq!(store_count, 1);
11891189+}
11901190+11911191+#[tokio::test]
11921192+async fn parity_empty_queries() {
11931193+ let f = ParityFixture::new().await;
11941194+ let uid = Uuid::new_v4();
11951195+ let did = test_did("empty");
11961196+ let handle = test_handle("empty");
11971197+ let root_cid = test_cid(0);
11981198+ let collection = test_nsid("post");
11991199+12001200+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
12011201+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
12021202+12031203+ let pg_records =
12041204+ f.pg.repo
12051205+ .list_records(uid, &collection, None, 100, false, None, None)
12061206+ .await
12071207+ .unwrap();
12081208+ let store_records = f
12091209+ .store
12101210+ .repo
12111211+ .list_records(uid, &collection, None, 100, false, None, None)
12121212+ .await
12131213+ .unwrap();
12141214+ assert_eq!(pg_records.len(), 0);
12151215+ assert_eq!(store_records.len(), 0);
12161216+12171217+ let pg_colls = f.pg.repo.list_collections(uid).await.unwrap();
12181218+ let store_colls = f.store.repo.list_collections(uid).await.unwrap();
12191219+ assert_eq!(pg_colls.len(), 0);
12201220+ assert_eq!(store_colls.len(), 0);
12211221+12221222+ let pg_count = f.pg.repo.count_records(uid).await.unwrap();
12231223+ let store_count = f.store.repo.count_records(uid).await.unwrap();
12241224+ assert_eq!(pg_count, 0);
12251225+ assert_eq!(store_count, 0);
12261226+12271227+ let pg_blobs = f.pg.blob.list_blobs_by_user(uid, None, 100).await.unwrap();
12281228+ let store_blobs = f
12291229+ .store
12301230+ .blob
12311231+ .list_blobs_by_user(uid, None, 100)
12321232+ .await
12331233+ .unwrap();
12341234+ assert_eq!(pg_blobs.len(), 0);
12351235+ assert_eq!(store_blobs.len(), 0);
12361236+12371237+ let nonexistent_cid = test_cid(255);
12381238+ let pg_meta = f.pg.blob.get_blob_metadata(&nonexistent_cid).await.unwrap();
12391239+ let store_meta = f
12401240+ .store
12411241+ .blob
12421242+ .get_blob_metadata(&nonexistent_cid)
12431243+ .await
12441244+ .unwrap();
12451245+ assert_eq!(pg_meta.is_none(), store_meta.is_none());
12461246+}
12471247+12481248+#[tokio::test]
12491249+async fn parity_deletion_requests() {
12501250+ let f = ParityFixture::new().await;
12511251+ let did = test_did("delreq");
12521252+ let token = format!("del-token-{}", Uuid::new_v4());
12531253+ let expires = chrono::Utc::now() + chrono::Duration::hours(24);
12541254+12551255+ f.pg.infra
12561256+ .create_deletion_request(&token, &did, expires)
12571257+ .await
12581258+ .unwrap();
12591259+ f.store
12601260+ .infra
12611261+ .create_deletion_request(&token, &did, expires)
12621262+ .await
12631263+ .unwrap();
12641264+12651265+ let pg_req = f.pg.infra.get_deletion_request(&token).await.unwrap();
12661266+ let store_req = f.store.infra.get_deletion_request(&token).await.unwrap();
12671267+ assert!(pg_req.is_some());
12681268+ assert!(store_req.is_some());
12691269+ assert_eq!(
12701270+ pg_req.as_ref().unwrap().did,
12711271+ store_req.as_ref().unwrap().did
12721272+ );
12731273+12741274+ let pg_by_did = f.pg.infra.get_deletion_request_by_did(&did).await.unwrap();
12751275+ let store_by_did = f
12761276+ .store
12771277+ .infra
12781278+ .get_deletion_request_by_did(&did)
12791279+ .await
12801280+ .unwrap();
12811281+ assert!(pg_by_did.is_some());
12821282+ assert!(store_by_did.is_some());
12831283+ assert_eq!(pg_by_did.unwrap().token, store_by_did.unwrap().token);
12841284+12851285+ f.pg.infra.delete_deletion_request(&token).await.unwrap();
12861286+ f.store.infra.delete_deletion_request(&token).await.unwrap();
12871287+12881288+ let pg_gone = f.pg.infra.get_deletion_request(&token).await.unwrap();
12891289+ let store_gone = f.store.infra.get_deletion_request(&token).await.unwrap();
12901290+ assert!(pg_gone.is_none());
12911291+ assert!(store_gone.is_none());
12921292+}
12931293+12941294+#[tokio::test]
12951295+async fn parity_signing_key_reservation() {
12961296+ let f = ParityFixture::new().await;
12971297+ let did = test_did("sigkey");
12981298+ let expires = chrono::Utc::now() + chrono::Duration::hours(1);
12991299+ let pub_key = format!("did:key:z6Mk{}", Uuid::new_v4().simple());
13001300+ let priv_bytes = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
13011301+13021302+ f.pg.infra
13031303+ .reserve_signing_key(Some(&did), &pub_key, &priv_bytes, expires)
13041304+ .await
13051305+ .unwrap();
13061306+ f.store
13071307+ .infra
13081308+ .reserve_signing_key(Some(&did), &pub_key, &priv_bytes, expires)
13091309+ .await
13101310+ .unwrap();
13111311+13121312+ let pg_key = f.pg.infra.get_reserved_signing_key(&pub_key).await.unwrap();
13131313+ let store_key = f
13141314+ .store
13151315+ .infra
13161316+ .get_reserved_signing_key(&pub_key)
13171317+ .await
13181318+ .unwrap();
13191319+ assert!(pg_key.is_some());
13201320+ assert!(store_key.is_some());
13211321+ assert_eq!(
13221322+ pg_key.unwrap().private_key_bytes,
13231323+ store_key.unwrap().private_key_bytes
13241324+ );
13251325+13261326+ let pg_full =
13271327+ f.pg.infra
13281328+ .get_reserved_signing_key_full(&pub_key)
13291329+ .await
13301330+ .unwrap();
13311331+ let store_full = f
13321332+ .store
13331333+ .infra
13341334+ .get_reserved_signing_key_full(&pub_key)
13351335+ .await
13361336+ .unwrap();
13371337+ assert!(pg_full.is_some());
13381338+ assert!(store_full.is_some());
13391339+ let pg_f = pg_full.unwrap();
13401340+ let store_f = store_full.unwrap();
13411341+ assert_eq!(pg_f.public_key_did_key, store_f.public_key_did_key);
13421342+ assert_eq!(pg_f.did, store_f.did);
13431343+}
13441344+13451345+#[tokio::test]
13461346+async fn parity_repo_root_operations() {
13471347+ let f = ParityFixture::new().await;
13481348+ let uid = Uuid::new_v4();
13491349+ let did = test_did("root");
13501350+ let handle = test_handle("root");
13511351+ let root_cid = test_cid(0);
13521352+13531353+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
13541354+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
13551355+13561356+ let pg_root = f.pg.repo.get_repo_root_by_did(&did).await.unwrap();
13571357+ let store_root = f.store.repo.get_repo_root_by_did(&did).await.unwrap();
13581358+ assert_eq!(pg_root, store_root);
13591359+13601360+ let new_root = test_cid(99);
13611361+ f.pg.repo
13621362+ .update_repo_root(uid, &new_root, "rev1")
13631363+ .await
13641364+ .unwrap();
13651365+ f.store
13661366+ .repo
13671367+ .update_repo_root(uid, &new_root, "rev1")
13681368+ .await
13691369+ .unwrap();
13701370+13711371+ let pg_updated = f.pg.repo.get_repo_root_by_did(&did).await.unwrap();
13721372+ let store_updated = f.store.repo.get_repo_root_by_did(&did).await.unwrap();
13731373+ assert_eq!(pg_updated, store_updated);
13741374+ assert_eq!(pg_updated.unwrap().as_str(), new_root.as_str());
13751375+13761376+ let pg_info = f.pg.repo.get_repo(uid).await.unwrap().unwrap();
13771377+ let store_info = f.store.repo.get_repo(uid).await.unwrap().unwrap();
13781378+ assert_eq!(pg_info.repo_rev, store_info.repo_rev);
13791379+ assert_eq!(
13801380+ pg_info.repo_root_cid.as_str(),
13811381+ store_info.repo_root_cid.as_str()
13821382+ );
13831383+}
13841384+13851385+#[tokio::test]
13861386+async fn parity_delete_all_records() {
13871387+ let f = ParityFixture::new().await;
13881388+ let uid = Uuid::new_v4();
13891389+ let did = test_did("delall");
13901390+ let handle = test_handle("delall");
13911391+ let root_cid = test_cid(0);
13921392+ let collection = test_nsid("post");
13931393+13941394+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
13951395+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
13961396+13971397+ let records: Vec<(Rkey, CidLink)> = (0u8..5)
13981398+ .map(|i| (test_rkey(&format!("3l{:02}aaaaaaaaa", i)), test_cid(i + 1)))
13991399+ .collect();
14001400+14011401+ seed_records(&f.pg, uid, &collection, &records).await;
14021402+ seed_records(&f.store, uid, &collection, &records).await;
14031403+14041404+ f.pg.repo.delete_all_records(uid).await.unwrap();
14051405+ f.store.repo.delete_all_records(uid).await.unwrap();
14061406+14071407+ let pg_count = f.pg.repo.count_records(uid).await.unwrap();
14081408+ let store_count = f.store.repo.count_records(uid).await.unwrap();
14091409+ assert_eq!(pg_count, 0);
14101410+ assert_eq!(store_count, 0);
14111411+14121412+ let pg_colls = f.pg.repo.list_collections(uid).await.unwrap();
14131413+ let store_colls = f.store.repo.list_collections(uid).await.unwrap();
14141414+ assert_eq!(pg_colls.len(), 0);
14151415+ assert_eq!(store_colls.len(), 0);
14161416+}
14171417+14181418+#[tokio::test]
14191419+async fn parity_plc_tokens() {
14201420+ let f = ParityFixture::new().await;
14211421+ let uid = Uuid::new_v4();
14221422+ let did = test_did("plctok");
14231423+ let handle = test_handle("plctok");
14241424+ let root_cid = test_cid(0);
14251425+14261426+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
14271427+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
14281428+14291429+ let token = format!("plc-{}", Uuid::new_v4());
14301430+ let expires = chrono::Utc::now() + chrono::Duration::hours(1);
14311431+14321432+ f.pg.infra
14331433+ .insert_plc_token(uid, &token, expires)
14341434+ .await
14351435+ .unwrap();
14361436+ f.store
14371437+ .infra
14381438+ .insert_plc_token(uid, &token, expires)
14391439+ .await
14401440+ .unwrap();
14411441+14421442+ let pg_expiry = f.pg.infra.get_plc_token_expiry(uid, &token).await.unwrap();
14431443+ let store_expiry = f
14441444+ .store
14451445+ .infra
14461446+ .get_plc_token_expiry(uid, &token)
14471447+ .await
14481448+ .unwrap();
14491449+ assert!(pg_expiry.is_some());
14501450+ assert!(store_expiry.is_some());
14511451+14521452+ let pg_by_did = f.pg.infra.get_plc_tokens_by_did(&did).await.unwrap();
14531453+ let store_by_did = f.store.infra.get_plc_tokens_by_did(&did).await.unwrap();
14541454+ assert_eq!(pg_by_did.len(), store_by_did.len());
14551455+14561456+ let pg_count = f.pg.infra.count_plc_tokens_by_did(&did).await.unwrap();
14571457+ let store_count = f.store.infra.count_plc_tokens_by_did(&did).await.unwrap();
14581458+ assert_eq!(pg_count, store_count);
14591459+ assert_eq!(pg_count, 1);
14601460+14611461+ f.pg.infra.delete_plc_token(uid, &token).await.unwrap();
14621462+ f.store.infra.delete_plc_token(uid, &token).await.unwrap();
14631463+14641464+ let pg_gone = f.pg.infra.get_plc_token_expiry(uid, &token).await.unwrap();
14651465+ let store_gone = f
14661466+ .store
14671467+ .infra
14681468+ .get_plc_token_expiry(uid, &token)
14691469+ .await
14701470+ .unwrap();
14711471+ assert!(pg_gone.is_none());
14721472+ assert!(store_gone.is_none());
14731473+}
14741474+14751475+#[tokio::test]
14761476+async fn parity_blob_delete_and_takedown() {
14771477+ let f = ParityFixture::new().await;
14781478+ let uid = Uuid::new_v4();
14791479+ let did = test_did("blobdel");
14801480+ let handle = test_handle("blobdel");
14811481+ let root_cid = test_cid(0);
14821482+14831483+ seed_repo(&f.pg, &did, &handle, &root_cid, uid).await;
14841484+ seed_repo(&f.store, &did, &handle, &root_cid, uid).await;
14851485+14861486+ let cid = test_cid(180);
14871487+ f.pg.blob
14881488+ .insert_blob(&cid, "image/png", 1024, uid, "blobs/td.png")
14891489+ .await
14901490+ .unwrap();
14911491+ f.store
14921492+ .blob
14931493+ .insert_blob(&cid, "image/png", 1024, uid, "blobs/td.png")
14941494+ .await
14951495+ .unwrap();
14961496+14971497+ let pg_td =
14981498+ f.pg.blob
14991499+ .update_blob_takedown(&cid, Some("mod-action-1"))
15001500+ .await
15011501+ .unwrap();
15021502+ let store_td = f
15031503+ .store
15041504+ .blob
15051505+ .update_blob_takedown(&cid, Some("mod-action-1"))
15061506+ .await
15071507+ .unwrap();
15081508+ assert_eq!(pg_td, store_td);
15091509+15101510+ let pg_with_td = f.pg.blob.get_blob_with_takedown(&cid).await.unwrap();
15111511+ let store_with_td = f.store.blob.get_blob_with_takedown(&cid).await.unwrap();
15121512+ assert_eq!(
15131513+ pg_with_td.as_ref().map(|b| b.takedown_ref.as_deref()),
15141514+ store_with_td.as_ref().map(|b| b.takedown_ref.as_deref())
15151515+ );
15161516+15171517+ f.pg.blob.delete_blob_by_cid(&cid).await.unwrap();
15181518+ f.store.blob.delete_blob_by_cid(&cid).await.unwrap();
15191519+15201520+ let pg_meta = f.pg.blob.get_blob_metadata(&cid).await.unwrap();
15211521+ let store_meta = f.store.blob.get_blob_metadata(&cid).await.unwrap();
15221522+ assert!(pg_meta.is_none());
15231523+ assert!(store_meta.is_none());
15241524+}
+12-12
crates/tranquil-pds/tests/whole_story.rs
···177177 .expect("Request delete failed");
178178 assert_eq!(request_delete_res.status(), StatusCode::OK);
179179180180- let pool = get_test_db_pool().await;
181181- let row = sqlx::query!(
182182- "SELECT token FROM account_deletion_requests WHERE did = $1",
183183- did
184184- )
185185- .fetch_one(pool)
186186- .await
187187- .expect("Failed to get deletion token");
180180+ let repos = get_test_repos().await;
181181+ let deletion_request = repos
182182+ .infra
183183+ .get_deletion_request_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
184184+ .await
185185+ .unwrap()
186186+ .unwrap();
188187189188 let final_delete_res = client
190189 .post(format!("{}/xrpc/com.atproto.server.deleteAccount", base))
191190 .json(&json!({
192191 "did": did,
193192 "password": password,
194194- "token": row.token
193193+ "token": deletion_request.token
195194 }))
196195 .send()
197196 .await
198197 .expect("Final delete failed");
199198 assert_eq!(final_delete_res.status(), StatusCode::OK);
200199201201- let user_gone = sqlx::query!("SELECT id FROM users WHERE did = $1", did)
202202- .fetch_optional(pool)
200200+ let user_gone = repos
201201+ .user
202202+ .get_by_did(&tranquil_types::Did::new(did.clone()).unwrap())
203203 .await
204204- .expect("Failed to check user");
204204+ .unwrap();
205205 assert!(user_gone.is_none(), "User should be deleted");
206206}
207207
+2-2
crates/tranquil-server/src/main.rs
···114114 let signal_sender = if tranquil_config::get().signal.enabled {
115115 let slot = Arc::new(tranquil_signal::SignalSlot::default());
116116 state = state.with_signal_sender(slot.clone());
117117- if let Some(client) =
118118- tranquil_signal::SignalClient::from_pool(&state.repos.pool, shutdown.clone()).await
117117+ if let Some(provider) = &state.signal_store_provider
118118+ && let Some(client) = provider.load_signal_client(shutdown.clone()).await
119119 {
120120 slot.set_client(client).await;
121121 info!("Signal device already linked");