learn and share notes on atproto (wip) — malfestio.stormlightlabs.org/
readability solid axum atproto srs
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 444 lines 15 kB view raw
1//! Sync API endpoints for bi-directional PDS synchronization. 2//! 3//! Provides endpoints for pushing local changes to PDS, getting sync status, 4//! and resolving conflicts. 5 6use crate::middleware::auth::UserContext; 7use crate::state::SharedState; 8use crate::sync_service::{ConflictStrategy, SyncError, SyncService}; 9use axum::{ 10 Json, 11 extract::{Extension, Path, State}, 12 http::StatusCode, 13 response::IntoResponse, 14}; 15use serde::{Deserialize, Serialize}; 16use serde_json::json; 17use std::str::FromStr; 18 19/// Response for sync push operation. 20#[derive(Debug, Clone, Serialize)] 21pub struct PushResponse { 22 pub entity_type: String, 23 pub entity_id: String, 24 pub pds_uri: Option<String>, 25 pub pds_cid: Option<String>, 26 pub version: i32, 27 pub status: String, 28} 29 30/// Response for sync status query. 31#[derive(Debug, Clone, Serialize)] 32pub struct SyncStatusResponse { 33 pub pending_count: usize, 34 pub conflict_count: usize, 35 pub pending_items: Vec<PendingItem>, 36 pub conflicts: Vec<ConflictItem>, 37} 38 39#[derive(Debug, Clone, Serialize)] 40pub struct PendingItem { 41 pub entity_type: String, 42 pub entity_id: String, 43} 44 45#[derive(Debug, Clone, Serialize)] 46pub struct ConflictItem { 47 pub entity_type: String, 48 pub entity_id: String, 49 pub local_version: i32, 50 pub remote_version: Option<i32>, 51} 52 53/// Request for conflict resolution. 54#[derive(Debug, Clone, Deserialize)] 55pub struct ResolveConflictRequest { 56 pub strategy: String, 57} 58 59/// Push a deck to the user's PDS. 
60/// 61/// POST /api/sync/push/deck/:id 62pub async fn push_deck( 63 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>, Path(deck_id): Path<String>, 64) -> impl IntoResponse { 65 let user = match ctx { 66 Some(Extension(user)) => user, 67 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(), 68 }; 69 70 let sync_service = create_sync_service(&state); 71 72 match sync_service.push_deck(&deck_id, &user).await { 73 Ok(result) => ( 74 StatusCode::OK, 75 Json(PushResponse { 76 entity_type: result.entity_type, 77 entity_id: result.entity_id, 78 pds_uri: result.pds_uri, 79 pds_cid: result.pds_cid, 80 version: result.new_version, 81 status: result.status.to_string(), 82 }), 83 ) 84 .into_response(), 85 Err(e) => sync_error_response(e), 86 } 87} 88 89/// Push a note to the user's PDS. 90/// 91/// POST /api/sync/push/note/:id 92pub async fn push_note( 93 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>, Path(note_id): Path<String>, 94) -> impl IntoResponse { 95 let user = match ctx { 96 Some(Extension(user)) => user, 97 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(), 98 }; 99 100 let sync_service = create_sync_service(&state); 101 102 match sync_service.push_note(&note_id, &user).await { 103 Ok(result) => ( 104 StatusCode::OK, 105 Json(PushResponse { 106 entity_type: result.entity_type, 107 entity_id: result.entity_id, 108 pds_uri: result.pds_uri, 109 pds_cid: result.pds_cid, 110 version: result.new_version, 111 status: result.status.to_string(), 112 }), 113 ) 114 .into_response(), 115 Err(e) => sync_error_response(e), 116 } 117} 118 119/// Get the current sync status for the authenticated user. 
120/// 121/// GET /api/sync/status 122pub async fn get_sync_status( 123 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>, 124) -> impl IntoResponse { 125 let user = match ctx { 126 Some(Extension(user)) => user, 127 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(), 128 }; 129 130 let sync_service = create_sync_service(&state); 131 132 match sync_service.get_sync_status(&user).await { 133 Ok(summary) => ( 134 StatusCode::OK, 135 Json(SyncStatusResponse { 136 pending_count: summary.pending_count, 137 conflict_count: summary.conflict_count, 138 pending_items: summary 139 .pending_items 140 .into_iter() 141 .map(|(entity_type, entity_id)| PendingItem { entity_type, entity_id }) 142 .collect(), 143 conflicts: summary 144 .conflicts 145 .into_iter() 146 .map(|c| ConflictItem { 147 entity_type: c.entity_type, 148 entity_id: c.entity_id, 149 local_version: c.local_version, 150 remote_version: c.remote_version, 151 }) 152 .collect(), 153 }), 154 ) 155 .into_response(), 156 Err(e) => sync_error_response(e), 157 } 158} 159 160/// Resolve a sync conflict. 161/// 162/// POST /api/sync/resolve/:entity_type/:id 163pub async fn resolve_conflict( 164 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>, 165 Path((entity_type, entity_id)): Path<(String, String)>, Json(payload): Json<ResolveConflictRequest>, 166) -> impl IntoResponse { 167 let user = match ctx { 168 Some(Extension(user)) => user, 169 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(), 170 }; 171 172 let strategy = match ConflictStrategy::from_str(&payload.strategy) { 173 Ok(s) => s, 174 Err(_) => { 175 return ( 176 StatusCode::BAD_REQUEST, 177 Json(json!({"error": "Invalid strategy. 
Use: last_write_wins, keep_local, or keep_remote"})), 178 ) 179 .into_response(); 180 } 181 }; 182 183 let sync_service = create_sync_service(&state); 184 185 match sync_service 186 .resolve_conflict(&entity_type, &entity_id, strategy, &user) 187 .await 188 { 189 Ok(result) => ( 190 StatusCode::OK, 191 Json(PushResponse { 192 entity_type: result.entity_type, 193 entity_id: result.entity_id, 194 pds_uri: result.pds_uri, 195 pds_cid: result.pds_cid, 196 version: result.new_version, 197 status: result.status.to_string(), 198 }), 199 ) 200 .into_response(), 201 Err(e) => sync_error_response(e), 202 } 203} 204 205/// Create a SyncService from the app state. 206fn create_sync_service(state: &SharedState) -> SyncService { 207 SyncService::new( 208 state.sync_repo.clone(), 209 state.deck_repo.clone(), 210 state.card_repo.clone(), 211 state.note_repo.clone(), 212 state.oauth_repo.clone(), 213 ) 214} 215 216/// Convert SyncError to HTTP response. 217fn sync_error_response(error: SyncError) -> axum::response::Response { 218 let (status, message) = match &error { 219 SyncError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()), 220 SyncError::AuthRequired(msg) => (StatusCode::UNAUTHORIZED, msg.clone()), 221 SyncError::NoTokens(msg) => (StatusCode::UNAUTHORIZED, msg.clone()), 222 SyncError::InvalidArgument(msg) => (StatusCode::BAD_REQUEST, msg.clone()), 223 SyncError::ConflictDetected(info) => ( 224 StatusCode::CONFLICT, 225 format!("Conflict for {}:{}", info.entity_type, info.entity_id), 226 ), 227 SyncError::PdsError(e) => (StatusCode::BAD_GATEWAY, e.to_string()), 228 SyncError::RepoError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), 229 }; 230 231 tracing::error!("Sync error: {}", error); 232 (status, Json(json!({"error": message}))).into_response() 233} 234 235#[cfg(test)] 236mod tests { 237 use super::*; 238 239 #[test] 240 fn test_push_response_serialization() { 241 let response = PushResponse { 242 entity_type: "deck".to_string(), 243 entity_id: 
"123".to_string(), 244 pds_uri: Some("at://did:plc:test/deck/tid".to_string()), 245 pds_cid: Some("bafycid".to_string()), 246 version: 2, 247 status: "synced".to_string(), 248 }; 249 250 let json = serde_json::to_string(&response).unwrap(); 251 assert!(json.contains("\"entity_type\":\"deck\"")); 252 assert!(json.contains("\"version\":2")); 253 } 254 255 #[test] 256 fn test_sync_status_response_serialization() { 257 let response = SyncStatusResponse { 258 pending_count: 2, 259 conflict_count: 1, 260 pending_items: vec![ 261 PendingItem { entity_type: "deck".to_string(), entity_id: "1".to_string() }, 262 PendingItem { entity_type: "note".to_string(), entity_id: "2".to_string() }, 263 ], 264 conflicts: vec![ConflictItem { 265 entity_type: "deck".to_string(), 266 entity_id: "3".to_string(), 267 local_version: 5, 268 remote_version: Some(6), 269 }], 270 }; 271 272 let json = serde_json::to_string(&response).unwrap(); 273 assert!(json.contains("\"pending_count\":2")); 274 assert!(json.contains("\"conflict_count\":1")); 275 } 276 277 #[test] 278 fn test_resolve_conflict_request_deserialization() { 279 let json = r#"{"strategy": "last_write_wins"}"#; 280 let request: ResolveConflictRequest = serde_json::from_str(json).unwrap(); 281 assert_eq!(request.strategy, "last_write_wins"); 282 283 let json = r#"{"strategy": "keep_local"}"#; 284 let request: ResolveConflictRequest = serde_json::from_str(json).unwrap(); 285 assert_eq!(request.strategy, "keep_local"); 286 } 287 288 #[test] 289 fn test_sync_error_response_not_found() { 290 let error = SyncError::NotFound("deck:123".to_string()); 291 let response = sync_error_response(error); 292 assert_eq!(response.status(), StatusCode::NOT_FOUND); 293 } 294 295 #[test] 296 fn test_sync_error_response_unauthorized() { 297 let error = SyncError::AuthRequired("missing token".to_string()); 298 let response = sync_error_response(error); 299 assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 300 } 301 302 #[test] 303 fn 
test_sync_error_response_bad_request() { 304 let error = SyncError::InvalidArgument("bad entity type".to_string()); 305 let response = sync_error_response(error); 306 assert_eq!(response.status(), StatusCode::BAD_REQUEST); 307 } 308 309 #[test] 310 fn test_sync_error_response_conflict() { 311 let error = SyncError::ConflictDetected(crate::sync_service::ConflictInfo { 312 entity_type: "deck".to_string(), 313 entity_id: "123".to_string(), 314 local_version: 5, 315 remote_version: Some(6), 316 local_updated_at: None, 317 remote_updated_at: None, 318 }); 319 let response = sync_error_response(error); 320 assert_eq!(response.status(), StatusCode::CONFLICT); 321 } 322 323 #[test] 324 fn test_pending_item_serialization() { 325 let item = PendingItem { entity_type: "note".to_string(), entity_id: "456".to_string() }; 326 327 let json = serde_json::to_string(&item).unwrap(); 328 assert!(json.contains("\"entity_type\":\"note\"")); 329 assert!(json.contains("\"entity_id\":\"456\"")); 330 } 331 332 #[test] 333 fn test_conflict_item_serialization() { 334 let item = ConflictItem { 335 entity_type: "deck".to_string(), 336 entity_id: "789".to_string(), 337 local_version: 3, 338 remote_version: Some(4), 339 }; 340 341 let json = serde_json::to_string(&item).unwrap(); 342 assert!(json.contains("\"local_version\":3")); 343 assert!(json.contains("\"remote_version\":4")); 344 } 345 346 #[test] 347 fn test_conflict_item_no_remote_version() { 348 let item = ConflictItem { 349 entity_type: "note".to_string(), 350 entity_id: "abc".to_string(), 351 local_version: 1, 352 remote_version: None, 353 }; 354 355 let json = serde_json::to_string(&item).unwrap(); 356 assert!(json.contains("\"remote_version\":null")); 357 } 358 359 #[tokio::test] 360 async fn test_push_deck_unauthorized() { 361 let pool = crate::db::create_mock_pool(); 362 let repos = crate::state::Repositories::default(); 363 let config = crate::state::AppConfig { pds_url: "https://test.example.com".to_string() }; 364 let state = 
crate::state::AppState::new(pool, repos, config); 365 366 let response = push_deck(State(state), None, Path("deck-123".to_string())) 367 .await 368 .into_response(); 369 370 assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 371 } 372 373 #[tokio::test] 374 async fn test_push_note_unauthorized() { 375 let pool = crate::db::create_mock_pool(); 376 let repos = crate::state::Repositories::default(); 377 let config = crate::state::AppConfig { pds_url: "https://test.example.com".to_string() }; 378 let state = crate::state::AppState::new(pool, repos, config); 379 380 let response = push_note(State(state), None, Path("note-456".to_string())) 381 .await 382 .into_response(); 383 384 assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 385 } 386 387 #[tokio::test] 388 async fn test_get_sync_status_unauthorized() { 389 let pool = crate::db::create_mock_pool(); 390 let repos = crate::state::Repositories::default(); 391 let config = crate::state::AppConfig { pds_url: "https://test.example.com".to_string() }; 392 let state = crate::state::AppState::new(pool, repos, config); 393 394 let response = get_sync_status(State(state), None).await.into_response(); 395 396 assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 397 } 398 399 #[tokio::test] 400 async fn test_resolve_conflict_unauthorized() { 401 let pool = crate::db::create_mock_pool(); 402 let repos = crate::state::Repositories::default(); 403 let config = crate::state::AppConfig { pds_url: "https://test.example.com".to_string() }; 404 let state = crate::state::AppState::new(pool, repos, config); 405 406 let response = resolve_conflict( 407 State(state), 408 None, 409 Path(("deck".to_string(), "123".to_string())), 410 Json(ResolveConflictRequest { strategy: "last_write_wins".to_string() }), 411 ) 412 .await 413 .into_response(); 414 415 assert_eq!(response.status(), StatusCode::UNAUTHORIZED); 416 } 417 418 #[tokio::test] 419 async fn test_resolve_conflict_invalid_strategy() { 420 let pool = 
crate::db::create_mock_pool(); 421 let repos = crate::state::Repositories::default(); 422 let config = crate::state::AppConfig { pds_url: "https://test.example.com".to_string() }; 423 let state = crate::state::AppState::new(pool, repos, config); 424 425 let user = UserContext { 426 did: "did:plc:alice".to_string(), 427 handle: "alice.bsky.social".to_string(), 428 access_token: "test_token".to_string(), 429 pds_url: "https://bsky.social".to_string(), 430 has_dpop: false, 431 }; 432 433 let response = resolve_conflict( 434 State(state), 435 Some(Extension(user)), 436 Path(("deck".to_string(), "123".to_string())), 437 Json(ResolveConflictRequest { strategy: "invalid_strategy".to_string() }), 438 ) 439 .await 440 .into_response(); 441 442 assert_eq!(response.status(), StatusCode::BAD_REQUEST); 443 } 444}