personal memory agent
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge branch 'hopper-iqpv6zfn-dream-resilience-scope'

+614 -10
+188
tests/test_circuit_breaker.py
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Tests for the provider circuit breaker and its Google provider wiring."""

import asyncio
import json
import time
from types import SimpleNamespace

import pytest

from think.providers.shared import CircuitBreaker, CircuitOpenError


class FakeClientError(Exception):
    """Stand-in for a provider client error that carries a status code."""

    def __init__(self, code, response_json=None):
        super().__init__(f"{code}")
        self.code = code
        self.response_json = response_json or {}


def _fail(breaker, times=1):
    """Record ``times`` consecutive 429 failures on ``breaker``."""
    for _ in range(times):
        breaker.record_failure(FakeClientError(429))


def _expire_cooldown(breaker):
    """Backdate the open timestamp so the current cooldown has just elapsed."""
    breaker._opened_at = time.time() - breaker._current_cooldown - 1


def test_starts_closed():
    breaker = CircuitBreaker("google")
    assert breaker.state == breaker.CLOSED


def test_stays_closed_below_threshold():
    breaker = CircuitBreaker("google", failure_threshold=5)
    _fail(breaker, times=4)
    assert breaker.state == breaker.CLOSED


def test_opens_at_threshold():
    breaker = CircuitBreaker("google", failure_threshold=5)
    _fail(breaker, times=5)
    assert breaker.state == breaker.OPEN


def test_open_rejects_requests():
    breaker = CircuitBreaker("google", failure_threshold=1)
    _fail(breaker)
    with pytest.raises(CircuitOpenError):
        breaker.check()


def test_circuit_open_error_attributes():
    error = CircuitOpenError("google", 12.5)
    assert error.provider == "google"
    assert error.cooldown_remaining == 12.5


def test_half_open_after_cooldown():
    breaker = CircuitBreaker("google", failure_threshold=1, cooldown_s=5)
    _fail(breaker)
    _expire_cooldown(breaker)
    assert breaker.state == breaker.HALF_OPEN


def test_half_open_success_closes():
    breaker = CircuitBreaker("google", failure_threshold=1, cooldown_s=5)
    _fail(breaker)
    _expire_cooldown(breaker)
    assert breaker.state == breaker.HALF_OPEN
    breaker.record_success()
    assert breaker.state == breaker.CLOSED
    assert breaker._failure_count == 0


def test_half_open_failure_reopens_with_doubled_cooldown():
    breaker = CircuitBreaker("google", failure_threshold=1, cooldown_s=5)
    _fail(breaker)
    _expire_cooldown(breaker)
    assert breaker.state == breaker.HALF_OPEN
    _fail(breaker)
    assert breaker.state == breaker.OPEN
    assert breaker._current_cooldown == 10


def test_cooldown_cap():
    breaker = CircuitBreaker(
        "google", failure_threshold=1, cooldown_s=400, max_cooldown_s=600
    )
    _fail(breaker)
    _expire_cooldown(breaker)
    assert breaker.state == breaker.HALF_OPEN
    _fail(breaker)
    # Doubling 400 would give 800; the cap keeps it at 600.
    assert breaker._current_cooldown == 600


def test_success_resets_failure_count():
    breaker = CircuitBreaker("google", failure_threshold=5)
    _fail(breaker, times=2)
    breaker.record_success()
    assert breaker._failure_count == 0
    assert breaker.state == breaker.CLOSED


def test_circuit_opens_on_consecutive_429s(monkeypatch):
    from think.providers import google as google_provider

    # Pre-populate with a clean breaker (no health_path) to avoid fixture writes
    breaker = CircuitBreaker("google")
    monkeypatch.setattr(google_provider, "_circuit_breakers", {"google": breaker})

    def _looks_like_quota_error(error):
        return getattr(error, "code", None) == 429

    monkeypatch.setattr(google_provider, "_is_quota_error", _looks_like_quota_error)

    # No-op rate budget to avoid fixture file creation
    class _NoopBudget:
        def acquire(self, **kw):
            pass

        async def aacquire(self, **kw):
            pass

    # Calling the class yields a fresh no-op budget, like the factory would.
    monkeypatch.setattr("think.rate_limiter.get_rate_budget", _NoopBudget)

    class DummyModels:
        def __init__(self):
            self.calls = 0

        async def generate_content(self, **kwargs):
            self.calls += 1
            raise FakeClientError(429, {"error": "rate limit"})

    models = DummyModels()
    client = SimpleNamespace(aio=SimpleNamespace(models=models))

    # Five quota failures reach the default threshold ...
    for _ in range(5):
        with pytest.raises(FakeClientError):
            asyncio.run(google_provider.run_agenerate("hello", client=client))

    # ... so the sixth call is rejected before it reaches the client.
    with pytest.raises(CircuitOpenError):
        asyncio.run(google_provider.run_agenerate("hello", client=client))

    assert client.aio.models.calls == 5


def test_health_file_write_on_open(tmp_path):
    """AC6: Circuit state visible in journal/health/agents.json."""
    health_path = tmp_path / "health" / "agents.json"
    health_path.parent.mkdir(parents=True)
    seed = {
        "results": [{"provider": "google", "ok": True}],
        "checked_at": "2026-04-02T00:00:00+00:00",
    }
    health_path.write_text(json.dumps(seed))

    breaker = CircuitBreaker("google", failure_threshold=5)
    breaker._health_path = health_path
    _fail(breaker, times=5)

    written = json.loads(health_path.read_text())
    assert "circuit_breakers" in written
    google_entry = written["circuit_breakers"]["google"]
    assert google_entry["state"] == "open"
    assert google_entry["failure_count"] == 5
    # Pre-existing keys must survive the merge untouched.
    assert written["results"] == [{"provider": "google", "ok": True}]
    assert written["checked_at"] == "2026-04-02T00:00:00+00:00"


def test_callosum_events_on_state_transitions(monkeypatch):
    """AC3: Circuit emits provider.unhealthy/provider.healthy via Callosum."""
    captured = []

    def fake_callosum_send(tract, event, **fields):
        captured.append({"tract": tract, "event": event, **fields})
        return True

    monkeypatch.setattr("think.callosum.callosum_send", fake_callosum_send)

    breaker = CircuitBreaker("google", failure_threshold=2)
    _fail(breaker, times=2)

    assert len(captured) == 1
    opened = captured[0]
    assert opened["tract"] == "provider"
    assert opened["event"] == "unhealthy"
    assert opened["provider"] == "google"
    assert "cooldown_s" in opened

    _expire_cooldown(breaker)
    assert breaker.state == breaker.HALF_OPEN
    breaker.record_success()

    assert len(captured) == 2
    recovered = captured[1]
    assert recovered["tract"] == "provider"
    assert recovered["event"] == "healthy"
    assert recovered["provider"] == "google"
+98
tests/test_rate_limiter.py
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

"""Tests for the file-based rate budget."""

import asyncio
import json
import time

import pytest

from think.rate_limiter import RateBudget, RateBudgetExhausted


def _seed_budget(path, *, remaining, window_start):
    """Write a budget file for a 3-request, 60-second window."""
    path.write_text(
        json.dumps(
            {
                "remaining": remaining,
                "window_start": window_start,
                "window_duration_s": 60.0,
                "budget_per_window": 3,
            }
        )
    )


def test_try_acquire_success(tmp_path):
    budget_path = tmp_path / "rate_budget.json"
    budget = RateBudget(budget_path, rpm=3)
    assert budget.try_acquire() is True
    persisted = json.loads(budget_path.read_text())
    assert persisted["remaining"] == 2


def test_try_acquire_exhausted(tmp_path):
    budget_path = tmp_path / "rate_budget.json"
    _seed_budget(budget_path, remaining=0, window_start=time.time())
    budget = RateBudget(budget_path, rpm=3)
    assert budget.try_acquire() is False


def test_window_replenishment(tmp_path):
    budget_path = tmp_path / "rate_budget.json"
    # The window started 61s ago, so the 60s window has already rolled over.
    _seed_budget(budget_path, remaining=0, window_start=time.time() - 61)
    budget = RateBudget(budget_path, rpm=3)
    assert budget.try_acquire() is True
    persisted = json.loads(budget_path.read_text())
    assert persisted["remaining"] == 2


def test_acquire_blocks_then_succeeds(tmp_path, monkeypatch):
    budget = RateBudget(tmp_path / "rate_budget.json", rpm=1)
    attempts = []

    def fake_try_acquire():
        # Refuse the first attempt and grant the second.
        attempts.append(None)
        return len(attempts) >= 2

    monkeypatch.setattr(budget, "try_acquire", fake_try_acquire)
    budget.acquire(timeout_s=1)
    assert len(attempts) == 2


def test_acquire_timeout_raises(tmp_path):
    budget = RateBudget(tmp_path / "rate_budget.json", rpm=1)
    assert budget.try_acquire() is True
    with pytest.raises(RateBudgetExhausted):
        budget.acquire(timeout_s=0.2)


def test_aacquire_success(tmp_path):
    budget_path = tmp_path / "rate_budget.json"

    async def scenario():
        budget = RateBudget(budget_path, rpm=1)
        await budget.aacquire(timeout_s=1)
        persisted = json.loads(budget_path.read_text())
        assert persisted["remaining"] == 0

    asyncio.run(scenario())


def test_concurrent_acquire(tmp_path):
    async def scenario():
        budget = RateBudget(tmp_path / "rate_budget.json", rpm=10)

        async def worker():
            granted = 0
            for _ in range(5):
                if budget.try_acquire():
                    granted += 1
                # Yield so the three workers interleave their attempts.
                await asyncio.sleep(0)
            return granted

        per_worker = await asyncio.gather(*(worker() for _ in range(3)))
        # 15 attempts against a budget of 10: exactly 10 may succeed.
        assert sum(per_worker) == 10

    asyncio.run(scenario())
+63 -10
think/providers/google.py
··· 34 34 import logging 35 35 import os 36 36 import traceback 37 + from pathlib import Path 37 38 from typing import Any, Callable 38 39 39 40 from google import genai ··· 49 50 build_cogitate_env, 50 51 ) 51 52 from .shared import ( 53 + CircuitBreaker, 52 54 GenerateResult, 53 55 JSONEventCallback, 54 56 ThinkingEvent, ··· 59 61 _DEFAULT_MODEL = GEMINI_FLASH 60 62 61 63 logger = logging.getLogger(__name__) 64 + _circuit_breakers: dict[str, CircuitBreaker] = {} 62 65 63 66 64 67 # --------------------------------------------------------------------------- ··· 89 92 http_options=types.HttpOptions(retry_options=types.HttpRetryOptions(attempts=8)), 90 93 ) 91 94 return client 95 + 96 + 97 + def _get_circuit_breaker() -> CircuitBreaker: 98 + """Get or create the Google circuit breaker singleton.""" 99 + if "google" not in _circuit_breakers: 100 + from think.utils import get_journal 101 + 102 + cb = CircuitBreaker("google") 103 + try: 104 + cb._health_path = Path(get_journal()) / "health" / "agents.json" 105 + except Exception: 106 + pass 107 + _circuit_breakers["google"] = cb 108 + return _circuit_breakers["google"] 109 + 110 + 111 + def _is_quota_error(error: Exception) -> bool: 112 + """Check if an error is a quota/rate-limit error from Google.""" 113 + try: 114 + from google.genai.errors import ClientError 115 + 116 + return isinstance(error, ClientError) and getattr(error, "code", None) == 429 117 + except ImportError: 118 + return False 92 119 93 120 94 121 def _compute_agent_thinking_params( ··· 382 409 timeout_s=timeout_s, 383 410 ) 384 411 385 - response = client.models.generate_content( 386 - model=model, 387 - contents=contents, 388 - config=config, 389 - ) 412 + cb = _get_circuit_breaker() 413 + cb.check() 414 + 415 + from think.rate_limiter import get_rate_budget 416 + 417 + get_rate_budget().acquire() 418 + 419 + try: 420 + response = client.models.generate_content( 421 + model=model, 422 + contents=contents, 423 + config=config, 424 + ) 425 + except 
Exception as e: 426 + if _is_quota_error(e): 427 + cb.record_failure(e) 428 + raise 390 429 430 + cb.record_success() 391 431 return GenerateResult( 392 432 text=_extract_response_text(response), 393 433 usage=_extract_usage(response), ··· 426 466 timeout_s=timeout_s, 427 467 ) 428 468 429 - response = await client.aio.models.generate_content( 430 - model=model, 431 - contents=contents, 432 - config=config, 433 - ) 469 + cb = _get_circuit_breaker() 470 + cb.check() 471 + 472 + from think.rate_limiter import get_rate_budget 473 + 474 + await get_rate_budget().aacquire() 475 + 476 + try: 477 + response = await client.aio.models.generate_content( 478 + model=model, 479 + contents=contents, 480 + config=config, 481 + ) 482 + except Exception as e: 483 + if _is_quota_error(e): 484 + cb.record_failure(e) 485 + raise 434 486 487 + cb.record_success() 435 488 return GenerateResult( 436 489 text=_extract_response_text(response), 437 490 usage=_extract_usage(response),
+139
think/providers/shared.py
import json
import time


class CircuitOpenError(Exception):
    """Raised when a circuit breaker is open and rejecting requests.

    Attributes:
        provider: Name of the provider whose circuit is open.
        cooldown_remaining: Seconds left before the breaker starts probing.
    """

    def __init__(self, provider: str, cooldown_remaining: float):
        self.provider = provider
        self.cooldown_remaining = cooldown_remaining
        super().__init__(
            f"Circuit breaker open for {provider} ({cooldown_remaining:.0f}s remaining)"
        )


class CircuitBreaker:
    """Per-process circuit breaker for API providers.

    States: closed (normal), open (rejecting), half_open (probing after cooldown).
    On failure_threshold consecutive quota errors, opens the circuit.
    Cooldown doubles on each failed probe, capped at max_cooldown_s.

    Intended use: call :meth:`check` before a request, then report the outcome
    with :meth:`record_success` or :meth:`record_failure`. Every state change
    emits a ``provider`` event (``unhealthy`` / ``healthy``) and, when
    ``_health_path`` is set, is mirrored into that JSON file. Both side
    effects are best-effort and never raise.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"

    def __init__(
        self,
        provider: str,
        failure_threshold: int = 5,
        cooldown_s: float = 60,
        max_cooldown_s: float = 600,
    ):
        """Create a closed breaker.

        Args:
            provider: Provider name used in events, errors and the health file.
            failure_threshold: Consecutive failures that open the circuit.
            cooldown_s: Initial time to stay open before allowing a probe.
            max_cooldown_s: Upper bound for the doubled cooldown.
        """
        self.provider = provider
        self.failure_threshold = failure_threshold
        self._initial_cooldown = cooldown_s
        self.max_cooldown_s = max_cooldown_s
        self._state = self.CLOSED
        self._failure_count = 0
        # Wall-clock timestamp of the last transition to OPEN (None when closed).
        self._opened_at = None
        self._current_cooldown = cooldown_s
        # Optional path of the health JSON file; assigned by the owner.
        self._health_path = None

    @property
    def state(self) -> str:
        """Return the current state, promoting OPEN to HALF_OPEN after cooldown."""
        if self._state == self.OPEN and self._opened_at is not None:
            if time.time() - self._opened_at >= self._current_cooldown:
                self._state = self.HALF_OPEN
        return self._state

    def check(self) -> None:
        """Raise CircuitOpenError if circuit is open. Call before each request.

        Raises:
            CircuitOpenError: The circuit is open; carries the remaining cooldown.
        """
        if self.state != self.OPEN:
            return
        if self._opened_at is None:
            # Defensive: OPEN without a timestamp cannot expire on its own, so
            # report the full cooldown instead of failing with a TypeError.
            remaining = self._current_cooldown
        else:
            remaining = self._current_cooldown - (time.time() - self._opened_at)
        raise CircuitOpenError(self.provider, max(0, remaining))

    def record_success(self) -> None:
        """Record a successful API call. Closes the circuit if it is not closed."""
        was_closed = self._state == self.CLOSED
        self._failure_count = 0
        if was_closed:
            return
        self._state = self.CLOSED
        self._current_cooldown = self._initial_cooldown
        self._opened_at = None
        self._emit("provider", "healthy", provider=self.provider)
        self._write_health()

    def record_failure(self, error: Exception) -> None:
        """Record a quota/429 failure. May open the circuit.

        Args:
            error: The failure that occurred. Not inspected here; the caller
                decides which errors count as quota failures.
        """
        self._failure_count += 1
        # Go through the property so an elapsed cooldown is seen as HALF_OPEN
        # even if nobody polled the state since the circuit opened.
        current = self.state
        if current == self.OPEN:
            # Late failure from a request that was already in flight when the
            # circuit opened. Re-tripping would restart the cooldown and emit
            # duplicate "unhealthy" events, so only the counter is updated.
            return
        if current == self.HALF_OPEN:
            # Failed probe: back off harder, up to the configured cap.
            self._current_cooldown = min(
                self._current_cooldown * 2, self.max_cooldown_s
            )
        elif self._failure_count < self.failure_threshold:
            return
        self._trip()

    def _trip(self) -> None:
        """Enter OPEN now and publish the transition."""
        self._state = self.OPEN
        self._opened_at = time.time()
        self._emit(
            "provider",
            "unhealthy",
            provider=self.provider,
            cooldown_s=self._current_cooldown,
        )
        self._write_health()

    def _emit(self, tract, event, **fields):
        """Emit callosum event. Best-effort, never raises."""
        try:
            from think.callosum import callosum_send

            callosum_send(tract, event, **fields)
        except Exception:
            # Deliberately swallowed: event delivery must never break a
            # provider call. Leave a trace for debugging instead of silence.
            import logging

            logging.getLogger(__name__).debug(
                "circuit breaker event %s.%s not delivered",
                tract,
                event,
                exc_info=True,
            )

    def _write_health(self):
        """Write circuit breaker state to agents.json. Best-effort.

        Merges this breaker's entry under ``circuit_breakers`` while keeping
        the file's other keys, serialized across processes by an exclusive
        ``flock`` on a sibling ``agents.json.lock`` file.
        """
        if self._health_path is None:
            return
        try:
            import fcntl
            from datetime import datetime, timezone

            health_dir = self._health_path.parent
            health_dir.mkdir(parents=True, exist_ok=True)
            lock_path = health_dir / "agents.json.lock"
            with open(lock_path, "w") as lock_file:
                fcntl.flock(lock_file, fcntl.LOCK_EX)
                try:
                    data = {}
                    if self._health_path.exists():
                        data = json.loads(self._health_path.read_text())
                    entry = {
                        "state": self._state,
                        "failure_count": self._failure_count,
                        "cooldown_s": self._current_cooldown,
                    }
                    if self._opened_at is not None:
                        entry["opened_at"] = datetime.fromtimestamp(
                            self._opened_at, tz=timezone.utc
                        ).isoformat()
                    data.setdefault("circuit_breakers", {})[self.provider] = entry
                    self._health_path.write_text(json.dumps(data, indent=2))
                finally:
                    fcntl.flock(lock_file, fcntl.LOCK_UN)
        except Exception:
            # Deliberately swallowed: health reporting is advisory only.
            import logging

            logging.getLogger(__name__).debug(
                "circuit breaker health write to %s failed",
                self._health_path,
                exc_info=True,
            )
+126
think/rate_limiter.py
··· 1 + # SPDX-License-Identifier: AGPL-3.0-only 2 + # Copyright (c) 2026 sol pbc 3 + 4 + """File-based rate limiter for cross-process API budget management. 5 + 6 + Uses fcntl.flock() for cross-process synchronization following the pattern 7 + in think/entities/saving.py. Budget file lives at journal/health/rate_budget.json. 8 + """ 9 + 10 + from __future__ import annotations 11 + 12 + import asyncio 13 + import json 14 + import os 15 + import time 16 + from pathlib import Path 17 + 18 + from think.utils import get_config, get_journal 19 + 20 + _rate_budget = None 21 + 22 + 23 + class RateBudgetExhausted(Exception): 24 + """Raised when rate budget cannot be acquired within timeout.""" 25 + 26 + 27 + class RateBudget: 28 + """File-based fixed-window request budget.""" 29 + 30 + def __init__(self, budget_path: Path, rpm: int = 1500, window_s: float = 60.0): 31 + self.budget_path = budget_path 32 + self.rpm = rpm 33 + self.window_s = window_s 34 + 35 + def _default_state(self, now: float) -> dict: 36 + return { 37 + "remaining": self.rpm, 38 + "window_start": now, 39 + "window_duration_s": self.window_s, 40 + "budget_per_window": self.rpm, 41 + } 42 + 43 + def try_acquire(self) -> bool: 44 + import fcntl 45 + 46 + now = time.time() 47 + self.budget_path.parent.mkdir(parents=True, exist_ok=True) 48 + lock_path = self.budget_path.with_name(f"{self.budget_path.name}.lock") 49 + with open(lock_path, "w") as lock_file: 50 + fcntl.flock(lock_file, fcntl.LOCK_EX) 51 + try: 52 + if self.budget_path.exists(): 53 + try: 54 + state = json.loads(self.budget_path.read_text()) 55 + except (json.JSONDecodeError, OSError): 56 + state = self._default_state(now) 57 + else: 58 + state = self._default_state(now) 59 + 60 + window_start = state.get("window_start", now) 61 + window_duration_s = state.get("window_duration_s", self.window_s) 62 + budget_per_window = state.get("budget_per_window", self.rpm) 63 + remaining = state.get("remaining", budget_per_window) 64 + 65 + if now - 
window_start >= window_duration_s: 66 + remaining = self.rpm 67 + state = { 68 + "remaining": remaining, 69 + "window_start": now, 70 + "window_duration_s": self.window_s, 71 + "budget_per_window": self.rpm, 72 + } 73 + 74 + if remaining > 0: 75 + state["remaining"] = remaining - 1 76 + self.budget_path.write_text(json.dumps(state, indent=2)) 77 + return True 78 + 79 + self.budget_path.write_text(json.dumps(state, indent=2)) 80 + return False 81 + finally: 82 + fcntl.flock(lock_file, fcntl.LOCK_UN) 83 + 84 + def acquire(self, timeout_s: float = 30.0) -> None: 85 + deadline = time.time() + timeout_s 86 + delay = 0.1 87 + while time.time() < deadline: 88 + if self.try_acquire(): 89 + return 90 + time.sleep(delay) 91 + delay = min(delay * 2, 1.0) 92 + raise RateBudgetExhausted( 93 + f"Rate budget not available within {timeout_s:.1f}s" 94 + ) 95 + 96 + async def aacquire(self, timeout_s: float = 30.0) -> None: 97 + deadline = time.time() + timeout_s 98 + delay = 0.1 99 + while time.time() < deadline: 100 + if self.try_acquire(): 101 + return 102 + await asyncio.sleep(delay) 103 + delay = min(delay * 2, 1.0) 104 + raise RateBudgetExhausted( 105 + f"Rate budget not available within {timeout_s:.1f}s" 106 + ) 107 + 108 + 109 + def get_rate_budget() -> RateBudget: 110 + """Get or create the global rate budget instance. 111 + 112 + Reads budget RPM from: 113 + 1. SOL_RATE_BUDGET_RPM env var 114 + 2. journal config providers.rate_budget_rpm 115 + 3. Default: 1500 116 + """ 117 + global _rate_budget 118 + if _rate_budget is None: 119 + rpm_str = os.getenv("SOL_RATE_BUDGET_RPM") 120 + if rpm_str is not None: 121 + rpm = int(rpm_str) 122 + else: 123 + rpm = get_config().get("providers", {}).get("rate_budget_rpm", 1500) 124 + budget_path = Path(get_journal()) / "health" / "rate_budget.json" 125 + _rate_budget = RateBudget(budget_path, rpm=rpm) 126 + return _rate_budget