don't
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

test(knot): extract mock pds into separate crate

Signed-off-by: tjh <x@tjh.dev>

tjh 1785efbb 2f8e45ba

+730 -343
+31 -2
Cargo.lock
··· 184 184 checksum = "7b7b6141e96a8c160799cc2d5adecd5cbbe5054cb8c7c4af53da0f83bb7ad256" 185 185 dependencies = [ 186 186 "aws-lc-sys", 187 + "untrusted 0.7.1", 187 188 "zeroize", 188 189 ] 189 190 ··· 2727 2726 "axum-extra", 2728 2727 "bytes", 2729 2728 "clap", 2729 + "dashmap", 2730 2730 "data-encoding", 2731 2731 "futures-util", 2732 2732 "git-service", ··· 2738 2736 "jetstream", 2739 2737 "lexicon", 2740 2738 "mimetype-detector", 2739 + "mock-pds", 2741 2740 "moka", 2742 2741 "multibase", 2743 2742 "rand 0.9.2", ··· 2956 2953 "libc", 2957 2954 "wasi", 2958 2955 "windows-sys 0.61.2", 2956 + ] 2957 + 2958 + [[package]] 2959 + name = "mock-pds" 2960 + version = "0.0.0" 2961 + dependencies = [ 2962 + "atproto", 2963 + "auth", 2964 + "aws-lc-rs", 2965 + "axum", 2966 + "data-encoding", 2967 + "futures-util", 2968 + "identity", 2969 + "multibase", 2970 + "serde", 2971 + "serde_json", 2972 + "sqlx", 2973 + "tokio", 2974 + "tracing", 2975 + "url", 2959 2976 ] 2960 2977 2961 2978 [[package]] ··· 3587 3564 "cfg-if", 3588 3565 "getrandom 0.2.17", 3589 3566 "libc", 3590 - "untrusted", 3567 + "untrusted 0.9.0", 3591 3568 "windows-sys 0.52.0", 3592 3569 ] 3593 3570 ··· 3712 3689 "aws-lc-rs", 3713 3690 "ring", 3714 3691 "rustls-pki-types", 3715 - "untrusted", 3692 + "untrusted 0.9.0", 3716 3693 ] 3717 3694 3718 3695 [[package]] ··· 4835 4812 version = "0.1.4" 4836 4813 source = "registry+https://github.com/rust-lang/crates.io-index" 4837 4814 checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" 4815 + 4816 + [[package]] 4817 + name = "untrusted" 4818 + version = "0.7.1" 4819 + source = "registry+https://github.com/rust-lang/crates.io-index" 4820 + checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" 4838 4821 4839 4822 [[package]] 4840 4823 name = "untrusted"
+1 -1
Cargo.toml
··· 8 8 "crates/jetstream", 9 9 "crates/knot", 10 10 "crates/lexicon", 11 - "crates/git-service", 11 + "crates/git-service", "crates/mock-pds", 12 12 ] 13 13 default-members = ["crates/knot"] 14 14
+17
crates/atproto/src/uri.rs
··· 11 11 pub rkey: String, 12 12 } 13 13 14 + impl RecordUri { 15 + pub fn new( 16 + authority: impl Into<OwnedDid>, 17 + collection: impl Into<String>, 18 + rkey: impl Into<String>, 19 + ) -> Self { 20 + let authority = authority.into(); 21 + let collection = collection.into(); 22 + let rkey = rkey.into(); 23 + Self { 24 + authority, 25 + collection, 26 + rkey, 27 + } 28 + } 29 + } 30 + 14 31 #[derive(Debug, PartialEq, thiserror::Error)] 15 32 pub enum Error { 16 33 #[error("Invalid authority: {0}")]
+2
crates/knot/Cargo.toml
··· 48 48 tower = { version = "0.5.2", features = ["buffer", "filter", "limit"] } 49 49 tower-http = { version = "0.6.6", features = ["decompression-gzip", "request-id", "trace", "tracing", "util"] } 50 50 tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } 51 + dashmap = "6.1.0" 52 + mock-pds = { version = "0.0.0", path = "../mock-pds" } 51 53 52 54 [dev-dependencies] 53 55 http-body-util = "0.1.3"
+190 -340
crates/knot/src/lib.rs
··· 9 9 pub mod private; 10 10 pub mod public; 11 11 pub mod services; 12 + pub mod sync; 12 13 pub mod types; 13 14 mod util; 15 + 16 + #[cfg(test)] 17 + pub(crate) mod mock; 14 18 15 19 pub async fn serve_all( 16 20 router: Router, ··· 44 40 45 41 #[cfg(test)] 46 42 mod tests { 47 - use std::{ 48 - borrow::Cow, 49 - collections::HashMap, 50 - net::{SocketAddr, TcpListener}, 51 - sync::{Arc, Mutex}, 52 - }; 43 + use std::borrow::Cow; 53 44 54 - use atproto::{Did, did::OwnedDid, tid::Tid}; 55 - use auth::jwt::{Claims, Header}; 56 - use aws_lc_rs::{ 57 - encoding::{AsBigEndian, EcPublicKeyCompressedBin}, 58 - rand::SystemRandom, 59 - signature::{ECDSA_P256K1_SHA256_FIXED_SIGNING, EcdsaKeyPair, KeyPair}, 60 - }; 45 + use atproto::{Did, tid::Tid}; 46 + use auth::jwt::Claims; 47 + 61 48 use axum::{ 62 - Json, Router, 63 49 body::Body, 64 - extract::{Query, State}, 65 - http::{HeaderValue, Method, Request, Response, StatusCode, header}, 50 + http::{Request, StatusCode}, 66 51 }; 67 - use futures_util::FutureExt; 68 - use identity::{DidDocument, ResolveIdentity, Resolver}; 69 - use sqlx::SqlitePool; 70 - use tempfile::TempDir; 71 52 use time::{OffsetDateTime, format_description::well_known::Rfc3339}; 72 53 use tower::ServiceExt; 73 54 74 - use crate::{ 75 - model::{Knot, config::KnotConfiguration}, 76 - services::database::DataStore, 77 - }; 55 + use crate::model::Knot; 78 56 79 57 const TEST_DID: &str = "did:plc:65gha4t3avpfpzmvpbwovss7"; 80 - const TEST_INSTANCE: &str = "did:web:test-knot"; 81 - 82 - /// Mock PDS, PLC directory and identity resolver. 83 - #[derive(Clone, Debug)] 84 - struct Pds { 85 - addr: Option<SocketAddr>, 86 - identities: Arc<Mutex<Vec<(DidDocument, EcdsaKeyPair)>>>, 87 - documents: Arc<Mutex<HashMap<String, serde_json::Value>>>, 88 - } 89 - 90 - impl Default for Pds { 91 - fn default() -> Self { 92 - Self::new() 93 - } 94 - } 95 - 96 - impl Pds { 97 - fn new() -> Self { 98 - Self { 99 - addr: None, 100 - identities: Default::default(), 101 - documents: Default::default(), 102 - } 103 - } 104 - 105 - fn service_endpoint(&self) -> Option<url::Url> { 106 - self.addr.map(|addr| { 107 - format!("http://{addr}/") 108 - .parse() 109 - .expect("service endpoint should be a valid URL") 110 - }) 111 - } 112 - 113 - fn get_record(&self, uri: &str) -> Option<serde_json::Value> { 114 - self.documents.lock().unwrap().get(uri).cloned() 115 - } 116 - 117 - fn insert_record( 118 - &self, 119 - did: &Did, 120 - collection: &str, 121 - rkey: &str, 122 - doc: serde_json::Value, 123 - ) -> String { 124 - let uri = format!("at://{did}/{collection}/{rkey}"); 125 - self.documents.lock().unwrap().insert(uri.clone(), doc); 126 - uri 127 - } 128 - 129 - /// Add a DID document created from `did`, `handle`, and a random ecdsa key-pair to the PDS. 130 - /// 131 - /// If [`Self::serve`] has been called before adding the identity, the local URL of the 132 - /// PDS will be set as the new DID document's '#atproto_pds' service endpoint. 133 - /// 134 - fn add_identity(&self, did: &Did, handle: &str) { 135 - let mut doc = DidDocument::new(did, handle).expect("valid did for did document"); 136 - if let Some(service_endpoint) = self.service_endpoint() { 137 - doc.service 138 - .push(identity::Service::atproto_pds(service_endpoint)); 139 - } 140 - 141 - // Generate a key pair and encode the public key as verification method for 142 - // the mock user. 143 - let key_pair = EcdsaKeyPair::generate(&ECDSA_P256K1_SHA256_FIXED_SIGNING).unwrap(); 144 - let public_key: EcPublicKeyCompressedBin = key_pair.public_key().as_be_bytes().unwrap(); 145 - let mut key_data = vec![0xe7, 0x01]; 146 - key_data.extend_from_slice(public_key.as_ref()); 147 - let public_key_multibase = multibase::encode(multibase::Base::Base58Btc, key_data); 148 - doc.verification_method 149 - .push(identity::VerificationMethod::Multikey { 150 - id: format!("{}#atproto", doc.id), 151 - controller: doc.id.clone(), 152 - public_key_multibase, 153 - }); 154 - 155 - self.identities.lock().unwrap().push((doc, key_pair)); 156 - } 157 - 158 - // Create an inter-service auth header for an account in the fake PDS. 159 - fn service_auth(&self, claims: &Claims) -> HeaderValue { 160 - use data_encoding::BASE64URL_NOPAD as Encoding; 161 - 162 - let mut token = String::new(); 163 - let header = Encoding.encode( 164 - &serde_json::to_vec(&Header { 165 - typ: auth::jwt::Type::JWT, 166 - alg: auth::jwt::Algorithm::ES256K, 167 - crv: None, 168 - }) 169 - .unwrap(), 170 - ); 171 - 172 - token.push_str(&header); 173 - 174 - let claims_enc = Encoding.encode(&serde_json::to_vec(claims).unwrap()); 175 - token.push('.'); 176 - token.push_str(&claims_enc); 177 - 178 - let guard = self.identities.lock().unwrap(); 179 - let key_pair = guard 180 - .iter() 181 - .find(|(doc, _)| doc.id == claims.iss) 182 - .map(|(_, key_pair)| key_pair) 183 - .expect("DID should exist to issue a service auth request"); 184 - 185 - let signature = key_pair 186 - .sign(&SystemRandom::new(), token.as_bytes()) 187 - .unwrap(); 188 - 189 - let signature = Encoding.encode(signature.as_ref()); 190 - token.push('.'); 191 - token.push_str(&signature); 192 - 193 - HeaderValue::from_str(&format!("Bearer {token}")) 194 - .expect("Service auth header should be valid") 195 - } 196 - 197 - fn service_auth_with<F>(&self, iss: &Did, aud: &Did, modify_claims: F) -> HeaderValue 198 - where 199 - F: FnOnce(&mut Claims), 200 - { 201 - let jti: [u8; 16] = rand::random(); 202 - let jti = data_encoding::BASE32_NOPAD_VISUAL 203 - .encode(&jti) 204 - .to_lowercase(); 205 - 206 - let mut claims = Claims { 207 - iss: iss.into(), 208 - aud: aud.into(), 209 - iat: OffsetDateTime::now_utc().unix_timestamp(), 210 - exp: OffsetDateTime::now_utc().unix_timestamp() + 10, 211 - lxm: None, 212 - jti: jti.into(), 213 - }; 214 - 215 - modify_claims(&mut claims); 216 - 217 - self.service_auth(&claims) 218 - } 219 - 220 - fn service_auth_from(&self, iss: &Did, aud: &Did, lxm: &str) -> HeaderValue { 221 - self.service_auth_with(iss, aud, |claims| { 222 - claims.lxm = Some( 223 - lxm.try_into() 224 - .expect("Lexicon method should be a valid NSID"), 225 - ); 226 - }) 227 - } 228 - 229 - fn serve(&mut self) { 230 - #[derive(serde::Deserialize)] 231 - struct GetRecord { 232 - repo: String, 233 - collection: String, 234 - rkey: String, 235 - } 236 - 237 - assert!(self.addr.is_none(), "serve() already called"); 238 - let pds = Router::new() 239 - .route( 240 - "/xrpc/com.atproto.repo.getRecord", 241 - axum::routing::get( 242 - async move |State(state): State<Pds>, 243 - Query(GetRecord { 244 - repo, 245 - collection, 246 - rkey, 247 - }): Query<GetRecord>| { 248 - Json(state.get_record(&format!("at://{repo}/{collection}/{rkey}"))) 249 - }, 250 - ), 251 - ) 252 - .with_state(self.clone()); 253 - 254 - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 255 - listener.set_nonblocking(true).unwrap(); 256 - self.addr = Some(listener.local_addr().unwrap()); 257 - 258 - tokio::spawn(async move { 259 - axum::serve(tokio::net::TcpListener::from_std(listener).unwrap(), pds) 260 - .await 261 - .unwrap(); 262 - }); 263 - } 264 - } 265 - 266 - impl ResolveIdentity for Pds { 267 - fn resolve_handle<'s: 'h, 'h>( 268 - &'s self, 269 - handle: &'h str, 270 - ) -> futures_util::future::BoxFuture<'h, Result<OwnedDid, identity::ResolveError>> { 271 - async move { 272 - self.identities 273 - .lock() 274 - .unwrap() 275 - .iter() 276 - .find_map(|(doc, _)| { 277 - match doc.primary_alias().is_some_and(|alias| alias == handle) { 278 - true => Some(doc.id.clone()), 279 - false => None, 280 - } 281 - }) 282 - .ok_or(identity::ResolveError::UnresolvedHandle) 283 - } 284 - .boxed() 285 - } 286 - 287 - fn resolve_did<'s: 'd, 'd>( 288 - &'s self, 289 - did: &'d Did, 290 - ) -> futures_util::future::BoxFuture<'d, Result<DidDocument, identity::ResolveError>> 291 - { 292 - async move { 293 - self.identities 294 - .lock() 295 - .unwrap() 296 - .iter() 297 - .find_map(|(doc, _)| match doc.id == did { 298 - true => Some(doc.clone()), 299 - false => None, 300 - }) 301 - .ok_or(identity::ResolveError::UnresolvedHandle) 302 - } 303 - .boxed() 304 - } 305 - } 306 - 307 - async fn setup(owner_did: Option<&str>) -> (TempDir, Pds, Knot) { 308 - let base = tempfile::tempdir().expect("temporary directory"); 309 - let pool = SqlitePool::connect("sqlite://:memory:").await.unwrap(); 310 - sqlx::migrate!().run(&pool).await.unwrap(); 311 - 312 - let mut pds = Pds::new(); 313 - pds.serve(); 314 - 315 - let config = KnotConfiguration::new( 316 - OwnedDid::from_static(TEST_DID), 317 - OwnedDid::from_static(TEST_INSTANCE), 318 - base.path(), 319 - ); 320 - 321 - let database = DataStore::new(pool); 322 - let resolver = Resolver::new(pds.clone()); 323 - 324 - let knot = Knot::new(config, resolver, reqwest::Client::new(), database, []).unwrap(); 325 - if let Some(owner_did) = owner_did 326 - .map(Did::parse) 327 - .transpose() 328 - .expect("knot owner did should be valid") 329 - { 330 - knot.database() 331 - .upsert_knot_member( 332 - "", 333 - "", 334 - "", 335 - &lexicon::sh_tangled::knot::Member { 336 - subject: Cow::Borrowed(&owner_did), 337 - domain: Cow::Borrowed(knot.instance_ident()), 338 - created_at: OffsetDateTime::now_utc(), 339 - }, 340 - ) 341 - .await 342 - .expect("knot member inserted into db"); 343 - } 344 - 345 - (base, pds, knot) 346 - } 58 + const TEST_INSTANCE: &str = "lib-knot-test"; 347 59 348 60 fn get(uri: &str) -> Request<Body> { 349 61 Request::builder().uri(uri).body(Body::empty()).unwrap() ··· 67 347 68 348 #[tokio::test] 69 349 async fn can_query_knot_owner() { 70 - let (_, _, knot) = setup(None).await; 350 + let (_, _, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 71 351 let response = super::public::router() 72 352 .with_state(knot) 73 353 .oneshot(get("/xrpc/sh.tangled.owner")) ··· 90 370 91 371 #[tokio::test] 92 372 async fn xrpc_sh_tangled_repo_missing_repo() { 93 - let (_, _, knot) = setup(None).await; 373 + let (_, _, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 94 374 for particle in ["tree", "log", "tags", "branches"] { 95 375 let response = super::public::router() 96 376 .with_state(knot.clone()) ··· 104 384 105 385 #[tokio::test] 106 386 async fn xrpc_sh_tangled_repo_bad_repo_format() { 107 - let (_, _, knot) = setup(None).await; 387 + let (_, _, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 108 388 for particle in ["tree", "log", "tags", "branches"] { 109 389 // Missing repo name 110 390 let response = super::public::router() ··· 135 415 136 416 #[tokio::test] 137 417 async fn xrpc_sh_tangled_repo_not_found() { 138 - let (_, _, knot) = setup(None).await; 418 + let (_, _, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 139 419 for particle in ["tree", "log", "tags", "branches"] { 140 420 let response = super::public::router() 141 421 .with_state(knot.clone()) ··· 152 432 mod sh_tangled_repo_create { 153 433 use super::super::public; 154 434 use super::*; 435 + use axum::http::{HeaderValue, Method, Response, header}; 436 + 437 + fn make_claims<F>(iss: &Did, aud: &Did, modify_claims: F) -> Claims 438 + where 439 + F: FnOnce(&mut Claims), 440 + { 441 + let jti: [u8; 16] = rand::random(); 442 + let jti = data_encoding::BASE32_NOPAD_VISUAL 443 + .encode(&jti) 444 + .to_lowercase(); 445 + 446 + let mut claims = Claims { 447 + iss: iss.into(), 448 + aud: aud.into(), 449 + iat: OffsetDateTime::now_utc().unix_timestamp(), 450 + exp: OffsetDateTime::now_utc().unix_timestamp() + 10, 451 + lxm: None, 452 + jti: jti.into(), 453 + }; 454 + 455 + modify_claims(&mut claims); 456 + claims 457 + } 458 + 459 + async fn service_auth_with<F>( 460 + pds: &mock_pds::Pds, 461 + iss: &Did, 462 + aud: &Did, 463 + modify_claims: F, 464 + ) -> HeaderValue 465 + where 466 + F: FnOnce(&mut Claims), 467 + { 468 + let claims = make_claims(iss, aud, modify_claims); 469 + let authorization = pds.service_auth(&claims).await; 470 + HeaderValue::from_str(&authorization).unwrap() 471 + } 155 472 156 473 #[tokio::test] 157 474 async fn reject_wrong_method() { 158 - let (_, _, knot) = setup(None).await; 475 + let (_, _, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 159 476 let response = public::router() 160 477 .with_state(knot.clone()) 161 478 .oneshot(get("/xrpc/sh.tangled.repo.create")) ··· 204 447 205 448 async fn create_repo_with<F>( 206 449 knot: &Knot, 207 - pds: Pds, 450 + pds: mock_pds::Pds, 208 451 did: &Did, 209 452 rkey: &str, 210 453 repo_name: &str, ··· 219 462 did, 220 463 "sh.tangled.repo", 221 464 rkey, 222 - serde_json::json!({ 223 - "uri": format!("at://{did}/sh.tangled.repo/{rkey}"), 224 - "cid": "bafyreie7ym6v4gepcdi2ul2kchylo5aahlw3nmvjg3veipoi76kziixfoa", 225 - "value": { 226 - "name": repo_name, 227 - "knot": knot.instance_ident(), 228 - "source": source, 229 - "createdAt": OffsetDateTime::now_utc().format(&Rfc3339).unwrap() 230 - } 465 + &serde_json::json!({ 466 + "name": repo_name, 467 + "knot": knot.instance_ident(), 468 + "source": source, 469 + "createdAt": OffsetDateTime::now_utc().format(&Rfc3339).unwrap() 231 470 }), 232 - ); 471 + ) 472 + .await; 233 473 234 474 // Generate the body of the 'sh.tangled.repo.create' request. 235 475 let create = lexicon::sh_tangled::repo::create::Input { ··· 235 481 source: None, 236 482 }; 237 483 238 - let auth = pds.service_auth_with(&did, &knot.instance, |claims| { 484 + let auth = service_auth_with(&pds, &did, &knot.instance, |claims| { 239 485 claims.lxm = Some("sh.tangled.repo.create".try_into().unwrap()); 240 486 modify_claims(claims); 241 - }); 487 + }) 488 + .await; 242 489 243 490 let response = public::router() 244 491 .with_state(knot.clone()) ··· 258 503 259 504 async fn create_repo( 260 505 knot: &Knot, 261 - pds: Pds, 506 + pds: mock_pds::Pds, 262 507 did: &Did, 263 508 rkey: &str, 264 509 repo_name: &str, ··· 278 523 279 524 #[tokio::test] 280 525 async fn can_create_repo() { 281 - let (_base, pds, knot) = setup(Some(TEST_DID)).await; 526 + let (_base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 282 527 283 528 let did = Did::from_static(TEST_DID); 284 - pds.add_identity(did, "tjh.dev"); 529 + pds.insert_identity(did, "tjh.dev").await; 530 + knot.add_member( 531 + "", 532 + "", 533 + "", 534 + &lexicon::sh_tangled::knot::Member::new( 535 + &did, 536 + knot.instance_ident(), 537 + OffsetDateTime::now_utc(), 538 + ), 539 + ) 540 + .await 541 + .unwrap(); 285 542 286 543 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 287 544 assert_eq!( ··· 308 541 309 542 #[tokio::test] 310 543 async fn can_create_fork_from_at() { 311 - let (_base, pds, knot) = setup(Some(TEST_DID)).await; 544 + let (_base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 312 545 313 546 let did = Did::from_static(TEST_DID); 314 - pds.add_identity(did, "tjh.dev"); 547 + pds.insert_identity(did, "tjh.dev").await; 548 + knot.add_member( 549 + "", 550 + "", 551 + "", 552 + &lexicon::sh_tangled::knot::Member::new( 553 + &did, 554 + knot.instance_ident(), 555 + OffsetDateTime::now_utc(), 556 + ), 557 + ) 558 + .await 559 + .unwrap(); 315 560 316 561 // Create a record for the repository to fork from. 317 562 // <https://pdsls.dev/at://did:plc:65gha4t3avpfpzmvpbwovss7/sh.tangled.repo/3m24udbjajf22#record> 318 - let aturi = pds.insert_record( 319 - did, 320 - "sh.tangled.repo", 321 - "3m24udbjajf22", 322 - serde_json::json!({ 323 - "uri": format!("at://{did}/sh.tangled.repo/3m24udbjajf22"), 324 - "cid": "some_cid", 325 - "value": { 563 + let aturi = pds 564 + .insert_record( 565 + did, 566 + "sh.tangled.repo", 567 + "3m24udbjajf22", 568 + &serde_json::json!({ 326 569 "name": "gordian", 327 570 "knot": "gordian.tjh.dev", 328 571 "createdAt": "2025-10-01T10:45:52Z" 329 - } 330 - }), 331 - ); 572 + }), 573 + ) 574 + .await; 332 575 333 576 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 334 577 assert_eq!( ··· 353 576 354 577 #[tokio::test] 355 578 async fn can_create_fork_from_http() { 356 - let (_base, pds, knot) = setup(Some(TEST_DID)).await; 579 + let (_base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 357 580 358 581 let did = Did::from_static(TEST_DID); 359 - pds.add_identity(did, "tjh.dev"); 582 + pds.insert_identity(did, "tjh.dev").await; 583 + knot.add_member( 584 + "", 585 + "", 586 + "", 587 + &lexicon::sh_tangled::knot::Member::new( 588 + &did, 589 + knot.instance_ident(), 590 + OffsetDateTime::now_utc(), 591 + ), 592 + ) 593 + .await 594 + .unwrap(); 360 595 361 596 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 362 597 let source = ··· 385 596 386 597 #[tokio::test] 387 598 async fn can_create_fork_from_http_fail() { 388 - let (base, pds, knot) = setup(Some(TEST_DID)).await; 599 + let (base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 389 600 390 601 let did = Did::from_static(TEST_DID); 391 - pds.add_identity(did, "tjh.dev"); 602 + pds.insert_identity(did, "tjh.dev").await; 603 + knot.add_member( 604 + "", 605 + "", 606 + "", 607 + &lexicon::sh_tangled::knot::Member::new( 608 + &did, 609 + knot.instance_ident(), 610 + OffsetDateTime::now_utc(), 611 + ), 612 + ) 613 + .await 614 + .unwrap(); 392 615 393 616 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 394 617 let source = ··· 423 622 424 623 #[tokio::test] 425 624 async fn rejects_if_owner_is_not_a_member() { 426 - let (_base, pds, knot) = setup(None).await; 625 + let (_base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 427 626 428 627 let did = Did::from_static(TEST_DID); 429 - pds.add_identity(did, "tjh.dev"); 628 + pds.insert_identity(did, "tjh.dev").await; 430 629 431 630 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 432 631 assert_ne!( ··· 441 640 442 641 #[tokio::test] 443 642 async fn rejects_auth_issued_in_future() { 444 - let (_base, pds, knot) = setup(Some(TEST_DID)).await; 643 + let (_base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 445 644 446 645 let did = Did::from_static(TEST_DID); 447 - pds.add_identity(did, "tjh.dev"); 646 + pds.insert_identity(did, "tjh.dev").await; 647 + knot.add_member( 648 + "", 649 + "", 650 + "", 651 + &lexicon::sh_tangled::knot::Member::new( 652 + &did, 653 + knot.instance_ident(), 654 + OffsetDateTime::now_utc(), 655 + ), 656 + ) 657 + .await 658 + .unwrap(); 448 659 449 660 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 450 661 assert_eq!( ··· 475 662 476 663 #[tokio::test] 477 664 async fn rejects_auth_expired() { 478 - let (_base, pds, knot) = setup(Some(TEST_DID)).await; 665 + let (_base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 479 666 480 667 let did = Did::from_static(TEST_DID); 481 - pds.add_identity(did, "tjh.dev"); 668 + pds.insert_identity(did, "tjh.dev").await; 669 + knot.add_member( 670 + "", 671 + "", 672 + "", 673 + &lexicon::sh_tangled::knot::Member::new( 674 + &did, 675 + knot.instance_ident(), 676 + OffsetDateTime::now_utc(), 677 + ), 678 + ) 679 + .await 680 + .unwrap(); 482 681 483 682 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 484 683 assert_eq!( ··· 507 682 508 683 #[tokio::test] 509 684 async fn can_delete_repo() { 510 - let (base, pds, knot) = setup(Some(TEST_DID)).await; 685 + let (base, pds, knot) = crate::mock::setup(TEST_DID, TEST_INSTANCE).await; 511 686 512 687 let did = Did::from_static(TEST_DID); 513 - pds.add_identity(did, "tjh.dev"); 688 + pds.insert_identity(did, "tjh.dev").await; 689 + knot.add_member( 690 + "", 691 + "", 692 + "", 693 + &lexicon::sh_tangled::knot::Member::new( 694 + &did, 695 + knot.instance_ident(), 696 + OffsetDateTime::now_utc(), 697 + ), 698 + ) 699 + .await 700 + .unwrap(); 514 701 515 702 let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 516 703 let name = "another-test-repo"; ··· 566 729 assert!(repo_exists_in_db(&knot, &did, &rkey).await); 567 730 568 731 // Or with the wrong auth. 569 - let auth = pds.service_auth_from(&did, &knot.instance(), "sh.tangled.repo.create"); 732 + let auth = service_auth_with(&pds, &did, &knot.instance(), |claims| { 733 + claims.lxm = Some("sh.tangled.repo.create".try_into().unwrap()); 734 + }) 735 + .await; 736 + 570 737 assert_eq!( 571 738 public::router() 572 739 .with_state(knot.clone()) ··· 594 753 assert!(repo_exists_in_db(&knot, &did, &rkey).await); 595 754 596 755 // Valid auth, empty request body. 597 - let auth = pds.service_auth_from(&did, &knot.instance(), "sh.tangled.repo.delete"); 756 + // Or with the wrong auth. 757 + let auth = service_auth_with(&pds, &did, &knot.instance(), |claims| { 758 + claims.lxm = Some("sh.tangled.repo.delete".try_into().unwrap()); 759 + }) 760 + .await; 598 761 assert_eq!( 599 762 public::router() 600 763 .with_state(knot.clone()) ··· 621 776 gix::open(base.path().join(did.as_str()).join(&rkey)).expect("repository should exist"); 622 777 assert!(repo_exists_in_db(&knot, &did, &rkey).await); 623 778 624 - let auth = pds.service_auth_from(&did, &knot.instance(), "sh.tangled.repo.delete"); 779 + // Or with the wrong auth. 780 + let auth = service_auth_with(&pds, &did, &knot.instance(), |claims| { 781 + claims.lxm = Some("sh.tangled.repo.delete".try_into().unwrap()); 782 + }) 783 + .await; 784 + 625 785 assert_eq!( 626 786 public::router() 627 787 .with_state(knot.clone())
+35
crates/knot/src/mock.rs
··· 1 + use crate::{ 2 + model::{Knot, config::KnotConfiguration}, 3 + services::database::DataStore, 4 + }; 5 + use atproto::did::OwnedDid; 6 + use identity::Resolver; 7 + 8 + pub async fn setup( 9 + owner_did: &str, 10 + instance_name: &str, 11 + ) -> (tempfile::TempDir, mock_pds::Pds, Knot) { 12 + let base = tempfile::tempdir().expect("temporary directory"); 13 + let pool = sqlx::SqlitePool::connect("sqlite://:memory:") 14 + .await 15 + .unwrap(); 16 + 17 + sqlx::migrate!().run(&pool).await.unwrap(); 18 + 19 + let (pds, listener) = mock_pds::init().await; 20 + let pds_api = mock_pds::router(pds.clone()); 21 + tokio::spawn(async move { 22 + axum::serve(listener, pds_api).await.unwrap(); 23 + }); 24 + 25 + let owner_did = OwnedDid::parse(owner_did).expect("owner DID must be valid"); 26 + let instance = OwnedDid::parse(format!("did:web:{instance_name}")) 27 + .expect("instance name should form a valid DID"); 28 + 29 + let database = DataStore::new(pool); 30 + let resolver = Resolver::new(pds.clone()); 31 + let config = KnotConfiguration::new(owner_did.clone(), instance, base.path()); 32 + let knot = Knot::new(config, resolver, reqwest::Client::new(), database, []).unwrap(); 33 + 34 + (base, pds, knot) 35 + }
+5
crates/knot/src/sync.rs
··· 1 + //! 2 + //! Atmosphere synchronization. 3 + //! 4 + 5 + pub mod tap;
crates/knot/src/sync/tap.rs
+10
crates/lexicon/src/sh_tangled/knot.rs
··· 41 41 #[serde(with = "time::serde::rfc3339")] 42 42 pub created_at: OffsetDateTime, 43 43 } 44 + 45 + impl<'a> Member<'a> { 46 + pub fn new(subject: &'a Did, domain: &'a str, created_at: OffsetDateTime) -> Self { 47 + Self { 48 + subject: subject.into(), 49 + domain: domain.into(), 50 + created_at, 51 + } 52 + } 53 + }
+24
crates/mock-pds/Cargo.toml
··· 1 + [package] 2 + name = "mock-pds" 3 + version.workspace = true 4 + authors.workspace = true 5 + repository.workspace = true 6 + license.workspace = true 7 + edition.workspace = true 8 + publish.workspace = true 9 + 10 + [dependencies] 11 + atproto = { workspace = true, features = ["serde", "sqlx"] } 12 + auth.workspace = true 13 + aws-lc-rs = "1.15.4" 14 + axum.workspace = true 15 + data-encoding.workspace = true 16 + futures-util = "0.3.31" 17 + identity.workspace = true 18 + multibase = "0.9.2" 19 + serde = { workspace = true, features = ["derive"] } 20 + serde_json.workspace = true 21 + sqlx = { version = "0.8.6", features = ["sqlite"] } 22 + tokio = { version = "1.49.0", default-features = false, features = ["sync"] } 23 + tracing.workspace = true 24 + url.workspace = true
+131
crates/mock-pds/src/api.rs
··· 1 + use crate::Pds; 2 + 3 + pub fn router<S>(state: Pds) -> axum::Router<S> { 4 + axum::Router::new() 5 + .merge(com_atproto::repo::get_record()) 6 + .merge(com_atproto::repo::list_records()) 7 + .with_state(state) 8 + } 9 + 10 + pub mod com_atproto { 11 + pub mod repo { 12 + use atproto::did::OwnedDid; 13 + use axum::{ 14 + Json, Router, 15 + extract::{FromRef, Query, State}, 16 + http::StatusCode, 17 + response::IntoResponse, 18 + }; 19 + use serde_json::Value; 20 + use sqlx::Row as _; 21 + 22 + use crate::Pds; 23 + 24 + #[derive(serde::Serialize)] 25 + pub struct Record { 26 + uri: String, 27 + cid: String, 28 + value: Value, 29 + } 30 + 31 + pub fn get_record<S: Clone + Send + Sync + 'static>() -> Router<S> 32 + where 33 + Pds: FromRef<S>, 34 + { 35 + const LXM: &str = "com.atproto.repo.getRecord"; 36 + 37 + #[derive(serde::Deserialize)] 38 + pub struct Params { 39 + repo: OwnedDid, 40 + collection: String, 41 + rkey: String, 42 + cid: Option<String>, 43 + } 44 + 45 + #[tracing::instrument(target = "com_atproto::repo::get_record", skip(pds))] 46 + async fn handle( 47 + State(pds): State<Pds>, 48 + Query(Params { 49 + repo, 50 + collection, 51 + rkey, 52 + cid, 53 + }): Query<Params>, 54 + ) -> Result<Json<Record>, StatusCode> { 55 + assert_eq!(cid, None, "Get record by CID not supported"); 56 + 57 + match sqlx::query("SELECT cid, data FROM record LEFT JOIN identity USING (did) WHERE did = ? AND collection = ? AND rkey = ?") 58 + .bind(repo.as_ref()) 59 + .bind(&collection) 60 + .bind(&rkey) 61 + .fetch_optional(pds.db()) 62 + .await { 63 + Err(error) => { 64 + tracing::error!(?error); 65 + Err(StatusCode::INTERNAL_SERVER_ERROR) 66 + } 67 + Ok(None) => Err(StatusCode::NOT_FOUND), 68 + Ok(Some(row)) => { 69 + let cid: String = row.get("cid"); 70 + let data: &str = row.get("data"); 71 + 72 + let uri = format!("at://{repo}/{collection}/{rkey}"); 73 + let value: Value = serde_json::from_str(data).expect("Record value in db should be valid json"); 74 + 75 + Ok(Json(Record { uri, cid, value })) 76 + } 77 + } 78 + } 79 + 80 + Router::new().route(&format!("/xrpc/{LXM}"), axum::routing::get(handle)) 81 + } 82 + 83 + pub fn list_records<S: Clone + Send + Sync + 'static>() -> Router<S> 84 + where 85 + Pds: FromRef<S>, 86 + { 87 + const LXM: &str = "com.atproto.repo.listRecords"; 88 + 89 + #[derive(serde::Deserialize)] 90 + struct Params { 91 + repo: OwnedDid, 92 + collection: String, 93 + } 94 + 95 + #[tracing::instrument(target = "com_atproto::repo::list_records")] 96 + async fn handle( 97 + State(pds): State<Pds>, 98 + Query(Params { repo, collection }): Query<Params>, 99 + ) -> impl IntoResponse { 100 + let rows = sqlx::query( 101 + "SELECT rkey, data FROM record WHERE did = ? AND collection = ? ORDER BY rkey", 102 + ) 103 + .bind(repo.as_ref()) 104 + .bind(&collection) 105 + .fetch_all(pds.db()) 106 + .await 107 + .unwrap(); 108 + 109 + let records = rows 110 + .into_iter() 111 + .map(|row| { 112 + let rkey: &str = row.get("rkey"); 113 + let data: &str = row.get("data"); 114 + 115 + let uri = format!("at://{repo}/{collection}/{rkey}"); 116 + let cid = "bafyreie7ym6v4gepcdi2ul2kchylo5aahlw3nmvjg3veipoi76kziixfoa" 117 + .to_string(); 118 + let value: Value = serde_json::from_str(data) 119 + .expect("Record value in db should be valid json"); 120 + 121 + Record { uri, cid, value } 122 + }) 123 + .collect::<Vec<_>>(); 124 + 125 + Json(serde_json::json!({"records": records })) 126 + } 127 + 128 + Router::new().route(&format!("/xrpc/{LXM}"), axum::routing::get(handle)) 129 + } 130 + } 131 + }
+6
crates/mock-pds/src/lib.rs
··· 1 + mod api; 2 + mod state; 3 + 4 + pub use api::router; 5 + pub use state::Pds; 6 + pub use state::init;
+278
crates/mock-pds/src/state.rs
··· 1 + use std::{fmt::Debug, net::SocketAddr, sync::Arc}; 2 + 3 + use atproto::{did::OwnedDid, tid::Tid}; 4 + use auth::jwt::{Claims, Header}; 5 + use aws_lc_rs::{ 6 + encoding::{AsBigEndian as _, EcPublicKeyCompressedBin}, 7 + rand::SystemRandom, 8 + signature::{ECDSA_P256K1_SHA256_FIXED_SIGNING, EcdsaKeyPair, KeyPair as _}, 9 + }; 10 + use futures_util::FutureExt as _; 11 + use identity::DidDocument; 12 + use sqlx::{ 13 + SqlitePool, 14 + sqlite::{SqliteConnectOptions, SqlitePoolOptions}, 15 + types::time::OffsetDateTime, 16 + }; 17 + use tokio::{ 18 + net::TcpListener, 19 + sync::broadcast::{self, Receiver, Sender}, 20 + }; 21 + 22 + pub type Event = (); 23 + 24 + #[derive(Debug)] 25 + pub struct Pds { 26 + inner: Arc<Inner>, 27 + } 28 + 29 + impl Pds { 30 + pub fn events(&self) -> Receiver<Event> { 31 + self.inner.tx.subscribe() 32 + } 33 + 34 + pub fn db(&self) -> &SqlitePool { 35 + &self.inner.db 36 + } 37 + 38 + /// Get the service endpoint for the PDS. 39 + pub fn service_endpoint(&self) -> url::Url { 40 + format!("http://{}/", self.inner.addr) 41 + .parse() 42 + .expect("service endpoint should be a valid URL") 43 + } 44 + 45 + /// Add a DID document created from `did`, `handle`, and a random ecdsa key-pair to the PDS. 46 + /// 47 + /// The internal address of the mock PDS will be set as the "#atproto_pds" service for 48 + /// the new identity. 49 + /// 50 + pub async fn insert_identity(&self, did: &atproto::Did, handle: &str) { 51 + let mut doc = DidDocument::new(did, handle).expect("valid did for did document"); 52 + doc.service 53 + .push(identity::Service::atproto_pds(self.service_endpoint())); 54 + 55 + // Generate a key pair and encode the public key as verification method for 56 + // the mock user. 57 + let key_pair = EcdsaKeyPair::generate(&ECDSA_P256K1_SHA256_FIXED_SIGNING).unwrap(); 58 + let public_key: EcPublicKeyCompressedBin = key_pair.public_key().as_be_bytes().unwrap(); 59 + let mut key_data = vec![0xe7, 0x01]; 60 + key_data.extend_from_slice(public_key.as_ref()); 61 + let public_key_multibase = multibase::encode(multibase::Base::Base58Btc, key_data); 62 + doc.verification_method 63 + .push(identity::VerificationMethod::Multikey { 64 + id: format!("{}#atproto", doc.id), 65 + controller: doc.id.clone(), 66 + public_key_multibase, 67 + }); 68 + 69 + let rev = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 70 + let doc = serde_json::to_string(&doc).unwrap(); 71 + let key = key_pair.to_pkcs8v1().unwrap(); 72 + 73 + sqlx::query("INSERT INTO identity (handle, did, rev, doc, key) VALUES (?, ?, ?, ?, ?)") 74 + .bind(handle) 75 + .bind(did) 76 + .bind(rev) 77 + .bind(doc) 78 + .bind(key.as_ref()) 79 + .execute(self.db()) 80 + .await 81 + .unwrap(); 82 + } 83 + 84 + pub async fn insert_record<T>( 85 + &self, 86 + repo: &atproto::Did, 87 + collection: &str, 88 + rkey: &str, 89 + value: &T, 90 + ) -> String 91 + where 92 + T: serde::Serialize, 93 + { 94 + let rev = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 95 + let cid = "bafyreie7ym6v4gepcdi2ul2kchylo5aahlw3nmvjg3veipoi76kziixfoa"; 96 + let data = serde_json::to_string(value).expect("Value must serialize to json"); 97 + 98 + let mut tx = self.db().begin().await.unwrap(); 99 + sqlx::query("INSERT INTO record (did, collection, rkey, cid, data) VALUES (?, ?, ?, ?, ?)") 100 + .bind(repo) 101 + .bind(collection) 102 + .bind(rkey) 103 + .bind(cid) 104 + .bind(data) 105 + .execute(&mut *tx) 106 + .await 107 + .unwrap(); 108 + 109 + sqlx::query("UPDATE identity SET rev = ? WHERE did = ?") 110 + .bind(rev) 111 + .bind(repo) 112 + .execute(&mut *tx) 113 + .await 114 + .unwrap(); 115 + 116 + tx.commit().await.unwrap(); 117 + 118 + format!("at://{repo}/{collection}/{rkey}") 119 + } 120 + 121 + // Create an inter-service auth header for an account in the fake PDS. 122 + pub async fn service_auth(&self, claims: &Claims) -> String { 123 + use data_encoding::BASE64URL_NOPAD as Encoding; 124 + use sqlx::Row as _; 125 + 126 + let mut token = String::new(); 127 + let header = Encoding.encode( 128 + &serde_json::to_vec(&Header { 129 + typ: auth::jwt::Type::JWT, 130 + alg: auth::jwt::Algorithm::ES256K, 131 + crv: None, 132 + }) 133 + .unwrap(), 134 + ); 135 + 136 + token.push_str(&header); 137 + 138 + let claims_enc = Encoding.encode(&serde_json::to_vec(claims).unwrap()); 139 + token.push('.'); 140 + token.push_str(&claims_enc); 141 + 142 + let result = sqlx::query("SELECT key FROM identity WHERE did = ?") 143 + .bind(claims.iss.as_ref()) 144 + .fetch_one(self.db()) 145 + .await 146 + .unwrap(); 147 + 148 + let pkcs8: &[u8] = result.get("key"); 149 + let key_pair = EcdsaKeyPair::from_pkcs8(&ECDSA_P256K1_SHA256_FIXED_SIGNING, pkcs8) 150 + .expect("PKCSv8 key must be valid"); 151 + 152 + let signature = key_pair 153 + .sign(&SystemRandom::new(), token.as_bytes()) 154 + .unwrap(); 155 + 156 + let signature = Encoding.encode(signature.as_ref()); 157 + token.push('.'); 158 + token.push_str(&signature); 159 + 160 + format!("Bearer {token}") 161 + } 162 + } 163 + 164 + impl Clone for Pds { 165 + fn clone(&self) -> Self { 166 + let inner = Arc::clone(&self.inner); 167 + Self { inner } 168 + } 169 + } 170 + 171 + impl identity::ResolveIdentity for Pds { 172 + fn resolve_handle<'s: 'h, 'h>( 173 + &'s self, 174 + handle: &'h str, 175 + ) -> futures_util::future::BoxFuture<'h, Result<OwnedDid, identity::ResolveError>> { 176 + use sqlx::Row as _; 177 + async move { 178 + let result = sqlx::query("SELECT did FROM identity WHERE handle = ?") 179 + .bind(handle) 180 + .fetch_one(self.db()) 181 + .await 182 + .inspect_err(|error| eprintln!("{error:?}")) 183 + .map_err(|_| identity::ResolveError::UnresolvedHandle)?; 184 + 185 + let did: &atproto::Did = result.get("did"); 186 + Ok(did.to_owned()) 187 + } 188 + .boxed() 189 + } 190 + 191 + fn resolve_did<'s: 'd, 'd>( 192 + &'s self, 193 + did: &'d atproto::Did, 194 + ) -> futures_util::future::BoxFuture<'d, Result<DidDocument, identity::ResolveError>> { 195 + use sqlx::Row as _; 196 + async move { 197 + let result = sqlx::query("SELECT doc FROM identity WHERE did = ?") 198 + .bind(did) 199 + .fetch_one(self.db()) 200 + .await 201 + .inspect_err(|error| eprintln!("{error:?}")) 202 + .map_err(|_| identity::ResolveError::UnresolvedHandle)?; 203 + 204 + let doc: &str = result.get("doc"); 205 + let doc = serde_json::from_str(doc).unwrap(); 206 + Ok(doc) 207 + } 208 + .boxed() 209 + } 210 + } 211 + 212 + #[derive(Debug)] 213 + struct Inner { 214 + db: SqlitePool, 215 + tx: Sender<Event>, 216 + addr: SocketAddr, 217 + } 218 + 219 + pub async fn init() -> (Pds, TcpListener) { 220 + let db = SqlitePoolOptions::new() 221 + .max_connections(1) 222 + .connect_with( 223 + SqliteConnectOptions::new() 224 + .in_memory(true) 225 + .shared_cache(true) 226 + .foreign_keys(true), 227 + ) 228 + .await 229 + .unwrap(); 230 + 231 + sqlx::query( 232 + "CREATE TABLE identity ( 233 + handle text NOT NULL, 234 + did text NOT NULL, 235 + rev text NOT NULL, 236 + doc json NOT NULL, 237 + key blob NOT NULL, 238 + 239 + PRIMARY KEY (did), 240 + UNIQUE (handle) 241 + )", 242 + ) 243 + .execute(&db) 244 + .await 245 + .unwrap(); 246 + 247 + sqlx::query( 248 + "CREATE TABLE record ( 249 + did text NOT NULL, 250 + collection text NOT NULL, 251 + rkey text NOT NULL, 252 + cid text NOT NULL, 253 + data json NOT NULL, 254 + 255 + PRIMARY KEY (did, collection, rkey), 256 + FOREIGN KEY (did) REFERENCES identity (did) ON DELETE CASCADE 257 + )", 258 + ) 259 + .execute(&db) 260 + .await 261 + .unwrap(); 262 + 263 + let (tx, _) = broadcast::channel(1); 264 + let listener = TcpListener::bind("127.0.0.1:0") 265 + .await 266 + .expect("Must be able to bind a socket"); 267 + 268 + let addr = listener 269 + .local_addr() 270 + .expect("Listener must have a local socket"); 271 + 272 + ( 273 + Pds { 274 + inner: Arc::new(Inner { addr, db, tx }), 275 + }, 276 + listener, 277 + ) 278 + }