personal memory agent
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

link: port test_mux to solstone fork

+167
+167
tests/link/test_mux.py
··· 1 + # SPDX-License-Identifier: AGPL-3.0-only 2 + # Copyright (c) 2026 sol pbc 3 + 4 + from __future__ import annotations 5 + 6 + import asyncio 7 + 8 + import pytest 9 + 10 + from think.link.framing import ( 11 + FLAG_CLOSE, 12 + FLAG_DATA, 13 + FLAG_RESET, 14 + Frame, 15 + FrameDecoder, 16 + build_close, 17 + build_data, 18 + build_open, 19 + ) 20 + from think.link.mux import Multiplexer 21 + 22 + 23 + def _decode_frames(chunks: list[bytes]) -> list[Frame]: 24 + decoder = FrameDecoder() 25 + for chunk in chunks: 26 + decoder.feed(chunk) 27 + return decoder.drain() 28 + 29 + 30 + @pytest.mark.asyncio 31 + async def test_open_with_initial_payload_hits_handler() -> None: 32 + handler_seen: dict[int, bytes] = {} 33 + 34 + async def handler( 35 + reader: asyncio.StreamReader, writer: object 36 + ) -> None: # pragma: no cover - typed by mux 37 + data = await reader.readuntil(b"\n") 38 + handler_seen[1] = data 39 + await writer.write(b"ack\n") # type: ignore[attr-defined] 40 + await writer.close() # type: ignore[attr-defined] 41 + 42 + sent: list[bytes] = [] 43 + 44 + async def send(data: bytes) -> None: 45 + sent.append(data) 46 + 47 + mux = Multiplexer(send, handler, is_listener=True) 48 + await mux.feed(build_open(1, b"hello\n").encode() + build_close(1).encode()) 49 + 50 + for _ in range(20): 51 + await asyncio.sleep(0.005) 52 + if handler_seen.get(1): 53 + break 54 + 55 + assert handler_seen.get(1) == b"hello\n" 56 + 57 + frames = _decode_frames(sent) 58 + flags = [frame.flags for frame in frames] 59 + assert any(flag & FLAG_DATA for flag in flags) 60 + assert any(flag & FLAG_CLOSE for flag in flags) 61 + assert ( 62 + b"".join(frame.payload for frame in frames if frame.flags & FLAG_DATA) 63 + == b"ack\n" 64 + ) 65 + await mux.close() 66 + 67 + 68 + @pytest.mark.asyncio 69 + async def test_wrong_parity_stream_id_gets_reset() -> None: 70 + sent: list[bytes] = [] 71 + 72 + async def send(data: bytes) -> None: 73 + sent.append(data) 74 + 75 + async def handler(*_: object) -> None: 76 + pytest.fail("handler should not be reached for wrong-parity stream ids") 77 + 78 + mux = Multiplexer(send, handler, is_listener=True) 79 + await mux.feed(build_open(2).encode()) 80 + 81 + frames = _decode_frames(sent) 82 + assert any(frame.stream_id == 2 and frame.flags & FLAG_RESET for frame in frames) 83 + await mux.close() 84 + 85 + 86 + @pytest.mark.asyncio 87 + async def test_unknown_stream_data_gets_reset() -> None: 88 + sent: list[bytes] = [] 89 + 90 + async def send(data: bytes) -> None: 91 + sent.append(data) 92 + 93 + async def handler(*_: object) -> None: 94 + return 95 + 96 + mux = Multiplexer(send, handler, is_listener=True) 97 + await mux.feed(build_data(99, b"x").encode()) 98 + 99 + frames = _decode_frames(sent) 100 + assert any(frame.stream_id == 99 and frame.flags & FLAG_RESET for frame in frames) 101 + await mux.close() 102 + 103 + 104 + @pytest.mark.asyncio 105 + async def test_concurrent_streams_do_not_interfere() -> None: 106 + responses: dict[int, bytes] = {} 107 + 108 + async def handler( 109 + reader: asyncio.StreamReader, writer: object 110 + ) -> None: # pragma: no cover - typed by mux 111 + payload = await reader.readuntil(b"\n") 112 + await writer.write(payload) # type: ignore[attr-defined] 113 + await writer.close() # type: ignore[attr-defined] 114 + 115 + sent: list[bytes] = [] 116 + 117 + async def send(data: bytes) -> None: 118 + sent.append(data) 119 + 120 + mux = Multiplexer(send, handler, is_listener=True) 121 + bulk = bytearray() 122 + for stream_id in (1, 3, 5, 7, 9): 123 + bulk.extend(build_open(stream_id, f"stream-{stream_id}\n".encode()).encode()) 124 + bulk.extend(build_close(stream_id).encode()) 125 + 126 + await mux.feed(bytes(bulk)) 127 + 128 + for _ in range(50): 129 + await asyncio.sleep(0.005) 130 + frames = _decode_frames(sent) 131 + for frame in frames: 132 + if frame.flags & FLAG_DATA: 133 + responses.setdefault(frame.stream_id, b"") 134 + responses[frame.stream_id] += frame.payload 135 + if all(stream_id in responses for stream_id in (1, 3, 5, 7, 9)): 136 + break 137 + 138 + for stream_id in (1, 3, 5, 7, 9): 139 + assert responses.get(stream_id) == f"stream-{stream_id}\n".encode() 140 + await mux.close() 141 + 142 + 143 + @pytest.mark.asyncio 144 + async def test_validates_open_reopen_is_protocol_error() -> None: 145 + sent: list[bytes] = [] 146 + 147 + async def send(data: bytes) -> None: 148 + sent.append(data) 149 + 150 + gate = asyncio.Event() 151 + 152 + async def handler( 153 + reader: asyncio.StreamReader, writer: object 154 + ) -> None: # pragma: no cover - typed by mux 155 + await gate.wait() 156 + await writer.close() # type: ignore[attr-defined] 157 + 158 + mux = Multiplexer(send, handler, is_listener=True) 159 + await mux.feed(build_open(1).encode()) 160 + await asyncio.sleep(0.01) 161 + await mux.feed(build_open(1).encode()) 162 + 163 + frames = _decode_frames(sent) 164 + assert any(frame.stream_id == 1 and frame.flags & FLAG_RESET for frame in frames) 165 + 166 + gate.set() 167 + await mux.close()