audio streaming app plyr.fm
38
fork

Configure Feed

Select the types of activity you want to include in your feed.

feat: add Jetstream WebSocket consumer for real-time ATProto ingestion (#1070)

add JetstreamConsumer that connects to ATProto Jetstream, filters for
known artist DIDs, and dispatches record events to the ingest task
layer via Docket.

- WebSocket consumer with auto-reconnect and exponential backoff
- Redis-persisted cursor for idempotent replay after restart
- Periodic DID set refresh from Artist table
- Runs as Docket perpetual task, disabled by default (JETSTREAM_ENABLED)
- JetstreamSettings in config for URL, cursor key, reconnect params

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

authored by

nate nowack
Claude Opus 4.6
and committed by
GitHub
26d90b47 02cec63d

+472 -4
+291
backend/src/backend/_internal/jetstream.py
"""ATProto Jetstream consumer for real-time record ingestion.

connects to Jetstream's WebSocket endpoint and listens for fm.plyr.* record
events. events for known DIDs (artists in our database) are dispatched to
docket tasks for resolution into the database.

the consumer itself is lightweight — just WebSocket receive + dispatch. all
heavy lifting (DB queries, record resolution) happens in docket tasks.
"""

import asyncio
import logging
import random
import time
from datetime import timedelta
from typing import Any

import logfire
import orjson
import websockets
from docket import Perpetual
from sqlalchemy import select
from websockets.asyncio.client import ClientConnection

from backend._internal.background import get_docket
from backend._internal.tasks.ingest import (
    ingest_comment_create,
    ingest_comment_delete,
    ingest_comment_update,
    ingest_like_create,
    ingest_like_delete,
    ingest_list_create,
    ingest_list_delete,
    ingest_list_update,
    ingest_profile_update,
    ingest_track_create,
    ingest_track_delete,
    ingest_track_update,
)
from backend.config import settings
from backend.models import Artist
from backend.utilities.database import db_session
from backend.utilities.redis import get_async_redis_client

logger = logging.getLogger(__name__)

# how far (in microseconds) the cursor is rewound when (re)connecting, so the
# events around the disconnect are delivered again and reprocessed.
_CURSOR_REWIND_US = 5_000_000


class JetstreamConsumer:
    """consumes ATProto Jetstream events for fm.plyr.* collections.

    args:
        collections: wildcard collection filter (e.g. "fm.plyr.*")
    """

    def __init__(self, collections: str = "fm.plyr.*") -> None:
        self._collections = collections
        self._ws: ClientConnection | None = None
        self._known_dids: set[str] = set()
        # `time_us` of the last event seen; None until one is seen or loaded
        self._cursor: int | None = None
        # time.monotonic() stamps of the last flush / refresh attempts
        self._last_cursor_flush: float = 0.0
        self._last_did_refresh: float = 0.0
        self._shutdown_event = asyncio.Event()

    def stop(self) -> None:
        """request a graceful shutdown.

        `run` notices the request when the next message arrives on the socket
        or when the reconnect backoff wait is interrupted.
        """
        self._shutdown_event.set()

    async def run(self) -> None:
        """main loop with auto-reconnect and exponential backoff.

        returns when shutdown is requested or when the task is cancelled while
        connecting/consuming (the cursor is flushed first). any other error is
        logged and followed by a reconnect.
        """
        base = settings.jetstream.reconnect_base_seconds
        ceiling = settings.jetstream.reconnect_max_seconds
        backoff = base

        while not self._shutdown_event.is_set():
            session_started = time.monotonic()
            try:
                await self._refresh_known_dids()
                # only fall back to Redis when we have no position of our own:
                # the persisted value may lag the in-memory one by up to a
                # flush interval, and reloading it would replay extra events.
                if self._cursor is None:
                    await self._load_cursor()
                await self._connect_and_consume()
            except asyncio.CancelledError:
                logger.info("jetstream consumer cancelled")
                await self._flush_cursor()
                return
            except Exception:
                logger.exception("jetstream consumer error, reconnecting")

            # persist progress at the end of every session (best-effort) so a
            # crash during the backoff does not lose it.
            await self._flush_cursor()

            if self._shutdown_event.is_set():
                return

            # a session that outlived the maximum delay counts as healthy, so
            # start the backoff over instead of staying pinned at the ceiling.
            if time.monotonic() - session_started >= ceiling:
                backoff = base

            # exponential backoff with jitter
            jitter = random.uniform(0, backoff * 0.5)
            delay = min(backoff + jitter, ceiling)
            logger.info("jetstream reconnecting in %.1fs", delay)
            try:
                await asyncio.wait_for(self._shutdown_event.wait(), timeout=delay)
                return  # shutdown requested during backoff
            except TimeoutError:
                pass
            backoff = min(backoff * 2, ceiling)

    async def _connect_and_consume(self) -> None:
        """connect to Jetstream and process events until disconnected."""
        url = self._build_url()
        logger.info("jetstream connecting to %s", url)

        with logfire.span(
            "jetstream consume",
            known_dids=len(self._known_dids),
            cursor=self._cursor,
        ):
            async with websockets.connect(url, max_size=2**20) as ws:
                self._ws = ws
                logfire.info(
                    "jetstream connected",
                    known_dids=len(self._known_dids),
                )

                try:
                    async for raw in ws:
                        if self._shutdown_event.is_set():
                            return

                        try:
                            event = orjson.loads(raw)
                        except (orjson.JSONDecodeError, TypeError):
                            continue

                        # valid JSON that is not an object cannot be an event;
                        # skip it rather than let it tear down the connection.
                        if not isinstance(event, dict):
                            continue

                        await self._process_event(event)
                        await self._maybe_flush_cursor()
                        await self._maybe_refresh_dids()
                finally:
                    # do not keep a reference to a closed connection
                    self._ws = None

    async def _process_event(self, event: dict[str, Any]) -> None:
        """check if event is for a known DID and dispatch to docket task.

        the cursor is advanced for every event carrying an integer `time_us`,
        including events that are filtered out below, so the persisted position
        follows the stream rather than the last event we happened to care about.
        """
        time_us = event.get("time_us")
        if isinstance(time_us, int) and time_us > 0:
            self._cursor = time_us

        if event.get("kind") != "commit":
            return

        did = event.get("did")
        if not did or did not in self._known_dids:
            return

        commit = event.get("commit")
        if not isinstance(commit, dict):
            # missing or null commit payload: nothing to dispatch
            return

        collection = commit.get("collection", "")
        operation = commit.get("operation", "")
        rkey = commit.get("rkey", "")
        if not collection or not rkey:
            # an AT URI cannot be formed without both parts
            return

        await self._dispatch(
            collection=collection,
            operation=operation,
            did=did,
            rkey=rkey,
            record=commit.get("record"),
            uri=f"at://{did}/{collection}/{rkey}",
            cid=commit.get("cid"),
        )

    async def _dispatch(
        self,
        collection: str,
        operation: str,
        did: str,
        rkey: str,
        record: dict[str, Any] | None,
        uri: str,
        cid: str | None,
    ) -> None:
        """dispatch event to the appropriate ingest task via docket.

        events whose (record type, operation) pair has no ingest task are
        silently ignored.
        """
        # determine which collection type this is (strip namespace prefix)
        # e.g. "fm.plyr.track" or "fm.plyr.dev.track" → "track"
        namespace, dot, record_type = collection.rpartition(".")
        if not dot:
            return

        # profile updates are a special case (nested collection)
        if collection.endswith(".actor.profile") and operation == "update":
            await get_docket().add(ingest_profile_update)(
                did=did, record=record or {}
            )
            logfire.info(
                "jetstream dispatched profile.update",
                did=did,
                _level="debug",
            )
            return

        task_map: dict[tuple[str, str], Any] = {
            ("track", "create"): ingest_track_create,
            ("track", "update"): ingest_track_update,
            ("track", "delete"): ingest_track_delete,
            ("like", "create"): ingest_like_create,
            ("like", "delete"): ingest_like_delete,
            ("comment", "create"): ingest_comment_create,
            ("comment", "update"): ingest_comment_update,
            ("comment", "delete"): ingest_comment_delete,
            ("list", "create"): ingest_list_create,
            ("list", "update"): ingest_list_update,
            ("list", "delete"): ingest_list_delete,
        }

        task = task_map.get((record_type, operation))
        if task is None:
            return

        kwargs: dict[str, Any] = {"did": did, "rkey": rkey, "uri": uri}
        if operation in ("create", "update"):
            # delete tasks only take the identifying fields
            kwargs["record"] = record or {}
            kwargs["cid"] = cid
        await get_docket().add(task)(**kwargs)
        logfire.info(
            "jetstream dispatched {record_type}.{operation}",
            record_type=record_type,
            operation=operation,
            did=did,
            uri=uri,
        )

    def _build_url(self) -> str:
        """build WebSocket URL with query parameters."""
        params = [f"wantedCollections={self._collections}"]
        if self._cursor is not None:
            # rewind the cursor for idempotent reprocessing, never below zero
            rewound = max(0, self._cursor - _CURSOR_REWIND_US)
            params.append(f"cursor={rewound}")
        return f"{settings.jetstream.url}?{'&'.join(params)}"

    async def _load_cursor(self) -> None:
        """load cursor from Redis on startup (best-effort, never raises)."""
        try:
            redis = get_async_redis_client()
            if raw := await redis.get(settings.jetstream.cursor_key):
                self._cursor = int(raw)
                logger.info("jetstream resuming from cursor %d", self._cursor)
        except Exception:
            # without a cursor the consumer simply starts from the live tail
            logger.warning(
                "jetstream could not load cursor from redis", exc_info=True
            )

    async def _flush_cursor(self) -> None:
        """persist current cursor to Redis (best-effort, never raises)."""
        if self._cursor is None:
            return
        try:
            redis = get_async_redis_client()
            await redis.set(settings.jetstream.cursor_key, str(self._cursor))
        except Exception:
            logger.warning(
                "jetstream could not flush cursor to redis", exc_info=True
            )
        # stamped even on failure so a broken Redis is retried once per
        # interval instead of once per event
        self._last_cursor_flush = time.monotonic()

    async def _maybe_flush_cursor(self) -> None:
        """flush cursor if enough time has elapsed."""
        elapsed = time.monotonic() - self._last_cursor_flush
        if elapsed >= settings.jetstream.cursor_flush_interval_seconds:
            await self._flush_cursor()

    async def _refresh_known_dids(self) -> None:
        """refresh the known DID set from the database.

        on failure the previous set is kept and a warning is logged.
        """
        try:
            async with db_session() as db:
                result = await db.execute(select(Artist.did))
                self._known_dids = {row[0] for row in result.fetchall()}
                logger.info(
                    "jetstream refreshed known DIDs: %d artists",
                    len(self._known_dids),
                )
        except Exception:
            logger.warning("jetstream could not refresh known DIDs", exc_info=True)
        # stamped even on failure so a broken database is retried once per
        # interval instead of once per event
        self._last_did_refresh = time.monotonic()

    async def _maybe_refresh_dids(self) -> None:
        """refresh known DIDs if enough time has elapsed."""
        elapsed = time.monotonic() - self._last_did_refresh
        if elapsed >= settings.jetstream.did_refresh_interval_seconds:
            await self._refresh_known_dids()


async def consume_jetstream(
    perpetual: Perpetual = Perpetual(every=timedelta(seconds=0), automatic=True),  # noqa: B008
) -> None:
    """perpetual task: run the Jetstream WebSocket consumer.

    docket's Redis lock ensures only one instance runs this across all workers.
    if the consumer exits (crash, disconnect), Perpetual reschedules immediately.

    when `settings.jetstream.enabled` is false the perpetual is cancelled and
    the task returns without constructing a consumer.
    """
    if not settings.jetstream.enabled:
        perpetual.cancel()
        return

    consumer = JetstreamConsumer()
    await consumer.run()
+3
backend/src/backend/_internal/tasks/__init__.py
··· 7 7 """ 8 8 9 9 from backend._internal.export_tasks import process_export 10 + from backend._internal.jetstream import consume_jetstream 10 11 from backend._internal.pds_backfill_tasks import backfill_tracks_to_pds 11 12 from backend._internal.tasks.copyright import ( 12 13 scan_copyright, ··· 68 69 69 70 # collection of all background task functions for docket registration 70 71 background_tasks = [ 72 + consume_jetstream, 71 73 scan_copyright, 72 74 sync_copyright_resolutions, 73 75 process_export, ··· 102 104 "SubjectNotFoundError", 103 105 "background_tasks", 104 106 "classify_genres", 107 + "consume_jetstream", 105 108 "generate_embedding", 106 109 "ingest_comment_create", 107 110 "ingest_comment_delete",
+44
backend/src/backend/config.py
··· 756 756 ) 757 757 758 758 759 + class JetstreamSettings(AppSettingsSection): 760 + """ATProto Jetstream consumer settings for real-time record ingestion.""" 761 + 762 + model_config = SettingsConfigDict( 763 + env_prefix="JETSTREAM_", 764 + env_file=".env", 765 + case_sensitive=False, 766 + extra="ignore", 767 + ) 768 + 769 + enabled: bool = Field( 770 + default=False, 771 + description="Enable Jetstream consumer for real-time ATProto event ingestion", 772 + ) 773 + url: str = Field( 774 + default="wss://jetstream2.us-east.bsky.network/subscribe", 775 + description="Jetstream WebSocket URL", 776 + ) 777 + cursor_key: str = Field( 778 + default="plyr:jetstream:cursor", 779 + description="Redis key for persisting the Jetstream cursor", 780 + ) 781 + cursor_flush_interval_seconds: int = Field( 782 + default=10, 783 + description="How often to flush cursor to Redis (seconds)", 784 + ) 785 + reconnect_base_seconds: float = Field( 786 + default=1.0, 787 + description="Base delay for exponential backoff on reconnect", 788 + ) 789 + reconnect_max_seconds: float = Field( 790 + default=30.0, 791 + description="Maximum delay between reconnect attempts", 792 + ) 793 + did_refresh_interval_seconds: int = Field( 794 + default=300, 795 + description="How often to refresh the known DID set from the database (seconds)", 796 + ) 797 + 798 + 759 799 class DocketSettings(AppSettingsSection): 760 800 """Background task queue configuration using pydocket. 761 801 ··· 928 968 bufo: BufoSettings = Field( 929 969 default_factory=BufoSettings, 930 970 description="bufo easter egg settings", 971 + ) 972 + jetstream: JetstreamSettings = Field( 973 + default_factory=JetstreamSettings, 974 + description="ATProto Jetstream consumer settings", 931 975 ) 932 976 docket: DocketSettings = Field( 933 977 default_factory=DocketSettings,
+132 -2
backend/tests/test_jetstream.py
··· 1 1 """tests for Jetstream consumer and ingest tasks.""" 2 2 3 3 import uuid 4 - from datetime import UTC, datetime 5 - from unittest.mock import AsyncMock, patch 4 + from datetime import UTC, datetime, timedelta 5 + from unittest.mock import AsyncMock, MagicMock, patch 6 6 7 7 import pytest 8 + from docket import Perpetual 8 9 from sqlalchemy import select 9 10 from sqlalchemy.ext.asyncio import AsyncSession 10 11 12 + from backend._internal.jetstream import JetstreamConsumer, consume_jetstream 11 13 from backend._internal.tasks.ingest import ( 12 14 SubjectNotFoundError, 13 15 ingest_comment_create, ··· 78 80 db_session.add(t) 79 81 await db_session.commit() 80 82 return t 83 + 84 + 85 + # --- consumer tests --- 86 + 87 + 88 + class TestJetstreamConsumer: 89 + async def test_dispatches_track_create(self) -> None: 90 + consumer = JetstreamConsumer() 91 + consumer._known_dids = {"did:plc:jetstream_test"} 92 + 93 + mock_docket = MagicMock() 94 + dispatched: list[dict] = [] 95 + 96 + async def capture(**kwargs: object) -> None: 97 + dispatched.append(dict(kwargs)) 98 + 99 + mock_docket.add = MagicMock(return_value=capture) 100 + 101 + event = { 102 + "kind": "commit", 103 + "did": "did:plc:jetstream_test", 104 + "time_us": 1000000, 105 + "commit": { 106 + "collection": "fm.plyr.track", 107 + "operation": "create", 108 + "rkey": "abc123", 109 + "record": {"title": "New Track"}, 110 + "cid": "bafynew", 111 + }, 112 + } 113 + 114 + with patch("backend._internal.jetstream.get_docket", return_value=mock_docket): 115 + await consumer._process_event(event) 116 + 117 + assert len(dispatched) == 1 118 + assert dispatched[0]["did"] == "did:plc:jetstream_test" 119 + assert ( 120 + dispatched[0]["uri"] == "at://did:plc:jetstream_test/fm.plyr.track/abc123" 121 + ) 122 + 123 + async def test_skips_unknown_did(self) -> None: 124 + consumer = JetstreamConsumer() 125 + consumer._known_dids = {"did:plc:known"} 126 + 127 + event = { 128 + "kind": "commit", 129 + "did": 
"did:plc:unknown", 130 + "commit": { 131 + "collection": "fm.plyr.track", 132 + "operation": "create", 133 + "rkey": "abc", 134 + }, 135 + } 136 + 137 + # _dispatch should never be called 138 + consumer._dispatch = AsyncMock() # type: ignore[method-assign] 139 + await consumer._process_event(event) 140 + consumer._dispatch.assert_not_called() # type: ignore[union-attr] 141 + 142 + async def test_skips_non_commit_events(self) -> None: 143 + consumer = JetstreamConsumer() 144 + consumer._known_dids = {"did:plc:jetstream_test"} 145 + consumer._dispatch = AsyncMock() # type: ignore[method-assign] 146 + 147 + event = {"kind": "identity", "did": "did:plc:jetstream_test"} 148 + await consumer._process_event(event) 149 + consumer._dispatch.assert_not_called() # type: ignore[union-attr] 150 + 151 + async def test_persists_cursor(self) -> None: 152 + consumer = JetstreamConsumer() 153 + mock_redis = AsyncMock() 154 + 155 + with patch( 156 + "backend._internal.jetstream.get_async_redis_client", 157 + return_value=mock_redis, 158 + ): 159 + consumer._cursor = 12345678 160 + await consumer._flush_cursor() 161 + 162 + mock_redis.set.assert_called_once() 163 + args = mock_redis.set.call_args 164 + assert args[0][1] == "12345678" 165 + 166 + async def test_resumes_from_cursor(self) -> None: 167 + consumer = JetstreamConsumer() 168 + mock_redis = AsyncMock() 169 + mock_redis.get = AsyncMock(return_value="9999999") 170 + 171 + with patch( 172 + "backend._internal.jetstream.get_async_redis_client", 173 + return_value=mock_redis, 174 + ): 175 + await consumer._load_cursor() 176 + 177 + assert consumer._cursor == 9999999 178 + url = consumer._build_url() 179 + assert "cursor=" in url 180 + 181 + async def test_build_url_without_cursor(self) -> None: 182 + consumer = JetstreamConsumer() 183 + url = consumer._build_url() 184 + assert "wantedCollections=fm.plyr.*" in url 185 + assert "cursor=" not in url 186 + 187 + async def test_build_url_with_cursor_rewinds(self) -> None: 188 + 
consumer = JetstreamConsumer() 189 + consumer._cursor = 10_000_000 # 10 seconds in microseconds 190 + url = consumer._build_url() 191 + # rewound by 5_000_000 → cursor=5000000 192 + assert "cursor=5000000" in url 193 + 194 + 195 + class TestConsumeJetstreamPerpetual: 196 + async def test_cancels_perpetual_when_disabled(self) -> None: 197 + perpetual = Perpetual(every=timedelta(seconds=0)) 198 + with patch("backend._internal.jetstream.settings") as mock_settings: 199 + mock_settings.jetstream.enabled = False 200 + await consume_jetstream(perpetual=perpetual) 201 + assert perpetual.cancelled 202 + 203 + async def test_runs_consumer_when_enabled(self) -> None: 204 + with ( 205 + patch("backend._internal.jetstream.settings") as mock_settings, 206 + patch.object(JetstreamConsumer, "run", new_callable=AsyncMock) as mock_run, 207 + ): 208 + mock_settings.jetstream.enabled = True 209 + await consume_jetstream() 210 + mock_run.assert_called_once() 81 211 82 212 83 213 # --- track ingestion tests ---
+2 -2
loq.toml
··· 51 51 52 52 [[rules]] 53 53 path = "backend/src/backend/config.py" 54 - max_lines = 950 54 + max_lines = 1000 55 55 56 56 [[rules]] 57 57 path = "backend/src/backend/storage/r2.py" ··· 223 223 224 224 [[rules]] 225 225 path = "backend/tests/test_jetstream.py" 226 - max_lines = 1000 226 + max_lines = 1130 227 227 228 228 [[rules]] 229 229 path = "frontend/src/lib/components/embed/CollectionEmbed.svelte"