personal memory agent
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

link: add test_wsgi_bridge for solstone

+210
+210
tests/link/test_wsgi_bridge.py
··· 1 + # SPDX-License-Identifier: AGPL-3.0-only 2 + # Copyright (c) 2026 sol pbc 3 + 4 + from __future__ import annotations 5 + 6 + import asyncio 7 + import dataclasses 8 + import hashlib 9 + import json 10 + import os 11 + import time 12 + from collections.abc import Iterator 13 + 14 + import pytest 15 + from flask import Flask, Response, jsonify, request, stream_with_context 16 + 17 + from think.link.wsgi_bridge import ExchangeMetadata, serve_request 18 + 19 + 20 + class _WriterStub: 21 + def __init__(self) -> None: 22 + self.writes: list[bytes] = [] 23 + 24 + async def write(self, data: bytes) -> None: 25 + self.writes.append(data) 26 + 27 + async def drain(self) -> None: 28 + return None 29 + 30 + def close(self) -> None: 31 + return None 32 + 33 + async def wait_closed(self) -> None: 34 + return None 35 + 36 + def joined(self) -> bytes: 37 + return b"".join(self.writes) 38 + 39 + 40 + def _build_app(*, propagate_exceptions: bool = False) -> Flask: 41 + app = Flask(__name__) 42 + app.config["PROPAGATE_EXCEPTIONS"] = propagate_exceptions 43 + 44 + @app.get("/hello") 45 + def hello() -> Response: 46 + return Response(b"hello", mimetype="text/plain") 47 + 48 + @app.get("/stream") 49 + def stream() -> Response: 50 + @stream_with_context 51 + def generate() -> Iterator[bytes]: 52 + for chunk in (b"part-1\n", b"part-2\n", b"part-3\n"): 53 + time.sleep(0.01) 54 + yield chunk 55 + 56 + return Response(generate(), mimetype="text/plain") 57 + 58 + @app.post("/upload") 59 + def upload() -> Response: 60 + body = request.get_data() 61 + return jsonify( 62 + {"sha256": hashlib.sha256(body).hexdigest(), "length": len(body)} 63 + ) 64 + 65 + @app.get("/boom") 66 + def boom() -> Response: 67 + raise RuntimeError("bridge test failure") 68 + 69 + return app 70 + 71 + 72 + def _make_reader(request_bytes: bytes) -> asyncio.StreamReader: 73 + reader = asyncio.StreamReader() 74 + reader.feed_data(request_bytes) 75 + reader.feed_eof() 76 + return reader 77 + 78 + 79 + async def _serve( 80 + request_bytes: bytes, 81 + *, 82 + app: Flask | None = None, 83 + stream_id: int = 1, 84 + ) -> tuple[ExchangeMetadata, _WriterStub]: 85 + writer = _WriterStub() 86 + metadata = await serve_request( 87 + _make_reader(request_bytes), 88 + writer, 89 + (app or _build_app()).wsgi_app, 90 + stream_id=stream_id, 91 + ) 92 + return metadata, writer 93 + 94 + 95 + def _split_response(raw: bytes) -> tuple[bytes, bytes]: 96 + head, sep, body = raw.partition(b"\r\n\r\n") 97 + assert sep == b"\r\n\r\n" 98 + return head, body 99 + 100 + 101 + @pytest.mark.asyncio 102 + async def test_get_returns_200_and_body() -> None: 103 + meta, writer = await _serve( 104 + b"GET /hello HTTP/1.1\r\nHost: link.test\r\nContent-Length: 0\r\n\r\n", 105 + ) 106 + 107 + head, body = _split_response(writer.joined()) 108 + 109 + assert meta == ExchangeMetadata( 110 + method="GET", 111 + path="/hello", 112 + status=200, 113 + request_bytes=0, 114 + response_bytes=len(writer.joined()), 115 + stream_id=1, 116 + ) 117 + assert set(ExchangeMetadata.__dataclass_fields__) == { 118 + "method", 119 + "path", 120 + "status", 121 + "request_bytes", 122 + "response_bytes", 123 + "stream_id", 124 + } 125 + assert head.startswith(b"HTTP/1.1 200 OK\r\n") 126 + assert body == b"hello" 127 + 128 + 129 + @pytest.mark.asyncio 130 + async def test_post_upload_roundtrips_sha256() -> None: 131 + payload = os.urandom(1024 * 1024) 132 + digest = hashlib.sha256(payload).hexdigest() 133 + request_bytes = ( 134 + b"POST /upload HTTP/1.1\r\n" 135 + b"Host: link.test\r\n" 136 + b"Content-Type: application/octet-stream\r\n" 137 + b"Content-Length: " + str(len(payload)).encode("ascii") + b"\r\n\r\n" + payload 138 + ) 139 + 140 + meta, writer = await _serve(request_bytes) 141 + head, body = _split_response(writer.joined()) 142 + parsed = json.loads(body) 143 + 144 + assert meta.method == "POST" 145 + assert meta.path == "/upload" 146 + assert meta.status == 200 147 + assert meta.request_bytes == 1024 * 1024 148 + assert meta.response_bytes == len(writer.joined()) 149 + assert head.startswith(b"HTTP/1.1 200 OK\r\n") 150 + assert parsed == {"sha256": digest, "length": 1024 * 1024} 151 + 152 + 153 + @pytest.mark.asyncio 154 + async def test_streaming_response_arrives_in_chunks() -> None: 155 + meta, writer = await _serve( 156 + b"GET /stream HTTP/1.1\r\nHost: link.test\r\nContent-Length: 0\r\n\r\n", 157 + ) 158 + 159 + assert meta.status == 200 160 + assert meta.method == "GET" 161 + assert meta.path == "/stream" 162 + assert writer.writes[0].startswith(b"HTTP/1.1 200 OK\r\n") 163 + assert writer.writes[1:] == [b"part-1\n", b"part-2\n", b"part-3\n"] 164 + assert len(writer.writes) == 4 165 + assert meta.response_bytes == len(writer.joined()) 166 + 167 + 168 + @pytest.mark.asyncio 169 + async def test_malformed_request_line_returns_400() -> None: 170 + meta, writer = await _serve(b"NOTAREALREQUEST\r\n\r\n") 171 + 172 + head, body = _split_response(writer.joined()) 173 + 174 + assert meta.method == "-" 175 + assert meta.path == "-" 176 + assert meta.status == 400 177 + assert head.startswith(b"HTTP/1.1 400 bad request\r\n") 178 + assert body == b"bad request\n" 179 + 180 + 181 + @pytest.mark.asyncio 182 + async def test_wsgi_exception_returns_500() -> None: 183 + meta, writer = await _serve( 184 + b"GET /boom HTTP/1.1\r\nHost: link.test\r\nContent-Length: 0\r\n\r\n", 185 + app=_build_app(propagate_exceptions=True), 186 + ) 187 + 188 + head, body = _split_response(writer.joined()) 189 + 190 + assert meta.method == "GET" 191 + assert meta.path == "/boom" 192 + assert meta.status == 500 193 + assert head.startswith(b"HTTP/1.1 500 internal server error\r\n") 194 + assert body == b"internal server error\n" 195 + 196 + 197 + @pytest.mark.asyncio 198 + async def test_metadata_has_no_payload_fields() -> None: 199 + meta, _ = await _serve( 200 + b"GET /hello HTTP/1.1\r\nHost: link.test\r\nContent-Length: 0\r\n\r\n", 201 + ) 202 + 203 + assert [field.name for field in dataclasses.fields(meta)] == [ 204 + "method", 205 + "path", 206 + "status", 207 + "request_bytes", 208 + "response_bytes", 209 + "stream_id", 210 + ]