P2P support library for the beaver compute environment
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 263 lines 8.0 kB view raw
1/* SPDX Id: AGPL-3.0-or-later */ 2 3use std::{ 4 collections::{HashMap, HashSet}, 5 sync::Arc, 6}; 7 8use iroh::{EndpointAddr, EndpointId, address_lookup::DiscoveryEvent, endpoint::SendStream}; 9 10use log::{info, warn}; 11use serde::{Deserialize, Serialize}; 12use std::sync::mpsc::Sender; 13use tokio::sync::Mutex; 14use tokio::sync::mpsc::{Receiver as TokioReceiver, Sender as TokioSender}; 15 16pub(crate) type SharedState = Arc<Mutex<State>>; 17 18#[derive(Clone, Debug, PartialEq)] 19pub enum EndpointStatus { 20 PairedConnected, 21 PairedDisconnected, 22 Discovered, 23} 24 25#[derive(Debug)] 26pub(crate) struct EndpointProxy { 27 name: String, 28 id: EndpointId, 29 addr: EndpointAddr, 30 status: EndpointStatus, 31 message_sender: Option<SendStream>, 32} 33 34impl EndpointProxy { 35 pub(crate) fn new( 36 name: &str, 37 id: EndpointId, 38 addr: EndpointAddr, 39 status: EndpointStatus, 40 ) -> Self { 41 Self { 42 name: name.to_owned(), 43 id: id.to_owned(), 44 addr, 45 status, 46 message_sender: None, 47 } 48 } 49 50 pub(crate) fn addr(&self) -> EndpointAddr { 51 self.addr.clone() 52 } 53 54 pub(crate) fn is_paired(&self) -> bool { 55 self.status == EndpointStatus::PairedConnected 56 || self.status == EndpointStatus::PairedDisconnected 57 } 58} 59 60// PairingFailed represents protocol errors such as 61// failure to complete the Ack handshake. 62#[derive(Debug)] 63pub enum PeerEvent { 64 Discovery(DiscoveryEvent), 65 PairingRequest(EndpointId), 66 PairingAccepted(EndpointId), 67 PairingRejected(EndpointId), 68 PairingFailed(EndpointId), 69 Message(EndpointId, Vec<u8>), 70} 71 72#[derive(Serialize, Deserialize, Debug, PartialEq)] 73pub(crate) enum PairingCommand { 74 Request, 75 Accept, 76 Reject, 77 Ack, 78} 79 80#[derive(Debug)] 81pub struct EndpointDescription { 82 pub name: String, 83 pub id: EndpointId, 84 pub addr: EndpointAddr, 85 pub status: EndpointStatus, 86} 87 88impl From<&EndpointProxy> for EndpointDescription { 89 fn from(value: &EndpointProxy) -> Self { 90 Self { 91 name: value.name.clone(), 92 id: value.id, 93 addr: value.addr.clone(), 94 status: value.status.clone(), 95 } 96 } 97} 98 99#[derive(Debug)] 100pub(crate) struct State { 101 /// The list of known endpoints, with their status. 102 endpoints: HashMap<EndpointId, EndpointProxy>, 103 104 /// The set of endpoints that we sent pairing requests to. 105 pairing_requested: HashSet<EndpointId>, 106 107 /// The senders used to provide accept/reject responses 108 pairing_responders: HashMap<EndpointId, (TokioSender<PairingCommand>, TokioReceiver<bool>)>, 109 110 /// The set of endpoints that we are waiting to receive ack from. 111 pending_ack: HashSet<EndpointId>, 112 113 /// The sender side of the channel used to receive high level events. 114 sender: Sender<PeerEvent>, 115} 116 117impl State { 118 pub(crate) fn new(sender: Sender<PeerEvent>) -> Self { 119 Self { 120 endpoints: HashMap::new(), 121 pairing_requested: HashSet::new(), 122 pairing_responders: HashMap::new(), 123 pending_ack: HashSet::new(), 124 sender, 125 } 126 } 127 128 pub(crate) fn endpoints(&self) -> &HashMap<EndpointId, EndpointProxy> { 129 &self.endpoints 130 } 131 132 pub(crate) fn notify(&self, event: PeerEvent) { 133 self.sender.send(event).expect("Failed to send {event:?}"); 134 } 135 136 pub(crate) fn has(&self, id: &EndpointId, status: EndpointStatus) -> bool { 137 self.endpoints 138 .get(id) 139 .is_some_and(|desc| desc.status == status) 140 } 141 142 pub(crate) fn has_any(&self, id: &EndpointId) -> bool { 143 self.endpoints.contains_key(id) 144 } 145 146 pub(crate) fn by_id(&self, id: &EndpointId) -> Option<&EndpointProxy> { 147 self.endpoints.get(id) 148 } 149 150 pub(crate) fn add_endpoint(&mut self, id: &EndpointId, description: EndpointProxy) { 151 self.endpoints.insert(*id, description); 152 } 153 154 pub(crate) fn set_message_sender(&mut self, id: &EndpointId, stream: SendStream) { 155 info!("set_message_sender for {id}"); 156 if let Some(desc) = self.endpoints.get_mut(id) { 157 desc.message_sender = Some(stream); 158 } 159 } 160 161 pub(crate) fn get_message_sender(&mut self, id: &EndpointId) -> Option<&mut SendStream> { 162 info!("get_message_sender for {id}"); 163 if let Some(desc) = self.endpoints.get_mut(id) { 164 desc.message_sender.as_mut() 165 } else { 166 None 167 } 168 } 169 170 pub(crate) fn remove_message_sender(&mut self, id: &EndpointId) { 171 info!("remove_message_sender for {id}"); 172 if let Some(desc) = self.endpoints.get_mut(id) { 173 desc.message_sender = None; 174 } 175 } 176 177 pub(crate) fn set_status(&mut self, id: &EndpointId, status: EndpointStatus) { 178 if let Some(desc) = self.endpoints.get_mut(id) { 179 desc.status = status; 180 } 181 } 182 183 fn discovered(&self, id: &EndpointId) -> bool { 184 self.endpoints 185 .get(id) 186 .is_some_and(|desc| desc.status != EndpointStatus::PairedDisconnected) 187 } 188 189 pub(crate) fn has_requested(&self, id: &EndpointId) -> bool { 190 self.pairing_requested.contains(id) 191 } 192 193 pub(crate) fn set_pairing_requested(&mut self, id: &EndpointId) { 194 self.pairing_requested.insert(*id); 195 } 196 197 pub(crate) fn remove_pairing_requested(&mut self, id: &EndpointId) { 198 self.pairing_requested.remove(id); 199 } 200 201 pub(crate) fn set_pending_ack(&mut self, id: &EndpointId) { 202 self.pending_ack.insert(*id); 203 } 204 205 pub(crate) fn set_pairing_responder( 206 &mut self, 207 id: &EndpointId, 208 params: (TokioSender<PairingCommand>, TokioReceiver<bool>), 209 ) { 210 self.pairing_responders.insert(*id, params); 211 } 212 213 pub(crate) fn take_pairing_responder( 214 &mut self, 215 id: &EndpointId, 216 ) -> Option<(TokioSender<PairingCommand>, TokioReceiver<bool>)> { 217 self.pairing_responders.remove(id) 218 } 219 220 pub(crate) fn on_discovery(&mut self, event: &DiscoveryEvent) { 221 match event { 222 DiscoveryEvent::Discovered { endpoint_info, .. } => { 223 // Ignore if we already know about this endpoint. 224 if self.discovered(&endpoint_info.endpoint_id) { 225 return; 226 } 227 228 // Add it as Discovered and notify the listener. 229 let description = EndpointProxy { 230 name: endpoint_info 231 .data 232 .user_data() 233 .map(|d| d.as_ref()) 234 .unwrap_or_else(|| "<no name>") 235 .into(), 236 id: endpoint_info.endpoint_id, 237 addr: endpoint_info.to_endpoint_addr(), 238 status: EndpointStatus::Discovered, 239 message_sender: None, 240 }; 241 self.endpoints 242 .insert(endpoint_info.endpoint_id, description); 243 self.notify(PeerEvent::Discovery(event.clone())); 244 } 245 DiscoveryEvent::Expired { endpoint_id } => { 246 // PairedConnected -> PairedDisconnected 247 // Discovered -> removed 248 if let Some(mut old_desc) = self.endpoints.remove(endpoint_id) { 249 if old_desc.status == EndpointStatus::PairedConnected { 250 old_desc.status = EndpointStatus::PairedDisconnected; 251 self.endpoints.insert(*endpoint_id, old_desc); 252 } else if old_desc.status != EndpointStatus::Discovered { 253 warn!( 254 "Unexpected status for expired endpoint: {:?}", 255 old_desc.status 256 ); 257 } 258 self.notify(PeerEvent::Discovery(event.clone())); 259 } 260 } 261 } 262 } 263}