//! P2P support library for the beaver compute environment.
/* SPDX Id: AGPL-3.0-or-later */
2
use std::sync::mpsc::Sender;
use std::{
    collections::{HashMap, HashSet},
    sync::Arc,
};

use iroh::{EndpointAddr, EndpointId, address_lookup::DiscoveryEvent, endpoint::SendStream};

use log::{info, warn};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tokio::sync::mpsc::{Receiver as TokioReceiver, Sender as TokioSender};
15
16pub(crate) type SharedState = Arc<Mutex<State>>;
17
/// Status of an endpoint tracked in the peer state.
#[derive(Clone, Debug, PartialEq)]
pub enum EndpointStatus {
    /// Paired endpoint; moved to [`EndpointStatus::PairedDisconnected`] when its
    /// discovery entry expires.
    PairedConnected,
    /// Paired endpoint whose discovery entry has expired.
    PairedDisconnected,
    /// Endpoint that was added from a discovery event and is not paired.
    Discovered,
}
24
25#[derive(Debug)]
26pub(crate) struct EndpointProxy {
27 name: String,
28 id: EndpointId,
29 addr: EndpointAddr,
30 status: EndpointStatus,
31 message_sender: Option<SendStream>,
32}
33
34impl EndpointProxy {
35 pub(crate) fn new(
36 name: &str,
37 id: EndpointId,
38 addr: EndpointAddr,
39 status: EndpointStatus,
40 ) -> Self {
41 Self {
42 name: name.to_owned(),
43 id: id.to_owned(),
44 addr,
45 status,
46 message_sender: None,
47 }
48 }
49
50 pub(crate) fn addr(&self) -> EndpointAddr {
51 self.addr.clone()
52 }
53
54 pub(crate) fn is_paired(&self) -> bool {
55 self.status == EndpointStatus::PairedConnected
56 || self.status == EndpointStatus::PairedDisconnected
57 }
58}
59
60// PairingFailed represents protocol errors such as
61// failure to complete the Ack handshake.
62#[derive(Debug)]
63pub enum PeerEvent {
64 Discovery(DiscoveryEvent),
65 PairingRequest(EndpointId),
66 PairingAccepted(EndpointId),
67 PairingRejected(EndpointId),
68 PairingFailed(EndpointId),
69 Message(EndpointId, Vec<u8>),
70}
71
72#[derive(Serialize, Deserialize, Debug, PartialEq)]
73pub(crate) enum PairingCommand {
74 Request,
75 Accept,
76 Reject,
77 Ack,
78}
79
80#[derive(Debug)]
81pub struct EndpointDescription {
82 pub name: String,
83 pub id: EndpointId,
84 pub addr: EndpointAddr,
85 pub status: EndpointStatus,
86}
87
88impl From<&EndpointProxy> for EndpointDescription {
89 fn from(value: &EndpointProxy) -> Self {
90 Self {
91 name: value.name.clone(),
92 id: value.id,
93 addr: value.addr.clone(),
94 status: value.status.clone(),
95 }
96 }
97}
98
99#[derive(Debug)]
100pub(crate) struct State {
101 /// The list of known endpoints, with their status.
102 endpoints: HashMap<EndpointId, EndpointProxy>,
103
104 /// The set of endpoints that we sent pairing requests to.
105 pairing_requested: HashSet<EndpointId>,
106
107 /// The senders used to provide accept/reject responses
108 pairing_responders: HashMap<EndpointId, (TokioSender<PairingCommand>, TokioReceiver<bool>)>,
109
110 /// The set of endpoints that we are waiting to receive ack from.
111 pending_ack: HashSet<EndpointId>,
112
113 /// The sender side of the channel used to receive high level events.
114 sender: Sender<PeerEvent>,
115}
116
117impl State {
118 pub(crate) fn new(sender: Sender<PeerEvent>) -> Self {
119 Self {
120 endpoints: HashMap::new(),
121 pairing_requested: HashSet::new(),
122 pairing_responders: HashMap::new(),
123 pending_ack: HashSet::new(),
124 sender,
125 }
126 }
127
128 pub(crate) fn endpoints(&self) -> &HashMap<EndpointId, EndpointProxy> {
129 &self.endpoints
130 }
131
132 pub(crate) fn notify(&self, event: PeerEvent) {
133 self.sender.send(event).expect("Failed to send {event:?}");
134 }
135
136 pub(crate) fn has(&self, id: &EndpointId, status: EndpointStatus) -> bool {
137 self.endpoints
138 .get(id)
139 .is_some_and(|desc| desc.status == status)
140 }
141
142 pub(crate) fn has_any(&self, id: &EndpointId) -> bool {
143 self.endpoints.contains_key(id)
144 }
145
146 pub(crate) fn by_id(&self, id: &EndpointId) -> Option<&EndpointProxy> {
147 self.endpoints.get(id)
148 }
149
150 pub(crate) fn add_endpoint(&mut self, id: &EndpointId, description: EndpointProxy) {
151 self.endpoints.insert(*id, description);
152 }
153
154 pub(crate) fn set_message_sender(&mut self, id: &EndpointId, stream: SendStream) {
155 info!("set_message_sender for {id}");
156 if let Some(desc) = self.endpoints.get_mut(id) {
157 desc.message_sender = Some(stream);
158 }
159 }
160
161 pub(crate) fn get_message_sender(&mut self, id: &EndpointId) -> Option<&mut SendStream> {
162 info!("get_message_sender for {id}");
163 if let Some(desc) = self.endpoints.get_mut(id) {
164 desc.message_sender.as_mut()
165 } else {
166 None
167 }
168 }
169
170 pub(crate) fn remove_message_sender(&mut self, id: &EndpointId) {
171 info!("remove_message_sender for {id}");
172 if let Some(desc) = self.endpoints.get_mut(id) {
173 desc.message_sender = None;
174 }
175 }
176
177 pub(crate) fn set_status(&mut self, id: &EndpointId, status: EndpointStatus) {
178 if let Some(desc) = self.endpoints.get_mut(id) {
179 desc.status = status;
180 }
181 }
182
183 fn discovered(&self, id: &EndpointId) -> bool {
184 self.endpoints
185 .get(id)
186 .is_some_and(|desc| desc.status != EndpointStatus::PairedDisconnected)
187 }
188
189 pub(crate) fn has_requested(&self, id: &EndpointId) -> bool {
190 self.pairing_requested.contains(id)
191 }
192
193 pub(crate) fn set_pairing_requested(&mut self, id: &EndpointId) {
194 self.pairing_requested.insert(*id);
195 }
196
197 pub(crate) fn remove_pairing_requested(&mut self, id: &EndpointId) {
198 self.pairing_requested.remove(id);
199 }
200
201 pub(crate) fn set_pending_ack(&mut self, id: &EndpointId) {
202 self.pending_ack.insert(*id);
203 }
204
205 pub(crate) fn set_pairing_responder(
206 &mut self,
207 id: &EndpointId,
208 params: (TokioSender<PairingCommand>, TokioReceiver<bool>),
209 ) {
210 self.pairing_responders.insert(*id, params);
211 }
212
213 pub(crate) fn take_pairing_responder(
214 &mut self,
215 id: &EndpointId,
216 ) -> Option<(TokioSender<PairingCommand>, TokioReceiver<bool>)> {
217 self.pairing_responders.remove(id)
218 }
219
220 pub(crate) fn on_discovery(&mut self, event: &DiscoveryEvent) {
221 match event {
222 DiscoveryEvent::Discovered { endpoint_info, .. } => {
223 // Ignore if we already know about this endpoint.
224 if self.discovered(&endpoint_info.endpoint_id) {
225 return;
226 }
227
228 // Add it as Discovered and notify the listener.
229 let description = EndpointProxy {
230 name: endpoint_info
231 .data
232 .user_data()
233 .map(|d| d.as_ref())
234 .unwrap_or_else(|| "<no name>")
235 .into(),
236 id: endpoint_info.endpoint_id,
237 addr: endpoint_info.to_endpoint_addr(),
238 status: EndpointStatus::Discovered,
239 message_sender: None,
240 };
241 self.endpoints
242 .insert(endpoint_info.endpoint_id, description);
243 self.notify(PeerEvent::Discovery(event.clone()));
244 }
245 DiscoveryEvent::Expired { endpoint_id } => {
246 // PairedConnected -> PairedDisconnected
247 // Discovered -> removed
248 if let Some(mut old_desc) = self.endpoints.remove(endpoint_id) {
249 if old_desc.status == EndpointStatus::PairedConnected {
250 old_desc.status = EndpointStatus::PairedDisconnected;
251 self.endpoints.insert(*endpoint_id, old_desc);
252 } else if old_desc.status != EndpointStatus::Discovered {
253 warn!(
254 "Unexpected status for expired endpoint: {:?}",
255 old_desc.status
256 );
257 }
258 self.notify(PeerEvent::Discovery(event.clone()));
259 }
260 }
261 }
262 }
263}