Unified Agent + reusable Go agent core.
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor: unify llm cache ttl configuration

Lyric d47a2372 15d171c5

+811 -102
+23 -1
agent/engine.go
··· 37 37 } 38 38 } 39 39 40 + func WithSystemPromptCacheControl(ctrl *llm.CacheControl) Option { 41 + return func(e *Engine) { 42 + if ctrl == nil { 43 + return 44 + } 45 + cloned := *ctrl 46 + e.systemPromptCacheControl = &cloned 47 + } 48 + } 49 + 40 50 // WithPromptBuilder replaces the default system prompt builder. 41 51 // This hook is intended for tests in this repository. 42 52 func WithPromptBuilder(fn func(*tools.Registry, string) string) Option { ··· 152 162 153 163 engineToolsConfig EngineToolsConfig 154 164 165 + systemPromptCacheControl *llm.CacheControl 166 + 155 167 promptBuilder func(registry *tools.Registry, task string) string 156 168 paramsBuilder func(opts RunOptions) map[string]any 157 169 onToolStart func(ctx *Context, toolName string) ··· 238 250 systemPrompt = BuildSystemPrompt(e.registry, e.spec) 239 251 } 240 252 241 - messages := []llm.Message{{Role: "system", Content: systemPrompt}} 253 + systemMessage := llm.Message{Role: "system", Content: systemPrompt} 254 + if e.systemPromptCacheControl != nil && strings.TrimSpace(systemPrompt) != "" { 255 + ctrl := *e.systemPromptCacheControl 256 + systemMessage.Parts = []llm.Part{{ 257 + Type: llm.PartTypeText, 258 + Text: systemPrompt, 259 + CacheControl: &ctrl, 260 + }} 261 + } 262 + 263 + messages := []llm.Message{systemMessage} 242 264 243 265 injectedMeta := runtimeclock.WithRuntimeClockMeta(opts.Meta, time.Now()) 244 266 if _, ok := injectedMeta["host_os"]; !ok {
+58
agent/engine_hooks_test.go
··· 185 185 } 186 186 } 187 187 188 + func TestWithSystemPromptCacheControl_SetsField(t *testing.T) { 189 + client := newMockClient(finalResponse("ok")) 190 + e := New( 191 + client, 192 + baseRegistry(), 193 + baseCfg(), 194 + DefaultPromptSpec(), 195 + WithSystemPromptCacheControl(&llm.CacheControl{TTL: "5m"}), 196 + ) 197 + if e.systemPromptCacheControl == nil { 198 + t.Fatal("expected systemPromptCacheControl to be set") 199 + } 200 + if e.systemPromptCacheControl.TTL != "5m" { 201 + t.Fatalf("systemPromptCacheControl.TTL = %q, want 5m", e.systemPromptCacheControl.TTL) 202 + } 203 + } 204 + 188 205 func TestWithOnToolSuccess_SetsField(t *testing.T) { 189 206 fn := func(ctx *Context, toolName string) {} 190 207 client := newMockClient(finalResponse("ok")) ··· 468 485 expected := BuildSystemPrompt(reg, DefaultPromptSpec()) 469 486 if calls[0].Messages[0].Content != expected { 470 487 t.Error("expected default BuildSystemPrompt to be used when promptBuilder is nil") 488 + } 489 + } 490 + 491 + func TestRun_AddsSystemPromptCacheControlPart(t *testing.T) { 492 + client := newMockClient(finalResponse("ok")) 493 + reg := baseRegistry() 494 + 495 + e := New( 496 + client, 497 + reg, 498 + baseCfg(), 499 + DefaultPromptSpec(), 500 + WithSystemPromptCacheControl(&llm.CacheControl{TTL: "1h"}), 501 + ) 502 + 503 + _, _, err := e.Run(context.Background(), "test task", RunOptions{}) 504 + if err != nil { 505 + t.Fatalf("unexpected error: %v", err) 506 + } 507 + 508 + calls := client.allCalls() 509 + if len(calls) == 0 { 510 + t.Fatal("expected at least one LLM call") 511 + } 512 + 513 + expected := BuildSystemPrompt(reg, DefaultPromptSpec()) 514 + msg := calls[0].Messages[0] 515 + if msg.Content != expected { 516 + t.Fatalf("system prompt content = %q, want %q", msg.Content, expected) 517 + } 518 + if len(msg.Parts) != 1 { 519 + t.Fatalf("system prompt parts len = %d, want 1", len(msg.Parts)) 520 + } 521 + if msg.Parts[0].Type != llm.PartTypeText { 522 + t.Fatalf("system prompt part type = %q, want text", msg.Parts[0].Type) 523 + } 524 + if msg.Parts[0].Text != expected { 525 + t.Fatalf("system prompt part text = %q, want %q", msg.Parts[0].Text, expected) 526 + } 527 + if msg.Parts[0].CacheControl == nil || msg.Parts[0].CacheControl.TTL != "1h" { 528 + t.Fatalf("system prompt cache control = %#v, want TTL 1h", msg.Parts[0].CacheControl) 471 529 } 472 530 } 473 531
+3
agent/local_subtask_runner.go
··· 75 75 if r.engine.guard != nil { 76 76 subOpts = append(subOpts, WithGuard(r.engine.guard)) 77 77 } 78 + if r.engine.systemPromptCacheControl != nil { 79 + subOpts = append(subOpts, WithSystemPromptCacheControl(r.engine.systemPromptCacheControl)) 80 + } 78 81 79 82 subEngine := New(client, req.Registry, Config{ 80 83 MaxSteps: r.engine.config.MaxSteps,
+9
assets/config/config.example.yaml
··· 26 26 # headers: 27 27 # YOUR-HEADER: "${YOUR_HEADER_VALUE}" 28 28 # "-X-ABC-TOKEN": "${ABC_TOKEN}" 29 + # Cache hint shared across providers. 30 + # Values: 31 + # - off 32 + # - short 33 + # - long 34 + # - Go duration strings such as "5m", "1h", "24h" 35 + # The runtime maps this to provider-supported cache buckets automatically. 36 + cache_ttl: "short" 29 37 # Per-LLM HTTP request timeout (0 uses provider default). 30 38 request_timeout: "90s" 31 39 # Optional default temperature. If empty/unset, do not call uniai.WithTemperature(...). ··· 58 66 # profiles: 59 67 # cheap: 60 68 # model: "gpt-4.1-mini" 69 + # cache_ttl: "long" 61 70 # reasoning: 62 71 # provider: xai 63 72 # model: "grok-4.1-fast-reasoning"
+7
cmd/mistermorph/runcmd/run.go
··· 147 147 APIBase: mainCfg.Endpoint, 148 148 Model: mainCfg.Model, 149 149 }) 150 + systemPromptCacheControl, err := llmutil.SystemPromptCacheControl(mainRoute.Values.CacheTTL) 151 + if err != nil { 152 + return err 153 + } 150 154 151 155 reg := (*tools.Registry)(nil) 152 156 if deps.RegistryFromViper != nil { ··· 213 217 } 214 218 opts = append(opts, agent.WithLogger(logger)) 215 219 opts = append(opts, agent.WithLogOptions(logOpts)) 220 + if systemPromptCacheControl != nil { 221 + opts = append(opts, agent.WithSystemPromptCacheControl(systemPromptCacheControl)) 222 + } 216 223 if !isHeartbeat { 217 224 opts = append(opts, agent.WithPlanStepUpdate(func(runCtx *agent.Context, update agent.PlanStepUpdate) { 218 225 if payload := formatPlanProgressUpdate(runCtx, update); payload != "" {
+1
docs/configuration.md
··· 260 260 - Most providers use `llm.endpoint`, `llm.api_key`, and `llm.model`. 261 261 - Azure uses `llm.azure.deployment`. 262 262 - Bedrock uses `llm.bedrock.*`. 263 + - `llm.cache_ttl` controls cache intent across providers. Supported values are `off`, `short`, `long`, and Go duration strings such as `5m`, `1h`, and `24h`. The runtime maps this to each provider's supported cache buckets. 263 264 - `llm.tools_emulation_mode` controls tool-call emulation for models without native tool calling. 264 265 - `llm.profiles` defines named profile overrides. 265 266 - `llm.routes` routes semantic purposes such as `main_loop`, `addressing`, `heartbeat`, `plan_create`, and `memory_draft`.
+7
integration/runtime.go
··· 146 146 if err != nil { 147 147 return nil, err 148 148 } 149 + systemPromptCacheControl, err := llmutil.SystemPromptCacheControl(mainRoute.Values.CacheTTL) 150 + if err != nil { 151 + return nil, err 152 + } 149 153 model := strings.TrimSpace(mainRoute.ClientConfig.Model) 150 154 var requestInspector *llminspect.RequestInspector 151 155 var promptInspector *llminspect.PromptInspector ··· 261 265 ACPSpawnEnabled: snap.Registry.ToolsACPSpawnEnabled && rt.isBuiltinToolSelected(toolsutil.BuiltinACPSpawn), 262 266 }), 263 267 agent.WithACPAgents(snap.ACPAgents), 268 + } 269 + if systemPromptCacheControl != nil { 270 + opts = append(opts, agent.WithSystemPromptCacheControl(systemPromptCacheControl)) 264 271 } 265 272 if g := rt.buildGuard(snap.Guard, logger); g != nil { 266 273 opts = append(opts, agent.WithGuard(g))
+41 -34
internal/channelruntime/heartbeat/run.go
··· 66 66 if err != nil { 67 67 return err 68 68 } 69 + systemPromptCacheControl, err := llmutil.SystemPromptCacheControl(route.Values.CacheTTL) 70 + if err != nil { 71 + return err 72 + } 69 73 client, err := depsutil.CreateClient(d.CreateLLMClient, route) 70 74 if err != nil { 71 75 return err ··· 122 126 defer wg.Done() 123 127 124 128 summary, runErr := runHeartbeatTask(ctx, d, heartbeatTaskOptions{ 125 - Logger: logger, 126 - LogOptions: logOpts, 127 - Client: client, 128 - Model: model, 129 - Task: task, 130 - Meta: meta, 131 - TaskRunID: taskRunID, 132 - BaseRegistry: baseReg, 133 - SharedGuard: sharedGuard, 134 - Config: cfg, 135 - EngineToolsConfig: opts.EngineToolsConfig, 136 - TaskTimeout: opts.TaskTimeout, 137 - WakeSignal: wakeSignal, 138 - MemoryOrchestrator: orchestrator, 139 - MemoryProjectionWorker: projectionWorker, 140 - MemoryInjectionEnabled: opts.MemoryInjectionEnabled, 141 - MemoryInjectionMaxItems: opts.MemoryInjectionMaxItems, 129 + Logger: logger, 130 + LogOptions: logOpts, 131 + Client: client, 132 + Model: model, 133 + Task: task, 134 + Meta: meta, 135 + TaskRunID: taskRunID, 136 + BaseRegistry: baseReg, 137 + SharedGuard: sharedGuard, 138 + Config: cfg, 139 + EngineToolsConfig: opts.EngineToolsConfig, 140 + TaskTimeout: opts.TaskTimeout, 141 + WakeSignal: wakeSignal, 142 + SystemPromptCacheControl: systemPromptCacheControl, 143 + MemoryOrchestrator: orchestrator, 144 + MemoryProjectionWorker: projectionWorker, 145 + MemoryInjectionEnabled: opts.MemoryInjectionEnabled, 146 + MemoryInjectionMaxItems: opts.MemoryInjectionMaxItems, 142 147 }) 143 148 if runErr != nil { 144 149 displayErr := depsutil.FormatRuntimeError(runErr) ··· 194 199 } 195 200 196 201 type heartbeatTaskOptions struct { 197 - Logger *slog.Logger 198 - LogOptions agent.LogOptions 199 - Client llm.Client 200 - Model string 201 - Task string 202 - Meta map[string]any 203 - TaskRunID string 204 - BaseRegistry *tools.Registry 205 - SharedGuard *guard.Guard 206 - Config agent.Config 207 - EngineToolsConfig agent.EngineToolsConfig 208 - TaskTimeout time.Duration 209 - WakeSignal daemonruntime.PokeInput 210 - MemoryOrchestrator *memoryruntime.Orchestrator 211 - MemoryProjectionWorker *memoryruntime.ProjectionWorker 212 - MemoryInjectionEnabled bool 213 - MemoryInjectionMaxItems int 202 + Logger *slog.Logger 203 + LogOptions agent.LogOptions 204 + Client llm.Client 205 + Model string 206 + Task string 207 + Meta map[string]any 208 + TaskRunID string 209 + BaseRegistry *tools.Registry 210 + SharedGuard *guard.Guard 211 + Config agent.Config 212 + EngineToolsConfig agent.EngineToolsConfig 213 + TaskTimeout time.Duration 214 + WakeSignal daemonruntime.PokeInput 215 + SystemPromptCacheControl *llm.CacheControl 216 + MemoryOrchestrator *memoryruntime.Orchestrator 217 + MemoryProjectionWorker *memoryruntime.ProjectionWorker 218 + MemoryInjectionEnabled bool 219 + MemoryInjectionMaxItems int 214 220 } 215 221 216 222 func runHeartbeatTask(ctx context.Context, d Dependencies, opts heartbeatTaskOptions) (string, error) { ··· 268 274 agent.WithLogOptions(opts.LogOptions), 269 275 agent.WithEngineToolsConfig(opts.EngineToolsConfig), 270 276 agent.WithACPAgents(depsutil.ACPAgentsFromCommon(depsutil.CommonFromHeartbeat(d))), 277 + agent.WithSystemPromptCacheControl(opts.SystemPromptCacheControl), 271 278 agent.WithGuard(opts.SharedGuard), 272 279 ) 273 280 final, _, err := engine.Run(runCtx, task, agent.RunOptions{
+7
internal/channelruntime/taskruntime/runtime.go
··· 186 186 return RunResult{}, err 187 187 } 188 188 defer closeRuntimeClient(logger, mainClient) 189 + systemPromptCacheControl, err := llmutil.SystemPromptCacheControl(mainRoute.Values.CacheTTL) 190 + if err != nil { 191 + return RunResult{}, err 192 + } 189 193 model := strings.TrimSpace(req.Model) 190 194 if model == "" { 191 195 model = strings.TrimSpace(mainRoute.ClientConfig.Model) ··· 236 240 agent.WithSubtaskRunner(rt), 237 241 agent.WithEngineToolsConfig(engineToolsConfig), 238 242 agent.WithACPAgents(rt.ACPAgents), 243 + } 244 + if systemPromptCacheControl != nil { 245 + engineOpts = append(engineOpts, agent.WithSystemPromptCacheControl(systemPromptCacheControl)) 239 246 } 240 247 if rt.SharedGuard != nil { 241 248 engineOpts = append(engineOpts, agent.WithGuard(rt.SharedGuard))
+1
internal/configdefaults/defaults.go
··· 17 17 v.SetDefault("llm.endpoint", "") 18 18 v.SetDefault("llm.model", "") 19 19 v.SetDefault("llm.api_key", "") 20 + v.SetDefault("llm.cache_ttl", "short") 20 21 v.SetDefault("llm.request_timeout", 90*time.Second) 21 22 v.SetDefault("llm.tools_emulation_mode", "off") 22 23 v.SetDefault("llm.cloudflare.account_id", "")
+45 -21
internal/llmutil/llmutil.go
··· 25 25 APIKey string `config:"llm.api_key"` 26 26 Model string `config:"llm.model"` 27 27 Headers map[string]string 28 + CacheTTL string `config:"llm.cache_ttl"` 28 29 AzureDeployment string `config:"llm.azure.deployment"` 29 30 RequestTimeoutRaw string `config:"llm.request_timeout"` 30 31 ToolsEmulationMode string `config:"llm.tools_emulation_mode"` ··· 49 50 return RuntimeValues{} 50 51 } 51 52 return RuntimeValues{ 52 - Provider: strings.TrimSpace(r.GetString("llm.provider")), 53 - Endpoint: strings.TrimSpace(r.GetString("llm.endpoint")), 54 - APIKey: strings.TrimSpace(r.GetString("llm.api_key")), 55 - Model: strings.TrimSpace(r.GetString("llm.model")), 56 - Headers: loadStringMapKeyFromReader(r, "llm.headers"), 57 - AzureDeployment: strings.TrimSpace(r.GetString("llm.azure.deployment")), 58 - RequestTimeoutRaw: strings.TrimSpace(r.GetString("llm.request_timeout")), 59 - ToolsEmulationMode: strings.TrimSpace(r.GetString("llm.tools_emulation_mode")), 60 - TemperatureRaw: strings.TrimSpace(r.GetString("llm.temperature")), 61 - ReasoningEffortRaw: strings.TrimSpace(r.GetString("llm.reasoning_effort")), 62 - ReasoningBudgetRaw: strings.TrimSpace(r.GetString("llm.reasoning_budget_tokens")), 63 - PricingFile: strings.TrimSpace(r.GetString("llm.pricing_file")), 64 - ConfigPath: strings.TrimSpace(r.GetString("config")), 65 - Profiles: loadLLMProfilesFromReader(r), 66 - Routes: loadLLMRoutesFromReader(r), 67 - BedrockAWSKey: firstNonEmpty(r.GetString("llm.bedrock.aws_key"), r.GetString("llm.aws.key")), 68 - BedrockAWSSecret: firstNonEmpty(r.GetString("llm.bedrock.aws_secret"), r.GetString("llm.aws.secret")), 69 - BedrockAWSRegion: firstNonEmpty(r.GetString("llm.bedrock.region"), r.GetString("llm.aws.region")), 70 - BedrockModelARN: firstNonEmpty(r.GetString("llm.bedrock.model_arn"), r.GetString("llm.aws.bedrock_model_arn")), 71 - CloudflareAccountID: firstNonEmpty(r.GetString("llm.cloudflare.account_id")), 72 - CloudflareAPIToken: firstNonEmpty(r.GetString("llm.cloudflare.api_token")), 53 + Provider: strings.TrimSpace(r.GetString("llm.provider")), 54 + Endpoint: strings.TrimSpace(r.GetString("llm.endpoint")), 55 + APIKey: strings.TrimSpace(r.GetString("llm.api_key")), 56 + Model: strings.TrimSpace(r.GetString("llm.model")), 57 + Headers: loadStringMapKeyFromReader(r, "llm.headers"), 58 + CacheTTL: strings.TrimSpace(r.GetString("llm.cache_ttl")), 59 + AzureDeployment: strings.TrimSpace(r.GetString("llm.azure.deployment")), 60 + RequestTimeoutRaw: strings.TrimSpace(r.GetString("llm.request_timeout")), 61 + ToolsEmulationMode: strings.TrimSpace(r.GetString("llm.tools_emulation_mode")), 62 + TemperatureRaw: strings.TrimSpace(r.GetString("llm.temperature")), 63 + ReasoningEffortRaw: strings.TrimSpace(r.GetString("llm.reasoning_effort")), 64 + ReasoningBudgetRaw: strings.TrimSpace(r.GetString("llm.reasoning_budget_tokens")), 65 + PricingFile: strings.TrimSpace(r.GetString("llm.pricing_file")), 66 + ConfigPath: strings.TrimSpace(r.GetString("config")), 67 + Profiles: loadLLMProfilesFromReader(r), 68 + Routes: loadLLMRoutesFromReader(r), 69 + BedrockAWSKey: firstNonEmpty(r.GetString("llm.bedrock.aws_key"), r.GetString("llm.aws.key")), 70 + BedrockAWSSecret: firstNonEmpty(r.GetString("llm.bedrock.aws_secret"), r.GetString("llm.aws.secret")), 71 + BedrockAWSRegion: firstNonEmpty(r.GetString("llm.bedrock.region"), r.GetString("llm.aws.region")), 72 + BedrockModelARN: firstNonEmpty(r.GetString("llm.bedrock.model_arn"), r.GetString("llm.aws.bedrock_model_arn")), 73 + CloudflareAccountID: firstNonEmpty( 74 + r.GetString("llm.cloudflare.account_id"), 75 + ), 76 + CloudflareAPIToken: firstNonEmpty( 77 + r.GetString("llm.cloudflare.api_token"), 78 + ), 73 79 } 74 80 } 75 81 ··· 157 163 Headers: cloneStringMap(cfg.Headers), 158 164 Pricing: pricing, 159 165 RequestTimeout: cfg.RequestTimeout, 166 + CacheTTL: strings.TrimSpace(values.CacheTTL), 160 167 ToolsEmulationMode: toolsEmulationMode, 161 168 Temperature: temperature, 162 169 ReasoningEffort: reasoningEffort, ··· 249 256 default: 250 257 return "", fmt.Errorf("invalid llm.tools_emulation_mode %q (expected off|fallback|force)", mode) 251 258 } 259 + } 260 + 261 + func SystemPromptCacheControl(rawTTL string) (*llm.CacheControl, error) { 262 + rawTTL = strings.TrimSpace(rawTTL) 263 + if rawTTL == "" || strings.EqualFold(rawTTL, "off") { 264 + return nil, nil 265 + } 266 + 267 + switch strings.ToLower(rawTTL) { 268 + case "short", "long": 269 + return &llm.CacheControl{TTL: strings.ToLower(rawTTL)}, nil 270 + } 271 + 272 + if _, err := time.ParseDuration(rawTTL); err != nil { 273 + return nil, fmt.Errorf("invalid llm.cache_ttl %q (expected off|short|long|Go duration)", rawTTL) 274 + } 275 + return &llm.CacheControl{TTL: rawTTL}, nil 252 276 } 253 277 254 278 func optionalFloat64FromValue(raw, path string) (*float64, error) {
+76
internal/llmutil/llmutil_test.go
··· 567 567 v.Set("llm.endpoint", "https://api.openai.com") 568 568 v.Set("llm.api_key", "base-key") 569 569 v.Set("llm.model", "gpt-5.2") 570 + v.Set("llm.cache_ttl", "short") 570 571 v.Set("llm.request_timeout", "90s") 571 572 v.Set("llm.profiles", map[string]any{ 572 573 "cheap": map[string]any{ 573 574 "model": "gpt-4.1-mini", 574 575 "temperature": "0.2", 576 + "cache_ttl": "long", 575 577 }, 576 578 "reasoning": map[string]any{ 577 579 "provider": "xai", ··· 597 599 if values.Profiles["cheap"].Model != "gpt-4.1-mini" { 598 600 t.Fatalf("cheap model = %q, want gpt-4.1-mini", values.Profiles["cheap"].Model) 599 601 } 602 + if values.CacheTTL != "short" { 603 + t.Fatalf("cache_ttl = %q, want short", values.CacheTTL) 604 + } 605 + if values.Profiles["cheap"].CacheTTL != "long" { 606 + t.Fatalf("cheap cache_ttl = %q, want long", values.Profiles["cheap"].CacheTTL) 607 + } 600 608 if values.Profiles["reasoning"].ReasoningEffortRaw != "high" { 601 609 t.Fatalf("reasoning effort = %q, want high", values.Profiles["reasoning"].ReasoningEffortRaw) 602 610 } ··· 616 624 t.Fatalf("memory draft route profile = %q, want cheap", values.Routes.MemoryDraft.Profile) 617 625 } 618 626 } 627 + 628 + func TestResolveProfile_AppliesCacheTTLOverrides(t *testing.T) { 629 + values := RuntimeValues{ 630 + Provider: "openai_resp", 631 + Model: "gpt-5.2", 632 + CacheTTL: "short", 633 + Profiles: map[string]ProfileConfig{ 634 + "cheap": { 635 + Model: "gpt-4.1-mini", 636 + CacheTTL: "long", 637 + }, 638 + }, 639 + } 640 + 641 + resolved, err := ResolveProfile(values, "cheap") 642 + if err != nil { 643 + t.Fatalf("ResolveProfile() error = %v", err) 644 + } 645 + if resolved.Values.CacheTTL != "long" { 646 + t.Fatalf("resolved cache_ttl = %q, want long", resolved.Values.CacheTTL) 647 + } 648 + if resolved.ClientConfig.Model != "gpt-4.1-mini" { 649 + t.Fatalf("resolved model = %q, want gpt-4.1-mini", resolved.ClientConfig.Model) 650 + } 651 + } 652 + 653 + func TestSystemPromptCacheControl(t *testing.T) { 654 + ctrl, err := SystemPromptCacheControl("short") 655 + if err != nil { 656 + t.Fatalf("SystemPromptCacheControl() error = %v", err) 657 + } 658 + if ctrl == nil || ctrl.TTL != "short" { 659 + t.Fatalf("cache control = %#v, want TTL short", ctrl) 660 + } 661 + } 662 + 663 + func TestSystemPromptCacheControlEmpty(t *testing.T) { 664 + ctrl, err := SystemPromptCacheControl("") 665 + if err != nil { 666 + t.Fatalf("SystemPromptCacheControl() error = %v", err) 667 + } 668 + if ctrl != nil { 669 + t.Fatalf("cache control = %#v, want nil", ctrl) 670 + } 671 + } 672 + 673 + func TestSystemPromptCacheControlOff(t *testing.T) { 674 + ctrl, err := SystemPromptCacheControl("off") 675 + if err != nil { 676 + t.Fatalf("SystemPromptCacheControl() error = %v", err) 677 + } 678 + if ctrl != nil { 679 + t.Fatalf("cache control = %#v, want nil", ctrl) 680 + } 681 + } 682 + 683 + func TestSystemPromptCacheControlRejectsInvalidTTL(t *testing.T) { 684 + ctrl, err := SystemPromptCacheControl("not-a-ttl") 685 + if err == nil { 686 + t.Fatal("expected error for invalid cache ttl") 687 + } 688 + if ctrl != nil { 689 + t.Fatalf("cache control = %#v, want nil", ctrl) 690 + } 691 + if !strings.Contains(err.Error(), "expected off|short|long|Go duration") { 692 + t.Fatalf("error = %v, want cache ttl validation message", err) 693 + } 694 + }
+3
internal/llmutil/routes.go
··· 25 25 APIKey string `mapstructure:"api_key"` 26 26 Model string `mapstructure:"model"` 27 27 Headers map[string]string `mapstructure:"headers"` 28 + CacheTTL string `mapstructure:"cache_ttl"` 28 29 RequestTimeoutRaw string `mapstructure:"request_timeout"` 29 30 ToolsEmulationMode string `mapstructure:"tools_emulation_mode"` 30 31 TemperatureRaw string `mapstructure:"temperature"` ··· 296 297 cfg.APIKey = strings.TrimSpace(cfg.APIKey) 297 298 cfg.Model = strings.TrimSpace(cfg.Model) 298 299 cfg.Headers = cloneStringMap(cfg.Headers) 300 + cfg.CacheTTL = strings.TrimSpace(cfg.CacheTTL) 299 301 cfg.RequestTimeoutRaw = strings.TrimSpace(cfg.RequestTimeoutRaw) 300 302 cfg.ToolsEmulationMode = strings.TrimSpace(cfg.ToolsEmulationMode) 301 303 cfg.TemperatureRaw = strings.TrimSpace(cfg.TemperatureRaw) ··· 389 391 applyStringOverride(&out.APIKey, override.APIKey) 390 392 applyStringOverride(&out.Model, override.Model) 391 393 out.Headers = mergeStringMaps(out.Headers, override.Headers) 394 + applyStringOverride(&out.CacheTTL, override.CacheTTL) 392 395 applyStringOverride(&out.RequestTimeoutRaw, override.RequestTimeoutRaw) 393 396 applyStringOverride(&out.ToolsEmulationMode, override.ToolsEmulationMode) 394 397 applyStringOverride(&out.TemperatureRaw, override.TemperatureRaw)
+14 -8
llm/llm.go
··· 13 13 ToolCalls []ToolCall `json:"tool_calls,omitempty"` 14 14 } 15 15 16 + type CacheControl struct { 17 + TTL string `json:"ttl,omitempty"` 18 + } 19 + 16 20 const ( 17 21 PartTypeText = "text" 18 22 PartTypeImageURL = "image_url" ··· 20 24 ) 21 25 22 26 type Part struct { 23 - Type string `json:"type"` 24 - Text string `json:"text,omitempty"` 25 - URL string `json:"url,omitempty"` 26 - DataBase64 string `json:"data_base64,omitempty"` 27 - MIMEType string `json:"mime_type,omitempty"` 27 + Type string `json:"type"` 28 + Text string `json:"text,omitempty"` 29 + URL string `json:"url,omitempty"` 30 + DataBase64 string `json:"data_base64,omitempty"` 31 + MIMEType string `json:"mime_type,omitempty"` 32 + CacheControl *CacheControl `json:"cache_control,omitempty"` 28 33 } 29 34 30 35 type Tool struct { 31 - Name string `json:"name"` 32 - Description string `json:"description,omitempty"` 33 - ParametersJSON string `json:"parameters_json,omitempty"` 36 + Name string `json:"name"` 37 + Description string `json:"description,omitempty"` 38 + ParametersJSON string `json:"parameters_json,omitempty"` 39 + CacheControl *CacheControl `json:"cache_control,omitempty"` 34 40 } 35 41 36 42 type ToolCall struct {
+298 -16
providers/uniai/client.go
··· 28 28 Temperature *float64 29 29 ReasoningEffort string 30 30 ReasoningBudget *int 31 + CacheTTL string 31 32 32 33 ToolsEmulationMode string 33 34 AzureAPIKey string ··· 51 52 temperature *float64 52 53 reasoningEffort string 53 54 reasoningBudget *int 55 + cacheTTL string 54 56 toolsEmulationMode uniaiapi.ToolsEmulationMode 55 57 client *uniaiapi.Client 56 58 } ··· 103 105 temperature: cloneFloat64(cfg.Temperature), 104 106 reasoningEffort: strings.ToLower(strings.TrimSpace(cfg.ReasoningEffort)), 105 107 reasoningBudget: cloneInt(cfg.ReasoningBudget), 108 + cacheTTL: strings.TrimSpace(cfg.CacheTTL), 106 109 toolsEmulationMode: normalizeToolsEmulationMode(cfg.ToolsEmulationMode), 107 110 client: uniaiapi.New(uCfg), 108 111 } ··· 115 118 ctx, cancel = context.WithTimeout(ctx, c.requestTimeout) 116 119 defer cancel() 117 120 } 118 - 119 - opts := buildChatOptions(req, c.provider, req.ForceJSON, c.toolsEmulationMode, c.temperature, c.reasoningEffort, c.reasoningBudget) 121 + opts := buildChatOptions(req, c.provider, c.model, c.cacheTTL, req.ForceJSON, c.toolsEmulationMode, c.temperature, c.reasoningEffort, c.reasoningBudget) 120 122 resp, err := c.client.Chat(ctx, opts...) 121 123 if err != nil { 122 124 c.emitChatError(req.DebugFn, err, req.ForceJSON, 1) 123 125 } 124 126 if err != nil && req.ForceJSON && shouldRetryWithoutResponseFormat(err) { 125 - opts = buildChatOptions(req, c.provider, false, c.toolsEmulationMode, c.temperature, c.reasoningEffort, c.reasoningBudget) 127 + opts = buildChatOptions(req, c.provider, c.model, c.cacheTTL, false, c.toolsEmulationMode, c.temperature, c.reasoningEffort, c.reasoningBudget) 126 128 resp, err = c.client.Chat(ctx, opts...) 127 129 if err != nil { 128 130 c.emitChatError(req.DebugFn, err, false, 2) ··· 156 158 return strings.EqualFold(strings.TrimSpace(provider), "gemini") 157 159 } 158 160 159 - func buildChatOptions(req llm.Request, provider string, forceJSON bool, toolsEmulationMode uniaiapi.ToolsEmulationMode, defaultTemperature *float64, defaultReasoningEffort string, defaultReasoningBudget *int) []uniaiapi.ChatOption { 161 + func buildChatOptions(req llm.Request, provider string, defaultModel string, cacheTTL string, forceJSON bool, toolsEmulationMode uniaiapi.ToolsEmulationMode, defaultTemperature *float64, defaultReasoningEffort string, defaultReasoningBudget *int) []uniaiapi.ChatOption { 162 + req = adaptRequestForProvider(req, provider) 160 163 msgs := make([]uniaiapi.Message, len(req.Messages)) 161 164 for i, m := range req.Messages { 162 165 msg := uniaiapi.Message{Role: m.Role, Content: m.Content} 163 166 if len(m.Parts) > 0 { 164 - msg.Parts = toUniaiPartsFromLLM(m.Parts) 167 + msg.Parts = toUniaiPartsFromLLM(provider, m.Parts) 165 168 } 166 169 if strings.TrimSpace(m.ToolCallID) != "" { 167 170 msg.ToolCallID = m.ToolCallID ··· 173 176 } 174 177 175 178 opts := []uniaiapi.ChatOption{uniaiapi.WithReplaceMessages(msgs...)} 179 + openAIOptions := structs.JSONMap{} 180 + azureOptions := structs.JSONMap{} 176 181 if provider != "" { 177 182 opts = append(opts, uniaiapi.WithProvider(provider)) 178 183 } ··· 190 195 if name == "" { 191 196 continue 192 197 } 193 - tools = append(tools, uniaiapi.FunctionTool( 198 + tool := uniaiapi.FunctionTool( 194 199 name, 195 200 strings.TrimSpace(t.Description), 196 201 []byte(t.ParametersJSON), 197 - )) 202 + ) 203 + if t.CacheControl != nil { 204 + if ctrl, ok := toUniaiCacheControlForProvider(provider, *t.CacheControl); ok { 205 + tool = uniaiapi.WithToolCacheControl(tool, ctrl) 206 + } 207 + } 208 + tools = append(tools, tool) 198 209 } 199 210 if len(tools) > 0 { 200 211 opts = append(opts, uniaiapi.WithTools(tools)) ··· 240 251 opts = append(opts, uniaiapi.WithReasoningBudgetTokens(*defaultReasoningBudget)) 241 252 } 242 253 254 + applyPromptCacheOptions(provider, firstNonEmpty(req.Model, defaultModel), cacheTTL, req, openAIOptions, azureOptions) 243 255 if forceJSON && len(req.Tools) == 0 { 244 - opts = append(opts, uniaichat.WithOpenAIOptions(structs.JSONMap{ 245 - "response_format": "json_object", 246 - })) 256 + openAIOptions["response_format"] = "json_object" 257 + if strings.EqualFold(strings.TrimSpace(provider), "azure") { 258 + azureOptions["response_format"] = "json_object" 259 + } 260 + } 261 + if len(openAIOptions) > 0 { 262 + opts = append(opts, uniaiapi.WithOpenAIOptions(openAIOptions)) 263 + } 264 + if len(azureOptions) > 0 { 265 + opts = append(opts, uniaiapi.WithAzureOptions(azureOptions)) 247 266 } 248 267 249 268 if req.DebugFn != nil { ··· 354 373 return out 355 374 } 356 375 376 + func toLLMCacheControl(ctrl *uniaiapi.CacheControl) *llm.CacheControl { 377 + if ctrl == nil { 378 + return nil 379 + } 380 + return &llm.CacheControl{TTL: strings.TrimSpace(ctrl.TTL)} 381 + } 382 + 383 + func toUniaiCacheControlForProvider(provider string, ctrl llm.CacheControl) (uniaiapi.CacheControl, bool) { 384 + ttl := explicitCacheTTLForProvider(provider, ctrl.TTL) 385 + if ttl == "" { 386 + return uniaiapi.CacheControl{}, false 387 + } 388 + return uniaiapi.CacheControl{TTL: ttl}, true 389 + } 390 + 391 + func adaptRequestForProvider(req llm.Request, provider string) llm.Request { 392 + switch strings.ToLower(strings.TrimSpace(provider)) { 393 + case "anthropic": 394 + return req 395 + case "bedrock": 396 + return stripExplicitCacheControl(req, false, true) 397 + default: 398 + return stripExplicitCacheControl(req, true, true) 399 + } 400 + } 401 + 402 + func stripExplicitCacheControl(req llm.Request, stripAllParts bool, stripTools bool) llm.Request { 403 + out := req 404 + 405 + if len(req.Messages) > 0 { 406 + messages := make([]llm.Message, len(req.Messages)) 407 + copy(messages, req.Messages) 408 + changed := false 409 + for i, msg := range messages { 410 + if len(msg.Parts) == 0 { 411 + continue 412 + } 413 + parts := make([]llm.Part, len(msg.Parts)) 414 + copy(parts, msg.Parts) 415 + partChanged := false 416 + for j, part := range parts { 417 + if part.CacheControl == nil { 418 + continue 419 + } 420 + if stripAllParts || strings.EqualFold(strings.TrimSpace(msg.Role), "system") { 421 + part.CacheControl = nil 422 + parts[j] = part 423 + partChanged = true 424 + } 425 + } 426 + if partChanged { 427 + msg.Parts = parts 428 + messages[i] = msg 429 + changed = true 430 + } 431 + } 432 + if changed { 433 + out.Messages = messages 434 + } 435 + } 436 + 437 + if stripTools && len(req.Tools) > 0 { 438 + tools := make([]llm.Tool, len(req.Tools)) 439 + copy(tools, req.Tools) 440 + changed := false 441 + for i, tool := range tools { 442 + if tool.CacheControl == nil { 443 + continue 444 + } 445 + tool.CacheControl = nil 446 + tools[i] = tool 447 + changed = true 448 + } 449 + if changed { 450 + out.Tools = tools 451 + } 452 + } 453 + 454 + return out 455 + } 456 + 457 + func applyPromptCacheOptions(provider, model, cacheTTL string, req llm.Request, openAIOptions, azureOptions structs.JSONMap) { 458 + retention := promptCacheRetentionForProvider(provider, cacheTTL) 459 + key := derivedPromptCacheKey(provider, model, req) 460 + if key == "" && retention == "" { 461 + return 462 + } 463 + var target structs.JSONMap 464 + switch strings.ToLower(strings.TrimSpace(provider)) { 465 + case "openai", "openai_resp": 466 + target = openAIOptions 467 + case "azure": 468 + target = azureOptions 469 + default: 470 + return 471 + } 472 + if key != "" { 473 + target["prompt_cache_key"] = key 474 + } 475 + if retention != "" { 476 + target["prompt_cache_retention"] = retention 477 + } 478 + } 479 + 357 480 func normalizeToolsEmulationMode(mode string) uniaiapi.ToolsEmulationMode { 358 481 switch strings.ToLower(strings.TrimSpace(mode)) { 359 482 case "force": ··· 437 560 continue 438 561 } 439 562 out = append(out, llm.Part{ 440 - Type: partType, 441 - Text: part.Text, 442 - URL: part.URL, 443 - DataBase64: part.DataBase64, 444 - MIMEType: part.MIMEType, 563 + Type: partType, 564 + Text: part.Text, 565 + URL: part.URL, 566 + DataBase64: part.DataBase64, 567 + MIMEType: part.MIMEType, 568 + CacheControl: toLLMCacheControl(part.CacheControl), 445 569 }) 446 570 } 447 571 if len(out) == 0 { ··· 450 574 return out 451 575 } 452 576 453 - func toUniaiPartsFromLLM(parts []llm.Part) []uniaiapi.Part { 577 + func toUniaiPartsFromLLM(provider string, parts []llm.Part) []uniaiapi.Part { 454 578 if len(parts) == 0 { 455 579 return nil 456 580 } ··· 466 590 URL: part.URL, 467 591 DataBase64: part.DataBase64, 468 592 MIMEType: part.MIMEType, 593 + CacheControl: func() *uniaiapi.CacheControl { 594 + if part.CacheControl == nil { 595 + return nil 596 + } 597 + ctrl, ok := toUniaiCacheControlForProvider(provider, *part.CacheControl) 598 + if !ok { 599 + return nil 600 + } 601 + return &ctrl 602 + }(), 603 + }) 604 + } 605 + if len(out) == 0 { 606 + return nil 607 + } 608 + return out 609 + } 610 + 611 + func promptCacheRetentionForProvider(provider, rawTTL string) string { 612 + switch strings.ToLower(strings.TrimSpace(provider)) { 613 + case "openai", "openai_resp", "azure": 614 + default: 615 + return "" 616 + } 617 + return normalizePromptCacheRetention(rawTTL) 618 + } 619 + 620 + func normalizePromptCacheRetention(rawTTL string) string { 621 + rawTTL = strings.TrimSpace(rawTTL) 622 + if rawTTL == "" || strings.EqualFold(rawTTL, "off") { 623 + return "" 624 + } 625 + switch strings.ToLower(rawTTL) { 626 + case "short": 627 + return "in-memory" 628 + case "long": 629 + return "24h" 630 + } 631 + d, err := time.ParseDuration(rawTTL) 632 + if err != nil { 633 + return "" 634 + } 635 + if d <= 5*time.Minute { 636 + return "in-memory" 637 + } 638 + return "24h" 639 + } 640 + 641 + func explicitCacheTTLForProvider(provider, rawTTL string) string { 642 + switch strings.ToLower(strings.TrimSpace(provider)) { 643 + case "anthropic", "bedrock": 644 + default: 645 + return "" 646 + } 647 + rawTTL = strings.TrimSpace(rawTTL) 648 + if rawTTL == "" || strings.EqualFold(rawTTL, "off") { 649 + return "" 650 + } 651 + switch strings.ToLower(rawTTL) { 652 + case "short": 653 + return "5m" 654 + case "long": 655 + return "1h" 656 + } 657 + d, err := time.ParseDuration(rawTTL) 658 + if err != nil { 659 + return "" 660 + } 661 + if d <= 5*time.Minute { 662 + return "5m" 663 + } 664 + return "1h" 665 + } 666 + 667 + func derivedPromptCacheKey(provider, model string, req llm.Request) string { 668 + switch strings.ToLower(strings.TrimSpace(provider)) { 669 + case "openai", "openai_resp", "azure": 670 + default: 671 + return "" 672 + } 673 + 674 + stable := promptCacheStablePayload{ 675 + Model: strings.TrimSpace(model), 676 + Scene: strings.TrimSpace(req.Scene), 677 + } 678 + for _, msg := range req.Messages { 679 + if !strings.EqualFold(strings.TrimSpace(msg.Role), "system") { 680 + continue 681 + } 682 + stable.Messages = append(stable.Messages, stablePromptMessage{ 683 + Content: strings.TrimSpace(msg.Content), 684 + Parts: stableParts(msg.Parts), 685 + }) 686 + } 687 + for _, tool := range req.Tools { 688 + name := strings.TrimSpace(tool.Name) 689 + if name == "" { 690 + continue 691 + } 692 + stable.Tools = append(stable.Tools, stablePromptTool{ 693 + Name: name, 694 + Description: strings.TrimSpace(tool.Description), 695 + ParametersJSON: strings.TrimSpace(tool.ParametersJSON), 696 + }) 697 + } 698 + if len(stable.Messages) == 0 && len(stable.Tools) == 0 { 699 + return "" 700 + } 701 + data, err := json.Marshal(stable) 702 + if err != nil { 703 + return "" 704 + } 705 + sum := sha256.Sum256(data) 706 + return "mm-" + base64.RawURLEncoding.EncodeToString(sum[:12]) 707 + } 708 + 709 + type promptCacheStablePayload struct { 710 + Model string `json:"model,omitempty"` 711 + Scene string `json:"scene,omitempty"` 712 + Messages []stablePromptMessage `json:"messages,omitempty"` 713 + Tools []stablePromptTool `json:"tools,omitempty"` 714 + } 715 + 716 + type stablePromptMessage struct { 717 + Content string `json:"content,omitempty"` 718 + Parts []stablePart `json:"parts,omitempty"` 719 + } 720 + 721 + type stablePromptTool struct { 722 + Name string `json:"name"` 723 + Description string `json:"description,omitempty"` 724 + ParametersJSON string `json:"parameters_json,omitempty"` 725 + } 726 + 727 + type stablePart struct { 728 + Type string `json:"type"` 729 + Text string `json:"text,omitempty"` 730 + URL string `json:"url,omitempty"` 731 + DataBase64 string `json:"data_base64,omitempty"` 732 + MIMEType string `json:"mime_type,omitempty"` 733 + } 734 + 735 + func stableParts(parts []llm.Part) []stablePart { 736 + if len(parts) == 0 { 737 + return nil 738 + } 739 + out := make([]stablePart, 0, len(parts)) 740 + for _, part := range parts { 741 + partType := strings.TrimSpace(part.Type) 742 + if partType == "" { 743 + continue 744 + } 745 + out = append(out, stablePart{ 746 + Type: partType, 747 + Text: strings.TrimSpace(part.Text), 748 + URL: strings.TrimSpace(part.URL), 749 + DataBase64: strings.TrimSpace(part.DataBase64), 750 + MIMEType: strings.TrimSpace(part.MIMEType), 469 751 }) 470 752 } 471 753 if len(out) == 0 {
+218 -22
providers/uniai/client_test.go
··· 1 1 package uniai 2 2 3 3 import ( 4 + "reflect" 4 5 "testing" 5 6 6 7 "github.com/quailyquaily/mistermorph/llm" ··· 17 18 18 19 opts := append( 19 20 []uniaiapi.ChatOption{uniaiapi.WithMessages(uniaiapi.User("old"))}, 20 - buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil)..., 21 + buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil)..., 21 22 ) 22 23 23 24 built, err := uniaichat.BuildRequest(opts...) ··· 44 45 }, 45 46 } 46 47 47 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 48 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 48 49 built, err := uniaichat.BuildRequest(opts...) 49 50 if err != nil { 50 51 t.Fatalf("build request: %v", err) ··· 70 71 }, 71 72 } 72 73 73 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 74 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 74 75 built, err := uniaichat.BuildRequest(opts...) 75 76 if err != nil { 76 77 t.Fatalf("build request: %v", err) ··· 98 99 }, 99 100 } 100 101 101 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 102 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 102 103 built, err := uniaichat.BuildRequest(opts...) 103 104 if err != nil { 104 105 t.Fatalf("build request: %v", err) ··· 141 142 return nil 142 143 }, 143 144 } 144 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 145 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 145 146 146 147 built, err := uniaichat.BuildRequest(opts...) 147 148 if err != nil { ··· 232 233 Messages: []llm.Message{{Role: "user", Content: "hello"}}, 233 234 OnStream: func(llm.StreamEvent) error { return nil }, 234 235 } 235 - opts := buildChatOptions(req, "gemini", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 236 + opts := buildChatOptions(req, "gemini", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 236 237 237 238 built, err := uniaichat.BuildRequest(opts...) 238 239 if err != nil { ··· 248 249 Messages: []llm.Message{{Role: "user", Content: "hello"}}, 249 250 OnStream: func(llm.StreamEvent) error { return nil }, 250 251 } 251 - opts := buildChatOptions(req, "cloudflare", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 252 + opts := buildChatOptions(req, "cloudflare", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 252 253 253 254 built, err := uniaichat.BuildRequest(opts...) 254 255 if err != nil { ··· 268 269 gotPayload = payload 269 270 }, 270 271 } 271 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 272 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 272 273 273 274 built, err := uniaichat.BuildRequest(opts...) 274 275 if err != nil { ··· 287 288 req := llm.Request{ 288 289 Messages: []llm.Message{{Role: "user", Content: "hello"}}, 289 290 } 290 - opts := buildChatOptions(req, "", true, uniaiapi.ToolsEmulationOff, nil, "", nil) 291 + opts := buildChatOptions(req, "", "", "", true, uniaiapi.ToolsEmulationOff, nil, "", nil) 291 292 292 293 built, err := uniaichat.BuildRequest(opts...) 293 294 if err != nil { ··· 310 311 ParametersJSON: `{"type":"object","properties":{},"additionalProperties":false}`, 311 312 }}, 312 313 } 313 - opts := buildChatOptions(req, "", true, uniaiapi.ToolsEmulationOff, nil, "", nil) 314 + opts := buildChatOptions(req, "", "", "", true, uniaiapi.ToolsEmulationOff, nil, "", nil) 314 315 315 316 built, err := uniaichat.BuildRequest(opts...) 316 317 if err != nil { ··· 326 327 } 327 328 } 328 329 330 + func TestBuildChatOptionsMapsPromptCacheOptionsForOpenAIResp(t *testing.T) { 331 + req := llm.Request{ 332 + Scene: "runtime.loop", 333 + Messages: []llm.Message{ 334 + {Role: "system", Content: "stable system"}, 335 + {Role: "user", Content: "hello"}, 336 + }, 337 + } 338 + opts := buildChatOptions(req, "openai_resp", "gpt-5.4", "short", true, uniaiapi.ToolsEmulationOff, nil, "", nil) 339 + 340 + built, err := uniaichat.BuildRequest(opts...) 341 + if err != nil { 342 + t.Fatalf("build request: %v", err) 343 + } 344 + if built.Options.OpenAI == nil { 345 + t.Fatal("expected openai options to be set") 346 + } 347 + if got := built.Options.OpenAI["prompt_cache_key"]; got == "" || got == nil { 348 + t.Fatalf("prompt_cache_key = %#v, want non-empty derived key", got) 349 + } 350 + if got := built.Options.OpenAI["prompt_cache_retention"]; got != "in-memory" { 351 + t.Fatalf("prompt_cache_retention = %#v, want in-memory", got) 352 + } 353 + if got := built.Options.OpenAI["response_format"]; got != "json_object" { 354 + t.Fatalf("response_format = %#v, want json_object", got) 355 + } 356 + } 357 + 358 + func TestBuildChatOptionsMapsPromptCacheOptionsForAzure(t *testing.T) { 359 + req := llm.Request{ 360 + Messages: []llm.Message{ 361 + {Role: "system", Content: "stable system"}, 362 + {Role: "user", Content: "hello"}, 363 + }, 364 + } 365 + opts := buildChatOptions(req, "azure", "gpt-5.4", "long", true, uniaiapi.ToolsEmulationOff, nil, "", nil) 366 + 367 + built, err := uniaichat.BuildRequest(opts...) 368 + if err != nil { 369 + t.Fatalf("build request: %v", err) 370 + } 371 + if built.Options.Azure == nil { 372 + t.Fatal("expected azure options to be set") 373 + } 374 + if got := built.Options.Azure["prompt_cache_key"]; got == "" || got == nil { 375 + t.Fatalf("prompt_cache_key = %#v, want non-empty derived key", got) 376 + } 377 + if got := built.Options.Azure["prompt_cache_retention"]; got != "24h" { 378 + t.Fatalf("prompt_cache_retention = %#v, want 24h", got) 379 + } 380 + if got := built.Options.Azure["response_format"]; got != "json_object" { 381 + t.Fatalf("response_format = %#v, want json_object", got) 382 + } 383 + } 384 + 329 385 func TestBuildChatOptionsDoesNotInjectTemperatureWhenUnset(t *testing.T) { 330 386 req := llm.Request{ 331 387 Messages: []llm.Message{{Role: "user", Content: "hello"}}, 332 388 } 333 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 389 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 334 390 built, err := uniaichat.BuildRequest(opts...) 335 391 if err != nil { 336 392 t.Fatalf("build request: %v", err) ··· 346 402 } 347 403 temperature := 0.4 348 404 reasoningBudget := 8192 349 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, &temperature, "high", &reasoningBudget) 405 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, &temperature, "high", &reasoningBudget) 350 406 built, err := uniaichat.BuildRequest(opts...) 351 407 if err != nil { 352 408 t.Fatalf("build request: %v", err) ··· 367 423 Messages: []llm.Message{{Role: "user", Content: "hello"}}, 368 424 } 369 425 reasoningBudget := 8192 370 - opts := buildChatOptions(req, "openai_resp", false, uniaiapi.ToolsEmulationOff, nil, "high", &reasoningBudget) 426 + opts := buildChatOptions(req, "openai_resp", "", "", false, uniaiapi.ToolsEmulationOff, nil, "high", &reasoningBudget) 371 427 built, err := uniaichat.BuildRequest(opts...) 372 428 if err != nil { 373 429 t.Fatalf("build request: %v", err) ··· 386 442 Parameters: map[string]any{"temperature": 0.1}, 387 443 } 388 444 temperature := 0.4 389 - opts := buildChatOptions(req, "", false, uniaiapi.ToolsEmulationOff, &temperature, "", nil) 445 + opts := buildChatOptions(req, "", "", "", false, uniaiapi.ToolsEmulationOff, &temperature, "", nil) 390 446 built, err := uniaichat.BuildRequest(opts...) 391 447 if err != nil { 392 448 t.Fatalf("build request: %v", err) ··· 398 454 399 455 func TestPartRoundTripBetweenLLMAndUniai(t *testing.T) { 400 456 src := []llm.Part{ 401 - {Type: llm.PartTypeText, Text: "hello"}, 457 + {Type: llm.PartTypeText, Text: "hello", CacheControl: &llm.CacheControl{TTL: "5m"}}, 402 458 {Type: llm.PartTypeImageURL, URL: "https://example.com/a.png"}, 403 459 {Type: llm.PartTypeImageBase64, MIMEType: "image/jpeg", DataBase64: "QUJD"}, 404 460 } 405 - toUniai := toUniaiPartsFromLLM(src) 461 + toUniai := toUniaiPartsFromLLM("anthropic", src) 406 462 back := toLLMParts(toUniai) 407 463 408 - if len(back) != len(src) { 409 - t.Fatalf("parts length mismatch: got %d want %d", len(back), len(src)) 464 + if !reflect.DeepEqual(back, src) { 465 + t.Fatalf("parts mismatch: got %+v want %+v", back, src) 410 466 } 411 - for i := range src { 412 - if back[i] != src[i] { 413 - t.Fatalf("part[%d] mismatch: got %+v want %+v", i, back[i], src[i]) 414 - } 467 + } 468 + 469 + func TestBuildChatOptionsMapsToolCacheControl(t *testing.T) { 470 + req := llm.Request{ 471 + Messages: []llm.Message{{Role: "user", Content: "hello"}}, 472 + Tools: []llm.Tool{{ 473 + Name: "lookup", 474 + Description: "search", 475 + ParametersJSON: `{"type":"object","properties":{},"additionalProperties":false}`, 476 + CacheControl: &llm.CacheControl{TTL: "1h"}, 477 + }}, 478 + } 479 + opts := buildChatOptions(req, "anthropic", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 480 + 481 + built, err := uniaichat.BuildRequest(opts...) 482 + if err != nil { 483 + t.Fatalf("build request: %v", err) 484 + } 485 + if len(built.Tools) != 1 { 486 + t.Fatalf("tools length = %d, want 1", len(built.Tools)) 487 + } 488 + if built.Tools[0].CacheControl == nil || built.Tools[0].CacheControl.TTL != "1h" { 489 + t.Fatalf("tool cache control = %#v, want 1h", built.Tools[0].CacheControl) 490 + } 491 + } 492 + 493 + func TestBuildChatOptionsKeepsExplicitCacheControlForAnthropic(t *testing.T) { 494 + req := llm.Request{ 495 + Messages: []llm.Message{ 496 + { 497 + Role: "system", 498 + Parts: []llm.Part{{ 499 + Type: llm.PartTypeText, 500 + Text: "sys", 501 + CacheControl: &llm.CacheControl{TTL: "5m"}, 502 + }}, 503 + }, 504 + {Role: "user", Content: "hello"}, 505 + }, 506 + Tools: []llm.Tool{{ 507 + Name: "lookup", 508 + Description: "search", 509 + ParametersJSON: `{"type":"object","properties":{},"additionalProperties":false}`, 510 + CacheControl: &llm.CacheControl{TTL: "1h"}, 511 + }}, 512 + } 513 + 514 + opts := buildChatOptions(req, "anthropic", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 515 + built, err := uniaichat.BuildRequest(opts...) 516 + if err != nil { 517 + t.Fatalf("build request: %v", err) 518 + } 519 + if got := built.Messages[0].Parts[0].CacheControl; got == nil || got.TTL != "5m" { 520 + t.Fatalf("system cache control = %#v, want TTL 5m", got) 521 + } 522 + if got := built.Tools[0].CacheControl; got == nil || got.TTL != "1h" { 523 + t.Fatalf("tool cache control = %#v, want TTL 1h", got) 524 + } 525 + } 526 + 527 + func TestBuildChatOptionsStripsExplicitCacheControlForOpenAI(t *testing.T) { 528 + req := llm.Request{ 529 + Messages: []llm.Message{{ 530 + Role: "system", 531 + Parts: []llm.Part{{ 532 + Type: llm.PartTypeText, 533 + Text: "sys", 534 + CacheControl: &llm.CacheControl{TTL: "5m"}, 535 + }}, 536 + }}, 537 + Tools: []llm.Tool{{ 538 + Name: "lookup", 539 + Description: "search", 540 + ParametersJSON: `{"type":"object","properties":{},"additionalProperties":false}`, 541 + CacheControl: &llm.CacheControl{TTL: "1h"}, 542 + }}, 543 + } 544 + 545 + opts := buildChatOptions(req, "openai_resp", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 546 + built, err := uniaichat.BuildRequest(opts...) 547 + if err != nil { 548 + t.Fatalf("build request: %v", err) 549 + } 550 + if got := built.Messages[0].Parts[0].CacheControl; got != nil { 551 + t.Fatalf("system cache control = %#v, want nil", got) 552 + } 553 + if got := built.Tools[0].CacheControl; got != nil { 554 + t.Fatalf("tool cache control = %#v, want nil", got) 555 + } 556 + } 557 + 558 + func TestBuildChatOptionsStripsOnlySystemPromptCacheControlForBedrock(t *testing.T) { 559 + req := llm.Request{ 560 + Messages: []llm.Message{ 561 + { 562 + Role: "system", 563 + Parts: []llm.Part{{ 564 + Type: llm.PartTypeText, 565 + Text: "sys", 566 + CacheControl: &llm.CacheControl{TTL: "5m"}, 567 + }}, 568 + }, 569 + { 570 + Role: "user", 571 + Parts: []llm.Part{{ 572 + Type: llm.PartTypeText, 573 + Text: "prefix", 574 + CacheControl: &llm.CacheControl{TTL: "1h"}, 575 + }}, 576 + }, 577 + }, 578 + Tools: []llm.Tool{{ 579 + Name: "lookup", 580 + Description: "search", 581 + ParametersJSON: `{"type":"object","properties":{},"additionalProperties":false}`, 582 + CacheControl: &llm.CacheControl{TTL: "1h"}, 583 + }}, 584 + } 585 + 586 + opts := buildChatOptions(req, "bedrock", "", "", false, uniaiapi.ToolsEmulationOff, nil, "", nil) 587 + built, err := uniaichat.BuildRequest(opts...) 588 + if err != nil { 589 + t.Fatalf("build request: %v", err) 590 + } 591 + if got := built.Messages[0].Parts[0].CacheControl; got != nil { 592 + t.Fatalf("system cache control = %#v, want nil", got) 593 + } 594 + if got := built.Messages[1].Parts[0].CacheControl; got == nil || got.TTL != "1h" { 595 + t.Fatalf("user cache control = %#v, want TTL 1h", got) 596 + } 597 + if got := built.Tools[0].CacheControl; got != nil { 598 + t.Fatalf("tool cache control = %#v, want nil", got) 599 + } 600 + } 601 + 602 + func TestNewStoresCacheTTLDefault(t *testing.T) { 603 + client := New(Config{ 604 + Provider: "openai_resp", 605 + Model: "gpt-5.2", 606 + CacheTTL: "long", 607 + }) 608 + 609 + if client.cacheTTL != "long" { 610 + t.Fatalf("cacheTTL = %q, want long", client.cacheTTL) 415 611 } 416 612 } 417 613