personal memory agent
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

link: add in-tree python tunnel client for tests

+753
+753
tests/link/client.py
··· 1 + # SPDX-License-Identifier: AGPL-3.0-only 2 + # Copyright (c) 2026 sol pbc 3 + 4 + from __future__ import annotations 5 + 6 + import asyncio 7 + import dataclasses 8 + import hashlib 9 + import logging 10 + import urllib.parse 11 + from collections.abc import AsyncIterator, Awaitable, Callable 12 + 13 + import requests 14 + import websockets 15 + from cryptography import x509 16 + from cryptography.hazmat.primitives import hashes, serialization 17 + from cryptography.hazmat.primitives.asymmetric import ec 18 + from cryptography.x509.oid import NameOID 19 + from OpenSSL import SSL, crypto 20 + from websockets.asyncio.client import ClientConnection 21 + from websockets.exceptions import ConnectionClosed 22 + 23 + from think.link.ca import cert_fingerprint 24 + from think.link.framing import ( 25 + FLAG_CLOSE, 26 + FLAG_DATA, 27 + FLAG_OPEN, 28 + FLAG_RESET, 29 + FLAG_WINDOW, 30 + INITIAL_WINDOW, 31 + MAX_CONCURRENT_STREAMS, 32 + MAX_PAYLOAD, 33 + RECOMMENDED_CHUNK, 34 + RESET_FLOW_CONTROL_ERROR, 35 + RESET_INTERNAL_ERROR, 36 + RESET_PROTOCOL_ERROR, 37 + Frame, 38 + FrameDecoder, 39 + ProtocolError, 40 + build_close, 41 + build_data, 42 + build_open, 43 + build_reset, 44 + build_window, 45 + parse_reset_reason, 46 + parse_window_credit, 47 + ) 48 + 49 + LOG = logging.getLogger(__name__) 50 + _CONNECT_TIMEOUT_SECONDS = 15 51 + _HTTP_TIMEOUT_SECONDS = 30 52 + 53 + 54 + class TlsError(RuntimeError): 55 + """Raised when the client-side TLS handshake or tunnel aborts.""" 56 + 57 + 58 + class StreamResetError(ConnectionError): 59 + """Raised when the peer sends a RESET frame for an active stream.""" 60 + 61 + 62 + @dataclasses.dataclass(frozen=True) 63 + class ClientIdentity: 64 + private_key_pem: str 65 + client_cert_pem: str 66 + ca_chain_pem: str 67 + fingerprint: str 68 + home_instance_id: str 69 + home_label: str 70 + home_attestation: str 71 + 72 + 73 + @dataclasses.dataclass(frozen=True) 74 + class EnrolledDevice: 75 + device_token: str 76 + identity: ClientIdentity 77 + 78 + 79 + @dataclasses.dataclass 80 + class _TlsClientState: 81 + conn: SSL.Connection 82 + handshake_done: bool = False 83 + 84 + 85 + @dataclasses.dataclass 86 + class _StreamState: 87 + stream_id: int 88 + buffered: list[bytes] = dataclasses.field(default_factory=list) 89 + waiters: list[asyncio.Future[bytes | None]] = dataclasses.field( 90 + default_factory=list 91 + ) 92 + closed_waiters: list[asyncio.Future[None]] = dataclasses.field(default_factory=list) 93 + send_credit: int = INITIAL_WINDOW 94 + recv_credit: int = INITIAL_WINDOW 95 + unacked_recv: int = 0 96 + writer_closed: bool = False 97 + reader_closed: bool = False 98 + reset_reason: int | None = None 99 + credit_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) 100 + 101 + def __post_init__(self) -> None: 102 + self.credit_event.set() 103 + 104 + 105 + class _DialerStream: 106 + def __init__(self, mux: _DialerMultiplexer, state: _StreamState) -> None: 107 + self._mux = mux 108 + self._state = state 109 + 110 + @property 111 + def id(self) -> int: 112 + return self._state.stream_id 113 + 114 + async def write(self, data: bytes) -> None: 115 + if self._state.writer_closed: 116 + raise ConnectionError(f"stream {self._state.stream_id} writer is closed") 117 + view = memoryview(data) 118 + while view: 119 + chunk_len = min( 120 + len(view), 121 + RECOMMENDED_CHUNK, 122 + MAX_PAYLOAD, 123 + self._state.send_credit, 124 + ) 125 + if chunk_len <= 0: 126 + self._state.credit_event.clear() 127 + await self._state.credit_event.wait() 128 + continue 129 + chunk = bytes(view[:chunk_len]) 130 + view = view[chunk_len:] 131 + self._state.send_credit -= chunk_len 132 + await self._mux._emit(build_data(self._state.stream_id, chunk)) 133 + 134 + async def close(self) -> None: 135 + if self._state.writer_closed: 136 + return 137 + self._state.writer_closed = True 138 + await self._mux._emit(build_close(self._state.stream_id)) 139 + if self._state.reader_closed: 140 + self._mux._forget(self._state.stream_id) 141 + 142 + async def reset(self, reason: int = RESET_INTERNAL_ERROR) -> None: 143 + if self._state.writer_closed and self._state.reader_closed: 144 + return 145 + self._state.writer_closed = True 146 + self._state.reader_closed = True 147 + self._state.reset_reason = reason 148 + await self._mux._emit(build_reset(self._state.stream_id, reason)) 149 + self._mux._close_stream(self._state, forget=True) 150 + 151 + async def read(self) -> AsyncIterator[bytes]: 152 + while True: 153 + if self._state.buffered: 154 + yield self._state.buffered.pop(0) 155 + continue 156 + if self._state.reader_closed: 157 + if self._state.reset_reason is not None: 158 + raise StreamResetError( 159 + f"stream {self._state.stream_id} reset: {self._state.reset_reason}" 160 + ) 161 + return 162 + fut: asyncio.Future[bytes | None] = ( 163 + asyncio.get_running_loop().create_future() 164 + ) 165 + self._state.waiters.append(fut) 166 + chunk = await fut 167 + if chunk is None: 168 + if self._state.reset_reason is not None: 169 + raise StreamResetError( 170 + f"stream {self._state.stream_id} reset: {self._state.reset_reason}" 171 + ) 172 + return 173 + yield chunk 174 + 175 + async def read_all(self) -> bytes: 176 + parts = bytearray() 177 + async for chunk in self.read(): 178 + parts.extend(chunk) 179 + return bytes(parts) 180 + 181 + @property 182 + def closed(self) -> asyncio.Future[None]: 183 + fut: asyncio.Future[None] = asyncio.get_running_loop().create_future() 184 + if self._state.reader_closed and self._state.writer_closed: 185 + fut.set_result(None) 186 + return fut 187 + self._state.closed_waiters.append(fut) 188 + return fut 189 + 190 + 191 + class _DialerMultiplexer: 192 + def __init__(self, send_frame: Callable[[bytes], Awaitable[None]]) -> None: 193 + self._decoder = FrameDecoder() 194 + self._send_frame = send_frame 195 + self._streams: dict[int, _StreamState] = {} 196 + self._next_local_id = 1 197 + self._closed = False 198 + 199 + async def open_stream(self, initial: bytes = b"") -> _DialerStream: 200 + if self._closed: 201 + raise ConnectionError("mux is closed") 202 + if len(self._streams) >= MAX_CONCURRENT_STREAMS: 203 + raise ConnectionError("concurrent stream cap reached") 204 + if len(initial) > MAX_PAYLOAD: 205 + raise ValueError(f"initial payload exceeds framing max {MAX_PAYLOAD}") 206 + stream_id = self._next_local_id 207 + self._next_local_id += 2 208 + state = _StreamState(stream_id=stream_id) 209 + if initial: 210 + state.send_credit -= len(initial) 211 + self._streams[stream_id] = state 212 + await self._emit(build_open(stream_id, initial)) 213 + return _DialerStream(self, state) 214 + 215 + async def feed(self, plaintext: bytes) -> None: 216 + if self._closed or not plaintext: 217 + return 218 + self._decoder.feed(plaintext) 219 + while True: 220 + try: 221 + frame = self._decoder.next() 222 + except ProtocolError: 223 + self.close() 224 + return 225 + if frame is None: 226 + return 227 + await self._dispatch(frame) 228 + 229 + def close(self) -> None: 230 + if self._closed: 231 + return 232 + self._closed = True 233 + for state in list(self._streams.values()): 234 + self._close_stream(state, forget=True) 235 + 236 + async def _dispatch(self, frame: Frame) -> None: 237 + if frame.flags & FLAG_OPEN: 238 + await self._emit(build_reset(frame.stream_id, RESET_PROTOCOL_ERROR)) 239 + return 240 + 241 + state = self._streams.get(frame.stream_id) 242 + if state is None: 243 + await self._emit(build_reset(frame.stream_id, RESET_PROTOCOL_ERROR)) 244 + return 245 + 246 + if frame.flags & FLAG_DATA: 247 + if len(frame.payload) > state.recv_credit: 248 + await self._emit(build_reset(frame.stream_id, RESET_FLOW_CONTROL_ERROR)) 249 + state.reset_reason = RESET_FLOW_CONTROL_ERROR 250 + self._close_stream(state, forget=True) 251 + return 252 + state.recv_credit -= len(frame.payload) 253 + state.unacked_recv += len(frame.payload) 254 + if state.waiters: 255 + waiter = state.waiters.pop(0) 256 + if not waiter.done(): 257 + waiter.set_result(frame.payload) 258 + else: 259 + state.buffered.append(frame.payload) 260 + if state.unacked_recv >= INITIAL_WINDOW // 2: 261 + grant = state.unacked_recv 262 + state.recv_credit += grant 263 + state.unacked_recv = 0 264 + await self._emit(build_window(frame.stream_id, grant)) 265 + 266 + if frame.flags & FLAG_CLOSE: 267 + state.reader_closed = True 268 + while state.waiters: 269 + waiter = state.waiters.pop(0) 270 + if not waiter.done(): 271 + waiter.set_result(None) 272 + if state.writer_closed: 273 + self._forget(frame.stream_id) 274 + self._resolve_closed(state) 275 + 276 + if frame.flags & FLAG_WINDOW: 277 + try: 278 + credit = parse_window_credit(frame) 279 + except ProtocolError: 280 + await self._emit(build_reset(frame.stream_id, RESET_PROTOCOL_ERROR)) 281 + state.reset_reason = RESET_PROTOCOL_ERROR 282 + self._close_stream(state, forget=True) 283 + return 284 + state.send_credit += credit 285 + state.credit_event.set() 286 + 287 + if frame.flags & FLAG_RESET: 288 + try: 289 + state.reset_reason = parse_reset_reason(frame) 290 + except ProtocolError: 291 + state.reset_reason = RESET_PROTOCOL_ERROR 292 + self._close_stream(state, forget=True) 293 + 294 + async def _emit(self, frame: Frame) -> None: 295 + if self._closed: 296 + return 297 + await self._send_frame(frame.encode()) 298 + 299 + def _close_stream(self, state: _StreamState, *, forget: bool) -> None: 300 + state.writer_closed = True 301 + state.reader_closed = True 302 + while state.waiters: 303 + waiter = state.waiters.pop(0) 304 + if not waiter.done(): 305 + waiter.set_result(None) 306 + state.credit_event.set() 307 + self._resolve_closed(state) 308 + if forget: 309 + self._forget(state.stream_id) 310 + 311 + def _resolve_closed(self, state: _StreamState) -> None: 312 + while state.closed_waiters: 313 + waiter = state.closed_waiters.pop(0) 314 + if not waiter.done(): 315 + waiter.set_result(None) 316 + 317 + def _forget(self, stream_id: int) -> None: 318 + self._streams.pop(stream_id, None) 319 + 320 + 321 + class TunnelSession: 322 + def __init__( 323 + self, 324 + *, 325 + ws: ClientConnection, 326 + tls: _TlsClientState, 327 + identity: ClientIdentity, 328 + ) -> None: 329 + self._ws = ws 330 + self._tls = tls 331 + self._identity = identity 332 + self._tls_lock = asyncio.Lock() 333 + self._mux = _DialerMultiplexer(self._send_plaintext) 334 + self._closed = asyncio.Event() 335 + self._reader_task = asyncio.create_task( 336 + self._read_ws(), 337 + name=f"link-client-{identity.home_instance_id}", 338 + ) 339 + 340 + async def __aenter__(self) -> TunnelSession: 341 + return self 342 + 343 + async def __aexit__(self, *_exc: object) -> None: 344 + await self.close() 345 + 346 + async def request( 347 + self, 348 + method: str, 349 + path: str, 350 + *, 351 + headers: dict[str, str] | None = None, 352 + body: bytes = b"", 353 + ) -> tuple[int, dict[str, str], bytes]: 354 + request_bytes = _http_request_bytes(method, path, headers=headers, body=body) 355 + stream = await self._mux.open_stream(request_bytes) 356 + await stream.close() 357 + response = await stream.read_all() 358 + return _parse_http_response(response) 359 + 360 + async def close(self) -> None: 361 + if self._closed.is_set(): 362 + return 363 + self._mux.close() 364 + if not self._ws.closed: 365 + await self._ws.close() 366 + await self._reader_task 367 + self._closed.set() 368 + 369 + async def _read_ws(self) -> None: 370 + try: 371 + async for message in self._ws: 372 + inbound = ( 373 + message if isinstance(message, bytes) else message.encode("utf-8") 374 + ) 375 + async with self._tls_lock: 376 + outbound, plaintext = _drive_tls_client(self._tls, inbound=inbound) 377 + if outbound: 378 + await self._ws.send(outbound) 379 + if plaintext: 380 + await self._mux.feed(plaintext) 381 + except ConnectionClosed: 382 + pass 383 + finally: 384 + self._mux.close() 385 + self._closed.set() 386 + 387 + async def _send_plaintext(self, plaintext: bytes) -> None: 388 + async with self._tls_lock: 389 + outbound, _ = _drive_tls_client(self._tls, plaintext_out=plaintext) 390 + if outbound: 391 + await self._ws.send(outbound) 392 + 393 + 394 + class Client: 395 + @staticmethod 396 + def pair( 397 + lan_url: str, 398 + device_label: str, 399 + *, 400 + ca_fingerprint_pin: str | None = None, 401 + ) -> ClientIdentity: 402 + base_url = lan_url.rstrip("/") 403 + LOG.info("client %s: pair start", device_label) 404 + pair_start = _post_json( 405 + f"{base_url}/app/link/pair-start", 406 + {"device_label": device_label}, 407 + ) 408 + nonce = pair_start.get("nonce") 409 + if not isinstance(nonce, str) or not nonce: 410 + raise RuntimeError("pair-start returned no nonce") 411 + 412 + private_key_pem, csr_pem = _build_csr(device_label) 413 + paired = _post_json( 414 + f"{base_url}/app/link/pair", 415 + { 416 + "nonce": nonce, 417 + "csr": csr_pem, 418 + "device_label": device_label, 419 + }, 420 + ) 421 + 422 + client_cert_pem = _required_str(paired, "client_cert") 423 + ca_chain = paired.get("ca_chain") 424 + if not isinstance(ca_chain, list) or not ca_chain: 425 + raise RuntimeError("pair returned no ca_chain") 426 + if not all(isinstance(item, str) and item for item in ca_chain): 427 + raise RuntimeError("pair returned invalid ca_chain") 428 + ca_chain_pem = "".join(ca_chain) 429 + ca_fingerprint = _cert_sha256_hex(_first_cert_pem(ca_chain_pem)) 430 + if ca_fingerprint_pin is not None and ca_fingerprint != ca_fingerprint_pin: 431 + raise RuntimeError( 432 + f"CA fingerprint mismatch: got {ca_fingerprint}, expected {ca_fingerprint_pin}" 433 + ) 434 + 435 + fingerprint = _required_str(paired, "fingerprint") 436 + if cert_fingerprint(client_cert_pem) != fingerprint: 437 + raise RuntimeError("pair returned certificate fingerprint mismatch") 438 + 439 + identity = ClientIdentity( 440 + private_key_pem=private_key_pem, 441 + client_cert_pem=client_cert_pem, 442 + ca_chain_pem=ca_chain_pem, 443 + fingerprint=fingerprint, 444 + home_instance_id=_required_str(paired, "instance_id"), 445 + home_label=_required_str(paired, "home_label"), 446 + home_attestation=_required_str(paired, "home_attestation"), 447 + ) 448 + LOG.info( 449 + "client %s: paired to %s", 450 + device_label, 451 + identity.home_instance_id, 452 + ) 453 + return identity 454 + 455 + @staticmethod 456 + def enroll_device(relay_url: str, identity: ClientIdentity) -> EnrolledDevice: 457 + endpoint = f"{relay_url.rstrip('/')}/enroll/device" 458 + LOG.info("client %s: enrolling device token", identity.fingerprint) 459 + payload = _post_json( 460 + endpoint, 461 + { 462 + "instance_id": identity.home_instance_id, 463 + "client_cert": identity.client_cert_pem, 464 + "home_attestation": identity.home_attestation, 465 + }, 466 + ) 467 + device_token = _required_str(payload, "device_token") 468 + LOG.info("client %s: enroll complete", identity.fingerprint) 469 + return EnrolledDevice(device_token=device_token, identity=identity) 470 + 471 + @staticmethod 472 + async def dial(relay_url: str, enrolled: EnrolledDevice) -> TunnelSession: 473 + identity = enrolled.identity 474 + url = ( 475 + _to_ws(relay_url.rstrip("/")) 476 + + "/session/dial?" 477 + + urllib.parse.urlencode( 478 + { 479 + "instance": identity.home_instance_id, 480 + "token": enrolled.device_token, 481 + } 482 + ) 483 + ) 484 + LOG.info("client %s: dialing %s", identity.fingerprint, _redact_url(url)) 485 + ws = await websockets.connect(url, max_size=None) 486 + try: 487 + tls = _new_tls_client(_build_tls_client_ctx(identity)) 488 + pending_plaintext = bytearray() 489 + outbound, plaintext = _drive_tls_client(tls) 490 + if outbound: 491 + await ws.send(outbound) 492 + pending_plaintext.extend(plaintext) 493 + while not tls.handshake_done: 494 + inbound = await asyncio.wait_for( 495 + ws.recv(), 496 + timeout=_CONNECT_TIMEOUT_SECONDS, 497 + ) 498 + inbound_bytes = ( 499 + inbound if isinstance(inbound, bytes) else inbound.encode("utf-8") 500 + ) 501 + outbound, plaintext = _drive_tls_client(tls, inbound=inbound_bytes) 502 + if outbound: 503 + await ws.send(outbound) 504 + pending_plaintext.extend(plaintext) 505 + except Exception: 506 + await ws.close() 507 + raise 508 + 509 + session = TunnelSession( 510 + ws=ws, 511 + tls=tls, 512 + identity=identity, 513 + ) 514 + if pending_plaintext: 515 + await session._mux.feed(bytes(pending_plaintext)) 516 + return session 517 + 518 + 519 + def _build_csr(device_label: str) -> tuple[str, str]: 520 + private_key = ec.generate_private_key(ec.SECP256R1()) 521 + csr = ( 522 + x509.CertificateSigningRequestBuilder() 523 + .subject_name( 524 + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, device_label)]) 525 + ) 526 + .sign(private_key, hashes.SHA256()) 527 + ) 528 + private_key_pem = private_key.private_bytes( 529 + serialization.Encoding.PEM, 530 + serialization.PrivateFormat.PKCS8, 531 + serialization.NoEncryption(), 532 + ).decode("ascii") 533 + csr_pem = csr.public_bytes(serialization.Encoding.PEM).decode("ascii") 534 + return private_key_pem, csr_pem 535 + 536 + 537 + def _build_tls_client_ctx(identity: ClientIdentity) -> SSL.Context: 538 + ctx = SSL.Context(SSL.TLS_METHOD) 539 + ctx.set_min_proto_version(SSL.TLS1_3_VERSION) 540 + ctx.set_max_proto_version(SSL.TLS1_3_VERSION) 541 + ctx.use_certificate( 542 + crypto.load_certificate( 543 + crypto.FILETYPE_PEM, 544 + identity.client_cert_pem.encode("ascii"), 545 + ) 546 + ) 547 + ctx.use_privatekey( 548 + crypto.load_privatekey( 549 + crypto.FILETYPE_PEM, 550 + identity.private_key_pem.encode("ascii"), 551 + ) 552 + ) 553 + store = ctx.get_cert_store() 554 + assert store is not None, "client TLS context must expose a cert store" 555 + for cert_pem in _split_pem_chain(identity.ca_chain_pem): 556 + store.add_cert( 557 + crypto.X509.from_cryptography(x509.load_pem_x509_certificate(cert_pem)), 558 + ) 559 + ctx.set_verify(SSL.VERIFY_PEER, _verify_server_cert) 560 + ctx.check_privatekey() 561 + return ctx 562 + 563 + 564 + def _verify_server_cert( 565 + _conn: SSL.Connection, 566 + _cert: crypto.X509, 567 + _errno: int, 568 + _depth: int, 569 + preverify_ok: int, 570 + ) -> bool: 571 + return bool(preverify_ok) 572 + 573 + 574 + def _new_tls_client(ctx: SSL.Context) -> _TlsClientState: 575 + conn = SSL.Connection(ctx, None) 576 + conn.set_connect_state() 577 + return _TlsClientState(conn=conn) 578 + 579 + 580 + def _drive_tls_client( 581 + state: _TlsClientState, 582 + *, 583 + inbound: bytes = b"", 584 + plaintext_out: bytes = b"", 585 + ) -> tuple[bytes, bytes]: 586 + if inbound: 587 + state.conn.bio_write(inbound) 588 + if plaintext_out: 589 + try: 590 + state.conn.send(plaintext_out) 591 + except SSL.WantReadError: 592 + pass 593 + except SSL.Error as exc: 594 + raise TlsError(f"send failed: {exc}") from exc 595 + 596 + if not state.handshake_done: 597 + try: 598 + state.conn.do_handshake() 599 + state.handshake_done = True 600 + except SSL.WantReadError: 601 + pass 602 + except SSL.Error as exc: 603 + raise TlsError(f"handshake failed: {exc}") from exc 604 + 605 + plaintext_in = bytearray() 606 + if state.handshake_done: 607 + while True: 608 + try: 609 + chunk = state.conn.recv(16 * 1024) 610 + except SSL.WantReadError: 611 + break 612 + except SSL.ZeroReturnError: 613 + break 614 + except SSL.Error as exc: 615 + raise TlsError(f"recv failed: {exc}") from exc 616 + if not chunk: 617 + break 618 + plaintext_in.extend(chunk) 619 + 620 + outbound = bytearray() 621 + while True: 622 + try: 623 + chunk = state.conn.bio_read(16 * 1024) 624 + except SSL.WantReadError: 625 + break 626 + if not chunk: 627 + break 628 + outbound.extend(chunk) 629 + return bytes(outbound), bytes(plaintext_in) 630 + 631 + 632 + def _http_request_bytes( 633 + method: str, 634 + path: str, 635 + *, 636 + headers: dict[str, str] | None, 637 + body: bytes, 638 + ) -> bytes: 639 + body_bytes = body or b"" 640 + normalized_headers = {"host": "spl.local"} 641 + if headers: 642 + normalized_headers.update({k.lower(): v for k, v in headers.items()}) 643 + normalized_headers["content-length"] = str(len(body_bytes)) 644 + head = ( 645 + f"{method} {path} HTTP/1.1\r\n" 646 + + "".join(f"{name}: {value}\r\n" for name, value in normalized_headers.items()) 647 + + "\r\n" 648 + ) 649 + return head.encode("ascii") + body_bytes 650 + 651 + 652 + def _parse_http_response(raw: bytes) -> tuple[int, dict[str, str], bytes]: 653 + split = raw.find(b"\r\n\r\n") 654 + if split < 0: 655 + raise ValueError("response missing header terminator") 656 + head = raw[:split].decode("latin-1") 657 + body = raw[split + 4 :] 658 + lines = head.split("\r\n") 659 + if not lines: 660 + raise ValueError("response missing status line") 661 + parts = lines[0].split(" ", 2) 662 + if len(parts) < 2 or not parts[1].isdigit(): 663 + raise ValueError(f"bad status line: {lines[0]!r}") 664 + status = int(parts[1]) 665 + headers: dict[str, str] = {} 666 + for line in lines[1:]: 667 + if not line or ":" not in line: 668 + continue 669 + name, value = line.split(":", 1) 670 + headers[name.strip().lower()] = value.strip() 671 + if headers.get("transfer-encoding", "").lower() == "chunked": 672 + return status, headers, _dechunk(body) 673 + content_length = headers.get("content-length") 674 + if content_length is None: 675 + return status, headers, body 676 + return status, headers, body[: int(content_length)] 677 + 678 + 679 + def _dechunk(raw: bytes) -> bytes: 680 + out = bytearray() 681 + index = 0 682 + while index < len(raw): 683 + line_end = raw.find(b"\r\n", index) 684 + if line_end < 0: 685 + raise ValueError("chunked response missing size terminator") 686 + size_text = raw[index:line_end].decode("ascii").split(";", 1)[0].strip() 687 + size = int(size_text, 16) 688 + index = line_end + 2 689 + if size == 0: 690 + return bytes(out) 691 + out.extend(raw[index : index + size]) 692 + index += size + 2 693 + return bytes(out) 694 + 695 + 696 + def _post_json(url: str, payload: dict[str, object]) -> dict[str, object]: 697 + response = requests.post(url, json=payload, timeout=_HTTP_TIMEOUT_SECONDS) 698 + if not response.ok: 699 + raise RuntimeError( 700 + f"POST {url} failed: HTTP {response.status_code}: {response.text}" 701 + ) 702 + parsed = response.json() 703 + if not isinstance(parsed, dict): 704 + raise RuntimeError(f"unexpected JSON response from {url}") 705 + return parsed 706 + 707 + 708 + def _required_str(payload: dict[str, object], key: str) -> str: 709 + value = payload.get(key) 710 + if not isinstance(value, str) or not value: 711 + raise RuntimeError(f"missing string field: {key}") 712 + return value 713 + 714 + 715 + def _split_pem_chain(pem_bundle: str) -> list[bytes]: 716 + marker = "-----END CERTIFICATE-----" 717 + certs: list[bytes] = [] 718 + for chunk in pem_bundle.split(marker): 719 + chunk = chunk.strip() 720 + if not chunk: 721 + continue 722 + certs.append(f"{chunk}\n{marker}\n".encode("ascii")) 723 + return certs 724 + 725 + 726 + def _first_cert_pem(pem_bundle: str) -> str: 727 + certs = _split_pem_chain(pem_bundle) 728 + if not certs: 729 + raise RuntimeError("empty certificate chain") 730 + return certs[0].decode("ascii") 731 + 732 + 733 + def _cert_sha256_hex(cert_pem: str) -> str: 734 + cert = x509.load_pem_x509_certificate(cert_pem.encode("ascii")) 735 + return hashlib.sha256(cert.public_bytes(serialization.Encoding.DER)).hexdigest() 736 + 737 + 738 + def _to_ws(url: str) -> str: 739 + if url.startswith("https://"): 740 + return "wss://" + url[len("https://") :] 741 + if url.startswith("http://"): 742 + return "ws://" + url[len("http://") :] 743 + return url 744 + 745 + 746 + def _redact_url(url: str) -> str: 747 + parsed = urllib.parse.urlparse(url) 748 + query = urllib.parse.parse_qs(parsed.query) 749 + if "token" in query: 750 + query["token"] = ["<redacted>"] 751 + return urllib.parse.urlunparse( 752 + parsed._replace(query=urllib.parse.urlencode(query, doseq=True)) 753 + )