don't
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(knot): forks! fork! forks!

Signed-off-by: tjh <x@tjh.dev>

tjh 7b27a745 283404ae

+438 -221
-68
.sqlx/query-a780b4181c75ad196c735e61b92fb3c6d2e5c3946a6344927a17e4e4115c6c1e.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "INSERT INTO repository (did, rkey, rev, cid, name, knot, spindle, source, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT (did, rkey) DO UPDATE SET rev = excluded.rev, cid = excluded.cid, knot = excluded.knot, name = excluded.name, created_at = excluded.created_at WHERE excluded.rev > repository.rev AND excluded.cid <> '' RETURNING *", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "did", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "rkey", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "rev", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - }, 21 - { 22 - "name": "cid", 23 - "ordinal": 3, 24 - "type_info": "Text" 25 - }, 26 - { 27 - "name": "name", 28 - "ordinal": 4, 29 - "type_info": "Text" 30 - }, 31 - { 32 - "name": "knot", 33 - "ordinal": 5, 34 - "type_info": "Text" 35 - }, 36 - { 37 - "name": "spindle", 38 - "ordinal": 6, 39 - "type_info": "Text" 40 - }, 41 - { 42 - "name": "source", 43 - "ordinal": 7, 44 - "type_info": "Text" 45 - }, 46 - { 47 - "name": "created_at", 48 - "ordinal": 8, 49 - "type_info": "Datetime" 50 - } 51 - ], 52 - "parameters": { 53 - "Right": 9 54 - }, 55 - "nullable": [ 56 - false, 57 - false, 58 - false, 59 - false, 60 - false, 61 - false, 62 - true, 63 - true, 64 - false 65 - ] 66 - }, 67 - "hash": "a780b4181c75ad196c735e61b92fb3c6d2e5c3946a6344927a17e4e4115c6c1e" 68 - }
+12
.sqlx/query-c833b5e0c7cd34716fce03a77cdb48ffd19fd5537ce91a28fdd9fd9c0d3c7f03.json
··· 1 + { 2 + "db_name": "SQLite", 3 + "query": "INSERT INTO repository (did, rkey, rev, cid, name, knot, spindle, source, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", 4 + "describe": { 5 + "columns": [], 6 + "parameters": { 7 + "Right": 9 8 + }, 9 + "nullable": [] 10 + }, 11 + "hash": "c833b5e0c7cd34716fce03a77cdb48ffd19fd5537ce91a28fdd9fd9c0d3c7f03" 12 + }
+8
crates/identity/src/lib.rs
··· 122 122 } 123 123 124 124 impl Resolver { 125 + pub fn new<R>(resolver: R) -> Self 126 + where 127 + R: ResolveIdentity + Send + 'static, 128 + { 129 + let inner = Arc::new(resolver); 130 + Self { inner } 131 + } 132 + 125 133 pub fn builder() -> ResolverBuilder { 126 134 ResolverBuilder::new() 127 135 }
+285 -141
crates/knot/src/lib.rs
··· 10 10 pub mod public; 11 11 pub mod services; 12 12 pub mod types; 13 + mod util; 13 14 14 15 pub async fn serve_all( 15 16 router: Router, ··· 40 39 41 40 #[cfg(test)] 42 41 mod tests { 43 - use std::{borrow::Cow, path::Path}; 42 + use std::{ 43 + borrow::Cow, 44 + collections::HashMap, 45 + net::{SocketAddr, TcpListener}, 46 + path::Path, 47 + sync::{Arc, Mutex}, 48 + }; 44 49 45 50 use atproto::{Did, did::OwnedDid, tid::Tid}; 46 51 use auth::{ ··· 61 54 use axum::{ 62 55 Json, Router, 63 56 body::Body, 64 - extract::Query, 57 + extract::{Query, State}, 65 58 http::{Method, Request, StatusCode, header}, 66 59 }; 67 - use identity::{DidDocument, Resolver}; 60 + use futures_util::FutureExt; 61 + use identity::{DidDocument, ResolveIdentity, Resolver}; 68 62 use sqlx::SqlitePool; 69 63 use time::{OffsetDateTime, format_description::well_known::Rfc3339}; 70 - use tokio::net::TcpListener; 71 64 use tower::ServiceExt; 72 65 73 66 use crate::{ ··· 75 68 services::database::DataStore, 76 69 }; 77 70 78 - fn make_knot_with(base: &Path, pool: SqlitePool, resolver: Resolver) -> Knot { 71 + fn make_knot_with<R: ResolveIdentity + Send + 'static>( 72 + base: &Path, 73 + pool: SqlitePool, 74 + resolver: R, 75 + ) -> Knot { 79 76 let config = KnotConfiguration::new( 80 77 OwnedDid::from_static("did:plc:65gha4t3avpfpzmvpbwovss7"), 81 78 OwnedDid::from_static("did:web:test"), ··· 87 76 ); 88 77 89 78 let database = DataStore::new(pool); 79 + let resolver = Resolver::new(resolver); 90 80 91 81 Knot::new(config, resolver, reqwest::Client::new(), database, []).unwrap() 92 82 } ··· 103 91 104 92 fn get(uri: &str) -> Request<Body> { 105 93 Request::builder().uri(uri).body(Body::empty()).unwrap() 106 - } 107 - 108 - fn mock_user(did: &Did, handle: &str, pds: Option<&str>) -> (DidDocument, EcdsaKeyPair) { 109 - let mut doc = DidDocument::new(did, handle).expect("valid did for did document"); 110 - if let Some(pds) = pds { 111 - doc.service.push(identity::Service { 112 - id: "#atproto_pds".to_string(), 113 - typ: "AtprotoPersonalDataServer".to_string(), 114 - service_endpoint: pds.parse().expect("valid url"), 115 - }); 116 - } 117 - 118 - // Generate a key pair and encode the public key as verification method for 119 - // the mock user. 120 - let key_pair = EcdsaKeyPair::generate(&ECDSA_P256K1_SHA256_FIXED_SIGNING).unwrap(); 121 - let public_key: EcPublicKeyCompressedBin = key_pair.public_key().as_be_bytes().unwrap(); 122 - let mut key_data = vec![0xe7, 0x01]; 123 - key_data.extend_from_slice(public_key.as_ref()); 124 - let pk = PublicKey::from_multicodec(&key_data).unwrap(); 125 - doc.verification_method 126 - .push(identity::VerificationMethod::Multikey { 127 - id: format!("{}#atproto", doc.id), 128 - controller: doc.id.clone(), 129 - public_key_multibase: pk.to_multibase(), 130 - }); 131 - 132 - (doc, key_pair) 133 - } 134 - 135 - fn service_auth(iss: &Did, aud: &Did, lxm: &str, key_pair: &EcdsaKeyPair) -> String { 136 - // Issue a service auth token for the mocked user. 137 - let jti: [u8; 16] = rand::random(); 138 - let jti = data_encoding::BASE32_NOPAD_VISUAL 139 - .encode(&jti) 140 - .to_lowercase(); 141 - 142 - let claims = Claims { 143 - iss: iss.into(), 144 - aud: aud.into(), 145 - iat: OffsetDateTime::now_utc().unix_timestamp(), 146 - exp: OffsetDateTime::now_utc().unix_timestamp() + 10, 147 - lxm: Some(lxm.try_into().unwrap()), 148 - jti: jti.into(), 149 - }; 150 - 151 - let header = Header { 152 - typ: auth::jwt::Type::JWT, 153 - alg: auth::jwt::Algorithm::ES256K, 154 - crv: None, 155 - }; 156 - 157 - let header = serde_json::to_vec(&header).unwrap(); 158 - let header = data_encoding::BASE64URL_NOPAD.encode(&header); 159 - 160 - let claims = serde_json::to_vec(&claims).unwrap(); 161 - let claims = data_encoding::BASE64URL_NOPAD.encode(&claims); 162 - 163 - let mut token = String::new(); 164 - token.push_str(&header); 165 - token.push('.'); 166 - token.push_str(&claims); 167 - 168 - let signature = key_pair 169 - .sign(&SystemRandom::new(), token.as_bytes()) 170 - .unwrap(); 171 - 172 - let signature = data_encoding::BASE64URL_NOPAD.encode(signature.as_ref()); 173 - token.push('.'); 174 - token.push_str(&signature); 175 - 176 - format!("Bearer {token}") 177 94 } 178 95 179 96 #[tokio::test] ··· 191 250 assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED); 192 251 } 193 252 194 - async fn create_repo<F>( 195 - knot: Knot, 196 - did: &Did, 197 - repo_name: &str, 198 - listener: TcpListener, 199 - svc_auth: F, 200 - ) -> String 201 - where 202 - F: Fn(&Did, &Did, &str) -> String, 203 - { 204 - let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 253 + #[derive(Clone, Debug)] 254 + struct Pds { 255 + addr: Option<SocketAddr>, 256 + identities: Arc<Mutex<Vec<(DidDocument, EcdsaKeyPair)>>>, 257 + documents: Arc<Mutex<HashMap<String, serde_json::Value>>>, 258 + } 205 259 206 - { 260 + impl Pds { 261 + fn new() -> Self { 262 + Self { 263 + addr: None, 264 + identities: Default::default(), 265 + documents: Default::default(), 266 + } 267 + } 268 + 269 + fn get_document(&self, uri: &str) -> Option<serde_json::Value> { 270 + self.documents.lock().unwrap().get(uri).cloned() 271 + } 272 + 273 + fn set_document( 274 + &self, 275 + did: &Did, 276 + collection: &str, 277 + rkey: &str, 278 + doc: serde_json::Value, 279 + ) -> String { 280 + let uri = format!("at://{did}/{collection}/{rkey}"); 281 + self.documents.lock().unwrap().insert(uri.clone(), doc); 282 + uri 283 + } 284 + 285 + fn mock_user(&self, did: &Did, handle: &str) { 286 + let mut doc = DidDocument::new(did, handle).expect("valid did for did document"); 287 + if let Some(pds_address) = self.addr { 288 + doc.service.push(identity::Service { 289 + id: "#atproto_pds".to_string(), 290 + typ: "AtprotoPersonalDataServer".to_string(), 291 + service_endpoint: format!("http://{pds_address}/").parse().expect("valid url"), 292 + }); 293 + } 294 + 295 + // Generate a key pair and encode the public key as verification method for 296 + // the mock user. 297 + let key_pair = EcdsaKeyPair::generate(&ECDSA_P256K1_SHA256_FIXED_SIGNING).unwrap(); 298 + let public_key: EcPublicKeyCompressedBin = key_pair.public_key().as_be_bytes().unwrap(); 299 + let mut key_data = vec![0xe7, 0x01]; 300 + key_data.extend_from_slice(public_key.as_ref()); 301 + let pk = PublicKey::from_multicodec(&key_data).unwrap(); 302 + doc.verification_method 303 + .push(identity::VerificationMethod::Multikey { 304 + id: format!("{}#atproto", doc.id), 305 + controller: doc.id.clone(), 306 + public_key_multibase: pk.to_multibase(), 307 + }); 308 + 309 + self.identities.lock().unwrap().push((doc, key_pair)); 310 + } 311 + 312 + // Issue a service auth token for the mocked user. 313 + fn service_auth(&self, iss: &Did, aud: &Did, lxm: &str) -> String { 314 + let jti: [u8; 16] = rand::random(); 315 + let jti = data_encoding::BASE32_NOPAD_VISUAL 316 + .encode(&jti) 317 + .to_lowercase(); 318 + 319 + let claims = Claims { 320 + iss: iss.into(), 321 + aud: aud.into(), 322 + iat: OffsetDateTime::now_utc().unix_timestamp(), 323 + exp: OffsetDateTime::now_utc().unix_timestamp() + 10, 324 + lxm: Some(lxm.try_into().unwrap()), 325 + jti: jti.into(), 326 + }; 327 + 328 + let header = Header { 329 + typ: auth::jwt::Type::JWT, 330 + alg: auth::jwt::Algorithm::ES256K, 331 + crv: None, 332 + }; 333 + 334 + let header = serde_json::to_vec(&header).unwrap(); 335 + let header = data_encoding::BASE64URL_NOPAD.encode(&header); 336 + 337 + let claims = serde_json::to_vec(&claims).unwrap(); 338 + let claims = data_encoding::BASE64URL_NOPAD.encode(&claims); 339 + 340 + let mut token = String::new(); 341 + token.push_str(&header); 342 + token.push('.'); 343 + token.push_str(&claims); 344 + 345 + let guard = self.identities.lock().unwrap(); 346 + let key_pair = guard 347 + .iter() 348 + .find(|(doc, _)| doc.id == iss) 349 + .map(|(_, key_pair)| key_pair) 350 + .unwrap(); 351 + 352 + let signature = key_pair 353 + .sign(&SystemRandom::new(), token.as_bytes()) 354 + .unwrap(); 355 + 356 + let signature = data_encoding::BASE64URL_NOPAD.encode(signature.as_ref()); 357 + token.push('.'); 358 + token.push_str(&signature); 359 + 360 + format!("Bearer {token}") 361 + } 362 + 363 + fn serve(&mut self) { 207 364 #[derive(serde::Deserialize)] 208 365 struct GetRecord { 209 366 repo: String, ··· 309 270 rkey: String, 310 271 } 311 272 312 - let did = did.to_owned(); 313 - let rkey = rkey.clone(); 314 - let name = repo_name.to_owned(); 315 - let state = knot.clone(); 316 - let pds = Router::new().route( 317 - "/xrpc/com.atproto.repo.getRecord", 318 - axum::routing::get(async move |Query(params): Query<GetRecord>| { 319 - assert_eq!(params.repo, did.as_str()); 320 - assert_eq!(params.collection, "sh.tangled.repo"); 321 - assert_eq!(params.rkey, rkey); 322 - Json(serde_json::json!({ 323 - "uri": format!("at://{did}/sh.tangled.repo/{rkey}"), 324 - "cid": "some_cid", 325 - "value": { 326 - "name": name, 327 - "knot": state.instance_ident(), 328 - "createdAt": OffsetDateTime::now_utc().format(&Rfc3339).unwrap() 329 - } 330 - })) 331 - }), 332 - ); 273 + let pds = Router::new() 274 + .route( 275 + "/xrpc/com.atproto.repo.getRecord", 276 + axum::routing::get( 277 + async move |State(state): State<Pds>, 278 + Query(GetRecord { 279 + repo, 280 + collection, 281 + rkey, 282 + }): Query<GetRecord>| { 283 + Json(state.get_document(&format!("at://{repo}/{collection}/{rkey}"))) 284 + }, 285 + ), 286 + ) 287 + .with_state(self.clone()); 288 + 289 + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 290 + listener.set_nonblocking(true).unwrap(); 291 + self.addr = Some(listener.local_addr().unwrap()); 333 292 334 293 tokio::spawn(async move { 335 - axum::serve(listener, pds).await.unwrap(); 294 + axum::serve(tokio::net::TcpListener::from_std(listener).unwrap(), pds) 295 + .await 296 + .unwrap(); 336 297 }); 337 298 } 299 + } 300 + 301 + impl ResolveIdentity for Pds { 302 + fn resolve_handle<'s: 'h, 'h>( 303 + &'s self, 304 + handle: &'h str, 305 + ) -> futures_util::future::BoxFuture<'h, Result<OwnedDid, identity::ResolveError>> { 306 + async move { 307 + self.identities 308 + .lock() 309 + .unwrap() 310 + .iter() 311 + .find_map(|(doc, _)| { 312 + match doc.primary_alias().is_some_and(|alias| alias == handle) { 313 + true => Some(doc.id.clone()), 314 + false => None, 315 + } 316 + }) 317 + .ok_or(identity::ResolveError::UnresolvedHandle) 318 + } 319 + .boxed() 320 + } 321 + 322 + fn resolve_did<'s: 'd, 'd>( 323 + &'s self, 324 + did: &'d Did, 325 + ) -> futures_util::future::BoxFuture<'d, Result<DidDocument, identity::ResolveError>> 326 + { 327 + async move { 328 + self.identities 329 + .lock() 330 + .unwrap() 331 + .iter() 332 + .find_map(|(doc, _)| match doc.id == did { 333 + true => Some(doc.clone()), 334 + false => None, 335 + }) 336 + .ok_or(identity::ResolveError::UnresolvedHandle) 337 + } 338 + .boxed() 339 + } 340 + } 341 + 342 + async fn create_repo( 343 + knot: Knot, 344 + pds: Pds, 345 + did: &Did, 346 + repo_name: &str, 347 + source: Option<&str>, 348 + ) -> String { 349 + // Create fake PDS record for our new repository. 350 + let rkey = Tid::from_datetime(OffsetDateTime::now_utc(), 0).to_string(); 351 + pds.set_document( 352 + did, 353 + "sh.tangled.repo", 354 + &rkey, 355 + serde_json::json!({ 356 + "uri": format!("at://{did}/sh.tangled.repo/{rkey}"), 357 + "cid": "some_cid", 358 + "value": { 359 + "name": repo_name, 360 + "knot": knot.instance_ident(), 361 + "source": source, 362 + "createdAt": OffsetDateTime::now_utc().format(&Rfc3339).unwrap() 363 + } 364 + }), 365 + ); 338 366 339 367 // Insert our mock user as a knot member. 340 - 341 368 knot.database() 342 369 .upsert_knot_member( 343 370 "", ··· 419 314 .expect("knot member inserted into db"); 420 315 421 316 // Generate the body of the 'sh.tangled.repo.create' request. 422 - 423 317 let create = lexicon::sh::tangled::repo::create::Input { 424 318 rkey: Cow::Borrowed(&rkey), 425 319 default_branch: Some("main".into()), ··· 442 338 StatusCode::UNAUTHORIZED 443 339 ); 444 340 445 - let auth = svc_auth(&did, &knot.instance, "sh.tangled.repo"); 341 + let auth = pds.service_auth(&did, &knot.instance, "sh.tangled.repo"); 446 342 assert_eq!( 447 343 super::public::router() 448 344 .with_state(knot.clone()) ··· 462 358 ); 463 359 464 360 // Valid auth, empty request body. 465 - let auth = svc_auth(&did, &knot.instance, "sh.tangled.repo.create"); 361 + let auth = pds.service_auth(&did, &knot.instance, "sh.tangled.repo.create"); 466 362 assert_eq!( 467 363 super::public::router() 468 364 .with_state(knot.clone()) ··· 481 377 StatusCode::BAD_REQUEST 482 378 ); 483 379 484 - let auth = svc_auth(&did, &knot.instance, "sh.tangled.repo.create"); 380 + let auth = pds.service_auth(&did, &knot.instance, "sh.tangled.repo.create"); 485 381 assert_eq!( 486 382 super::public::router() 487 383 .with_state(knot.clone()) ··· 517 413 let pool = SqlitePool::connect("sqlite://:memory:").await.unwrap(); 518 414 sqlx::migrate!().run(&pool).await.unwrap(); 519 415 520 - let listener = TcpListener::bind("127.0.0.1:0") 521 - .await 522 - .expect("tcp listener"); 416 + let mut pds = Pds::new(); 417 + pds.serve(); 523 418 524 - let socket = listener.local_addr().expect("bound local socket"); 419 + pds.mock_user(did, "tjh.dev"); 420 + let knot = make_knot_with(base.path(), pool, pds.clone()); 421 + create_repo(knot, pds, did, "test-repo", None).await; 422 + } 525 423 526 - let (doc, key_pair) = mock_user(did, "tjh.dev", Some(&format!("http://{socket}/"))); 527 - let (resolver, _) = Resolver::mocked([doc]); 528 - let knot = make_knot_with(base.path(), pool, resolver); 424 + #[tokio::test] 425 + async fn xrpc_sh_tangled_repo_create_fork_atrepo() { 426 + let base = tempfile::tempdir().expect("temporary directory"); 427 + let did = Did::from_static("did:plc:65gha4t3avpfpzmvpbwovss7"); 529 428 530 - create_repo(knot, did, "test-repo", listener, |iss, aud, lxm| { 531 - service_auth(iss, aud, lxm, &key_pair) 532 - }) 429 + let pool = SqlitePool::connect("sqlite://:memory:").await.unwrap(); 430 + sqlx::migrate!().run(&pool).await.unwrap(); 431 + 432 + let mut pds = Pds::new(); 433 + pds.serve(); 434 + 435 + // Create a record for the repository to fork from. 436 + // <https://pdsls.dev/at://did:plc:65gha4t3avpfpzmvpbwovss7/sh.tangled.repo/3m24udbjajf22#record> 437 + let aturi = pds.set_document( 438 + did, 439 + "sh.tangled.repo", 440 + "3m24udbjajf22", 441 + serde_json::json!({ 442 + "uri": format!("at://{did}/sh.tangled.repo/3m24udbjajf22"), 443 + "cid": "some_cid", 444 + "value": { 445 + "name": "gordian", 446 + "knot": "gordian.tjh.dev", 447 + "createdAt": "2025-10-01T10:45:52Z" 448 + } 449 + }), 450 + ); 451 + 452 + pds.mock_user(did, "tjh.dev"); 453 + let knot = make_knot_with(base.path(), pool, pds.clone()); 454 + create_repo(knot, pds, did, "test-repo", Some(&aturi)).await; 455 + } 456 + 457 + #[tokio::test] 458 + async fn xrpc_sh_tangled_repo_create_fork_plain() { 459 + let base = tempfile::tempdir().expect("temporary directory"); 460 + let did = Did::from_static("did:plc:65gha4t3avpfpzmvpbwovss7"); 461 + 462 + let pool = SqlitePool::connect("sqlite://:memory:").await.unwrap(); 463 + sqlx::migrate!().run(&pool).await.unwrap(); 464 + 465 + let mut pds = Pds::new(); 466 + pds.serve(); 467 + 468 + pds.mock_user(did, "tjh.dev"); 469 + let knot = make_knot_with(base.path(), pool, pds.clone()); 470 + create_repo( 471 + knot, 472 + pds, 473 + did, 474 + "test-repo", 475 + Some("https://gordian.tjh.dev/did:plc:65gha4t3avpfpzmvpbwovss7/3m24udbjajf22"), 476 + ) 533 477 .await; 534 478 } 535 479 ··· 590 438 let pool = SqlitePool::connect("sqlite://:memory:").await.unwrap(); 591 439 sqlx::migrate!().run(&pool).await.unwrap(); 592 440 593 - let listener = TcpListener::bind("127.0.0.1:0") 594 - .await 595 - .expect("tcp listener"); 441 + let mut pds = Pds::new(); 442 + pds.serve(); 596 443 597 - let socket = listener.local_addr().expect("bound local socket"); 598 - 599 - let (doc, key_pair) = mock_user(did, "tjh.dev", Some(&format!("http://{socket}/"))); 600 - let (resolver, _) = Resolver::mocked([doc]); 601 - let knot = make_knot_with(base.path(), pool, resolver); 602 - 444 + pds.mock_user(did, "tjh.dev"); 445 + let knot = make_knot_with(base.path(), pool, pds.clone()); 603 446 let name = "another-test-repo"; 604 - let rkey = create_repo(knot.clone(), did, name, listener, |iss, aud, lxm| { 605 - service_auth(iss, aud, lxm, &key_pair) 606 - }) 607 - .await; 447 + let rkey = create_repo(knot.clone(), pds.clone(), did, name, None).await; 608 448 609 449 gix::open(base.path().join(did.as_str()).join(&rkey)).expect("new repository should exist"); 610 450 ··· 632 488 ); 633 489 634 490 // Or with the wrong auth. 635 - let auth = service_auth(&did, &knot.instance(), "sh.tangled.repo.create", &key_pair); 491 + let auth = pds.service_auth(&did, &knot.instance(), "sh.tangled.repo.create"); 636 492 assert_eq!( 637 493 super::public::router() 638 494 .with_state(knot.clone()) ··· 652 508 ); 653 509 654 510 // Valid auth, empty request body. 655 - let auth = service_auth(&did, &knot.instance(), "sh.tangled.repo.delete", &key_pair); 511 + let auth = pds.service_auth(&did, &knot.instance(), "sh.tangled.repo.delete"); 656 512 assert_eq!( 657 513 super::public::router() 658 514 .with_state(knot.clone()) ··· 671 527 StatusCode::BAD_REQUEST 672 528 ); 673 529 674 - let auth = service_auth(&did, &knot.instance(), "sh.tangled.repo.delete", &key_pair); 530 + let auth = pds.service_auth(&did, &knot.instance(), "sh.tangled.repo.delete"); 675 531 assert_eq!( 676 532 super::public::router() 677 533 .with_state(knot.clone())
+112 -6
crates/knot/src/model/knot_state.rs
··· 4 4 net::SocketAddr, 5 5 ops, 6 6 path::PathBuf, 7 + process::Stdio, 7 8 sync::{Arc, Mutex}, 8 9 time::Duration, 9 10 }; 10 11 11 - use atproto::did::Did; 12 + use atproto::{aturi::AtUri, did::Did}; 12 13 use futures_util::{FutureExt, future::BoxFuture}; 13 14 use identity::{HttpClient, Resolver}; 14 - use lexicon::sh::tangled::{git::RefUpdate, repo::Repo}; 15 + use lexicon::{ 16 + com::atproto::repo::list_records::Record, 17 + sh::tangled::{git::RefUpdate, repo::Repo}, 18 + }; 15 19 use moka::future::{Cache, CacheBuilder}; 16 20 use rayon::{ThreadPool, ThreadPoolBuilder}; 17 21 use serde::Serialize; 18 22 use time::OffsetDateTime; 23 + use tokio::process::Command; 24 + use url::Url; 19 25 20 26 use crate::{ 27 + release_or_debug, 21 28 services::{ 29 + atrepo, 22 30 authorization::{AuthorizationClaimsStore, AuthorizationClaimsStoreError}, 23 31 database::{DataStore, DataStoreError}, 24 32 }, ··· 81 73 82 74 repo_cache: Cache<RepositoryKey, gix::ThreadSafeRepository>, 83 75 76 + repo_mutex: Mutex<HashMap<RepositoryKey, Arc<tokio::sync::Mutex<()>>>>, 77 + 84 78 push_seed: Mutex<HashMap<RepositoryKey, Box<str>>>, 85 79 86 80 private_addrs: String, ··· 127 117 .time_to_idle(Duration::from_secs(60)) 128 118 .build(), 129 119 repo_cache, 120 + repo_mutex: Default::default(), 130 121 push_seed: Default::default(), 131 122 private_addrs, 132 123 }); ··· 171 160 172 161 pub fn private_endpoints(&self) -> &str { 173 162 &self.private_addrs 163 + } 164 + 165 + pub fn get_repo_mutex(&self, repo_key: &RepositoryKey) -> Arc<tokio::sync::Mutex<()>> { 166 + Arc::clone( 167 + &self 168 + .repo_mutex 169 + .lock() 170 + .expect("mutex should not be poisoned") 171 + .entry(repo_key.clone()) 172 + .or_default(), 173 + ) 174 174 } 175 175 176 176 /// Resolve a repository path ({handle,did},{rkey,name}) to a repository key (did, rkey). ··· 272 250 assert_eq!(*collection, "sh.tangled.repo"); 273 251 assert_eq!(repo.knot, self.instance_ident()); 274 252 275 - repository_path::validate(&did)?; 276 - repository_path::validate(&rkey)?; 253 + let repo_key = RepositoryKey::new(*did, *rkey)?; 277 254 repository_path::validate(&repo.name)?; 255 + 256 + // We're going to receive the jetstream event and the xrpc request. 257 + // 258 + // If the other is already in progress, wait here. The database insert should return 259 + // Ok(false), and repository creation will be skipped. 260 + let _guard = self.get_repo_mutex(&repo_key).lock_owned().await; 278 261 279 262 let is_new = self 280 263 .database() ··· 290 263 return Ok(()); 291 264 } 292 265 293 - let repo_key = RepositoryKey::new(*did, *rkey)?; 294 - self.init_repo(&repo_key, &repo.name)?; 266 + match &repo.source { 267 + Some(source) => self.fork_repo(&repo_key, &repo.name, source).await?, 268 + None => self.init_repo(&repo_key, &repo.name)?, 269 + } 295 270 296 271 Ok(()) 297 272 } ··· 337 308 std::os::unix::fs::symlink(&repo_key.rkey, &symlink_path)?; 338 309 339 310 Ok(()) 311 + } 312 + 313 + #[tracing::instrument(skip(self), ret)] 314 + pub async fn fork_repo( 315 + &self, 316 + repo_key: &RepositoryKey, 317 + name: &str, 318 + source: &str, 319 + ) -> anyhow::Result<()> { 320 + // Release build: only clone over https; Debug builds: try https then http. 321 + release_or_debug!(const CLONE_SCHEMES: &[&str] = &["https"], &["https", "http"]); 322 + 323 + let path = self.path_for_repository(repo_key); 324 + tracing::debug!(?path, "forking into"); 325 + 326 + let clone_urls: Vec<_> = match AtUri::parse(source) { 327 + Ok(source_uri) => { 328 + let source_did = source_uri.did().ok_or(anyhow::anyhow!( 329 + "source repository record uri does not contain a did authority " 330 + ))?; 331 + let source_rkey = source_uri.rkey.ok_or(anyhow::anyhow!( 332 + "source repository record uri does not contain a rkey" 333 + ))?; 334 + 335 + // Fetch repository record from pds. 336 + let response = atrepo::fetch_record_bytes( 337 + self.resolver(), 338 + self.http(), 339 + source_did, 340 + "sh.tangled.repo", 341 + source_rkey, 342 + ) 343 + .await?; 344 + 345 + let record = serde_json::from_slice::<Record>(&response)?; 346 + let repo: Repo = serde_json::from_str(record.value.get())?; 347 + 348 + CLONE_SCHEMES 349 + .iter() 350 + .map(|scheme| format!("{scheme}://{}/{source_did}/{}", repo.knot, repo.name)) 351 + .collect() 352 + } 353 + Err(_) => match Url::parse(source) { 354 + Ok(url) if CLONE_SCHEMES.contains(&url.scheme()) => vec![source.to_string()], 355 + _ => return Err(anyhow::anyhow!("Unrecognised URL: {source}")), 356 + }, 357 + }; 358 + 359 + for clone_url in clone_urls { 360 + tracing::info!("forking repo from '{clone_url}'"); 361 + let output = Command::new("/usr/bin/git") 362 + .env_clear() 363 + .args(["clone", "--bare", &clone_url]) 364 + .arg(&path) 365 + .stdout(Stdio::inherit()) 366 + .stderr(Stdio::inherit()) 367 + .output() 368 + .await?; 369 + 370 + if output.status.success() { 371 + // Create a symlink to map the repository name -> rkey. 372 + let symlink_path = path 373 + .parent() 374 + .expect("parent for repository path") 375 + .join(name); 376 + 377 + let _ = std::fs::remove_file(&symlink_path); 378 + let _ = std::os::unix::fs::symlink(&repo_key.rkey, &symlink_path); 379 + 380 + return Ok(()); 381 + } else { 382 + tracing::error!(?clone_url, ?path, "git clone failed"); 383 + } 384 + } 385 + 386 + tracing::error!("all clone attempts failed"); 387 + Err(anyhow::anyhow!("failed to fork repo")) 340 388 } 341 389 342 390 pub fn delete_repo(&self, did: &Did, rkey: &str) -> anyhow::Result<()> {
+12 -6
crates/knot/src/services/database.rs
··· 6 6 use jetstream::{Delete, Value}; 7 7 use lexicon::sh::tangled::{PublicKey, knot::Member, repo::Repo}; 8 8 use serde::Serialize; 9 - use sqlx::SqlitePool; 9 + use sqlx::{SqlitePool, error::ErrorKind}; 10 10 use time::OffsetDateTime; 11 11 use types::{DeletedRecord, EventRow}; 12 12 ··· 303 303 repository: &Repo<'_>, 304 304 ) -> Result<bool, DataStoreError> { 305 305 let result = sqlx::query!( 306 - "INSERT INTO repository (did, rkey, rev, cid, name, knot, spindle, source, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT (did, rkey) DO UPDATE SET rev = excluded.rev, cid = excluded.cid, knot = excluded.knot, name = excluded.name, created_at = excluded.created_at WHERE excluded.rev > repository.rev AND excluded.cid <> '' RETURNING *", 306 + "INSERT INTO repository (did, rkey, rev, cid, name, knot, spindle, source, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", 307 307 did, 308 308 rkey, 309 309 rev, ··· 313 313 repository.spindle, 314 314 repository.source, 315 315 repository.created_at 316 - ).fetch_optional(&self.db).await?; 316 + ).fetch_optional(&self.db).await; 317 317 318 - tracing::debug!(?result); 319 - 320 - Ok(result.is_some_and(|record| record.rev == rev && record.cid == cid)) 318 + match result { 319 + Ok(_) => Ok(true), 320 + Err(error) => match error.as_database_error() { 321 + Some(database_error) if database_error.kind() == ErrorKind::UniqueViolation => { 322 + Ok(false) 323 + } 324 + _ => Err(error)?, 325 + }, 326 + } 321 327 } 322 328 323 329 pub async fn update_repository(
+9
crates/knot/src/util.rs
··· 1 + #[macro_export] 2 + macro_rules! release_or_debug { 3 + (const $val:ident: $t:ty = $rel:expr, $dbg:expr) => { 4 + #[cfg(not(debug_assertions))] 5 + const $val: $t = $rel; 6 + #[cfg(debug_assertions)] 7 + const $val: $t = $dbg; 8 + }; 9 + }