personal memory agent
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at d18a7c02359cd827d0ff15058861de5c2600a96f 1414 lines 46 kB view raw
1# SPDX-License-Identifier: AGPL-3.0-only 2# Copyright (c) 2026 sol pbc 3 4import fnmatch 5import inspect 6import json 7import logging 8import os 9import subprocess 10import time 11from datetime import datetime, timezone 12from pathlib import Path 13from typing import Any, Dict, List, Optional, Union 14 15import frontmatter 16from jsonschema import Draft202012Validator 17 18from think.utils import get_config, get_journal 19 20logger = logging.getLogger(__name__) 21 22# --------------------------------------------------------------------------- 23# Tier constants 24# --------------------------------------------------------------------------- 25 26TIER_PRO = 1 27TIER_FLASH = 2 28TIER_LITE = 3 29 30# --------------------------------------------------------------------------- 31# Model constants 32# 33# IMPORTANT: When updating these models, verify pricing support: 34# 1. Run: make test-only TEST=tests/test_models.py::test_all_default_models_have_pricing 35# 2. If test fails, update genai-prices: make update-prices 36# 3. If still failing, the model may be too new for genai-prices 37# 38# The genai-prices library provides token cost data. New models may not have 39# pricing immediately after release. See: https://pypi.org/project/genai-prices/ 40# --------------------------------------------------------------------------- 41 42# Valid OpenAI reasoning effort suffixes appended to model names. 43# E.g., "gpt-5.2-high" → reasoning_effort="high", "gpt-5.2" → omitted. 44OPENAI_EFFORT_SUFFIXES = ("-none", "-low", "-medium", "-high", "-xhigh") 45 46# Map model names that genai-prices doesn't recognize yet to a known equivalent. 47MODEL_PRICE_ALIASES: Dict[str, str] = { 48 "gpt-5.4": "gpt-5.2", 49 "gpt-5.4-mini": "gpt-5-mini", 50} 51 52GEMINI_PRO = "gemini-3.1-pro-preview" 53GEMINI_FLASH = "gemini-3-flash-preview" 54GEMINI_LITE = "gemini-2.5-flash-lite" 55 56GPT_5 = "gpt-5.4" 57GPT_5_MINI = "gpt-5.4-low" 58GPT_5_NANO = "gpt-5.4-mini" 59 60CLAUDE_OPUS_4 = "claude-opus-4-5" 61CLAUDE_SONNET_4 = "claude-sonnet-4-5" 62CLAUDE_HAIKU_4 = "claude-haiku-4-5" 63 64OLLAMA_PRO = "ollama-local/qwen3.5:35b-a3b-bf16" 65OLLAMA_FLASH = "ollama-local/qwen3.5:9b" 66OLLAMA_LITE = "ollama-local/qwen3.5:2b" 67 68# --------------------------------------------------------------------------- 69# System defaults: provider -> tier -> model 70# --------------------------------------------------------------------------- 71 72PROVIDER_DEFAULTS: Dict[str, Dict[int, str]] = { 73 "google": { 74 TIER_PRO: GEMINI_PRO, 75 TIER_FLASH: GEMINI_FLASH, 76 TIER_LITE: GEMINI_LITE, 77 }, 78 "openai": { 79 TIER_PRO: GPT_5, 80 TIER_FLASH: GPT_5_MINI, 81 TIER_LITE: GPT_5_NANO, 82 }, 83 "anthropic": { 84 TIER_PRO: CLAUDE_OPUS_4, 85 TIER_FLASH: CLAUDE_SONNET_4, 86 TIER_LITE: CLAUDE_HAIKU_4, 87 }, 88 "ollama": { 89 TIER_PRO: OLLAMA_PRO, 90 TIER_FLASH: OLLAMA_FLASH, 91 TIER_LITE: OLLAMA_LITE, 92 }, 93} 94 95TYPE_DEFAULTS: Dict[str, Dict[str, Any]] = { 96 "generate": {"provider": "google", "tier": TIER_FLASH, "backup": "anthropic"}, 97 "cogitate": {"provider": "openai", "tier": TIER_FLASH, "backup": "anthropic"}, 98} 99 100 101# --------------------------------------------------------------------------- 102# Exceptions 103# --------------------------------------------------------------------------- 104 105 106class IncompleteJSONError(ValueError): 107 """Raised when JSON response is truncated due to token limits or other reasons. 108 109 Attributes: 110 reason: The finish/stop reason from the API (e.g., "MAX_TOKENS", "length"). 111 partial_text: The truncated response text, useful for debugging. 112 """ 113 114 def __init__(self, reason: str, partial_text: str): 115 self.reason = reason 116 self.partial_text = partial_text 117 super().__init__(f"JSON response incomplete (reason: {reason})") 118 119 120# --------------------------------------------------------------------------- 121# Prompt context discovery 122# 123# Context metadata (tier, label, group) is defined in prompt .md files via 124# YAML frontmatter. This eliminates duplication between code and config. 125# 126# NAMING CONVENTION: 127# {module}.{feature}[.{operation}] 128# 129# Examples: 130# - observe.describe.frame -> observe module, describe feature, frame operation 131# - observe.enrich -> observe module, enrich feature (no sub-operation) 132# - talent.system.meetings -> talent module, system source, meetings config 133# - talent.entities.observer -> talent module, entities app, observer config 134# - app.chat.title -> apps module, chat app, title operation 135# 136# DISCOVERY SOURCES: 137# 1. Prompt files listed in PROMPT_PATHS (with context in frontmatter) 138# 2. Categories from observe/categories/*.md (tier/label/group in frontmatter) 139# 3. Talent configs from talent/*.md and apps/*/talent/*.md 140# 141# When adding new contexts: 142# 1. Create a .md prompt file with YAML frontmatter containing: 143# context, tier, label, group 144# 2. Add the path to PROMPT_PATHS 145# 3. If not listed, context falls back to the type's default tier 146# --------------------------------------------------------------------------- 147 148# Flat list of prompt files that define context metadata in frontmatter. 149# Each must have: context, tier, label, group in YAML frontmatter. 150PROMPT_PATHS: List[str] = [ 151 "observe/describe.md", 152 "observe/enrich.md", 153 "observe/extract.md", 154 "observe/transcribe/gemini.md", 155 "think/detect_created.md", 156 "think/detect_transcript_segment.md", 157 "think/detect_transcript_json.md", 158 "think/planner.md", 159] 160 161 162# --------------------------------------------------------------------------- 163# Dynamic context discovery 164# --------------------------------------------------------------------------- 165 166# Cached context registry (built lazily on first use) 167_context_registry: Optional[Dict[str, Dict[str, Any]]] = None 168_LEGACY_CONTEXT_PREFIX = "talent." 169_TALENT_CONTEXT_PREFIX = "talent." 170 171 172def _discover_prompt_contexts() -> Dict[str, Dict[str, Any]]: 173 """Load context metadata from prompt files listed in PROMPT_PATHS. 174 175 Each file must have YAML frontmatter with: 176 - context: The context string (e.g., "observe.enrich") 177 - tier: Tier number (1=pro, 2=flash, 3=lite) 178 - label: Human-readable name 179 - group: Settings UI category 180 181 Returns 182 ------- 183 Dict[str, Dict[str, Any]] 184 Mapping of context patterns to {tier, label, group} dicts. 185 """ 186 contexts = {} 187 base_dir = Path(__file__).parent.parent # Project root 188 189 for rel_path in PROMPT_PATHS: 190 path = base_dir / rel_path 191 if not path.exists(): 192 logging.getLogger(__name__).warning(f"Prompt file not found: {path}") 193 continue 194 195 try: 196 post = frontmatter.load(path) 197 meta = post.metadata or {} 198 199 context = meta.get("context") 200 if not context: 201 logging.getLogger(__name__).warning(f"No context in {path}") 202 continue 203 204 contexts[context] = { 205 "tier": meta.get("tier", TIER_FLASH), 206 "label": meta.get("label", context), 207 "group": meta.get("group", "Other"), 208 } 209 except Exception as e: 210 logging.getLogger(__name__).warning(f"Failed to load {path}: {e}") 211 212 return contexts 213 214 215def _discover_talent_contexts() -> Dict[str, Dict[str, Any]]: 216 """Discover talent context defaults from talent/*.md config files. 217 218 Uses get_talent_configs() from think.talent to load all talent configurations 219 and converts them to context patterns with tier/label/group metadata. 220 221 Returns 222 ------- 223 Dict[str, Dict[str, Any]] 224 Mapping of context patterns to {tier, label, group, type} dicts. 225 Context patterns are: talent.system.{name} or talent.{app}.{name} 226 """ 227 from think.talent import get_talent_configs, key_to_context 228 229 contexts = {} 230 231 # Load all talent configs (including disabled for completeness) 232 all_configs = get_talent_configs(include_disabled=True) 233 234 for key, config in all_configs.items(): 235 context = key_to_context(key) 236 contexts[context] = { 237 "tier": config.get("tier", TIER_FLASH), 238 "label": config.get("label", config.get("title", key)), 239 "group": config.get("group", "Think"), 240 "type": config.get("type"), 241 } 242 243 return contexts 244 245 246def _build_context_registry() -> Dict[str, Dict[str, Any]]: 247 """Build complete context registry from discovered configs. 248 249 Merges: 250 1. Prompt contexts from _discover_prompt_contexts() 251 2. Category contexts from observe/describe.py CATEGORIES 252 3. Talent contexts from _discover_talent_contexts() 253 254 Returns 255 ------- 256 Dict[str, Dict[str, Any]] 257 Complete context registry mapping patterns to {tier, label, group}. 258 """ 259 # Start with prompt contexts (from PROMPT_PATHS) 260 registry = _discover_prompt_contexts() 261 262 # Merge category contexts (lazy import to avoid circular dependency) 263 try: 264 from observe.describe import CATEGORIES 265 266 for category, metadata in CATEGORIES.items(): 267 context = metadata.get("context", f"observe.describe.{category}") 268 registry[context] = { 269 "tier": metadata.get("tier", TIER_FLASH), 270 "label": metadata.get("label", category.replace("_", " ").title()), 271 "group": metadata.get("group", "Screen Analysis"), 272 } 273 except ImportError: 274 pass # observe module not available 275 276 # Merge talent contexts (agents + generators) 277 talent_contexts = _discover_talent_contexts() 278 registry.update(talent_contexts) 279 280 return registry 281 282 283def get_context_registry() -> Dict[str, Dict[str, Any]]: 284 """Get the complete context registry, building it lazily on first use. 285 286 Returns 287 ------- 288 Dict[str, Dict[str, Any]] 289 Complete context registry mapping patterns to {tier, label, group}. 290 """ 291 global _context_registry 292 if _context_registry is None: 293 _context_registry = _build_context_registry() 294 return _context_registry 295 296 297def _resolve_tier(context: str, agent_type: str) -> int: 298 """Resolve context to tier number. 299 300 Checks journal config contexts first, then dynamic context registry with glob matching. 301 302 Parameters 303 ---------- 304 context 305 Context string (e.g., "talent.system.default", "observe.describe.frame"). 306 agent_type 307 Agent type ("generate" or "cogitate"). 308 309 Returns 310 ------- 311 int 312 Tier number (1=pro, 2=flash, 3=lite). 313 """ 314 from think.utils import get_config 315 316 default_tier = TYPE_DEFAULTS[agent_type]["tier"] 317 318 journal_config = get_config() 319 providers_config = journal_config.get("providers", {}) 320 contexts = providers_config.get("contexts", {}) 321 322 # Get dynamic context registry (discovered prompts, categories, talent configs) 323 registry = get_context_registry() 324 325 # Check journal config contexts first (exact match) 326 if context in contexts: 327 return contexts[context].get("tier", default_tier) 328 329 # Check context registry (exact match) 330 if context in registry: 331 return registry[context]["tier"] 332 333 # Check glob patterns in both 334 for pattern, ctx_config in contexts.items(): 335 if fnmatch.fnmatch(context, pattern): 336 return ctx_config.get("tier", default_tier) 337 338 for pattern, ctx_default in registry.items(): 339 if fnmatch.fnmatch(context, pattern): 340 return ctx_default["tier"] 341 342 return default_tier 343 344 345def _resolve_model(provider: str, tier: int, config_models: Dict[str, Any]) -> str: 346 """Resolve tier to model string for a given provider. 347 348 Checks config overrides first, then falls back to system defaults. 349 If requested tier is unavailable, falls back to more capable tiers 350 (3→2→1, i.e., lite→flash→pro). 351 352 Parameters 353 ---------- 354 provider 355 Provider name ("google", "openai", "anthropic"). 356 tier 357 Tier number (1=pro, 2=flash, 3=lite). 358 config_models 359 The "models" section from providers config, mapping provider to tier overrides. 360 361 Returns 362 ------- 363 str 364 Model identifier string. 365 """ 366 # Check config overrides first 367 provider_overrides = config_models.get(provider, {}) 368 369 # Try requested tier, then fall back to more capable tiers (lower numbers) 370 for t in [tier, tier - 1, tier - 2] if tier > 1 else [tier]: 371 if t < 1: 372 continue 373 374 # Check config override (tier as string key in JSON) 375 tier_key = str(t) 376 if tier_key in provider_overrides: 377 return provider_overrides[tier_key] 378 379 # Check system defaults 380 provider_defaults = PROVIDER_DEFAULTS.get(provider, {}) 381 if t in provider_defaults: 382 return provider_defaults[t] 383 384 # Ultimate fallback: system default for provider at TIER_FLASH 385 provider_defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["google"]) 386 return provider_defaults.get(TIER_FLASH, GEMINI_FLASH) 387 388 389def resolve_model_for_provider( 390 context: str, provider: str, agent_type: str = "generate" 391) -> str: 392 """Resolve model for a specific provider based on context tier. 393 394 Use this when provider is overridden from the default - resolves the 395 appropriate model for the given provider at the context's tier. 396 397 Parameters 398 ---------- 399 context 400 Context string (e.g., "talent.system.default"). 401 provider 402 Provider name ("google", "openai", "anthropic"). 403 agent_type 404 Agent type ("generate" or "cogitate"). 405 406 Returns 407 ------- 408 str 409 Model identifier string for the provider at the context's tier. 410 """ 411 from think.utils import get_config 412 413 tier = _resolve_tier(context, agent_type) 414 journal_config = get_config() 415 providers_config = journal_config.get("providers", {}) 416 config_models = providers_config.get("models", {}) 417 418 return _resolve_model(provider, tier, config_models) 419 420 421def resolve_provider(context: str, agent_type: str) -> tuple[str, str]: 422 """Resolve context to provider and model based on configuration. 423 424 Matches context against configured contexts using exact match first, 425 then glob patterns (via fnmatch), falling back to type-specific defaults. 426 427 Supports both explicit model strings and tier-based routing: 428 - {"provider": "google", "model": "gemini-3-flash-preview"} - explicit model 429 - {"provider": "google", "tier": 2} - tier-based (2=flash) 430 - {"tier": 1} - tier only, inherits provider from type default 431 432 The "models" section in providers config allows overriding which model 433 is used for each tier per provider. 434 435 Parameters 436 ---------- 437 context 438 Context string (e.g., "observe.describe.frame", "talent.system.meetings"). 439 agent_type 440 Agent type ("generate" or "cogitate"). 441 442 Returns 443 ------- 444 tuple[str, str] 445 (provider_name, model) tuple. Provider is one of "google", "openai", 446 "anthropic". Model is the full model identifier string. 447 """ 448 config = get_config() 449 providers = config.get("providers", {}) 450 config_models = providers.get("models", {}) 451 452 # Get type-specific defaults from config, falling back to system constants 453 type_defaults = TYPE_DEFAULTS[agent_type] 454 type_config = providers.get(agent_type, {}) 455 default_provider = type_config.get("provider", type_defaults["provider"]) 456 default_tier = type_config.get("tier", type_defaults["tier"]) 457 458 # Handle explicit "model" key in type config (overrides tier-based resolution) 459 if "model" in type_config and "tier" not in type_config: 460 default_model = type_config["model"] 461 else: 462 default_model = _resolve_model(default_provider, default_tier, config_models) 463 464 contexts = providers.get("contexts", {}) 465 466 # Find matching context config 467 match_config: Optional[Dict[str, Any]] = None 468 469 if context and contexts: 470 # Check for exact match first 471 if context in contexts: 472 match_config = contexts[context] 473 else: 474 # Check glob patterns - most specific (longest non-wildcard prefix) wins 475 matches = [] 476 for pattern, ctx_config in contexts.items(): 477 if fnmatch.fnmatch(context, pattern): 478 specificity = len(pattern.split("*")[0]) 479 matches.append((specificity, pattern, ctx_config)) 480 481 if matches: 482 matches.sort(key=lambda x: x[0], reverse=True) 483 _, _, match_config = matches[0] 484 485 # No context match - check dynamic context registry for this context 486 if match_config is None: 487 # Get dynamic context registry (discovered prompts, categories, talent configs) 488 registry = get_context_registry() 489 490 # Check for matching context default (exact match first, then glob) 491 context_tier = None 492 if context: 493 if context in registry: 494 context_tier = registry[context]["tier"] 495 else: 496 # Check glob patterns 497 matches = [] 498 for pattern, ctx_default in registry.items(): 499 if fnmatch.fnmatch(context, pattern): 500 specificity = len(pattern.split("*")[0]) 501 matches.append((specificity, ctx_default["tier"])) 502 if matches: 503 matches.sort(key=lambda x: x[0], reverse=True) 504 context_tier = matches[0][1] 505 506 if context_tier is not None: 507 model = _resolve_model(default_provider, context_tier, config_models) 508 return (default_provider, model) 509 510 return (default_provider, default_model) 511 512 # Resolve provider (from match or default) 513 provider = match_config.get("provider", default_provider) 514 515 # Resolve model: explicit model takes precedence over tier 516 if "model" in match_config: 517 model = match_config["model"] 518 elif "tier" in match_config: 519 tier = match_config["tier"] 520 # Validate tier 521 if not isinstance(tier, int) or tier < 1 or tier > 3: 522 logging.getLogger(__name__).warning( 523 "Invalid tier %r in context %r, using default", tier, context 524 ) 525 tier = default_tier 526 model = _resolve_model(provider, tier, config_models) 527 else: 528 # No model or tier specified - use default tier 529 model = _resolve_model(provider, default_tier, config_models) 530 531 return (provider, model) 532 533 534def log_token_usage( 535 model: str, 536 usage: Union[Dict[str, Any], Any], 537 context: Optional[str] = None, 538 segment: Optional[str] = None, 539 type: Optional[str] = None, 540) -> None: 541 """Log token usage to journal with unified schema. 542 543 Providers normalize usage into the unified schema (see USAGE_KEYS in 544 shared.py) before returning GenerateResult. This function passes 545 through those known keys, computes total_tokens when missing, and 546 handles a few legacy field aliases from CLI backends. 547 548 Parameters 549 ---------- 550 model : str 551 Model name (e.g., "gpt-5", "gemini-2.5-flash") 552 usage : dict 553 Normalized usage dict with keys from USAGE_KEYS. 554 context : str, optional 555 Context string (e.g., "module.function:123" or "talent.system.default"). 556 If None, auto-detects from call stack. 557 segment : str, optional 558 Segment key (e.g., "143022_300") for attribution. 559 If None, falls back to SOL_SEGMENT environment variable. 560 type : str, optional 561 Token entry type (e.g., "generate", "cogitate"). 562 """ 563 from think.providers.shared import USAGE_KEYS 564 565 try: 566 journal = get_journal() 567 568 # Auto-detect calling context if not provided 569 if context is None: 570 frame = inspect.currentframe() 571 caller_frame = frame.f_back if frame else None 572 573 # Skip frames that contain "gemini" in function name 574 while caller_frame and "gemini" in caller_frame.f_code.co_name.lower(): 575 caller_frame = caller_frame.f_back 576 577 if caller_frame: 578 module_name = caller_frame.f_globals.get("__name__", "unknown") 579 func_name = caller_frame.f_code.co_name 580 line_num = caller_frame.f_lineno 581 582 # Clean up module name 583 for prefix in ["think.", "observe.", "convey."]: 584 if module_name.startswith(prefix): 585 module_name = module_name[len(prefix) :] 586 break 587 588 context = f"{module_name}.{func_name}:{line_num}" 589 590 # Pass through known keys from the already-normalized usage dict. 591 normalized_usage: Dict[str, int] = {} 592 for key in USAGE_KEYS: 593 val = usage.get(key) 594 if val: 595 normalized_usage[key] = val 596 597 # Legacy alias: some CLI backends emit cached_input_tokens 598 if not normalized_usage.get("cached_tokens") and usage.get( 599 "cached_input_tokens" 600 ): 601 normalized_usage["cached_tokens"] = usage["cached_input_tokens"] 602 603 # Compute total_tokens from parts when missing (e.g. Codex CLI omits it) 604 if not normalized_usage.get("total_tokens"): 605 inp = normalized_usage.get("input_tokens", 0) 606 out = normalized_usage.get("output_tokens", 0) 607 if inp or out: 608 normalized_usage["total_tokens"] = inp + out 609 610 # Build token log entry 611 token_data = { 612 "timestamp": time.time(), 613 "model": model, 614 "context": context, 615 "usage": normalized_usage, 616 } 617 618 # Add segment: prefer parameter, fallback to env (set by think/insight, observe handlers) 619 segment_key = segment or os.getenv("SOL_SEGMENT") 620 if segment_key: 621 token_data["segment"] = segment_key 622 if type: 623 token_data["type"] = type 624 625 # Save to journal/tokens/<YYYYMMDD>.jsonl (one file per day) 626 tokens_dir = Path(journal) / "tokens" 627 tokens_dir.mkdir(exist_ok=True) 628 629 filename = time.strftime("%Y%m%d.jsonl") 630 filepath = tokens_dir / filename 631 632 # Atomic append - safe for parallel writers 633 with open(filepath, "a") as f: 634 f.write(json.dumps(token_data) + "\n") 635 636 except Exception: 637 # Silently fail - logging shouldn't break the main flow 638 pass 639 640 641def get_model_provider(model: str) -> str: 642 """Get the provider name from a model identifier. 643 644 Parameters 645 ---------- 646 model : str 647 Model name (e.g., "gpt-5", "gemini-2.5-flash", "claude-sonnet-4-5") 648 649 Returns 650 ------- 651 str 652 Provider name: "openai", "google", "anthropic", "ollama", or "unknown" 653 """ 654 model_lower = model.lower() 655 656 if model_lower.startswith("ollama-local/"): 657 return "ollama" 658 elif model_lower.startswith("gpt"): 659 return "openai" 660 elif model_lower.startswith("gemini"): 661 return "google" 662 elif model_lower.startswith("claude"): 663 return "anthropic" 664 else: 665 return "unknown" 666 667 668def calc_token_cost(token_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: 669 """Calculate cost for a token usage record. 670 671 Parameters 672 ---------- 673 token_data : dict 674 Token usage record from journal logs with structure: 675 { 676 "model": "gemini-2.5-flash", 677 "usage": { 678 "input_tokens": 1500, 679 "output_tokens": 500, 680 "cached_tokens": 800, 681 "reasoning_tokens": 200, 682 ... 683 } 684 } 685 686 Returns 687 ------- 688 dict or None 689 Cost breakdown: 690 { 691 "total_cost": 0.00123, 692 "input_cost": 0.00075, 693 "output_cost": 0.00048, 694 "currency": "USD" 695 } 696 Returns None if pricing unavailable or calculation fails. 697 """ 698 try: 699 from genai_prices import Usage, calc_price 700 701 model = token_data.get("model") 702 usage_data = token_data.get("usage", {}) 703 704 if not model or not usage_data: 705 return None 706 707 # Strip OpenAI reasoning effort suffixes for price lookup 708 for suffix in OPENAI_EFFORT_SUFFIXES: 709 if model.endswith(suffix): 710 model = model[: -len(suffix)] 711 break 712 713 # Get provider ID before aliasing (alias may change the model family) 714 provider_id = get_model_provider(model) 715 if provider_id == "unknown": 716 return None 717 718 # Ollama models are local — no cost 719 if provider_id == "ollama": 720 return { 721 "total_cost": 0.0, 722 "input_cost": 0.0, 723 "output_cost": 0.0, 724 "currency": "USD", 725 } 726 727 # Apply price aliases for models genai-prices doesn't recognize yet 728 model = MODEL_PRICE_ALIASES.get(model, model) 729 730 # Map our token fields to genai_prices Usage format 731 # Note: Gemini reports reasoning_tokens separately, but they're billed at 732 # output token rates. genai-prices doesn't have a separate field for reasoning, 733 # so we add them to output_tokens for correct pricing. 734 input_tokens = usage_data.get("input_tokens", 0) 735 output_tokens = usage_data.get("output_tokens", 0) 736 cached_tokens = usage_data.get("cached_tokens", 0) 737 reasoning_tokens = usage_data.get("reasoning_tokens", 0) 738 739 # Add reasoning tokens to output for pricing (Gemini bills them as output) 740 total_output_tokens = output_tokens + reasoning_tokens 741 742 # Create Usage object 743 usage = Usage( 744 input_tokens=input_tokens, 745 output_tokens=total_output_tokens, 746 cache_read_tokens=cached_tokens if cached_tokens > 0 else None, 747 ) 748 749 # Calculate price 750 result = calc_price( 751 usage=usage, 752 model_ref=model, 753 provider_id=provider_id, 754 ) 755 756 # Return simplified cost breakdown 757 return { 758 "total_cost": float(result.total_price), 759 "input_cost": float(result.input_price), 760 "output_cost": float(result.output_price), 761 "currency": "USD", 762 } 763 764 except Exception: 765 # Silently fail if pricing unavailable 766 return None 767 768 769def calc_agent_cost( 770 model: Optional[str], usage: Optional[Dict[str, Any]] 771) -> Optional[float]: 772 """Calculate total cost for an agent run from model and usage data. 773 774 Convenience wrapper around calc_token_cost for agent cost lookups. 775 776 Returns total cost in USD, or None if data is missing or pricing unavailable. 777 """ 778 if not model or not usage: 779 return None 780 try: 781 cost_data = calc_token_cost({"model": model, "usage": usage}) 782 if cost_data: 783 return cost_data["total_cost"] 784 except Exception: 785 return None 786 return None 787 788 789def _normalize_legacy_context(ctx: str) -> str: 790 """Normalize legacy token-log context strings to the talent namespace.""" 791 if ctx.startswith(_LEGACY_CONTEXT_PREFIX): 792 return _TALENT_CONTEXT_PREFIX + ctx[len(_LEGACY_CONTEXT_PREFIX) :] 793 return ctx 794 795 796def iter_token_log(day: str) -> Any: 797 """Iterate over token log entries for a given day. 798 799 Yields parsed JSON entries from the token log file, skipping empty lines 800 and invalid JSON. This is a shared utility for code that processes token logs. 801 802 Parameters 803 ---------- 804 day : str 805 Day in YYYYMMDD format. 806 807 Yields 808 ------ 809 dict 810 Parsed token log entry with fields: timestamp, model, context, usage, 811 and optionally segment. 812 """ 813 journal = get_journal() 814 log_path = Path(journal) / "tokens" / f"{day}.jsonl" 815 816 if not log_path.exists(): 817 return 818 819 with open(log_path, "r") as f: 820 for line in f: 821 line = line.strip() 822 if not line: 823 continue 824 try: 825 entry = json.loads(line) 826 ctx = entry.get("context") 827 if isinstance(ctx, str): 828 entry["context"] = _normalize_legacy_context(ctx) 829 yield entry 830 except json.JSONDecodeError: 831 continue 832 833 834def get_usage_cost( 835 day: str, 836 segment: Optional[str] = None, 837 context: Optional[str] = None, 838) -> Dict[str, Any]: 839 """Get aggregated token usage and cost for a day, optionally filtered. 840 841 This is a shared utility for apps that want to display cost information 842 for segments, agent runs, or other contexts. 843 844 Parameters 845 ---------- 846 day : str 847 Day in YYYYMMDD format. 848 segment : str, optional 849 Filter to entries with this exact segment key. 850 context : str, optional 851 Filter to entries where context starts with this prefix. 852 For example, "talent.system" matches "talent.system.default". 853 854 Returns 855 ------- 856 dict 857 Aggregated usage data: 858 { 859 "requests": int, 860 "tokens": int, 861 "cost": float, # USD 862 } 863 Returns zeros if no matching entries or day file doesn't exist. 864 """ 865 result = {"requests": 0, "tokens": 0, "cost": 0.0} 866 867 for entry in iter_token_log(day): 868 # Apply filters 869 if segment is not None and entry.get("segment") != segment: 870 continue 871 if context is not None: 872 entry_context = entry.get("context", "") 873 if not entry_context.startswith(context): 874 continue 875 876 # Skip unknown providers (can't calculate cost) 877 model = entry.get("model", "unknown") 878 if get_model_provider(model) == "unknown": 879 continue 880 881 # Accumulate 882 usage = entry.get("usage", {}) 883 result["requests"] += 1 884 result["tokens"] += usage.get("total_tokens", 0) or 0 885 886 cost_data = calc_token_cost(entry) 887 if cost_data: 888 result["cost"] += cost_data["total_cost"] 889 890 return result 891 892 893# --------------------------------------------------------------------------- 894# Unified generate/agenerate with provider routing 895# --------------------------------------------------------------------------- 896 897 898def _validate_json_response(result: Dict[str, Any], json_output: bool) -> None: 899 """Validate response for JSON output mode. 900 901 Raises IncompleteJSONError if finish_reason indicates truncation. 902 """ 903 if not json_output: 904 return 905 906 finish_reason = result.get("finish_reason") 907 if finish_reason and finish_reason != "stop": 908 raise IncompleteJSONError( 909 reason=finish_reason, 910 partial_text=result.get("text", ""), 911 ) 912 913 914def _validate_schema(text: str, schema: dict) -> dict: 915 """Validate JSON text against a JSON Schema and log any violations.""" 916 917 def truncate_repr(value: Any) -> str: 918 value_repr = repr(value) 919 if len(value_repr) <= 80: 920 return value_repr 921 return value_repr[:77] + "..." 922 923 def build_pointer(path: Any) -> str: 924 segments = list(path) 925 if not segments: 926 return "" 927 escaped_segments = [] 928 for segment in segments: 929 escaped = str(segment).replace("~", "~0").replace("/", "~1") 930 escaped_segments.append(escaped) 931 return "/" + "/".join(escaped_segments) 932 933 try: 934 parsed = json.loads(text) 935 except ValueError as exc: 936 error = { 937 "path": "", 938 "constraint": "json_parse", 939 "message": str(exc), 940 } 941 logger.warning( 942 "schema_validation: %s: %s: %s (value=%s)", 943 "", 944 "json_parse", 945 str(exc), 946 truncate_repr(text), 947 ) 948 return {"valid": False, "errors": [error]} 949 950 errors = [] 951 try: 952 validator = Draft202012Validator(schema) 953 validation_errors = list(validator.iter_errors(parsed)) 954 except Exception as exc: 955 error = { 956 "path": "", 957 "constraint": "schema_validation", 958 "message": str(exc), 959 } 960 logger.warning( 961 "schema_validation: %s: %s: %s (value=%s)", 962 "", 963 "schema_validation", 964 str(exc), 965 truncate_repr(parsed), 966 ) 967 return {"valid": False, "errors": [error]} 968 969 for error in validation_errors: 970 path = build_pointer(error.absolute_path) 971 constraint = str(error.validator) 972 message = error.message 973 errors.append( 974 { 975 "path": path, 976 "constraint": constraint, 977 "message": message, 978 } 979 ) 980 logger.warning( 981 "schema_validation: %s: %s: %s (value=%s)", 982 path, 983 constraint, 984 message, 985 truncate_repr(error.instance), 986 ) 987 988 return {"valid": len(errors) == 0, "errors": errors} 989 990 991def generate( 992 contents: Union[str, List[Any]], 993 context: str, 994 temperature: float = 0.3, 995 max_output_tokens: int = 8192 * 2, 996 system_instruction: Optional[str] = None, 997 json_output: bool = False, 998 *, 999 json_schema: dict | None = None, 1000 thinking_budget: Optional[int] = None, 1001 timeout_s: Optional[float] = None, 1002 **kwargs: Any, 1003) -> str: 1004 """Generate text using the configured provider for the given context. 1005 1006 Routes the request to the appropriate backend (Google, OpenAI, or Anthropic) 1007 based on the providers configuration in journal.json. 1008 1009 Parameters 1010 ---------- 1011 contents : str or List 1012 The content to send to the model. 1013 context : str 1014 Context string for routing and token logging (e.g., "talent.system.meetings"). 1015 This is required and determines which provider/model to use. 1016 temperature : float 1017 Temperature for generation (default: 0.3). 1018 max_output_tokens : int 1019 Maximum tokens for the model's response output. 1020 system_instruction : str, optional 1021 System instruction for the model. 1022 json_output : bool 1023 Whether to request JSON response format. 1024 json_schema : dict, optional 1025 JSON Schema to request structured output from the provider. When supplied, 1026 this forces json_output=True and runs advisory local validation on the 1027 returned text after truncation checks. 1028 thinking_budget : int, optional 1029 Token budget for model thinking (ignored by providers that don't support it). 1030 timeout_s : float, optional 1031 Request timeout in seconds. 1032 **kwargs 1033 Additional provider-specific options passed through to the backend. 1034 1035 Returns 1036 ------- 1037 str 1038 Response text from the model. 1039 1040 Raises 1041 ------ 1042 ValueError 1043 If the resolved provider is not supported. 1044 IncompleteJSONError 1045 If json_output=True and response was truncated. 1046 """ 1047 from think.providers import get_provider_module 1048 1049 if json_schema is not None: 1050 json_output = True 1051 1052 # Allow model override via kwargs (used by callers with explicit model selection) 1053 model_override = kwargs.pop("model", None) 1054 1055 provider, model = resolve_provider(context, "generate") 1056 if model_override: 1057 model = model_override 1058 1059 # Get provider module via registry (raises ValueError for unknown providers) 1060 provider_mod = get_provider_module(provider) 1061 1062 # Call provider's run_generate (returns GenerateResult) 1063 result = provider_mod.run_generate( 1064 contents=contents, 1065 model=model, 1066 temperature=temperature, 1067 max_output_tokens=max_output_tokens, 1068 system_instruction=system_instruction, 1069 json_output=json_output, 1070 json_schema=json_schema, 1071 thinking_budget=thinking_budget, 1072 timeout_s=timeout_s, 1073 **kwargs, 1074 ) 1075 1076 # Log token usage centrally (before validation so truncated responses 1077 # still get their usage recorded) 1078 if result.get("usage"): 1079 log_token_usage( 1080 model=model, 1081 usage=result["usage"], 1082 context=context, 1083 type="generate", 1084 ) 1085 1086 # Validate JSON output if requested 1087 _validate_json_response(result, json_output) 1088 1089 if json_schema is not None: 1090 _validate_schema(result["text"], json_schema) 1091 1092 return result["text"] 1093 1094 1095# --------------------------------------------------------------------------- 1096# Provider Health & Fallback Helpers 1097# --------------------------------------------------------------------------- 1098 1099 1100def get_backup_provider(agent_type: str) -> Optional[str]: 1101 """Get the backup provider for the given agent type. 1102 1103 Reads from the type-specific section in journal config, falling back 1104 to TYPE_DEFAULTS. 1105 1106 Returns None if backup would be the same as the primary provider. 1107 """ 1108 type_defaults = TYPE_DEFAULTS[agent_type] 1109 config = get_config() 1110 providers_config = config.get("providers", {}) 1111 type_config = providers_config.get(agent_type, {}) 1112 primary_provider = type_config.get("provider", type_defaults["provider"]) 1113 backup = type_config.get("backup", type_defaults["backup"]) 1114 if backup == primary_provider: 1115 return None 1116 return backup 1117 1118 1119def load_health_status() -> Optional[dict]: 1120 """Load health status from journal/health/agents.json. 1121 1122 Returns parsed dict or None if file is missing/unreadable. 1123 """ 1124 try: 1125 health_path = Path(get_journal()) / "health" / "agents.json" 1126 with open(health_path) as f: 1127 return json.load(f) 1128 except (FileNotFoundError, json.JSONDecodeError, OSError): 1129 return None 1130 1131 1132def is_provider_healthy(provider: str, health_data: Optional[dict]) -> bool: 1133 """Check if a provider is healthy based on health data. 1134 1135 Returns True (assume healthy) when: 1136 - health_data is None (no data available) 1137 - No results exist for the provider 1138 - Any result for the provider has ok=True 1139 1140 Returns False only when all results for the provider have ok=False. 1141 """ 1142 if health_data is None: 1143 return True 1144 results = health_data.get("results", []) 1145 provider_results = [r for r in results if r.get("provider") == provider] 1146 if not provider_results: 1147 return True 1148 return any(r.get("ok") for r in provider_results) 1149 1150 1151def should_recheck_health(health_data: Optional[dict]) -> bool: 1152 """Check if health data is stale (>1 hour old). 1153 1154 Returns False when health_data is None or on parse errors. 1155 """ 1156 if health_data is None: 1157 return False 1158 checked_at = health_data.get("checked_at") 1159 if not checked_at: 1160 return False 1161 try: 1162 checked_time = datetime.fromisoformat(checked_at) 1163 if checked_time.tzinfo is None: 1164 checked_time = checked_time.replace(tzinfo=timezone.utc) 1165 age = datetime.now(timezone.utc) - checked_time 1166 return age.total_seconds() > 3600 1167 except (ValueError, TypeError): 1168 return False 1169 1170 1171def request_health_recheck() -> None: 1172 """Request a health re-check by spawning a background process. 1173 1174 Fire-and-forget; errors are logged but never propagated. 1175 """ 1176 try: 1177 subprocess.Popen( 1178 ["sol", "providers", "check", "--targeted"], 1179 stdout=subprocess.DEVNULL, 1180 stderr=subprocess.DEVNULL, 1181 ) 1182 except Exception: 1183 logging.getLogger(__name__).debug( 1184 "Failed to request health recheck", exc_info=True 1185 ) 1186 1187 1188def generate_with_result( 1189 contents: Union[str, List[Any]], 1190 context: str, 1191 temperature: float = 0.3, 1192 max_output_tokens: int = 8192 * 2, 1193 system_instruction: Optional[str] = None, 1194 json_output: bool = False, 1195 *, 1196 json_schema: dict | None = None, 1197 thinking_budget: Optional[int] = None, 1198 timeout_s: Optional[float] = None, 1199 **kwargs: Any, 1200) -> dict: 1201 """Generate text and return full result with usage data. 1202 1203 Same as generate() but returns the full GenerateResult dict instead of 1204 just the text. Used by cortex-managed generators that need usage data 1205 for event emission. 1206 1207 Parameters 1208 ---------- 1209 contents : str or List 1210 The content to send to the model. 1211 context : str 1212 Context string for routing and token logging. 1213 temperature : float 1214 Temperature for generation (default: 0.3). 1215 max_output_tokens : int 1216 Maximum tokens for the model's response output. 1217 system_instruction : str, optional 1218 System instruction for the model. 1219 json_output : bool 1220 Whether to request JSON response format. 1221 json_schema : dict, optional 1222 JSON Schema to request structured output from the provider. When supplied, 1223 this forces json_output=True and runs advisory local validation on the 1224 returned text after truncation checks. 1225 thinking_budget : int, optional 1226 Token budget for model thinking (ignored by providers that don't support it). 1227 timeout_s : float, optional 1228 Request timeout in seconds. 1229 **kwargs 1230 Additional provider-specific options passed through to the backend. 1231 1232 Returns 1233 ------- 1234 dict 1235 GenerateResult with: text, usage, finish_reason, thinking, and 1236 schema_validation when json_schema is supplied. Validation is advisory 1237 and runs after truncation checks succeed. 1238 """ 1239 from think.providers import get_provider_module 1240 1241 if json_schema is not None: 1242 json_output = True 1243 1244 model_override = kwargs.pop("model", None) 1245 provider_override = kwargs.pop("provider", None) 1246 1247 provider, model = resolve_provider(context, "generate") 1248 if provider_override: 1249 provider = provider_override 1250 if not model_override: 1251 model = resolve_model_for_provider(context, provider, "generate") 1252 if model_override: 1253 model = model_override 1254 1255 provider_mod = get_provider_module(provider) 1256 1257 result = provider_mod.run_generate( 1258 contents=contents, 1259 model=model, 1260 temperature=temperature, 1261 max_output_tokens=max_output_tokens, 1262 system_instruction=system_instruction, 1263 json_output=json_output, 1264 json_schema=json_schema, 1265 thinking_budget=thinking_budget, 1266 timeout_s=timeout_s, 1267 **kwargs, 1268 ) 1269 1270 # Log token usage centrally (before validation so truncated responses 1271 # still get their usage recorded) 1272 if result.get("usage"): 1273 log_token_usage( 1274 model=model, 1275 usage=result["usage"], 1276 context=context, 1277 type="generate", 1278 ) 1279 1280 # Validate JSON output if requested 1281 _validate_json_response(result, json_output) 1282 1283 if json_schema is not None: 1284 result["schema_validation"] = _validate_schema(result["text"], json_schema) 1285 1286 return result 1287 1288 1289async def agenerate( 1290 contents: Union[str, List[Any]], 1291 context: str, 1292 temperature: float = 0.3, 1293 max_output_tokens: int = 8192 * 2, 1294 system_instruction: Optional[str] = None, 1295 json_output: bool = False, 1296 *, 1297 json_schema: dict | None = None, 1298 thinking_budget: Optional[int] = None, 1299 timeout_s: Optional[float] = None, 1300 **kwargs: Any, 1301) -> str: 1302 """Async generate text using the configured provider for the given context. 1303 1304 Routes the request to the appropriate backend (Google, OpenAI, or Anthropic) 1305 based on the providers configuration in journal.json. 1306 1307 Parameters 1308 ---------- 1309 contents : str or List 1310 The content to send to the model. 1311 context : str 1312 Context string for routing and token logging (e.g., "talent.system.meetings"). 1313 This is required and determines which provider/model to use. 1314 temperature : float 1315 Temperature for generation (default: 0.3). 1316 max_output_tokens : int 1317 Maximum tokens for the model's response output. 1318 system_instruction : str, optional 1319 System instruction for the model. 1320 json_output : bool 1321 Whether to request JSON response format. 1322 json_schema : dict, optional 1323 JSON Schema to request structured output from the provider. When supplied, 1324 this forces json_output=True and runs advisory local validation on the 1325 returned text after truncation checks. 1326 thinking_budget : int, optional 1327 Token budget for model thinking (ignored by providers that don't support it). 1328 timeout_s : float, optional 1329 Request timeout in seconds. 1330 **kwargs 1331 Additional provider-specific options passed through to the backend. 1332 1333 Returns 1334 ------- 1335 str 1336 Response text from the model. 1337 1338 Raises 1339 ------ 1340 ValueError 1341 If the resolved provider is not supported. 1342 IncompleteJSONError 1343 If json_output=True and response was truncated. 1344 """ 1345 from think.providers import get_provider_module 1346 1347 if json_schema is not None: 1348 json_output = True 1349 1350 # Allow model override via kwargs (used by Batch for explicit model selection) 1351 model_override = kwargs.pop("model", None) 1352 1353 provider, model = resolve_provider(context, "generate") 1354 if model_override: 1355 model = model_override 1356 1357 # Get provider module via registry (raises ValueError for unknown providers) 1358 provider_mod = get_provider_module(provider) 1359 1360 # Call provider's run_agenerate (returns GenerateResult) 1361 result = await provider_mod.run_agenerate( 1362 contents=contents, 1363 model=model, 1364 temperature=temperature, 1365 max_output_tokens=max_output_tokens, 1366 system_instruction=system_instruction, 1367 json_output=json_output, 1368 json_schema=json_schema, 1369 thinking_budget=thinking_budget, 1370 timeout_s=timeout_s, 1371 **kwargs, 1372 ) 1373 1374 # Log token usage centrally (before validation so truncated responses 1375 # still get their usage recorded) 1376 if result.get("usage"): 1377 log_token_usage( 1378 model=model, 1379 usage=result["usage"], 1380 context=context, 1381 type="generate", 1382 ) 1383 1384 # Validate JSON output if requested 1385 _validate_json_response(result, json_output) 1386 1387 if json_schema is not None: 1388 _validate_schema(result["text"], json_schema) 1389 1390 return result["text"] 1391 1392 1393__all__ = [ 1394 # Provider configuration 1395 "TYPE_DEFAULTS", 1396 "PROMPT_PATHS", 1397 "get_context_registry", 1398 # Model constants (used by provider backends for defaults) 1399 "GEMINI_FLASH", 1400 "GPT_5", 1401 "CLAUDE_SONNET_4", 1402 # Unified API 1403 "generate", 1404 "generate_with_result", 1405 "agenerate", 1406 "resolve_provider", 1407 # Utilities 1408 "log_token_usage", 1409 "calc_token_cost", 1410 "calc_agent_cost", 1411 "get_usage_cost", 1412 "iter_token_log", 1413 "get_model_provider", 1414]