Tap drinker
1mod cli;
2mod query;
3
4use std::{
5 collections::{HashSet, VecDeque},
6 io, process,
7 time::Duration,
8};
9
10use futures_util::StreamExt;
11use serde_json::Value;
12use sqlx::PgPool;
13use tokio::{sync::mpsc, task::JoinSet};
14use tokio_util::sync::CancellationToken;
15use tracing::{Instrument as _, level_filters::LevelFilter};
16use tracing_subscriber::{EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _};
17use trap::tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent};
18
/// DID provenance.
///
/// Sent together with each DID over the channel that feeds `did_task`, which
/// includes it in the "tracking DID" log line.
#[derive(Debug)]
// NOTE(review): the `Record` payload appears to be read only through the
// derived `Debug` output, which is presumably why the lint is silenced — confirm.
#[allow(unused)]
enum DidSource {
    /// The DID was supplied through the `seed` command-line arguments.
    Seed,
    /// The DID was extracted from a record; carries that record's collection.
    Record(Box<str>),
}
26
27#[tokio::main]
28async fn main() -> anyhow::Result<()> {
29 setup_tracing()?;
30
31 let arguments = cli::parse();
32 let pool = PgPool::connect(arguments.db.as_str()).await?;
33 let version = db_version(&pool).await?;
34 tracing::info!(%version, "connected to db");
35
36 sqlx::migrate!().run(&pool).await?;
37
38 let shutdown = CancellationToken::new();
39 let (did_tx, did_rx) = mpsc::unbounded_channel::<(String, DidSource)>();
40
41 let tap = TapClient::new(arguments.tap, arguments.tap_password.as_deref())?;
42 let (tap_channel, tap_task) = tap.channel();
43
44 let mut tasks = JoinSet::new();
45 tasks.spawn(async move {
46 if let Err(error) = tap_task.await {
47 tracing::error!(?error);
48 process::abort();
49 }
50 Ok(())
51 });
52
53 tasks.spawn(event_consumer(
54 tap_channel,
55 pool.clone(),
56 did_tx.clone(),
57 shutdown.child_token(),
58 arguments.crawl,
59 ));
60
61 tasks.spawn(did_task(
62 tap,
63 pool,
64 did_rx,
65 shutdown.child_token(),
66 arguments.crawl,
67 ));
68 tasks.spawn(shutdown_task(shutdown.clone()));
69
70 // Submit seed DIDs to the Tap service.
71 for did in arguments.seed.into_iter().filter(|s| possible_did(s)) {
72 did_tx.send((did, DidSource::Seed))?;
73 }
74
75 for task in tasks.join_all().await {
76 if let Err(error) = task {
77 tracing::error!(?error, "task failed");
78 shutdown.cancel();
79 }
80 }
81
82 Ok(())
83}
84
85fn setup_tracing() -> anyhow::Result<()> {
86 tracing_subscriber::registry()
87 .with(
88 EnvFilter::builder()
89 .with_default_directive(LevelFilter::INFO.into())
90 .from_env()?,
91 )
92 .with(tracing_subscriber::fmt::layer().with_writer(io::stderr))
93 .try_init()?;
94
95 Ok(())
96}
97
98async fn db_version(pool: &PgPool) -> sqlx::Result<String> {
99 let row: (String,) = sqlx::query_as("SELECT version()").fetch_one(pool).await?;
100 Ok(row.0)
101}
102
103#[tracing::instrument(skip(channel, pool, tx, shutdown))]
104async fn event_consumer(
105 mut channel: TapChannel,
106 pool: PgPool,
107 tx: mpsc::UnboundedSender<(String, DidSource)>,
108 shutdown: CancellationToken,
109 crawl: bool,
110) -> anyhow::Result<()> {
111 while let Some(Some((span, event, ack))) = shutdown.run_until_cancelled(channel.recv()).await {
112 async {
113 let mut transaction = pool.begin().await?;
114
115 save_event(&event, &mut transaction).await?;
116 match event {
117 TapEvent::Record(record) => {
118 let (record, parsed_record) = handle_record(record, &mut transaction).await?;
119
120 if crawl {
121 // Expand the network of tracked DIDs.
122 let nsid = record.collection.into_boxed_str();
123 for did in extract_dids(&parsed_record) {
124 tx.send((did, DidSource::Record(nsid.clone())))?;
125 }
126 }
127 }
128 TapEvent::Identity(identity) => {
129 handle_identity(identity, &mut transaction).await?;
130 }
131 }
132
133 transaction.commit().await?;
134 ack.acknowledge().await?;
135 Ok::<_, anyhow::Error>(())
136 }
137 .instrument(span)
138 .await?;
139 }
140
141 tracing::info!("complete");
142 Ok(())
143}
144
145async fn save_event(
146 event: &TapEvent,
147 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>,
148) -> anyhow::Result<()> {
149 let json = serde_json::to_value(&event)?;
150 sqlx::query!(
151 "INSERT INTO event (id, data) VALUES ($1, $2)",
152 i64::try_from(event.id()).expect("event ID should not exceed i64::MAX"),
153 json
154 )
155 .execute(&mut **transaction)
156 .await?;
157
158 Ok(())
159}
160
161async fn handle_identity(
162 identity_event: IdentityEvent,
163 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>,
164) -> anyhow::Result<()> {
165 let IdentityEvent {
166 id: _,
167 did,
168 handle,
169 is_active,
170 status,
171 } = identity_event;
172
173 query::upsert_identity(&did, &handle, &status, is_active)
174 .execute(&mut **transaction)
175 .await?;
176
177 Ok(())
178}
179
180async fn handle_record(
181 record_event: RecordEvent,
182 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>,
183) -> anyhow::Result<(RecordEvent, Value)> {
184 let RecordEvent {
185 id: _,
186 did,
187 rev,
188 collection,
189 rkey,
190 action,
191 record,
192 cid,
193 live,
194 } = &record_event;
195
196 let parsed_record: Value = record
197 .as_ref()
198 .map(|record| serde_json::from_str(record.get()))
199 .transpose()?
200 .unwrap_or_default();
201
202 match action {
203 RecordAction::Create | RecordAction::Update => {
204 sqlx::query_file!(
205 "queries/upsert_record.sql",
206 did.as_str(),
207 collection,
208 rkey,
209 rev,
210 cid.as_deref(),
211 live,
212 parsed_record
213 )
214 .execute(&mut **transaction)
215 .await?;
216 }
217 RecordAction::Delete => {
218 query::delete_record(did, collection, rkey, rev)
219 .execute(&mut **transaction)
220 .await?;
221 }
222 }
223
224 Ok((record_event, parsed_record))
225}
226
227#[tracing::instrument(skip(tap, pool, did_rx, shutdown))]
228async fn did_task(
229 tap: TapClient,
230 pool: PgPool,
231 mut did_rx: mpsc::UnboundedReceiver<(String, DidSource)>,
232 shutdown: CancellationToken,
233 crawl: bool,
234) -> anyhow::Result<()> {
235 const BATCH: usize = 64;
236
237 let mut seen: HashSet<String> = HashSet::with_capacity(10_000);
238 let mut dids = Vec::new();
239
240 if crawl {
241 // Query known DIDs from the database.
242 let mut query = sqlx::query!("SELECT did FROM identity").fetch(&pool);
243 while let Some(Ok(row)) = query.next().await {
244 seen.insert(row.did);
245 }
246
247 tracing::debug!(count = seen.len(), "loaded tracked DIDs from database");
248 }
249
250 loop {
251 tokio::time::sleep(Duration::from_millis(200)).await;
252 match shutdown
253 .run_until_cancelled(did_rx.recv_many(&mut dids, BATCH))
254 .await
255 {
256 Some(0) | None => break,
257 Some(_) => {}
258 }
259
260 // Convert Vec<Box<Did>> to a Vec<&Did>.
261 let mut dedup: HashSet<&str> = HashSet::new();
262 let mut slice = Vec::with_capacity(dids.len());
263 for (did, source) in &dids {
264 if !dedup.insert(did) {
265 continue;
266 }
267
268 if !seen.contains(did) || slice.contains(&did.as_ref()) {
269 tracing::info!(?did, ?source, "tracking DID");
270 slice.push(did.as_ref());
271 }
272 }
273
274 tap.add_repos(&slice).await?;
275
276 dids.drain(..).for_each(|(did, _)| _ = seen.insert(did));
277 }
278
279 tracing::info!("complete");
280 Ok(())
281}
282
283#[tracing::instrument(skip(shutdown))]
284async fn shutdown_task(shutdown: CancellationToken) -> anyhow::Result<()> {
285 tokio::signal::ctrl_c().await?;
286 eprintln!();
287 tracing::info!("shutdown signal received");
288 shutdown.cancel();
289 Ok(())
290}
291
292/// Extract any strings that look like DIDs from a JSON document.
293fn extract_dids(value: &Value) -> HashSet<String> {
294 let mut dids = HashSet::new();
295
296 let mut queue = VecDeque::from_iter([value]);
297 while let Some(value) = queue.pop_front() {
298 match value {
299 Value::Null | Value::Bool(_) | Value::Number(_) => {}
300 Value::Array(values) => {
301 for value in values {
302 queue.push_back(value);
303 }
304 }
305 Value::Object(map) => {
306 for (_, value) in map {
307 queue.push_back(value);
308 }
309 }
310 Value::String(maybe_did) => {
311 if possible_did(maybe_did) {
312 dids.insert(maybe_did.to_string());
313 continue;
314 }
315
316 // First segment of an "at://..." URI might be a DID.
317 if let Some(uri) = maybe_did.strip_prefix("at://")
318 && let Some((maybe_did, _)) = uri.split_once('/')
319 && possible_did(maybe_did)
320 {
321 dids.insert(maybe_did.to_string());
322 continue;
323 }
324 }
325 }
326 }
327
328 dids
329}
330
/// Cheap check for whether `s` looks like a `did:plc` or `did:web` identifier.
///
/// The method name must be followed by `:` and a non-empty identifier, so
/// other methods that merely share a prefix (for example `did:webvh:...`) and
/// bare method names (`did:plc`) are rejected. The identifier itself is not
/// validated.
fn possible_did(s: &str) -> bool {
    const METHOD_PREFIXES: [&str; 2] = ["did:plc:", "did:web:"];

    METHOD_PREFIXES.iter().any(|prefix| {
        s.strip_prefix(prefix)
            .is_some_and(|identifier| !identifier.is_empty())
    })
}