P2P support library for the beaver compute environment
1/* SPDX Id: AGPL-3.0-or-later */
2
3use std::collections::BTreeSet;
4
5use iroh::EndpointAddr;
6use iroh::endpoint::Connection;
7use iroh::protocol::{AcceptError, ProtocolHandler};
8use log::{error, info};
9
10use tokio::sync::mpsc::channel as tokio_channel;
11
12use crate::packet::PostcardPacket;
13pub use crate::state::PeerEvent;
14use crate::state::{EndpointProxy, EndpointStatus, PairingCommand, SharedState};
15
16#[derive(Debug)]
17pub(crate) struct PairingProtocol {
18 state: SharedState,
19}
20
21impl PairingProtocol {
22 pub(crate) fn new(state: SharedState) -> Self {
23 Self { state }
24 }
25}
26
27/// Pairing handshake:
28/// NodeA NodeB
29/// Request ----------------->
30/// <---------------------- Accept/Reject
31/// Ack ---------------------->
32///
33impl ProtocolHandler for PairingProtocol {
34 async fn accept(&self, connection: Connection) -> Result<(), AcceptError> {
35 let remote_id = connection.remote_id();
36 info!(
37 "accepted connection on {:?} from {:?}",
38 String::from_utf8_lossy(connection.alpn()),
39 remote_id
40 );
41
42 let path_info = connection.to_info().selected_path().unwrap();
43 let remote_addr = path_info.remote_addr();
44 let mut addrs = BTreeSet::new();
45 addrs.insert(remote_addr.clone());
46 let addr = EndpointAddr {
47 id: remote_id,
48 addrs,
49 };
50
51 // If we don't know about this endpoint yet, add it to our set as Discovered.
52 {
53 let mut state = self.state.lock().await;
54 if !state.has_any(&remote_id) {
55 info!("Registering auto-discovered endpoint at {remote_id}");
56 let description = EndpointProxy::new(
57 "<auto-discovered>",
58 remote_id,
59 addr.clone(),
60 EndpointStatus::Discovered,
61 );
62 state.add_endpoint(&remote_id, description);
63 }
64 }
65
66 let (mut sender, mut receiver) = connection.accept_bi().await?;
67
68 let command: PairingCommand = PostcardPacket::recv(&mut receiver)
69 .await
70 .expect("Failed to read");
71
72 // Step 1. Receive a request, store the tokio channel sender and ack receiver in the state.
73 match command {
74 PairingCommand::Request => {
75 self.state
76 .lock()
77 .await
78 .notify(PeerEvent::PairingRequest(remote_id));
79 }
80 _ => {
81 error!("Unexpected command: {command:?}");
82 return Err(AcceptError::from(n0_error::AnyError::from(format!(
83 "Unexpected command: {command:?}"
84 ))));
85 }
86 }
87
88 let (tokio_sender, mut tokio_receiver) = tokio_channel(2);
89 let (ack_sender, ack_receiver) = tokio_channel(2);
90 {
91 let mut state = self.state.lock().await;
92 state.set_pairing_responder(&remote_id, (tokio_sender, ack_receiver));
93 }
94
95 // Step 2. Wait for the Accept or Reject response on the channel and send it back.
96 let answer = tokio_receiver
97 .recv()
98 .await
99 .expect("Failed to receive pairing answer");
100
101 let accepted = answer == PairingCommand::Accept;
102
103 PostcardPacket::send(answer, &mut sender)
104 .await
105 .expect("Failed to send");
106
107 // Step 3. Wait for the Ack from the other side.
108 let command: PairingCommand = PostcardPacket::recv(&mut receiver)
109 .await
110 .expect("Failed to read");
111
112 match command {
113 PairingCommand::Ack => {
114 {
115 let mut state = self.state.lock().await;
116 if accepted {
117 state.notify(PeerEvent::PairingAccepted(remote_id));
118 state.set_status(&remote_id, EndpointStatus::PairedConnected);
119 } else {
120 state.notify(PeerEvent::PairingRejected(remote_id));
121 }
122 }
123 let _ = ack_sender.send(true).await;
124 }
125 _ => {
126 self.state
127 .lock()
128 .await
129 .notify(PeerEvent::PairingFailed(remote_id));
130 let _ = ack_sender.send(false).await;
131 error!("Unexpected command: {command:?}");
132 return Err(AcceptError::from(n0_error::AnyError::from(format!(
133 "Unexpected command: {command:?}"
134 ))));
135 }
136 }
137
138 Ok(())
139 }
140}