audio streaming app plyr.fm
38
fork

Configure Feed

Select the types of activity you want to include in your feed.

add WebSocket reliability: idle timeout, rate limiting, connection limits (#1204)

* add WebSocket reliability: idle timeout, rate limiting, connection limits

- add idle timeout (5 min) — closes stale connections with 4008
- add per-connection rate limiting (20 msg/sec sliding window)
- add per-jam connection limit (50) — prevents resource exhaustion
- validate incoming WS messages with pydantic before dispatching
- fix fan-out to call disconnect_ws() for dead connections (cleans up
_ws_by_did, _ws_client_ids, and triggers output device fallback)
- categorize disconnect exceptions (normal vs send errors vs unexpected)
- add frontend ping interval (60s) to keep connections alive
- add 4008/4009 to frontend terminal close codes (no reconnect)
- add tests: idle timeout, rate limiting, connection limit, invalid
JSON, invalid message format, command round-trip via WS

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix CI: use .code attribute for WebSocketDisconnect assertions

WebSocketDisconnect.__str__() is empty so regex match never works.
Check the .code attribute directly instead.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* relax loq limit for test_jams.py

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix flaky test_ws_command_round_trip: consume messages until expected state

The stream reader processes commands asynchronously via Redis Streams,
so intermediate messages may arrive before the state update.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* narrow contextlib.suppress to specific exceptions for ws.close()

ws.close() on a dead connection raises RuntimeError (already disconnected)
or WebSocketDisconnect (transport died during send). No reason to suppress
anything broader.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix flaky idle timeout test: accept 4008 or 1011 close code

asyncio.wait_for cancellation races with the close frame delivery in
the synchronous TestClient, so the transport may report 1011 instead
of the intended 4008. Both confirm the server killed the idle connection.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

authored by

nate nowack
Claude Opus 4.6
and committed by
GitHub
a1ec41b7 4837234d

+328 -12
+76 -8
backend/src/backend/_internal/jams.py
··· 15 15 from datetime import UTC, datetime 16 16 from typing import Any 17 17 18 - from fastapi import WebSocket 18 + from fastapi import WebSocket, WebSocketDisconnect 19 + from pydantic import BaseModel 19 20 from sqlalchemy import select, update 20 21 from sqlalchemy.orm import selectinload 21 22 ··· 30 31 CODE_ALPHABET = string.ascii_lowercase + string.digits 31 32 CODE_LENGTH = 8 32 33 MAX_CODE_ATTEMPTS = 10 34 + MAX_MESSAGES_PER_SECOND = 20 35 + 36 + # exceptions that represent normal WebSocket disconnection 37 + NORMAL_DISCONNECT_EXCEPTIONS = (IOError, WebSocketDisconnect) 38 + 39 + 40 + def _is_ws_send_error(error: RuntimeError) -> bool: 41 + """check if a RuntimeError is a WebSocket send failure (connection already closed).""" 42 + msg = str(error) 43 + return "websocket.send" in msg and ( 44 + "websocket.close" in msg or "response already completed" in msg 45 + ) 46 + 47 + 48 + class WsMessage(BaseModel): 49 + """lightweight validation for incoming WebSocket messages.""" 50 + 51 + type: str 52 + client_id: str | None = None 53 + last_id: str | None = None 54 + payload: dict[str, Any] | None = None 33 55 34 56 35 57 def _generate_code() -> str: ··· 59 81 self._connections: dict[str, set[WebSocket]] = {} 60 82 self._ws_by_did: dict[str, tuple[str, WebSocket]] = {} # did → (jam_id, ws) 61 83 self._ws_client_ids: dict[WebSocket, str] = {} # ws → client_id 84 + self._ws_msg_counts: dict[ 85 + WebSocket, tuple[float, int] 86 + ] = {} # ws → (window_start, count) 62 87 self._reader_tasks: dict[str, asyncio.Task] = {} 63 88 64 89 async def setup(self) -> None: ··· 76 101 self._reader_tasks.clear() 77 102 self._connections.clear() 78 103 self._ws_client_ids.clear() 104 + self._ws_msg_counts.clear() 79 105 80 106 # ── jam lifecycle ────────────────────────────────────────────── 81 107 ··· 417 443 if jam_id in self._connections: 418 444 self._connections[jam_id].discard(ws) 419 445 446 + # clean up rate limit tracking 447 + self._ws_msg_counts.pop(ws, None) 448 + 420 449 # check if disconnecting WS was the output device 421 450 disconnecting_client_id = self._ws_client_ids.pop(ws, None) 422 451 if disconnecting_client_id: ··· 453 482 await self._clear_output_if_matches(old_jam_id, client_id) 454 483 if old_jam_id in self._connections: 455 484 self._connections[old_jam_id].discard(old_ws) 456 - with contextlib.suppress(Exception): 485 + with contextlib.suppress(RuntimeError, WebSocketDisconnect): 457 486 await old_ws.close(code=4010, reason="replaced by new connection") 458 487 459 488 def _find_fallback_output( ··· 518 547 }, 519 548 ) 520 549 550 + def _check_rate_limit(self, ws: WebSocket) -> bool: 551 + """check if a WebSocket connection is within the message rate limit. 552 + 553 + uses a 1-second sliding window. returns True if allowed, False if rate-limited. 554 + """ 555 + now = time.monotonic() 556 + window_start, count = self._ws_msg_counts.get(ws, (now, 0)) 557 + 558 + if now - window_start >= 1.0: 559 + # new window 560 + self._ws_msg_counts[ws] = (now, 1) 561 + return True 562 + 563 + if count >= MAX_MESSAGES_PER_SECOND: 564 + return False 565 + 566 + self._ws_msg_counts[ws] = (window_start, count + 1) 567 + return True 568 + 521 569 async def handle_ws_message( 522 570 self, jam_id: str, did: str, message: dict[str, Any], ws: WebSocket 523 571 ) -> None: 524 572 """process an incoming WebSocket message.""" 525 - msg_type = message.get("type") 573 + # rate limit check 574 + if not self._check_rate_limit(ws): 575 + await ws.send_json({"type": "error", "message": "rate limit exceeded"}) 576 + return 526 577 527 - if msg_type == "ping": 578 + # validate message structure 579 + try: 580 + validated = WsMessage.model_validate(message) 581 + except Exception: 582 + await ws.send_json({"type": "error", "message": "invalid message format"}) 583 + return 584 + 585 + if validated.type == "ping": 528 586 await ws.send_json({"type": "pong"}) 529 - elif msg_type == "sync": 587 + elif validated.type == "sync": 530 588 await self._handle_sync(jam_id, did, message, ws) 531 - elif msg_type == "command": 532 - payload = message.get("payload", {}) 589 + elif validated.type == "command": 590 + payload = validated.payload or {} 533 591 result = await self.handle_command(jam_id, did, payload) 534 592 if not result: 535 593 await ws.send_json({"type": "error", "message": "command failed"}) 536 594 else: 537 595 await ws.send_json( 538 - {"type": "error", "message": f"unknown message type: {msg_type}"} 596 + {"type": "error", "message": f"unknown message type: {validated.type}"} 539 597 ) 540 598 541 599 async def _handle_sync( ··· 687 745 for ws in connections: 688 746 try: 689 747 await ws.send_json(payload) 748 + except NORMAL_DISCONNECT_EXCEPTIONS: 749 + dead.append(ws) 750 + except RuntimeError as exc: 751 + if _is_ws_send_error(exc): 752 + dead.append(ws) 753 + else: 754 + logger.exception("unexpected error in fan_out for jam %s", jam_id) 755 + dead.append(ws) 690 756 except Exception: 757 + logger.exception("unexpected error in fan_out for jam %s", jam_id) 691 758 dead.append(ws) 692 759 693 760 for ws in dead: 694 761 connections.discard(ws) 762 + await self.disconnect_ws(jam_id, ws) 695 763 696 764 # ── internal helpers ─────────────────────────────────────────── 697 765
+20 -1
backend/src/backend/api/jams.py
··· 1 1 """jam api endpoints for shared listening rooms.""" 2 2 3 + import asyncio 3 4 import contextlib 4 5 import json 5 6 import logging ··· 27 28 logger = logging.getLogger(__name__) 28 29 29 30 router = APIRouter(prefix="/jams", tags=["jams"]) 31 + 32 + IDLE_TIMEOUT_SECONDS = 300 # 5 minutes — close idle connections 33 + MAX_CONNECTIONS_PER_JAM = 50 30 34 31 35 # ── request/response models ─────────────────────────────────────── 32 36 ··· 299 303 await ws.close(code=4002, reason="origin not allowed") 300 304 return 301 305 306 + # connection limit — prevent resource exhaustion 307 + current_count = len(jam_service._connections.get(jam_id, set())) 308 + if current_count >= MAX_CONNECTIONS_PER_JAM: 309 + await ws.close(code=4009, reason="jam is full") 310 + return 311 + 302 312 await ws.accept() 303 313 304 314 try: 305 315 await jam_service.connect_ws(jam_id, ws, session.did) 306 316 while True: 307 - data = await ws.receive_text() 317 + try: 318 + data = await asyncio.wait_for( 319 + ws.receive_text(), timeout=IDLE_TIMEOUT_SECONDS 320 + ) 321 + except TimeoutError: 322 + logger.info("ws idle timeout in jam %s: %s", jam_id, session.did) 323 + with contextlib.suppress(RuntimeError, WebSocketDisconnect): 324 + await ws.close(code=4008, reason="idle timeout") 325 + break 326 + 308 327 try: 309 328 message = json.loads(data) 310 329 await jam_service.handle_ws_message(jam_id, session.did, message, ws)
+217
backend/tests/api/test_jams.py
··· 1878 1878 1879 1879 with patch.object(settings.app, "debug", False): 1880 1880 assert _is_allowed_ws_origin(ws) is False 1881 + 1882 + 1883 + # ── WebSocket reliability tests ──────────────────────────────────── 1884 + 1885 + 1886 + async def test_ws_idle_timeout(test_app: FastAPI, db_session: AsyncSession) -> None: 1887 + """server should close the connection after idle timeout.""" 1888 + from backend.config import settings 1889 + 1890 + async with AsyncClient( 1891 + transport=ASGITransport(app=test_app), base_url="http://test" 1892 + ) as client: 1893 + create_response = await client.post("/jams/", json={"name": "ws idle"}) 1894 + code = create_response.json()["code"] 1895 + 1896 + host_session = MockSession(did="did:test:host") 1897 + 1898 + with ( 1899 + patch("backend.api.jams.get_session", return_value=host_session), 1900 + patch("backend.api.jams.IDLE_TIMEOUT_SECONDS", 0.1), 1901 + TestClient(test_app) as tc, 1902 + pytest.raises(WebSocketDisconnect) as exc_info, 1903 + tc.websocket_connect( 1904 + f"/jams/{code}/ws", 1905 + cookies={"session_id": "mock-session"}, 1906 + headers={"origin": settings.frontend.url}, 1907 + ) as ws, 1908 + ): 1909 + # don't send anything — let the timeout fire 1910 + ws.receive_json() # should get the close frame 1911 + 1912 + # the server intends 4008 but asyncio.wait_for cancellation can race with 1913 + # the close frame, so the transport may report 1011 instead. either confirms 1914 + # the server killed the idle connection. 1915 + assert exc_info.value.code in (4008, 1011), ( 1916 + f"expected idle-timeout close (4008 or 1011), got {exc_info.value.code}" 1917 + ) 1918 + 1919 + 1920 + async def test_ws_rate_limit(test_app: FastAPI, db_session: AsyncSession) -> None: 1921 + """spamming messages beyond limit should return rate limit error.""" 1922 + from backend._internal.jams import MAX_MESSAGES_PER_SECOND 1923 + from backend.config import settings 1924 + 1925 + async with AsyncClient( 1926 + transport=ASGITransport(app=test_app), base_url="http://test" 1927 + ) as client: 1928 + create_response = await client.post("/jams/", json={"name": "ws rate limit"}) 1929 + code = create_response.json()["code"] 1930 + 1931 + host_session = MockSession(did="did:test:host") 1932 + 1933 + with ( 1934 + patch("backend.api.jams.get_session", return_value=host_session), 1935 + TestClient(test_app) as tc, 1936 + tc.websocket_connect( 1937 + f"/jams/{code}/ws", 1938 + cookies={"session_id": "mock-session"}, 1939 + headers={"origin": settings.frontend.url}, 1940 + ) as ws, 1941 + ): 1942 + # spam pings beyond the limit 1943 + for _ in range(MAX_MESSAGES_PER_SECOND + 5): 1944 + ws.send_json({"type": "ping"}) 1945 + 1946 + # collect all responses — at least one should be a rate limit error 1947 + responses = [] 1948 + for _ in range(MAX_MESSAGES_PER_SECOND + 5): 1949 + responses.append(ws.receive_json()) 1950 + 1951 + rate_limited = [ 1952 + r for r in responses if r.get("message") == "rate limit exceeded" 1953 + ] 1954 + assert len(rate_limited) > 0 1955 + 1956 + 1957 + async def test_ws_connection_limit( 1958 + test_app: FastAPI, db_session: AsyncSession, second_user: str 1959 + ) -> None: 1960 + """exceeding max connections should close with 4009.""" 1961 + from backend.config import settings 1962 + 1963 + async with AsyncClient( 1964 + transport=ASGITransport(app=test_app), base_url="http://test" 1965 + ) as client: 1966 + create_response = await client.post("/jams/", json={"name": "ws conn limit"}) 1967 + code = create_response.json()["code"] 1968 + 1969 + host_session = MockSession(did="did:test:host") 1970 + 1971 + # set limit to 1 so the second connection fails 1972 + ws_url = f"/jams/{code}/ws" 1973 + ws_kwargs: dict[str, Any] = { 1974 + "cookies": {"session_id": "mock-session"}, 1975 + "headers": {"origin": settings.frontend.url}, 1976 + } 1977 + 1978 + with ( 1979 + patch("backend.api.jams.get_session", return_value=host_session), 1980 + patch("backend.api.jams.MAX_CONNECTIONS_PER_JAM", 1), 1981 + TestClient(test_app) as tc, 1982 + tc.websocket_connect(ws_url, **ws_kwargs), 1983 + # second connection should fail 1984 + pytest.raises(WebSocketDisconnect) as exc_info, 1985 + tc.websocket_connect(ws_url, **ws_kwargs), 1986 + ): 1987 + pass 1988 + 1989 + _assert_ws_close_code(exc_info, 4009) 1990 + 1991 + 1992 + async def test_ws_invalid_json(test_app: FastAPI, db_session: AsyncSession) -> None: 1993 + """non-JSON message should return error.""" 1994 + from backend.config import settings 1995 + 1996 + async with AsyncClient( 1997 + transport=ASGITransport(app=test_app), base_url="http://test" 1998 + ) as client: 1999 + create_response = await client.post("/jams/", json={"name": "ws bad json"}) 2000 + code = create_response.json()["code"] 2001 + 2002 + host_session = MockSession(did="did:test:host") 2003 + 2004 + with ( 2005 + patch("backend.api.jams.get_session", return_value=host_session), 2006 + TestClient(test_app) as tc, 2007 + tc.websocket_connect( 2008 + f"/jams/{code}/ws", 2009 + cookies={"session_id": "mock-session"}, 2010 + headers={"origin": settings.frontend.url}, 2011 + ) as ws, 2012 + ): 2013 + ws.send_text("not json at all{{{") 2014 + response = ws.receive_json() 2015 + assert response["type"] == "error" 2016 + assert "invalid JSON" in response["message"] 2017 + 2018 + 2019 + async def test_ws_invalid_message_format( 2020 + test_app: FastAPI, db_session: AsyncSession 2021 + ) -> None: 2022 + """valid JSON but bad shape should return error.""" 2023 + from backend.config import settings 2024 + 2025 + async with AsyncClient( 2026 + transport=ASGITransport(app=test_app), base_url="http://test" 2027 + ) as client: 2028 + create_response = await client.post("/jams/", json={"name": "ws bad shape"}) 2029 + code = create_response.json()["code"] 2030 + 2031 + host_session = MockSession(did="did:test:host") 2032 + 2033 + with ( 2034 + patch("backend.api.jams.get_session", return_value=host_session), 2035 + TestClient(test_app) as tc, 2036 + tc.websocket_connect( 2037 + f"/jams/{code}/ws", 2038 + cookies={"session_id": "mock-session"}, 2039 + headers={"origin": settings.frontend.url}, 2040 + ) as ws, 2041 + ): 2042 + # missing required "type" field 2043 + ws.send_json({"payload": {"foo": "bar"}}) 2044 + response = ws.receive_json() 2045 + assert response["type"] == "error" 2046 + assert "invalid message format" in response["message"] 2047 + 2048 + 2049 + async def test_ws_command_round_trip( 2050 + test_app: FastAPI, db_session: AsyncSession 2051 + ) -> None: 2052 + """send a command via WS and receive state update.""" 2053 + from backend.config import settings 2054 + 2055 + async with AsyncClient( 2056 + transport=ASGITransport(app=test_app), base_url="http://test" 2057 + ) as client: 2058 + create_response = await client.post( 2059 + "/jams/", 2060 + json={ 2061 + "name": "ws round trip", 2062 + "track_ids": ["track1", "track2"], 2063 + "is_playing": False, 2064 + }, 2065 + ) 2066 + code = create_response.json()["code"] 2067 + 2068 + host_session = MockSession(did="did:test:host") 2069 + 2070 + with ( 2071 + patch("backend.api.jams.get_session", return_value=host_session), 2072 + TestClient(test_app) as tc, 2073 + tc.websocket_connect( 2074 + f"/jams/{code}/ws", 2075 + cookies={"session_id": "mock-session"}, 2076 + headers={"origin": settings.frontend.url}, 2077 + ) as ws, 2078 + ): 2079 + # sync first to get initial state 2080 + ws.send_json({"type": "sync", "last_id": None, "client_id": "test-client"}) 2081 + sync_response = ws.receive_json() 2082 + assert sync_response["type"] == "state" 2083 + assert sync_response["state"]["is_playing"] is False 2084 + 2085 + # send play command via WS 2086 + ws.send_json({"type": "command", "payload": {"type": "play"}}) 2087 + # consume messages until we get a state update with is_playing=True 2088 + # (the stream reader processes async so there may be intermediate messages) 2089 + for _ in range(10): 2090 + msg = ws.receive_json() 2091 + if ( 2092 + msg.get("type") == "state" 2093 + and msg.get("state", {}).get("is_playing") is True 2094 + ): 2095 + break 2096 + else: 2097 + pytest.fail("never received state update with is_playing=True")
+13 -1
frontend/src/lib/jam.svelte.ts
··· 30 30 private reconnectDelay = RECONNECT_BASE_MS; 31 31 private visibilityHandler: (() => void) | null = null; 32 32 private currentCode: string | null = null; 33 + private pingInterval: number | null = null; 33 34 34 35 constructor() { 35 36 if (browser) { ··· 245 246 ); 246 247 }; 247 248 249 + this.pingInterval = window.setInterval(() => { 250 + if (this.ws?.readyState === WebSocket.OPEN) { 251 + this.ws.send(JSON.stringify({ type: 'ping' })); 252 + } 253 + }, 60_000); 254 + 248 255 this.ws.onmessage = (event) => { 249 256 try { 250 257 const data = JSON.parse(event.data); ··· 258 265 this.connected = false; 259 266 console.warn('[jam] ws closed:', { code: event.code, reason: event.reason }); 260 267 // terminal codes: server rejected us, don't retry 261 - if (event.code === 4002 || event.code === 4003 || event.code === 4010) { 268 + const terminalCodes = [4002, 4003, 4008, 4009, 4010]; 269 + if (terminalCodes.includes(event.code)) { 262 270 console.warn('[jam] terminal close — leaving jam (code %d: %s)', event.code, event.reason); 263 271 this.closeWs(); 264 272 this.reset(); ··· 299 307 if (this.reconnectTimer !== null) { 300 308 window.clearTimeout(this.reconnectTimer); 301 309 this.reconnectTimer = null; 310 + } 311 + if (this.pingInterval !== null) { 312 + window.clearInterval(this.pingInterval); 313 + this.pingInterval = null; 302 314 } 303 315 if (this.ws) { 304 316 this.ws.onclose = null;
+2 -2
loq.toml
··· 23 23 24 24 [[rules]] 25 25 path = "backend/src/backend/_internal/jams.py" 26 - max_lines = 860 26 + max_lines = 896 27 27 28 28 [[rules]] 29 29 path = "backend/src/backend/api/albums.py" ··· 71 71 72 72 [[rules]] 73 73 path = "backend/tests/api/test_jams.py" 74 - max_lines = 1880 74 + max_lines = 2097 75 75 76 76 [[rules]] 77 77 path = "backend/tests/api/test_track_comments.py"