Tap drinker
2
fork

Configure Feed

Select the types of activity you want to include in your feed.

at main 333 lines 9.2 kB view raw
1mod cli; 2mod query; 3 4use std::{ 5 collections::{HashSet, VecDeque}, 6 io, process, 7 time::Duration, 8}; 9 10use futures_util::StreamExt; 11use serde_json::Value; 12use sqlx::PgPool; 13use tokio::{sync::mpsc, task::JoinSet}; 14use tokio_util::sync::CancellationToken; 15use tracing::{Instrument as _, level_filters::LevelFilter}; 16use tracing_subscriber::{EnvFilter, layer::SubscriberExt as _, util::SubscriberInitExt as _}; 17use trap::tap::{IdentityEvent, RecordAction, RecordEvent, TapChannel, TapClient, TapEvent}; 18 19/// DID provenance. 20#[derive(Debug)] 21#[allow(unused)] 22enum DidSource { 23 Seed, 24 Record(Box<str>), 25} 26 27#[tokio::main] 28async fn main() -> anyhow::Result<()> { 29 setup_tracing()?; 30 31 let arguments = cli::parse(); 32 let pool = PgPool::connect(arguments.db.as_str()).await?; 33 let version = db_version(&pool).await?; 34 tracing::info!(%version, "connected to db"); 35 36 sqlx::migrate!().run(&pool).await?; 37 38 let shutdown = CancellationToken::new(); 39 let (did_tx, did_rx) = mpsc::unbounded_channel::<(String, DidSource)>(); 40 41 let tap = TapClient::new(arguments.tap, arguments.tap_password.as_deref())?; 42 let (tap_channel, tap_task) = tap.channel(); 43 44 let mut tasks = JoinSet::new(); 45 tasks.spawn(async move { 46 if let Err(error) = tap_task.await { 47 tracing::error!(?error); 48 process::abort(); 49 } 50 Ok(()) 51 }); 52 53 tasks.spawn(event_consumer( 54 tap_channel, 55 pool.clone(), 56 did_tx.clone(), 57 shutdown.child_token(), 58 arguments.crawl, 59 )); 60 61 tasks.spawn(did_task( 62 tap, 63 pool, 64 did_rx, 65 shutdown.child_token(), 66 arguments.crawl, 67 )); 68 tasks.spawn(shutdown_task(shutdown.clone())); 69 70 // Submit seed DIDs to the Tap service. 71 for did in arguments.seed.into_iter().filter(|s| possible_did(s)) { 72 did_tx.send((did, DidSource::Seed))?; 73 } 74 75 for task in tasks.join_all().await { 76 if let Err(error) = task { 77 tracing::error!(?error, "task failed"); 78 shutdown.cancel(); 79 } 80 } 81 82 Ok(()) 83} 84 85fn setup_tracing() -> anyhow::Result<()> { 86 tracing_subscriber::registry() 87 .with( 88 EnvFilter::builder() 89 .with_default_directive(LevelFilter::INFO.into()) 90 .from_env()?, 91 ) 92 .with(tracing_subscriber::fmt::layer().with_writer(io::stderr)) 93 .try_init()?; 94 95 Ok(()) 96} 97 98async fn db_version(pool: &PgPool) -> sqlx::Result<String> { 99 let row: (String,) = sqlx::query_as("SELECT version()").fetch_one(pool).await?; 100 Ok(row.0) 101} 102 103#[tracing::instrument(skip(channel, pool, tx, shutdown))] 104async fn event_consumer( 105 mut channel: TapChannel, 106 pool: PgPool, 107 tx: mpsc::UnboundedSender<(String, DidSource)>, 108 shutdown: CancellationToken, 109 crawl: bool, 110) -> anyhow::Result<()> { 111 while let Some(Some((span, event, ack))) = shutdown.run_until_cancelled(channel.recv()).await { 112 async { 113 let mut transaction = pool.begin().await?; 114 115 save_event(&event, &mut transaction).await?; 116 match event { 117 TapEvent::Record(record) => { 118 let (record, parsed_record) = handle_record(record, &mut transaction).await?; 119 120 if crawl { 121 // Expand the network of tracked DIDs. 122 let nsid = record.collection.into_boxed_str(); 123 for did in extract_dids(&parsed_record) { 124 tx.send((did, DidSource::Record(nsid.clone())))?; 125 } 126 } 127 } 128 TapEvent::Identity(identity) => { 129 handle_identity(identity, &mut transaction).await?; 130 } 131 } 132 133 transaction.commit().await?; 134 ack.acknowledge().await?; 135 Ok::<_, anyhow::Error>(()) 136 } 137 .instrument(span) 138 .await?; 139 } 140 141 tracing::info!("complete"); 142 Ok(()) 143} 144 145async fn save_event( 146 event: &TapEvent, 147 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, 148) -> anyhow::Result<()> { 149 let json = serde_json::to_value(&event)?; 150 sqlx::query!( 151 "INSERT INTO event (id, data) VALUES ($1, $2)", 152 i64::try_from(event.id()).expect("event ID should not exceed i64::MAX"), 153 json 154 ) 155 .execute(&mut **transaction) 156 .await?; 157 158 Ok(()) 159} 160 161async fn handle_identity( 162 identity_event: IdentityEvent, 163 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, 164) -> anyhow::Result<()> { 165 let IdentityEvent { 166 id: _, 167 did, 168 handle, 169 is_active, 170 status, 171 } = identity_event; 172 173 query::upsert_identity(&did, &handle, &status, is_active) 174 .execute(&mut **transaction) 175 .await?; 176 177 Ok(()) 178} 179 180async fn handle_record( 181 record_event: RecordEvent, 182 transaction: &mut sqlx::Transaction<'static, sqlx::Postgres>, 183) -> anyhow::Result<(RecordEvent, Value)> { 184 let RecordEvent { 185 id: _, 186 did, 187 rev, 188 collection, 189 rkey, 190 action, 191 record, 192 cid, 193 live, 194 } = &record_event; 195 196 let parsed_record: Value = record 197 .as_ref() 198 .map(|record| serde_json::from_str(record.get())) 199 .transpose()? 200 .unwrap_or_default(); 201 202 match action { 203 RecordAction::Create | RecordAction::Update => { 204 sqlx::query_file!( 205 "queries/upsert_record.sql", 206 did.as_str(), 207 collection, 208 rkey, 209 rev, 210 cid.as_deref(), 211 live, 212 parsed_record 213 ) 214 .execute(&mut **transaction) 215 .await?; 216 } 217 RecordAction::Delete => { 218 query::delete_record(did, collection, rkey, rev) 219 .execute(&mut **transaction) 220 .await?; 221 } 222 } 223 224 Ok((record_event, parsed_record)) 225} 226 227#[tracing::instrument(skip(tap, pool, did_rx, shutdown))] 228async fn did_task( 229 tap: TapClient, 230 pool: PgPool, 231 mut did_rx: mpsc::UnboundedReceiver<(String, DidSource)>, 232 shutdown: CancellationToken, 233 crawl: bool, 234) -> anyhow::Result<()> { 235 const BATCH: usize = 64; 236 237 let mut seen: HashSet<String> = HashSet::with_capacity(10_000); 238 let mut dids = Vec::new(); 239 240 if crawl { 241 // Query known DIDs from the database. 242 let mut query = sqlx::query!("SELECT did FROM identity").fetch(&pool); 243 while let Some(Ok(row)) = query.next().await { 244 seen.insert(row.did); 245 } 246 247 tracing::debug!(count = seen.len(), "loaded tracked DIDs from database"); 248 } 249 250 loop { 251 tokio::time::sleep(Duration::from_millis(200)).await; 252 match shutdown 253 .run_until_cancelled(did_rx.recv_many(&mut dids, BATCH)) 254 .await 255 { 256 Some(0) | None => break, 257 Some(_) => {} 258 } 259 260 // Convert Vec<Box<Did>> to a Vec<&Did>. 261 let mut dedup: HashSet<&str> = HashSet::new(); 262 let mut slice = Vec::with_capacity(dids.len()); 263 for (did, source) in &dids { 264 if !dedup.insert(did) { 265 continue; 266 } 267 268 if !seen.contains(did) || slice.contains(&did.as_ref()) { 269 tracing::info!(?did, ?source, "tracking DID"); 270 slice.push(did.as_ref()); 271 } 272 } 273 274 tap.add_repos(&slice).await?; 275 276 dids.drain(..).for_each(|(did, _)| _ = seen.insert(did)); 277 } 278 279 tracing::info!("complete"); 280 Ok(()) 281} 282 283#[tracing::instrument(skip(shutdown))] 284async fn shutdown_task(shutdown: CancellationToken) -> anyhow::Result<()> { 285 tokio::signal::ctrl_c().await?; 286 eprintln!(); 287 tracing::info!("shutdown signal received"); 288 shutdown.cancel(); 289 Ok(()) 290} 291 292/// Extract any strings that look like DIDs from a JSON document. 293fn extract_dids(value: &Value) -> HashSet<String> { 294 let mut dids = HashSet::new(); 295 296 let mut queue = VecDeque::from_iter([value]); 297 while let Some(value) = queue.pop_front() { 298 match value { 299 Value::Null | Value::Bool(_) | Value::Number(_) => {} 300 Value::Array(values) => { 301 for value in values { 302 queue.push_back(value); 303 } 304 } 305 Value::Object(map) => { 306 for (_, value) in map { 307 queue.push_back(value); 308 } 309 } 310 Value::String(maybe_did) => { 311 if possible_did(maybe_did) { 312 dids.insert(maybe_did.to_string()); 313 continue; 314 } 315 316 // First segment of an "at://..." URI might be a DID. 317 if let Some(uri) = maybe_did.strip_prefix("at://") 318 && let Some((maybe_did, _)) = uri.split_once('/') 319 && possible_did(maybe_did) 320 { 321 dids.insert(maybe_did.to_string()); 322 continue; 323 } 324 } 325 } 326 } 327 328 dids 329} 330 331fn possible_did(s: &str) -> bool { 332 s.starts_with("did:plc") || s.starts_with("did:web") 333}