# personal memory agent
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright (c) 2026 sol pbc

import fnmatch
import inspect
import json
import logging
import os
import subprocess
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Union

import frontmatter
from jsonschema import Draft202012Validator

from think.utils import get_config, get_journal

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Tier constants
# ---------------------------------------------------------------------------

TIER_PRO = 1
TIER_FLASH = 2
TIER_LITE = 3

# ---------------------------------------------------------------------------
# Model constants
#
# IMPORTANT: When updating these models, verify pricing support:
# 1. Run: make test-only TEST=tests/test_models.py::test_all_default_models_have_pricing
# 2. If the test fails, update genai-prices: make update-prices
# 3. If it still fails, the model may be too new for genai-prices
#
# The genai-prices library provides token cost data. New models may not have
# pricing immediately after release. See: https://pypi.org/project/genai-prices/
# ---------------------------------------------------------------------------

# Valid OpenAI reasoning effort suffixes appended to model names.
# E.g., "gpt-5.2-high" → reasoning_effort="high", "gpt-5.2" → omitted.
OPENAI_EFFORT_SUFFIXES = ("-none", "-low", "-medium", "-high", "-xhigh")

# Map model names that genai-prices doesn't recognize yet to a known equivalent.
MODEL_PRICE_ALIASES: Dict[str, str] = {
    "gpt-5.4": "gpt-5.2",
    "gpt-5.4-mini": "gpt-5-mini",
}

GEMINI_PRO = "gemini-3.1-pro-preview"
GEMINI_FLASH = "gemini-3-flash-preview"
GEMINI_LITE = "gemini-2.5-flash-lite"

GPT_5 = "gpt-5.4"
GPT_5_MINI = "gpt-5.4-low"
GPT_5_NANO = "gpt-5.4-mini"

CLAUDE_OPUS_4 = "claude-opus-4-5"
CLAUDE_SONNET_4 = "claude-sonnet-4-5"
CLAUDE_HAIKU_4 = "claude-haiku-4-5"

OLLAMA_PRO = "ollama-local/qwen3.5:35b-a3b-bf16"
OLLAMA_FLASH = "ollama-local/qwen3.5:9b"
OLLAMA_LITE = "ollama-local/qwen3.5:2b"

# ---------------------------------------------------------------------------
# System defaults: provider -> tier -> model
# ---------------------------------------------------------------------------

PROVIDER_DEFAULTS: Dict[str, Dict[int, str]] = {
    "google": {
        TIER_PRO: GEMINI_PRO,
        TIER_FLASH: GEMINI_FLASH,
        TIER_LITE: GEMINI_LITE,
    },
    "openai": {
        TIER_PRO: GPT_5,
        TIER_FLASH: GPT_5_MINI,
        TIER_LITE: GPT_5_NANO,
    },
    "anthropic": {
        TIER_PRO: CLAUDE_OPUS_4,
        TIER_FLASH: CLAUDE_SONNET_4,
        TIER_LITE: CLAUDE_HAIKU_4,
    },
    "ollama": {
        TIER_PRO: OLLAMA_PRO,
        TIER_FLASH: OLLAMA_FLASH,
        TIER_LITE: OLLAMA_LITE,
    },
}

TYPE_DEFAULTS: Dict[str, Dict[str, Any]] = {
    "generate": {"provider": "google", "tier": TIER_FLASH, "backup": "anthropic"},
    "cogitate": {"provider": "openai", "tier": TIER_FLASH, "backup": "anthropic"},
}
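
# Illustrative shape of a journal.json "providers" section that these defaults
# merge with. This is a sketch inferred from the lookups below
# (resolve_provider / _resolve_model / _resolve_tier); the concrete keys and
# values in any given journal may differ.
#
#   "providers": {
#     "generate": {"provider": "google", "tier": 2, "backup": "anthropic"},
#     "contexts": {
#       "observe.describe.*": {"tier": 3},
#       "talent.system.meetings": {"provider": "anthropic", "model": "claude-sonnet-4-5"}
#     },
#     "models": {
#       "openai": {"1": "gpt-5.4", "3": "gpt-5.4-mini"}
#     }
#   }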


# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------


class IncompleteJSONError(ValueError):
    """Raised when a JSON response is truncated due to token limits or other reasons.

    Attributes:
        reason: The finish/stop reason from the API (e.g., "MAX_TOKENS", "length").
        partial_text: The truncated response text, useful for debugging.
    """

    def __init__(self, reason: str, partial_text: str):
        self.reason = reason
        self.partial_text = partial_text
        super().__init__(f"JSON response incomplete (reason: {reason})")


# ---------------------------------------------------------------------------
# Prompt context discovery
#
# Context metadata (tier, label, group) is defined in prompt .md files via
# YAML frontmatter. This eliminates duplication between code and config.
#
# NAMING CONVENTION:
# {module}.{feature}[.{operation}]
#
# Examples:
# - observe.describe.frame -> observe module, describe feature, frame operation
# - observe.enrich -> observe module, enrich feature (no sub-operation)
# - talent.system.meetings -> talent module, system source, meetings config
# - talent.entities.observer -> talent module, entities app, observer config
# - app.chat.title -> apps module, chat app, title operation
#
# DISCOVERY SOURCES:
# 1. Prompt files listed in PROMPT_PATHS (with context in frontmatter)
# 2. Categories from observe/categories/*.md (tier/label/group in frontmatter)
# 3. Talent configs from talent/*.md and apps/*/talent/*.md
#
# When adding new contexts:
# 1. Create a .md prompt file with YAML frontmatter containing:
#    context, tier, label, group
# 2. Add the path to PROMPT_PATHS
# 3. If not listed, the context falls back to the type's default tier
# ---------------------------------------------------------------------------

# Flat list of prompt files that define context metadata in frontmatter.
# Each must have: context, tier, label, group in YAML frontmatter.
PROMPT_PATHS: List[str] = [
    "observe/describe.md",
    "observe/enrich.md",
    "observe/extract.md",
    "observe/transcribe/gemini.md",
    "think/detect_created.md",
    "think/detect_transcript_segment.md",
    "think/detect_transcript_json.md",
    "think/planner.md",
]
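
# Illustrative frontmatter for a prompt file registered in PROMPT_PATHS. The
# four keys are the ones _discover_prompt_contexts() reads; the values here
# are made up for the sketch.
#
#   ---
#   context: observe.enrich
#   tier: 2
#   label: Enrichment
#   group: Screen Analysis
#   ---
#   Prompt body follows the frontmatter...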


# ---------------------------------------------------------------------------
# Dynamic context discovery
# ---------------------------------------------------------------------------

# Cached context registry (built lazily on first use)
_context_registry: Optional[Dict[str, Dict[str, Any]]] = None
_LEGACY_CONTEXT_PREFIX = "talent."
_TALENT_CONTEXT_PREFIX = "talent."


def _discover_prompt_contexts() -> Dict[str, Dict[str, Any]]:
    """Load context metadata from prompt files listed in PROMPT_PATHS.

    Each file must have YAML frontmatter with:
    - context: The context string (e.g., "observe.enrich")
    - tier: Tier number (1=pro, 2=flash, 3=lite)
    - label: Human-readable name
    - group: Settings UI category

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Mapping of context patterns to {tier, label, group} dicts.
    """
    contexts = {}
    base_dir = Path(__file__).parent.parent  # Project root

    for rel_path in PROMPT_PATHS:
        path = base_dir / rel_path
        if not path.exists():
            logger.warning(f"Prompt file not found: {path}")
            continue

        try:
            post = frontmatter.load(path)
            meta = post.metadata or {}

            context = meta.get("context")
            if not context:
                logger.warning(f"No context in {path}")
                continue

            contexts[context] = {
                "tier": meta.get("tier", TIER_FLASH),
                "label": meta.get("label", context),
                "group": meta.get("group", "Other"),
            }
        except Exception as e:
            logger.warning(f"Failed to load {path}: {e}")

    return contexts


def _discover_talent_contexts() -> Dict[str, Dict[str, Any]]:
    """Discover talent context defaults from talent/*.md config files.

    Uses get_talent_configs() from think.talent to load all talent configurations
    and converts them to context patterns with tier/label/group metadata.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Mapping of context patterns to {tier, label, group, type} dicts.
        Context patterns are: talent.system.{name} or talent.{app}.{name}
    """
    from think.talent import get_talent_configs, key_to_context

    contexts = {}

    # Load all talent configs (including disabled, for completeness)
    all_configs = get_talent_configs(include_disabled=True)

    for key, config in all_configs.items():
        context = key_to_context(key)
        contexts[context] = {
            "tier": config.get("tier", TIER_FLASH),
            "label": config.get("label", config.get("title", key)),
            "group": config.get("group", "Think"),
            "type": config.get("type"),
        }

    return contexts


def _build_context_registry() -> Dict[str, Dict[str, Any]]:
    """Build the complete context registry from discovered configs.

    Merges:
    1. Prompt contexts from _discover_prompt_contexts()
    2. Category contexts from observe/describe.py CATEGORIES
    3. Talent contexts from _discover_talent_contexts()

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Complete context registry mapping patterns to {tier, label, group}.
    """
    # Start with prompt contexts (from PROMPT_PATHS)
    registry = _discover_prompt_contexts()

    # Merge category contexts (lazy import to avoid a circular dependency)
    try:
        from observe.describe import CATEGORIES

        for category, metadata in CATEGORIES.items():
            context = metadata.get("context", f"observe.describe.{category}")
            registry[context] = {
                "tier": metadata.get("tier", TIER_FLASH),
                "label": metadata.get("label", category.replace("_", " ").title()),
                "group": metadata.get("group", "Screen Analysis"),
            }
    except ImportError:
        pass  # observe module not available

    # Merge talent contexts (agents + generators)
    talent_contexts = _discover_talent_contexts()
    registry.update(talent_contexts)

    return registry


def get_context_registry() -> Dict[str, Dict[str, Any]]:
    """Get the complete context registry, building it lazily on first use.

    Returns
    -------
    Dict[str, Dict[str, Any]]
        Complete context registry mapping patterns to {tier, label, group}.
    """
    global _context_registry
    if _context_registry is None:
        _context_registry = _build_context_registry()
    return _context_registry


def _resolve_tier(context: str, agent_type: str) -> int:
    """Resolve a context to a tier number.

    Checks journal config contexts first, then the dynamic context registry,
    using exact matches before glob matching in each.

    Parameters
    ----------
    context
        Context string (e.g., "talent.system.default", "observe.describe.frame").
    agent_type
        Agent type ("generate" or "cogitate").

    Returns
    -------
    int
        Tier number (1=pro, 2=flash, 3=lite).
    """
    default_tier = TYPE_DEFAULTS[agent_type]["tier"]

    journal_config = get_config()
    providers_config = journal_config.get("providers", {})
    contexts = providers_config.get("contexts", {})

    # Get the dynamic context registry (discovered prompts, categories, talent configs)
    registry = get_context_registry()

    # Check journal config contexts first (exact match)
    if context in contexts:
        return contexts[context].get("tier", default_tier)

    # Check the context registry (exact match)
    if context in registry:
        return registry[context]["tier"]

    # Check glob patterns in both
    for pattern, ctx_config in contexts.items():
        if fnmatch.fnmatch(context, pattern):
            return ctx_config.get("tier", default_tier)

    for pattern, ctx_default in registry.items():
        if fnmatch.fnmatch(context, pattern):
            return ctx_default["tier"]

    return default_tier
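
# Resolution sketch (illustrative, not fixtures): with no journal overrides,
# a context like "observe.enrich" takes the tier loaded from the frontmatter
# of observe/enrich.md, while a context that matches nothing falls through to
# the type default.
#
#   _resolve_tier("observe.enrich", "generate")   # -> tier from enrich.md frontmatter
#   _resolve_tier("app.chat.title", "generate")   # -> 2 (TIER_FLASH, assuming no match)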


def _resolve_model(provider: str, tier: int, config_models: Dict[str, Any]) -> str:
    """Resolve a tier to a model string for a given provider.

    Checks config overrides first, then falls back to system defaults.
    If the requested tier is unavailable, falls back to more capable tiers
    (3→2→1, i.e., lite→flash→pro).

    Parameters
    ----------
    provider
        Provider name ("google", "openai", "anthropic").
    tier
        Tier number (1=pro, 2=flash, 3=lite).
    config_models
        The "models" section from providers config, mapping provider to tier overrides.

    Returns
    -------
    str
        Model identifier string.
    """
    # Check config overrides first
    provider_overrides = config_models.get(provider, {})

    # Try the requested tier, then fall back to more capable tiers (lower numbers)
    for t in [tier, tier - 1, tier - 2] if tier > 1 else [tier]:
        if t < 1:
            continue

        # Check config override (tier is a string key in JSON)
        tier_key = str(t)
        if tier_key in provider_overrides:
            return provider_overrides[tier_key]

        # Check system defaults
        provider_defaults = PROVIDER_DEFAULTS.get(provider, {})
        if t in provider_defaults:
            return provider_defaults[t]

    # Ultimate fallback: system default for the provider at TIER_FLASH
    provider_defaults = PROVIDER_DEFAULTS.get(provider, PROVIDER_DEFAULTS["google"])
    return provider_defaults.get(TIER_FLASH, GEMINI_FLASH)
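
# Fallback sketch: asking "openai" for TIER_LITE walks lite -> flash -> pro,
# checking the config override for each tier before the system default.
# The override dict below is illustrative.
#
#   _resolve_model("openai", TIER_LITE, {})
#   # -> GPT_5_NANO (system default)
#   _resolve_model("openai", TIER_LITE, {"openai": {"3": "gpt-5.4-mini"}})
#   # -> "gpt-5.4-mini" (config override wins)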


def resolve_model_for_provider(
    context: str, provider: str, agent_type: str = "generate"
) -> str:
    """Resolve the model for a specific provider based on the context's tier.

    Use this when the provider is overridden from the default; it resolves the
    appropriate model for the given provider at the context's tier.

    Parameters
    ----------
    context
        Context string (e.g., "talent.system.default").
    provider
        Provider name ("google", "openai", "anthropic").
    agent_type
        Agent type ("generate" or "cogitate").

    Returns
    -------
    str
        Model identifier string for the provider at the context's tier.
    """
    tier = _resolve_tier(context, agent_type)
    journal_config = get_config()
    providers_config = journal_config.get("providers", {})
    config_models = providers_config.get("models", {})

    return _resolve_model(provider, tier, config_models)


def resolve_provider(context: str, agent_type: str) -> tuple[str, str]:
    """Resolve a context to a provider and model based on configuration.

    Matches the context against configured contexts using exact match first,
    then glob patterns (via fnmatch), falling back to type-specific defaults.

    Supports both explicit model strings and tier-based routing:
    - {"provider": "google", "model": "gemini-3-flash-preview"} - explicit model
    - {"provider": "google", "tier": 2} - tier-based (2=flash)
    - {"tier": 1} - tier only, inherits provider from the type default

    The "models" section in providers config allows overriding which model
    is used for each tier per provider.

    Parameters
    ----------
    context
        Context string (e.g., "observe.describe.frame", "talent.system.meetings").
    agent_type
        Agent type ("generate" or "cogitate").

    Returns
    -------
    tuple[str, str]
        (provider_name, model) tuple. Provider is one of "google", "openai",
        "anthropic". Model is the full model identifier string.
    """
    config = get_config()
    providers = config.get("providers", {})
    config_models = providers.get("models", {})

    # Get type-specific defaults from config, falling back to system constants
    type_defaults = TYPE_DEFAULTS[agent_type]
    type_config = providers.get(agent_type, {})
    default_provider = type_config.get("provider", type_defaults["provider"])
    default_tier = type_config.get("tier", type_defaults["tier"])

    # Handle an explicit "model" key in the type config (overrides tier-based resolution)
    if "model" in type_config and "tier" not in type_config:
        default_model = type_config["model"]
    else:
        default_model = _resolve_model(default_provider, default_tier, config_models)

    contexts = providers.get("contexts", {})

    # Find a matching context config
    match_config: Optional[Dict[str, Any]] = None

    if context and contexts:
        # Check for an exact match first
        if context in contexts:
            match_config = contexts[context]
        else:
            # Check glob patterns - the most specific (longest non-wildcard prefix) wins
            matches = []
            for pattern, ctx_config in contexts.items():
                if fnmatch.fnmatch(context, pattern):
                    specificity = len(pattern.split("*")[0])
                    matches.append((specificity, pattern, ctx_config))

            if matches:
                matches.sort(key=lambda x: x[0], reverse=True)
                _, _, match_config = matches[0]

    # No context match - check the dynamic context registry for this context
    if match_config is None:
        # Get the dynamic context registry (discovered prompts, categories, talent configs)
        registry = get_context_registry()

        # Check for a matching context default (exact match first, then glob)
        context_tier = None
        if context:
            if context in registry:
                context_tier = registry[context]["tier"]
            else:
                # Check glob patterns
                matches = []
                for pattern, ctx_default in registry.items():
                    if fnmatch.fnmatch(context, pattern):
                        specificity = len(pattern.split("*")[0])
                        matches.append((specificity, ctx_default["tier"]))
                if matches:
                    matches.sort(key=lambda x: x[0], reverse=True)
                    context_tier = matches[0][1]

        if context_tier is not None:
            model = _resolve_model(default_provider, context_tier, config_models)
            return (default_provider, model)

        return (default_provider, default_model)

    # Resolve the provider (from match or default)
    provider = match_config.get("provider", default_provider)

    # Resolve the model: an explicit model takes precedence over tier
    if "model" in match_config:
        model = match_config["model"]
    elif "tier" in match_config:
        tier = match_config["tier"]
        # Validate tier
        if not isinstance(tier, int) or tier < 1 or tier > 3:
            logger.warning("Invalid tier %r in context %r, using default", tier, context)
            tier = default_tier
        model = _resolve_model(provider, tier, config_models)
    else:
        # No model or tier specified - use the default tier
        model = _resolve_model(provider, default_tier, config_models)

    return (provider, model)
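
# Routing sketch (illustrative): with the system defaults above, no journal
# overrides, and no registry tier for this context, a "generate" call routes
# to Google at TIER_FLASH.
#
#   provider, model = resolve_provider("talent.system.meetings", "generate")
#   # provider == "google", model == GEMINI_FLASH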


def log_token_usage(
    model: str,
    usage: Union[Dict[str, Any], Any],
    context: Optional[str] = None,
    segment: Optional[str] = None,
    type: Optional[str] = None,
) -> None:
    """Log token usage to the journal with the unified schema.

    Providers normalize usage into the unified schema (see USAGE_KEYS in
    shared.py) before returning GenerateResult. This function passes
    through those known keys, computes total_tokens when missing, and
    handles a few legacy field aliases from CLI backends.

    Parameters
    ----------
    model : str
        Model name (e.g., "gpt-5", "gemini-2.5-flash")
    usage : dict
        Normalized usage dict with keys from USAGE_KEYS.
    context : str, optional
        Context string (e.g., "module.function:123" or "talent.system.default").
        If None, auto-detects from the call stack.
    segment : str, optional
        Segment key (e.g., "143022_300") for attribution.
        If None, falls back to the SOL_SEGMENT environment variable.
    type : str, optional
        Token entry type (e.g., "generate", "cogitate").
    """
    from think.providers.shared import USAGE_KEYS

    try:
        journal = get_journal()

        # Auto-detect the calling context if not provided
        if context is None:
            frame = inspect.currentframe()
            caller_frame = frame.f_back if frame else None

            # Skip frames that contain "gemini" in the function name
            while caller_frame and "gemini" in caller_frame.f_code.co_name.lower():
                caller_frame = caller_frame.f_back

            if caller_frame:
                module_name = caller_frame.f_globals.get("__name__", "unknown")
                func_name = caller_frame.f_code.co_name
                line_num = caller_frame.f_lineno

                # Clean up the module name
                for prefix in ["think.", "observe.", "convey."]:
                    if module_name.startswith(prefix):
                        module_name = module_name[len(prefix) :]
                        break

                context = f"{module_name}.{func_name}:{line_num}"

        # Pass through known keys from the already-normalized usage dict.
        normalized_usage: Dict[str, int] = {}
        for key in USAGE_KEYS:
            val = usage.get(key)
            if val:
                normalized_usage[key] = val

        # Legacy alias: some CLI backends emit cached_input_tokens
        if not normalized_usage.get("cached_tokens") and usage.get(
            "cached_input_tokens"
        ):
            normalized_usage["cached_tokens"] = usage["cached_input_tokens"]

        # Compute total_tokens from parts when missing (e.g. Codex CLI omits it)
        if not normalized_usage.get("total_tokens"):
            inp = normalized_usage.get("input_tokens", 0)
            out = normalized_usage.get("output_tokens", 0)
            if inp or out:
                normalized_usage["total_tokens"] = inp + out

        # Build the token log entry
        token_data = {
            "timestamp": time.time(),
            "model": model,
            "context": context,
            "usage": normalized_usage,
        }

        # Add segment: prefer the parameter, fall back to env (set by think/insight, observe handlers)
        segment_key = segment or os.getenv("SOL_SEGMENT")
        if segment_key:
            token_data["segment"] = segment_key
        if type:
            token_data["type"] = type

        # Save to journal/tokens/<YYYYMMDD>.jsonl (one file per day)
        tokens_dir = Path(journal) / "tokens"
        tokens_dir.mkdir(exist_ok=True)

        filename = time.strftime("%Y%m%d.jsonl")
        filepath = tokens_dir / filename

        # Atomic append - safe for parallel writers
        with open(filepath, "a") as f:
            f.write(json.dumps(token_data) + "\n")

    except Exception:
        # Fail silently - logging must never break the main flow
        pass
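
# Illustrative shape of one appended JSONL entry (field values made up):
#
#   {"timestamp": 1767225600.0, "model": "gemini-3-flash-preview",
#    "context": "talent.system.meetings", "type": "generate",
#    "segment": "143022_300",
#    "usage": {"input_tokens": 1500, "output_tokens": 500, "total_tokens": 2000}}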


def get_model_provider(model: str) -> str:
    """Get the provider name from a model identifier.

    Parameters
    ----------
    model : str
        Model name (e.g., "gpt-5", "gemini-2.5-flash", "claude-sonnet-4-5")

    Returns
    -------
    str
        Provider name: "openai", "google", "anthropic", "ollama", or "unknown"
    """
    model_lower = model.lower()

    if model_lower.startswith("ollama-local/"):
        return "ollama"
    elif model_lower.startswith("gpt"):
        return "openai"
    elif model_lower.startswith("gemini"):
        return "google"
    elif model_lower.startswith("claude"):
        return "anthropic"
    else:
        return "unknown"


def calc_token_cost(token_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Calculate the cost for a token usage record.

    Parameters
    ----------
    token_data : dict
        Token usage record from journal logs with structure:
        {
            "model": "gemini-2.5-flash",
            "usage": {
                "input_tokens": 1500,
                "output_tokens": 500,
                "cached_tokens": 800,
                "reasoning_tokens": 200,
                ...
            }
        }

    Returns
    -------
    dict or None
        Cost breakdown:
        {
            "total_cost": 0.00123,
            "input_cost": 0.00075,
            "output_cost": 0.00048,
            "currency": "USD"
        }
        Returns None if pricing is unavailable or the calculation fails.
    """
    try:
        from genai_prices import Usage, calc_price

        model = token_data.get("model")
        usage_data = token_data.get("usage", {})

        if not model or not usage_data:
            return None

        # Strip OpenAI reasoning effort suffixes for the price lookup
        for suffix in OPENAI_EFFORT_SUFFIXES:
            if model.endswith(suffix):
                model = model[: -len(suffix)]
                break

        # Get the provider ID before aliasing (the alias may change the model family)
        provider_id = get_model_provider(model)
        if provider_id == "unknown":
            return None

        # Ollama models are local - no cost
        if provider_id == "ollama":
            return {
                "total_cost": 0.0,
                "input_cost": 0.0,
                "output_cost": 0.0,
                "currency": "USD",
            }

        # Apply price aliases for models genai-prices doesn't recognize yet
        model = MODEL_PRICE_ALIASES.get(model, model)

        # Map our token fields to the genai_prices Usage format.
        # Note: Gemini reports reasoning_tokens separately, but they're billed at
        # output token rates. genai-prices doesn't have a separate field for
        # reasoning, so we add them to output_tokens for correct pricing.
        input_tokens = usage_data.get("input_tokens", 0)
        output_tokens = usage_data.get("output_tokens", 0)
        cached_tokens = usage_data.get("cached_tokens", 0)
        reasoning_tokens = usage_data.get("reasoning_tokens", 0)

        # Add reasoning tokens to output for pricing (Gemini bills them as output)
        total_output_tokens = output_tokens + reasoning_tokens

        # Create the Usage object
        usage = Usage(
            input_tokens=input_tokens,
            output_tokens=total_output_tokens,
            cache_read_tokens=cached_tokens if cached_tokens > 0 else None,
        )

        # Calculate the price
        result = calc_price(
            usage=usage,
            model_ref=model,
            provider_id=provider_id,
        )

        # Return a simplified cost breakdown
        return {
            "total_cost": float(result.total_price),
            "input_cost": float(result.input_price),
            "output_cost": float(result.output_price),
            "currency": "USD",
        }

    except Exception:
        # Fail silently if pricing is unavailable
        return None
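
# Pricing sketch (numbers illustrative, not real genai-prices output). The
# "-high" suffix is stripped first, then the alias maps "gpt-5.4" to "gpt-5.2"
# for the lookup:
#
#   calc_token_cost({
#       "model": "gpt-5.4-high",
#       "usage": {"input_tokens": 1500, "output_tokens": 500, "reasoning_tokens": 200},
#   })
#   # -> {"total_cost": 0.0123, "input_cost": 0.0075,
#   #     "output_cost": 0.0048, "currency": "USD"}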


def calc_agent_cost(
    model: Optional[str], usage: Optional[Dict[str, Any]]
) -> Optional[float]:
    """Calculate the total cost for an agent run from model and usage data.

    Convenience wrapper around calc_token_cost for agent cost lookups.

    Returns the total cost in USD, or None if data is missing or pricing is
    unavailable.
    """
    if not model or not usage:
        return None
    try:
        cost_data = calc_token_cost({"model": model, "usage": usage})
        if cost_data:
            return cost_data["total_cost"]
    except Exception:
        return None
    return None


def _normalize_legacy_context(ctx: str) -> str:
    """Normalize legacy token-log context strings to the talent namespace."""
    if ctx.startswith(_LEGACY_CONTEXT_PREFIX):
        return _TALENT_CONTEXT_PREFIX + ctx[len(_LEGACY_CONTEXT_PREFIX) :]
    return ctx


def iter_token_log(day: str) -> Iterator[Dict[str, Any]]:
    """Iterate over token log entries for a given day.

    Yields parsed JSON entries from the token log file, skipping empty lines
    and invalid JSON. This is a shared utility for code that processes token logs.

    Parameters
    ----------
    day : str
        Day in YYYYMMDD format.

    Yields
    ------
    dict
        Parsed token log entry with fields: timestamp, model, context, usage,
        and optionally segment.
    """
    journal = get_journal()
    log_path = Path(journal) / "tokens" / f"{day}.jsonl"

    if not log_path.exists():
        return

    with open(log_path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                entry = json.loads(line)
                ctx = entry.get("context")
                if isinstance(ctx, str):
                    entry["context"] = _normalize_legacy_context(ctx)
                yield entry
            except json.JSONDecodeError:
                continue


def get_usage_cost(
    day: str,
    segment: Optional[str] = None,
    context: Optional[str] = None,
) -> Dict[str, Any]:
    """Get aggregated token usage and cost for a day, optionally filtered.

    This is a shared utility for apps that want to display cost information
    for segments, agent runs, or other contexts.

    Parameters
    ----------
    day : str
        Day in YYYYMMDD format.
    segment : str, optional
        Filter to entries with this exact segment key.
    context : str, optional
        Filter to entries whose context starts with this prefix.
        For example, "talent.system" matches "talent.system.default".

    Returns
    -------
    dict
        Aggregated usage data:
        {
            "requests": int,
            "tokens": int,
            "cost": float,  # USD
        }
        Returns zeros if there are no matching entries or the day file doesn't exist.
    """
    result = {"requests": 0, "tokens": 0, "cost": 0.0}

    for entry in iter_token_log(day):
        # Apply filters
        if segment is not None and entry.get("segment") != segment:
            continue
        if context is not None:
            entry_context = entry.get("context", "")
            if not entry_context.startswith(context):
                continue

        # Skip unknown providers (can't calculate cost)
        model = entry.get("model", "unknown")
        if get_model_provider(model) == "unknown":
            continue

        # Accumulate
        usage = entry.get("usage", {})
        result["requests"] += 1
        result["tokens"] += usage.get("total_tokens", 0) or 0

        cost_data = calc_token_cost(entry)
        if cost_data:
            result["cost"] += cost_data["total_cost"]

    return result
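
# Usage sketch (day string and totals illustrative): sum the cost of one
# day's system talent runs via the prefix filter.
#
#   totals = get_usage_cost("20260101", context="talent.system")
#   # -> {"requests": 12, "tokens": 48210, "cost": 0.37}  (made-up numbers)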


# ---------------------------------------------------------------------------
# Unified generate/agenerate with provider routing
# ---------------------------------------------------------------------------


def _validate_json_response(result: Dict[str, Any], json_output: bool) -> None:
    """Validate a response for JSON output mode.

    Raises IncompleteJSONError if finish_reason indicates truncation.
    """
    if not json_output:
        return

    finish_reason = result.get("finish_reason")
    if finish_reason and finish_reason != "stop":
        raise IncompleteJSONError(
            reason=finish_reason,
            partial_text=result.get("text", ""),
        )


def _validate_schema(text: str, schema: dict) -> dict:
    """Validate JSON text against a JSON Schema and log any violations."""

    def truncate_repr(value: Any) -> str:
        value_repr = repr(value)
        if len(value_repr) <= 80:
            return value_repr
        return value_repr[:77] + "..."

    def build_pointer(path: Any) -> str:
        segments = list(path)
        if not segments:
            return ""
        escaped_segments = []
        for segment in segments:
            escaped = str(segment).replace("~", "~0").replace("/", "~1")
            escaped_segments.append(escaped)
        return "/" + "/".join(escaped_segments)

    try:
        parsed = json.loads(text)
    except ValueError as exc:
        error = {
            "path": "",
            "constraint": "json_parse",
            "message": str(exc),
        }
        logger.warning(
            "schema_validation: %s: %s: %s (value=%s)",
            "",
            "json_parse",
            str(exc),
            truncate_repr(text),
        )
        return {"valid": False, "errors": [error]}

    errors = []
    try:
        validator = Draft202012Validator(schema)
        validation_errors = list(validator.iter_errors(parsed))
    except Exception as exc:
        error = {
            "path": "",
            "constraint": "schema_validation",
            "message": str(exc),
        }
        logger.warning(
            "schema_validation: %s: %s: %s (value=%s)",
            "",
            "schema_validation",
            str(exc),
            truncate_repr(parsed),
        )
        return {"valid": False, "errors": [error]}

    for error in validation_errors:
        path = build_pointer(error.absolute_path)
        constraint = str(error.validator)
        message = error.message
        errors.append(
            {
                "path": path,
                "constraint": constraint,
                "message": message,
            }
        )
        logger.warning(
            "schema_validation: %s: %s: %s (value=%s)",
            path,
            constraint,
            message,
            truncate_repr(error.instance),
        )

    return {"valid": len(errors) == 0, "errors": errors}
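
# Return-shape sketch for a failed validation (values illustrative). "path" is
# a JSON Pointer into the parsed document:
#
#   {"valid": False,
#    "errors": [{"path": "/items/0/name", "constraint": "type",
#                "message": "3 is not of type 'string'"}]}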


def generate(
    contents: Union[str, List[Any]],
    context: str,
    temperature: float = 0.3,
    max_output_tokens: int = 8192 * 2,
    system_instruction: Optional[str] = None,
    json_output: bool = False,
    *,
    json_schema: dict | None = None,
    thinking_budget: Optional[int] = None,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> str:
    """Generate text using the configured provider for the given context.

    Routes the request to the appropriate backend (Google, OpenAI, or Anthropic)
    based on the providers configuration in journal.json.

    Parameters
    ----------
    contents : str or List
        The content to send to the model.
    context : str
        Context string for routing and token logging (e.g., "talent.system.meetings").
        This is required and determines which provider/model to use.
    temperature : float
        Temperature for generation (default: 0.3).
    max_output_tokens : int
        Maximum tokens for the model's response output.
    system_instruction : str, optional
        System instruction for the model.
    json_output : bool
        Whether to request JSON response format.
    json_schema : dict, optional
        JSON Schema to request structured output from the provider. When supplied,
        this forces json_output=True and runs advisory local validation on the
        returned text after truncation checks.
    thinking_budget : int, optional
        Token budget for model thinking (ignored by providers that don't support it).
    timeout_s : float, optional
        Request timeout in seconds.
    **kwargs
        Additional provider-specific options passed through to the backend.

    Returns
    -------
    str
        Response text from the model.

    Raises
    ------
    ValueError
        If the resolved provider is not supported.
    IncompleteJSONError
        If json_output=True and the response was truncated.
    """
    from think.providers import get_provider_module

    if json_schema is not None:
        json_output = True

    # Allow a model override via kwargs (used by callers with explicit model selection)
    model_override = kwargs.pop("model", None)

    provider, model = resolve_provider(context, "generate")
    if model_override:
        model = model_override

    # Get the provider module via the registry (raises ValueError for unknown providers)
    provider_mod = get_provider_module(provider)

    # Call the provider's run_generate (returns GenerateResult)
    result = provider_mod.run_generate(
        contents=contents,
        model=model,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        system_instruction=system_instruction,
        json_output=json_output,
        json_schema=json_schema,
        thinking_budget=thinking_budget,
        timeout_s=timeout_s,
        **kwargs,
    )

    # Log token usage centrally (before validation, so truncated responses
    # still get their usage recorded)
    if result.get("usage"):
        log_token_usage(
            model=model,
            usage=result["usage"],
            context=context,
            type="generate",
        )

    # Validate JSON output if requested
    _validate_json_response(result, json_output)

    if json_schema is not None:
        _validate_schema(result["text"], json_schema)

    return result["text"]
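
# Caller sketch (prompt and context illustrative): request structured output
# and let the schema force json_output.
#
#   text = generate(
#       "Summarize today's meetings as JSON.",
#       context="talent.system.meetings",
#       json_schema={"type": "object", "properties": {"summary": {"type": "string"}}},
#   )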


# ---------------------------------------------------------------------------
# Provider Health & Fallback Helpers
# ---------------------------------------------------------------------------


def get_backup_provider(agent_type: str) -> Optional[str]:
    """Get the backup provider for the given agent type.

    Reads from the type-specific section in the journal config, falling back
    to TYPE_DEFAULTS.

    Returns None if the backup would be the same as the primary provider.
    """
    type_defaults = TYPE_DEFAULTS[agent_type]
    config = get_config()
    providers_config = config.get("providers", {})
    type_config = providers_config.get(agent_type, {})
    primary_provider = type_config.get("provider", type_defaults["provider"])
    backup = type_config.get("backup", type_defaults["backup"])
    if backup == primary_provider:
        return None
    return backup


def load_health_status() -> Optional[dict]:
    """Load health status from journal/health/agents.json.

    Returns the parsed dict, or None if the file is missing/unreadable.
    """
    try:
        health_path = Path(get_journal()) / "health" / "agents.json"
        with open(health_path) as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError, OSError):
        return None


def is_provider_healthy(provider: str, health_data: Optional[dict]) -> bool:
    """Check whether a provider is healthy based on health data.

    Returns True (assume healthy) when:
    - health_data is None (no data available)
    - No results exist for the provider
    - Any result for the provider has ok=True

    Returns False only when all results for the provider have ok=False.
    """
    if health_data is None:
        return True
    results = health_data.get("results", [])
    provider_results = [r for r in results if r.get("provider") == provider]
    if not provider_results:
        return True
    return any(r.get("ok") for r in provider_results)
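
# Assumed shape of agents.json, inferred from the accessors above (field
# values illustrative; the real file may carry more fields per result):
#
#   {"checked_at": "2026-01-01T12:00:00+00:00",
#    "results": [{"provider": "google", "ok": true},
#                {"provider": "openai", "ok": false}]}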


def should_recheck_health(health_data: Optional[dict]) -> bool:
    """Check whether the health data is stale (more than an hour old).

    Returns False when health_data is None or on parse errors.
    """
    if health_data is None:
        return False
    checked_at = health_data.get("checked_at")
    if not checked_at:
        return False
    try:
        checked_time = datetime.fromisoformat(checked_at)
        if checked_time.tzinfo is None:
            checked_time = checked_time.replace(tzinfo=timezone.utc)
        age = datetime.now(timezone.utc) - checked_time
        return age.total_seconds() > 3600
    except (ValueError, TypeError):
        return False


def request_health_recheck() -> None:
    """Request a health re-check by spawning a background process.

    Fire-and-forget; errors are logged but never propagated.
    """
    try:
        subprocess.Popen(
            ["sol", "providers", "check", "--targeted"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        logger.debug("Failed to request health recheck", exc_info=True)


def generate_with_result(
    contents: Union[str, List[Any]],
    context: str,
    temperature: float = 0.3,
    max_output_tokens: int = 8192 * 2,
    system_instruction: Optional[str] = None,
    json_output: bool = False,
    *,
    json_schema: dict | None = None,
    thinking_budget: Optional[int] = None,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> dict:
    """Generate text and return the full result with usage data.

    Same as generate() but returns the full GenerateResult dict instead of
    just the text. Used by cortex-managed generators that need usage data
    for event emission.

    Parameters
    ----------
    contents : str or List
        The content to send to the model.
    context : str
        Context string for routing and token logging.
    temperature : float
        Temperature for generation (default: 0.3).
    max_output_tokens : int
        Maximum tokens for the model's response output.
    system_instruction : str, optional
        System instruction for the model.
    json_output : bool
        Whether to request JSON response format.
    json_schema : dict, optional
        JSON Schema to request structured output from the provider. When supplied,
        this forces json_output=True and runs advisory local validation on the
        returned text after truncation checks.
    thinking_budget : int, optional
        Token budget for model thinking (ignored by providers that don't support it).
    timeout_s : float, optional
        Request timeout in seconds.
    **kwargs
        Additional provider-specific options passed through to the backend.

    Returns
    -------
    dict
        GenerateResult with: text, usage, finish_reason, thinking, and
        schema_validation when json_schema is supplied. Validation is advisory
        and runs after truncation checks succeed.
    """
    from think.providers import get_provider_module

    if json_schema is not None:
        json_output = True

    model_override = kwargs.pop("model", None)
    provider_override = kwargs.pop("provider", None)

    provider, model = resolve_provider(context, "generate")
    if provider_override:
        provider = provider_override
        if not model_override:
            model = resolve_model_for_provider(context, provider, "generate")
    if model_override:
        model = model_override

    provider_mod = get_provider_module(provider)

    result = provider_mod.run_generate(
        contents=contents,
        model=model,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        system_instruction=system_instruction,
        json_output=json_output,
        json_schema=json_schema,
        thinking_budget=thinking_budget,
        timeout_s=timeout_s,
        **kwargs,
    )

    # Log token usage centrally (before validation, so truncated responses
    # still get their usage recorded)
    if result.get("usage"):
        log_token_usage(
            model=model,
            usage=result["usage"],
            context=context,
            type="generate",
        )

    # Validate JSON output if requested
    _validate_json_response(result, json_output)

    if json_schema is not None:
        result["schema_validation"] = _validate_schema(result["text"], json_schema)

    return result
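
# Fallback sketch combining the health helpers with the provider override.
# This pattern is an assumption about intended use, not code that exists
# elsewhere in this module:
#
#   health = load_health_status()
#   if should_recheck_health(health):
#       request_health_recheck()
#   provider = None
#   if not is_provider_healthy("google", health):
#       provider = get_backup_provider("generate")  # may be None
#   result = generate_with_result(
#       "Hello", context="talent.system.default",
#       **({"provider": provider} if provider else {}),
#   )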


async def agenerate(
    contents: Union[str, List[Any]],
    context: str,
    temperature: float = 0.3,
    max_output_tokens: int = 8192 * 2,
    system_instruction: Optional[str] = None,
    json_output: bool = False,
    *,
    json_schema: dict | None = None,
    thinking_budget: Optional[int] = None,
    timeout_s: Optional[float] = None,
    **kwargs: Any,
) -> str:
    """Async version of generate(): generate text using the configured provider.

    Routes the request to the appropriate backend (Google, OpenAI, or Anthropic)
    based on the providers configuration in journal.json.

    Parameters
    ----------
    contents : str or List
        The content to send to the model.
    context : str
        Context string for routing and token logging (e.g., "talent.system.meetings").
        This is required and determines which provider/model to use.
    temperature : float
        Temperature for generation (default: 0.3).
    max_output_tokens : int
        Maximum tokens for the model's response output.
    system_instruction : str, optional
        System instruction for the model.
    json_output : bool
        Whether to request JSON response format.
    json_schema : dict, optional
        JSON Schema to request structured output from the provider. When supplied,
        this forces json_output=True and runs advisory local validation on the
        returned text after truncation checks.
    thinking_budget : int, optional
        Token budget for model thinking (ignored by providers that don't support it).
    timeout_s : float, optional
        Request timeout in seconds.
    **kwargs
        Additional provider-specific options passed through to the backend.

    Returns
    -------
    str
        Response text from the model.

    Raises
    ------
    ValueError
        If the resolved provider is not supported.
    IncompleteJSONError
        If json_output=True and the response was truncated.
    """
    from think.providers import get_provider_module

    if json_schema is not None:
        json_output = True

    # Allow a model override via kwargs (used by Batch for explicit model selection)
    model_override = kwargs.pop("model", None)

    provider, model = resolve_provider(context, "generate")
    if model_override:
        model = model_override

    # Get the provider module via the registry (raises ValueError for unknown providers)
    provider_mod = get_provider_module(provider)

    # Call the provider's run_agenerate (returns GenerateResult)
    result = await provider_mod.run_agenerate(
        contents=contents,
        model=model,
        temperature=temperature,
        max_output_tokens=max_output_tokens,
        system_instruction=system_instruction,
        json_output=json_output,
        json_schema=json_schema,
        thinking_budget=thinking_budget,
        timeout_s=timeout_s,
        **kwargs,
    )

    # Log token usage centrally (before validation, so truncated responses
    # still get their usage recorded)
    if result.get("usage"):
        log_token_usage(
            model=model,
            usage=result["usage"],
            context=context,
            type="generate",
        )

    # Validate JSON output if requested
    _validate_json_response(result, json_output)

    if json_schema is not None:
        _validate_schema(result["text"], json_schema)

    return result["text"]


__all__ = [
    # Provider configuration
    "TYPE_DEFAULTS",
    "PROMPT_PATHS",
    "get_context_registry",
    # Model constants (used by provider backends for defaults)
    "GEMINI_FLASH",
    "GPT_5",
    "CLAUDE_SONNET_4",
    # Unified API
    "generate",
    "generate_with_result",
    "agenerate",
    "resolve_provider",
    # Utilities
    "log_token_usage",
    "calc_token_cost",
    "calc_agent_cost",
    "get_usage_cost",
    "iter_token_log",
    "get_model_provider",
]