jj workspaces over the network
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat(server): add WebSocket yrs sync endpoint

+129
+129
crates/tandem-server/src/sync.rs
··· 1 + use axum::{ 2 + extract::{ws::{Message, WebSocket, WebSocketUpgrade}, Path, State}, 3 + response::IntoResponse, 4 + }; 5 + use futures_util::{SinkExt, StreamExt}; 6 + use std::collections::HashMap; 7 + use tokio::sync::RwLock; 8 + use uuid::Uuid; 9 + use crate::AppState; 10 + 11 + /// WebSocket handler for yrs sync 12 + pub async fn sync_handler( 13 + ws: WebSocketUpgrade, 14 + Path(repo_id): Path<String>, 15 + State(state): State<AppState>, 16 + ) -> impl IntoResponse { 17 + ws.on_upgrade(move |socket| handle_sync(socket, repo_id, state)) 18 + } 19 + 20 + async fn handle_sync(socket: WebSocket, repo_id: String, state: AppState) { 21 + let (mut sender, mut receiver) = socket.split(); 22 + 23 + let doc = match state.docs.get_or_load(&repo_id).await { 24 + Ok(doc) => doc, 25 + Err(e) => { 26 + tracing::error!("Failed to load doc for {}: {}", repo_id, e); 27 + return; 28 + } 29 + }; 30 + 31 + let client_id = Uuid::new_v4(); 32 + let mut broadcast_rx = state.sync.subscribe(&repo_id).await; 33 + 34 + tracing::info!("Client {} connected to sync for repo {}", client_id, repo_id); 35 + 36 + loop { 37 + tokio::select! { 38 + Some(msg) = receiver.next() => { 39 + match msg { 40 + Ok(Message::Binary(data)) => { 41 + let doc = doc.read().await; 42 + 43 + if let Err(_e) = doc.apply_update(&data) { 44 + // Might be a state vector - compute diff and send 45 + let update = doc.encode_update_from(&data); 46 + drop(doc); 47 + if let Err(e) = sender.send(Message::Binary(update)).await { 48 + tracing::error!("Failed to send update: {}", e); 49 + break; 50 + } 51 + } else { 52 + // Successfully applied update 53 + drop(doc); 54 + 55 + // Save to disk 56 + if let Err(e) = state.docs.save(&repo_id).await { 57 + tracing::warn!("Failed to save doc: {}", e); 58 + } 59 + 60 + // Broadcast to other clients 61 + state.sync.broadcast(&repo_id, client_id, data).await; 62 + } 63 + } 64 + Ok(Message::Close(_)) => { 65 + tracing::info!("Client {} disconnected from repo {}", client_id, repo_id); 66 + break; 67 + } 68 + Ok(Message::Ping(data)) => { 69 + if let Err(e) = sender.send(Message::Pong(data)).await { 70 + tracing::error!("Failed to send pong: {}", e); 71 + break; 72 + } 73 + } 74 + Err(e) => { 75 + tracing::error!("WebSocket error: {}", e); 76 + break; 77 + } 78 + _ => {} 79 + } 80 + } 81 + Ok(msg) = broadcast_rx.recv() => { 82 + // Don't echo back to sender 83 + if msg.sender_id != client_id { 84 + if let Err(e) = sender.send(Message::Binary(msg.data)).await { 85 + tracing::error!("Failed to forward broadcast: {}", e); 86 + break; 87 + } 88 + } 89 + } 90 + } 91 + } 92 + } 93 + 94 + /// Message wrapper that includes sender ID to prevent echo 95 + #[derive(Clone, Debug)] 96 + pub(crate) struct BroadcastMessage { 97 + sender_id: Uuid, 98 + data: Vec<u8>, 99 + } 100 + 101 + /// Track connected clients for broadcasting 102 + pub struct SyncManager { 103 + channels: RwLock<HashMap<String, tokio::sync::broadcast::Sender<BroadcastMessage>>>, 104 + } 105 + 106 + impl SyncManager { 107 + pub fn new() -> Self { 108 + Self { 109 + channels: RwLock::new(HashMap::new()), 110 + } 111 + } 112 + 113 + pub async fn subscribe(&self, repo_id: &str) -> tokio::sync::broadcast::Receiver<BroadcastMessage> { 114 + let mut channels = self.channels.write().await; 115 + let tx = channels.entry(repo_id.to_string()).or_insert_with(|| { 116 + let (tx, _rx) = tokio::sync::broadcast::channel(100); 117 + tx 118 + }); 119 + tx.subscribe() 120 + } 121 + 122 + pub async fn broadcast(&self, repo_id: &str, sender_id: Uuid, data: Vec<u8>) { 123 + let channels = self.channels.read().await; 124 + if let Some(tx) = channels.get(repo_id) { 125 + let msg = BroadcastMessage { sender_id, data }; 126 + let _ = tx.send(msg); 127 + } 128 + } 129 + }