learn and share notes on atproto (wip)
malfestio.stormlightlabs.org/
readability
solid
axum
atproto
srs
1//! Sync API endpoints for bi-directional PDS synchronization.
2//!
3//! Provides endpoints for pushing local changes to PDS, getting sync status,
4//! and resolving conflicts.
5
6use crate::middleware::auth::UserContext;
7use crate::state::SharedState;
8use crate::sync_service::{ConflictStrategy, SyncError, SyncService};
9use axum::{
10 Json,
11 extract::{Extension, Path, State},
12 http::StatusCode,
13 response::IntoResponse,
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17use std::str::FromStr;
18
19/// Response for sync push operation.
20#[derive(Debug, Clone, Serialize)]
21pub struct PushResponse {
22 pub entity_type: String,
23 pub entity_id: String,
24 pub pds_uri: Option<String>,
25 pub pds_cid: Option<String>,
26 pub version: i32,
27 pub status: String,
28}
29
30/// Response for sync status query.
31#[derive(Debug, Clone, Serialize)]
32pub struct SyncStatusResponse {
33 pub pending_count: usize,
34 pub conflict_count: usize,
35 pub pending_items: Vec<PendingItem>,
36 pub conflicts: Vec<ConflictItem>,
37}
38
39#[derive(Debug, Clone, Serialize)]
40pub struct PendingItem {
41 pub entity_type: String,
42 pub entity_id: String,
43}
44
45#[derive(Debug, Clone, Serialize)]
46pub struct ConflictItem {
47 pub entity_type: String,
48 pub entity_id: String,
49 pub local_version: i32,
50 pub remote_version: Option<i32>,
51}
52
53/// Request for conflict resolution.
54#[derive(Debug, Clone, Deserialize)]
55pub struct ResolveConflictRequest {
56 pub strategy: String,
57}
58
59/// Push a deck to the user's PDS.
60///
61/// POST /api/sync/push/deck/:id
62pub async fn push_deck(
63 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>, Path(deck_id): Path<String>,
64) -> impl IntoResponse {
65 let user = match ctx {
66 Some(Extension(user)) => user,
67 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(),
68 };
69
70 let sync_service = create_sync_service(&state);
71
72 match sync_service.push_deck(&deck_id, &user).await {
73 Ok(result) => (
74 StatusCode::OK,
75 Json(PushResponse {
76 entity_type: result.entity_type,
77 entity_id: result.entity_id,
78 pds_uri: result.pds_uri,
79 pds_cid: result.pds_cid,
80 version: result.new_version,
81 status: result.status.to_string(),
82 }),
83 )
84 .into_response(),
85 Err(e) => sync_error_response(e),
86 }
87}
88
89/// Push a note to the user's PDS.
90///
91/// POST /api/sync/push/note/:id
92pub async fn push_note(
93 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>, Path(note_id): Path<String>,
94) -> impl IntoResponse {
95 let user = match ctx {
96 Some(Extension(user)) => user,
97 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(),
98 };
99
100 let sync_service = create_sync_service(&state);
101
102 match sync_service.push_note(¬e_id, &user).await {
103 Ok(result) => (
104 StatusCode::OK,
105 Json(PushResponse {
106 entity_type: result.entity_type,
107 entity_id: result.entity_id,
108 pds_uri: result.pds_uri,
109 pds_cid: result.pds_cid,
110 version: result.new_version,
111 status: result.status.to_string(),
112 }),
113 )
114 .into_response(),
115 Err(e) => sync_error_response(e),
116 }
117}
118
119/// Get the current sync status for the authenticated user.
120///
121/// GET /api/sync/status
122pub async fn get_sync_status(
123 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>,
124) -> impl IntoResponse {
125 let user = match ctx {
126 Some(Extension(user)) => user,
127 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(),
128 };
129
130 let sync_service = create_sync_service(&state);
131
132 match sync_service.get_sync_status(&user).await {
133 Ok(summary) => (
134 StatusCode::OK,
135 Json(SyncStatusResponse {
136 pending_count: summary.pending_count,
137 conflict_count: summary.conflict_count,
138 pending_items: summary
139 .pending_items
140 .into_iter()
141 .map(|(entity_type, entity_id)| PendingItem { entity_type, entity_id })
142 .collect(),
143 conflicts: summary
144 .conflicts
145 .into_iter()
146 .map(|c| ConflictItem {
147 entity_type: c.entity_type,
148 entity_id: c.entity_id,
149 local_version: c.local_version,
150 remote_version: c.remote_version,
151 })
152 .collect(),
153 }),
154 )
155 .into_response(),
156 Err(e) => sync_error_response(e),
157 }
158}
159
160/// Resolve a sync conflict.
161///
162/// POST /api/sync/resolve/:entity_type/:id
163pub async fn resolve_conflict(
164 State(state): State<SharedState>, ctx: Option<Extension<UserContext>>,
165 Path((entity_type, entity_id)): Path<(String, String)>, Json(payload): Json<ResolveConflictRequest>,
166) -> impl IntoResponse {
167 let user = match ctx {
168 Some(Extension(user)) => user,
169 None => return (StatusCode::UNAUTHORIZED, Json(json!({"error": "Unauthorized"}))).into_response(),
170 };
171
172 let strategy = match ConflictStrategy::from_str(&payload.strategy) {
173 Ok(s) => s,
174 Err(_) => {
175 return (
176 StatusCode::BAD_REQUEST,
177 Json(json!({"error": "Invalid strategy. Use: last_write_wins, keep_local, or keep_remote"})),
178 )
179 .into_response();
180 }
181 };
182
183 let sync_service = create_sync_service(&state);
184
185 match sync_service
186 .resolve_conflict(&entity_type, &entity_id, strategy, &user)
187 .await
188 {
189 Ok(result) => (
190 StatusCode::OK,
191 Json(PushResponse {
192 entity_type: result.entity_type,
193 entity_id: result.entity_id,
194 pds_uri: result.pds_uri,
195 pds_cid: result.pds_cid,
196 version: result.new_version,
197 status: result.status.to_string(),
198 }),
199 )
200 .into_response(),
201 Err(e) => sync_error_response(e),
202 }
203}
204
205/// Create a SyncService from the app state.
206fn create_sync_service(state: &SharedState) -> SyncService {
207 SyncService::new(
208 state.sync_repo.clone(),
209 state.deck_repo.clone(),
210 state.card_repo.clone(),
211 state.note_repo.clone(),
212 state.oauth_repo.clone(),
213 )
214}
215
216/// Convert SyncError to HTTP response.
217fn sync_error_response(error: SyncError) -> axum::response::Response {
218 let (status, message) = match &error {
219 SyncError::NotFound(msg) => (StatusCode::NOT_FOUND, msg.clone()),
220 SyncError::AuthRequired(msg) => (StatusCode::UNAUTHORIZED, msg.clone()),
221 SyncError::NoTokens(msg) => (StatusCode::UNAUTHORIZED, msg.clone()),
222 SyncError::InvalidArgument(msg) => (StatusCode::BAD_REQUEST, msg.clone()),
223 SyncError::ConflictDetected(info) => (
224 StatusCode::CONFLICT,
225 format!("Conflict for {}:{}", info.entity_type, info.entity_id),
226 ),
227 SyncError::PdsError(e) => (StatusCode::BAD_GATEWAY, e.to_string()),
228 SyncError::RepoError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()),
229 };
230
231 tracing::error!("Sync error: {}", error);
232 (status, Json(json!({"error": message}))).into_response()
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `SharedState` with a mock pool, default repositories, and a placeholder PDS URL.
    ///
    /// The handler tests below only reach code paths that return before a sync service is
    /// created, so no real backing store is needed.
    fn mock_state() -> SharedState {
        let pool = crate::db::create_mock_pool();
        let repos = crate::state::Repositories::default();
        let config = crate::state::AppConfig { pds_url: "https://test.example.com".to_string() };
        crate::state::AppState::new(pool, repos, config)
    }

    #[test]
    fn test_push_response_serialization() {
        let response = PushResponse {
            entity_type: "deck".to_string(),
            entity_id: "123".to_string(),
            pds_uri: Some("at://did:plc:test/deck/tid".to_string()),
            pds_cid: Some("bafycid".to_string()),
            version: 2,
            status: "synced".to_string(),
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"entity_type\":\"deck\""));
        assert!(json.contains("\"version\":2"));
    }

    #[test]
    fn test_sync_status_response_serialization() {
        let response = SyncStatusResponse {
            pending_count: 2,
            conflict_count: 1,
            pending_items: vec![
                PendingItem { entity_type: "deck".to_string(), entity_id: "1".to_string() },
                PendingItem { entity_type: "note".to_string(), entity_id: "2".to_string() },
            ],
            conflicts: vec![ConflictItem {
                entity_type: "deck".to_string(),
                entity_id: "3".to_string(),
                local_version: 5,
                remote_version: Some(6),
            }],
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"pending_count\":2"));
        assert!(json.contains("\"conflict_count\":1"));
    }

    #[test]
    fn test_resolve_conflict_request_deserialization() {
        let json = r#"{"strategy": "last_write_wins"}"#;
        let request: ResolveConflictRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.strategy, "last_write_wins");

        let json = r#"{"strategy": "keep_local"}"#;
        let request: ResolveConflictRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.strategy, "keep_local");
    }

    #[test]
    fn test_sync_error_response_not_found() {
        let error = SyncError::NotFound("deck:123".to_string());
        let response = sync_error_response(error);
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_sync_error_response_unauthorized() {
        let error = SyncError::AuthRequired("missing token".to_string());
        let response = sync_error_response(error);
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_sync_error_response_no_tokens() {
        let error = SyncError::NoTokens("no stored tokens".to_string());
        let response = sync_error_response(error);
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_sync_error_response_bad_request() {
        let error = SyncError::InvalidArgument("bad entity type".to_string());
        let response = sync_error_response(error);
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_sync_error_response_conflict() {
        let error = SyncError::ConflictDetected(crate::sync_service::ConflictInfo {
            entity_type: "deck".to_string(),
            entity_id: "123".to_string(),
            local_version: 5,
            remote_version: Some(6),
            local_updated_at: None,
            remote_updated_at: None,
        });
        let response = sync_error_response(error);
        assert_eq!(response.status(), StatusCode::CONFLICT);
    }

    #[test]
    fn test_pending_item_serialization() {
        let item = PendingItem { entity_type: "note".to_string(), entity_id: "456".to_string() };

        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"entity_type\":\"note\""));
        assert!(json.contains("\"entity_id\":\"456\""));
    }

    #[test]
    fn test_conflict_item_serialization() {
        let item = ConflictItem {
            entity_type: "deck".to_string(),
            entity_id: "789".to_string(),
            local_version: 3,
            remote_version: Some(4),
        };

        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"local_version\":3"));
        assert!(json.contains("\"remote_version\":4"));
    }

    #[test]
    fn test_conflict_item_no_remote_version() {
        let item = ConflictItem {
            entity_type: "note".to_string(),
            entity_id: "abc".to_string(),
            local_version: 1,
            remote_version: None,
        };

        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"remote_version\":null"));
    }

    #[tokio::test]
    async fn test_push_deck_unauthorized() {
        let response = push_deck(State(mock_state()), None, Path("deck-123".to_string()))
            .await
            .into_response();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_push_note_unauthorized() {
        let response = push_note(State(mock_state()), None, Path("note-456".to_string()))
            .await
            .into_response();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_get_sync_status_unauthorized() {
        let response = get_sync_status(State(mock_state()), None).await.into_response();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_resolve_conflict_unauthorized() {
        let response = resolve_conflict(
            State(mock_state()),
            None,
            Path(("deck".to_string(), "123".to_string())),
            Json(ResolveConflictRequest { strategy: "last_write_wins".to_string() }),
        )
        .await
        .into_response();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_resolve_conflict_invalid_strategy() {
        let user = UserContext {
            did: "did:plc:alice".to_string(),
            handle: "alice.bsky.social".to_string(),
            access_token: "test_token".to_string(),
            pds_url: "https://bsky.social".to_string(),
            has_dpop: false,
        };

        let response = resolve_conflict(
            State(mock_state()),
            Some(Extension(user)),
            Path(("deck".to_string(), "123".to_string())),
            Json(ResolveConflictRequest { strategy: "invalid_strategy".to_string() }),
        )
        .await
        .into_response();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }
}