ive harnessed the harness
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

wire up memory

dawn 11654c54 521c761f

+566 -52
+173 -29
klbr-core/src/agent.rs
··· 1 1 use anyhow::Result; 2 + use std::time::Duration; 2 3 use tokio::sync::{broadcast, mpsc}; 3 4 4 5 use crate::MetricsSnapshot; ··· 9 10 interrupt::Interrupt, 10 11 llm::{LlmClient, LlmEvent, Message}, 11 12 memory::MemoryStore, 12 - mvp::SimilarityMetric, 13 + mvp::{L1MemoryRecord, SimilarityMetric}, 13 14 retrieval, 14 15 router::{RouteDecision, Router}, 16 + support::SupportScorer, 15 17 tools::{self, ToolContext}, 16 18 AgentEvent, AgentMetrics, 17 19 }; ··· 125 127 &emb, 126 128 &retrieval::RetrievalConfig { 127 129 namespace: "default".to_string(), 128 - top_k: config.memory_top_k, 130 + top_k: config.memory_candidate_k, 129 131 initial_window_days: config.memory_initial_window_days, 130 132 expansion_window_days: config 131 133 .memory_expansion_window_days ··· 154 156 ))); 155 157 } 156 158 157 - let candidates: Vec<_> = outcome 158 - .top_candidates 159 - .into_iter() 160 - .filter(|candidate| { 161 - candidate.score < config.memory_sim_threshold 162 - }) 163 - .collect(); 164 - let top1_score = candidates.first().map(|c| c.score); 165 - candidates 166 - .into_iter() 167 - .filter(|candidate| { 168 - match (config.memory_support_score_gap, top1_score) { 169 - (Some(gap), Some(top)) => { 170 - candidate.score - top <= gap 171 - } 172 - _ => true, 173 - } 174 - }) 175 - .map(|candidate| RecalledMemory { 176 - id: candidate.memory.memory_id, 177 - tags: candidate.memory.tags, 178 - content: candidate.memory.text, 179 - }) 180 - .collect() 159 + select_recalled_memories( 160 + &config, 161 + &llm, 162 + &corpus, 163 + prompt_text.as_str(), 164 + outcome.top_candidates, 165 + &output, 166 + ) 167 + .await 181 168 } 182 169 Err(e) => { 183 170 let msg = format!("memory recall failed: {e}"); ··· 391 378 } 392 379 } 393 380 381 + async fn select_recalled_memories( 382 + config: &Config, 383 + llm: &LlmClient, 384 + corpus: &[L1MemoryRecord], 385 + query: &str, 386 + first_stage: Vec<retrieval::RetrievedMemory>, 387 + output: &broadcast::Sender<AgentEvent>, 388 + ) -> Vec<RecalledMemory> { 389 + if first_stage.is_empty() { 390 + return vec![]; 391 + } 392 + 393 + if config.memory_rerank { 394 + match rerank_memory_candidates(config, llm, corpus, query, &first_stage).await { 395 + Ok(Some(memories)) => return memories, 396 + Ok(None) => return vec![], 397 + Err(err) => { 398 + let _ = output.send(AgentEvent::Status(format!( 399 + "memory rerank failed; using first stage: {err}" 400 + ))); 401 + } 402 + } 403 + } 404 + 405 + select_first_stage_memories(config, first_stage) 406 + } 407 + 408 + fn select_first_stage_memories( 409 + config: &Config, 410 + candidates: Vec<retrieval::RetrievedMemory>, 411 + ) -> Vec<RecalledMemory> { 412 + let candidates: Vec<_> = candidates 413 + .into_iter() 414 + .filter(|candidate| candidate.score < config.memory_sim_threshold) 415 + .collect(); 416 + let top1_score = candidates.first().map(|candidate| candidate.score); 417 + candidates 418 + .into_iter() 419 + .filter( 420 + |candidate| match (config.memory_support_score_gap, top1_score) { 421 + (Some(gap), Some(top)) => candidate.score - top <= gap, 422 + _ => true, 423 + }, 424 + ) 425 + .take(config.memory_top_k) 426 + .map(|candidate| RecalledMemory { 427 + id: candidate.memory.memory_id, 428 + tags: candidate.memory.tags, 429 + content: candidate.memory.text, 430 + }) 431 + .collect() 432 + } 433 + 434 + async fn rerank_memory_candidates( 435 + config: &Config, 436 + llm: &LlmClient, 437 + corpus: &[L1MemoryRecord], 438 + query: &str, 439 + first_stage: &[retrieval::RetrievedMemory], 440 + ) -> Result<Option<Vec<RecalledMemory>>> { 441 + let rerank_limit = config.memory_rerank_top_k.min(first_stage.len()).max(1); 442 + let rerank_pool = &first_stage[..rerank_limit]; 443 + let documents = rerank_pool 444 + .iter() 445 + .map(|candidate| { 446 + if candidate.memory.tags.is_empty() { 447 + candidate.memory.text.clone() 448 + } else { 449 + format!( 450 + "[tags: {}] {}", 451 + candidate.memory.tags.join(", "), 452 + candidate.memory.text 453 + ) 454 + } 455 + }) 456 + .collect::<Vec<_>>(); 457 + 458 + let rerank = tokio::time::timeout( 459 + Duration::from_millis(config.memory_rerank_timeout_ms), 460 + llm.rerank(query, &documents, false), 461 + ) 462 + .await 463 + .map_err(|_| anyhow::anyhow!("timed out after {}ms", config.memory_rerank_timeout_ms))??; 464 + 465 + let mut reranked = rerank 466 + .into_iter() 467 + .filter_map(|result| { 468 + rerank_pool.get(result.index).map(|candidate| { 469 + let mut candidate = candidate.clone(); 470 + candidate.score = result.score; 471 + candidate 472 + }) 473 + }) 474 + .collect::<Vec<_>>(); 475 + reranked.sort_by(|left, right| { 476 + right 477 + .score 478 + .partial_cmp(&left.score) 479 + .unwrap_or(std::cmp::Ordering::Equal) 480 + }); 481 + 482 + let Some(top_score) = reranked.first().map(|candidate| candidate.score) else { 483 + return Ok(None); 484 + }; 485 + if config 486 + .memory_rerank_min_score 487 + .is_some_and(|threshold| top_score < threshold) 488 + { 489 + return Ok(None); 490 + } 491 + if let Some(threshold) = config.memory_rerank_min_margin { 492 + let second_score = reranked.get(1).map(|candidate| candidate.score); 493 + let margin = second_score 494 + .map(|second| top_score - second) 495 + .unwrap_or(f32::INFINITY); 496 + if margin < threshold { 497 + return Ok(None); 498 + } 499 + } 500 + 501 + if let Some(threshold) = config.memory_support_threshold { 502 + let texts = corpus 503 + .iter() 504 + .map(|memory| memory.text.as_str()) 505 + .chain(std::iter::once(query)); 506 + let scorer = SupportScorer::from_texts(texts); 507 + let support = reranked 508 + .first() 509 + .map(|candidate| scorer.score(query, &candidate.memory.text)); 510 + if support.is_none_or(|score| score < threshold) { 511 + return Ok(None); 512 + } 513 + } 514 + 515 + let memories = reranked 516 + .into_iter() 517 + .filter(|candidate| match config.memory_rerank_score_gap { 518 + Some(gap) => top_score - candidate.score <= gap, 519 + None => true, 520 + }) 521 + .take(config.memory_top_k) 522 + .map(|candidate| RecalledMemory { 523 + id: candidate.memory.memory_id, 524 + tags: candidate.memory.tags, 525 + content: candidate.memory.text, 526 + }) 527 + .collect(); 528 + 529 + Ok(Some(memories)) 530 + } 531 + 394 532 async fn compact( 395 533 anchor: &str, 396 534 tool_ctx: &ToolContext, ··· 475 613 ctx: &Context, 476 614 output: &broadcast::Sender<AgentEvent>, 477 615 ) -> Result<()> { 478 - let pinned = tool_ctx.memory.pinned_memories().unwrap_or_default(); 616 + let pinned = tool_ctx.memory.pinned_memory_entries().unwrap_or_default(); 479 617 let unpinned = tool_ctx.memory.recent_unpinned(20).unwrap_or_default(); 480 618 481 619 let pinned_text = if pinned.is_empty() { ··· 483 621 } else { 484 622 pinned 485 623 .iter() 486 - .enumerate() 487 - .map(|(i, s)| format!("[id:{}] {s}", i + 1)) 624 + .map(|memory| { 625 + let tag_str = if memory.tags.is_empty() { 626 + String::new() 627 + } else { 628 + format!(" [{}]", memory.tags.join(", ")) 629 + }; 630 + format!("[id:{}]{tag_str} {}", memory.id, memory.content) 631 + }) 488 632 .collect::<Vec<_>>() 489 633 .join("\n") 490 634 };
+78 -1
klbr-core/src/config.rs
··· 1 1 use std::path::{Path, PathBuf}; 2 2 3 3 use anyhow::{Context, Result}; 4 - use serde::Deserialize; 4 + use serde::{Deserialize, Deserializer}; 5 5 6 6 #[derive(Clone, Debug, Deserialize)] 7 7 #[serde(default)] ··· 25 25 pub history_window: usize, 26 26 /// memories to inject per turn 27 27 pub memory_top_k: usize, 28 + /// first-stage memory candidates to retrieve before filtering/reranking. 29 + pub memory_candidate_k: usize, 28 30 /// cosine distance cutoff — only inject memories below this (0=identical, 2=opposite). 29 31 /// 0.3 ≈ cosine similarity ≥ 0.7, a reasonable default for bge-m3. 30 32 pub memory_sim_threshold: f32, ··· 37 39 /// trim injected support set: only include candidates whose cosine distance 38 40 /// is within this many units of the top-1 candidate. None = inject all top-k. 39 41 pub memory_support_score_gap: Option<f32>, 42 + /// enable runtime reranking of first-stage memory candidates. 43 + pub memory_rerank: bool, 44 + /// max first-stage candidates to send to the reranker. 45 + pub memory_rerank_top_k: usize, 46 + /// minimum reranker top-1 raw score before injecting memories. 47 + pub memory_rerank_min_score: Option<f32>, 48 + /// minimum raw reranker top-1/top-2 margin before injecting memories. 49 + pub memory_rerank_min_margin: Option<f32>, 50 + /// only inject reranked candidates within this many raw score points of top-1. 51 + pub memory_rerank_score_gap: Option<f32>, 52 + /// minimum lexical support score for the reranked top candidate before injecting memories. 53 + pub memory_support_threshold: Option<f32>, 54 + /// timeout for optional runtime rerank requests. 55 + pub memory_rerank_timeout_ms: u64, 40 56 pub db_path: String, 41 57 pub anchor: String, 42 58 pub embed_dim: usize, ··· 96 112 compaction_keep: Option<usize>, 97 113 history_window: Option<usize>, 98 114 memory_top_k: Option<usize>, 115 + memory_candidate_k: Option<usize>, 99 116 memory_sim_threshold: Option<f32>, 117 + #[serde(default, deserialize_with = "nullable_field")] 100 118 memory_initial_window_days: Option<Option<u32>>, 101 119 memory_expansion_window_days: Option<Vec<Option<u32>>>, 120 + #[serde(default, deserialize_with = "nullable_field")] 102 121 memory_expand_distance_threshold: Option<Option<f32>>, 122 + #[serde(default, deserialize_with = "nullable_field")] 103 123 memory_support_score_gap: Option<Option<f32>>, 124 + memory_rerank: Option<bool>, 125 + memory_rerank_top_k: Option<usize>, 126 + #[serde(default, deserialize_with = "nullable_field")] 127 + memory_rerank_min_score: Option<Option<f32>>, 128 + #[serde(default, deserialize_with = "nullable_field")] 129 + memory_rerank_min_margin: Option<Option<f32>>, 130 + #[serde(default, deserialize_with = "nullable_field")] 131 + memory_rerank_score_gap: Option<Option<f32>>, 132 + #[serde(default, deserialize_with = "nullable_field")] 133 + memory_support_threshold: Option<Option<f32>>, 134 + memory_rerank_timeout_ms: Option<u64>, 104 135 db_path: Option<String>, 105 136 anchor: Option<String>, 106 137 embed_dim: Option<usize>, ··· 109 140 compaction_model: Option<String>, 110 141 } 111 142 143 + fn nullable_field<'de, D, T>(deserializer: D) -> Result<Option<Option<T>>, D::Error> 144 + where 145 + D: Deserializer<'de>, 146 + T: Deserialize<'de>, 147 + { 148 + Option::<T>::deserialize(deserializer).map(Some) 149 + } 150 + 112 151 impl FileConfig { 113 152 fn merge_into(self, mut config: Config) -> Config { 114 153 if let Some(value) = self.llm_url { ··· 141 180 if let Some(value) = self.memory_top_k { 142 181 config.memory_top_k = value; 143 182 } 183 + if let Some(value) = self.memory_candidate_k { 184 + config.memory_candidate_k = value; 185 + } 144 186 if let Some(value) = self.memory_sim_threshold { 145 187 config.memory_sim_threshold = value; 146 188 } ··· 156 198 if let Some(value) = self.memory_support_score_gap { 157 199 config.memory_support_score_gap = value; 158 200 } 201 + if let Some(value) = self.memory_rerank { 202 + config.memory_rerank = value; 203 + } 204 + if let Some(value) = self.memory_rerank_top_k { 205 + config.memory_rerank_top_k = value; 206 + } 207 + if let Some(value) = self.memory_rerank_min_score { 208 + config.memory_rerank_min_score = value; 209 + } 210 + if let Some(value) = self.memory_rerank_min_margin { 211 + config.memory_rerank_min_margin = value; 212 + } 213 + if let Some(value) = self.memory_rerank_score_gap { 214 + config.memory_rerank_score_gap = value; 215 + } 216 + if let Some(value) = self.memory_support_threshold { 217 + config.memory_support_threshold = value; 218 + } 219 + if let Some(value) = self.memory_rerank_timeout_ms { 220 + config.memory_rerank_timeout_ms = value; 221 + } 159 222 if let Some(value) = self.db_path { 160 223 config.db_path = value; 161 224 } ··· 191 254 compaction_keep: 10, 192 255 history_window: 50, 193 256 memory_top_k: 3, 257 + memory_candidate_k: 20, 194 258 memory_sim_threshold: 0.3, 195 259 memory_initial_window_days: Some(7), 196 260 memory_expansion_window_days: vec![Some(30), Some(90), None], 197 261 memory_expand_distance_threshold: Some(0.35), 198 262 memory_support_score_gap: Some(0.05), 263 + memory_rerank: false, 264 + memory_rerank_top_k: 10, 265 + memory_rerank_min_score: Some(-6.0), 266 + memory_rerank_min_margin: Some(0.0), 267 + memory_rerank_score_gap: Some(5.0), 268 + memory_support_threshold: Some(0.3), 269 + memory_rerank_timeout_ms: 3_000, 199 270 db_path: "agent.db".into(), 200 271 anchor: ANCHOR.into(), 201 272 embed_dim: 1024, ··· 269 340 r#"{ 270 341 "llm_url": "http://localhost:9000", 271 342 "history_window": 99, 343 + "memory_candidate_k": 12, 344 + "memory_rerank": true, 345 + "memory_rerank_min_score": null, 272 346 "memory_sim_threshold": 0.22, 273 347 "compaction_model": "tiny-compactor" 274 348 }"#, ··· 278 352 let config = Config::load_path(&path).unwrap(); 279 353 assert_eq!(config.llm_url, "http://localhost:9000"); 280 354 assert_eq!(config.history_window, 99); 355 + assert_eq!(config.memory_candidate_k, 12); 356 + assert!(config.memory_rerank); 357 + assert_eq!(config.memory_rerank_min_score, None); 281 358 assert!((config.memory_sim_threshold - 0.22).abs() < f32::EPSILON); 282 359 assert_eq!(config.compaction_model.as_deref(), Some("tiny-compactor")); 283 360 assert_eq!(config.embed_url, Config::default().embed_url);
+1
klbr-core/src/lib.rs
··· 7 7 pub mod mvp; 8 8 pub mod retrieval; 9 9 pub mod router; 10 + pub mod support; 10 11 pub mod tools; 11 12 12 13 use std::sync::Arc;
+126 -19
klbr-core/src/memory.rs
··· 16 16 pub reasoning: Option<String>, 17 17 } 18 18 19 + #[derive(Debug, Clone, PartialEq, Eq)] 20 + pub struct MemorySummary { 21 + pub id: i64, 22 + pub content: String, 23 + pub tags: Vec<String>, 24 + } 25 + 19 26 /// a single result from recall() 20 27 #[derive(Debug, Clone)] 21 28 pub struct RecallEntry { ··· 84 91 85 92 fn migrate(&self) -> Result<()> { 86 93 let conn = self.conn.lock().unwrap(); 87 - let _ = conn.execute_batch( 88 - "ALTER TABLE memories ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0;\ 89 - ALTER TABLE memories ADD COLUMN tags TEXT NOT NULL DEFAULT '[]';\ 90 - ALTER TABLE memories ADD COLUMN namespace TEXT NOT NULL DEFAULT 'default';\ 91 - ALTER TABLE memories ADD COLUMN layer TEXT NOT NULL DEFAULT 'L1';\ 92 - ALTER TABLE memories ADD COLUMN event_time INTEGER NOT NULL DEFAULT (unixepoch());\ 93 - ALTER TABLE memories ADD COLUMN ingest_time INTEGER NOT NULL DEFAULT (unixepoch());\ 94 - ALTER TABLE memories ADD COLUMN embedding_model TEXT NOT NULL DEFAULT 'unknown';\ 95 - ALTER TABLE memories ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0;\ 96 - ALTER TABLE memories ADD COLUMN embedding_version TEXT NOT NULL DEFAULT 'v1';\ 97 - ALTER TABLE memories ADD COLUMN status TEXT NOT NULL DEFAULT 'active';\ 98 - ALTER TABLE memories ADD COLUMN source_ref TEXT;", 94 + for statement in [ 95 + "ALTER TABLE memories ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0", 96 + "ALTER TABLE memories ADD COLUMN tags TEXT NOT NULL DEFAULT '[]'", 97 + "ALTER TABLE memories ADD COLUMN namespace TEXT NOT NULL DEFAULT 'default'", 98 + "ALTER TABLE memories ADD COLUMN layer TEXT NOT NULL DEFAULT 'L1'", 99 + "ALTER TABLE memories ADD COLUMN event_time INTEGER NOT NULL DEFAULT 0", 100 + "ALTER TABLE memories ADD COLUMN ingest_time INTEGER NOT NULL DEFAULT 0", 101 + "ALTER TABLE memories ADD COLUMN embedding_model TEXT NOT NULL DEFAULT 'unknown'", 102 + "ALTER TABLE memories ADD COLUMN embedding_dim INTEGER NOT NULL DEFAULT 0", 103 + "ALTER TABLE memories ADD COLUMN embedding_version TEXT NOT NULL DEFAULT 'v1'", 104 + "ALTER TABLE memories ADD COLUMN status TEXT NOT NULL DEFAULT 'active'", 105 + "ALTER TABLE memories ADD COLUMN source_ref TEXT", 106 + ] { 107 + let _ = conn.execute(statement, []); 108 + } 109 + let _ = conn.execute( 110 + "UPDATE memories 111 + SET event_time = CASE WHEN event_time = 0 THEN COALESCE(ts, unixepoch()) ELSE event_time END, 112 + ingest_time = CASE WHEN ingest_time = 0 THEN COALESCE(ts, unixepoch()) ELSE ingest_time END", 113 + [], 99 114 ); 100 115 Ok(()) 101 116 } ··· 299 314 300 315 /// all pinned memory contents, oldest first 301 316 pub fn pinned_memories(&self) -> Result<Vec<String>> { 317 + Ok(self 318 + .pinned_memory_entries()? 319 + .into_iter() 320 + .map(|entry| entry.content) 321 + .collect()) 322 + } 323 + 324 + /// all active pinned memories, oldest first, with stable database ids 325 + pub fn pinned_memory_entries(&self) -> Result<Vec<MemorySummary>> { 302 326 let conn = self.conn.lock().unwrap(); 303 327 let mut stmt = conn.prepare( 304 - "SELECT content FROM memories 305 - WHERE pinned = 1 AND COALESCE(source_ref, '') != ?1 328 + "SELECT id, content, tags FROM memories 329 + WHERE pinned = 1 330 + AND status = 'active' 331 + AND COALESCE(source_ref, '') != ?1 306 332 ORDER BY ts ASC", 307 333 )?; 308 334 let results = stmt 309 - .query_map(params![ANCHOR_SOURCE_REF], |row| row.get(0))? 335 + .query_map(params![ANCHOR_SOURCE_REF], |row| { 336 + Ok(( 337 + row.get::<_, i64>(0)?, 338 + row.get::<_, String>(1)?, 339 + row.get::<_, String>(2)?, 340 + )) 341 + })? 310 342 .filter_map(|r| r.ok()) 343 + .map(|(id, content, tags_str)| MemorySummary { 344 + id, 345 + content, 346 + tags: serde_json::from_str(&tags_str).unwrap_or_default(), 347 + }) 311 348 .collect(); 312 349 Ok(results) 313 350 } ··· 317 354 let conn = self.conn.lock().unwrap(); 318 355 let mut stmt = conn.prepare( 319 356 "SELECT id, content, tags FROM memories 320 - WHERE pinned = 0 AND COALESCE(source_ref, '') != ?2 357 + WHERE pinned = 0 358 + AND status = 'active' 359 + AND COALESCE(source_ref, '') != ?2 321 360 ORDER BY ts DESC LIMIT ?1", 322 361 )?; 323 362 let results = stmt ··· 415 454 "SELECT m.id, m.content, m.tags, v.embedding 416 455 FROM memories m 417 456 JOIN vec_memories v ON v.rowid = m.id 418 - WHERE COALESCE(m.source_ref, '') != '{}' AND {}", 457 + WHERE m.status = 'active' 458 + AND COALESCE(m.source_ref, '') != '{}' 459 + AND {}", 419 460 ANCHOR_SOURCE_REF, 420 461 conditions.join(if tag_and { " AND " } else { " OR " }) 421 462 ); ··· 460 501 "SELECT m.id, m.content, m.tags, v.distance 461 502 FROM vec_memories v 462 503 JOIN memories m ON m.id = v.rowid 463 - WHERE COALESCE(m.source_ref, '') != ?3 504 + WHERE m.status = 'active' 505 + AND COALESCE(m.source_ref, '') != ?3 464 506 AND v.embedding MATCH ?1 AND k = ?2 465 507 ORDER BY v.distance", 466 508 )?; ··· 495 537 .collect(); 496 538 let sql = format!( 497 539 "SELECT id, content, tags FROM memories 498 - WHERE COALESCE(source_ref, '') != '{}' AND {} 540 + WHERE status = 'active' 541 + AND COALESCE(source_ref, '') != '{}' 542 + AND {} 499 543 ORDER BY ts DESC", 500 544 ANCHOR_SOURCE_REF, 501 545 conditions.join(if tag_and { " AND " } else { " OR " }) ··· 700 744 FROM memories 701 745 WHERE namespace = ?1 702 746 AND content = ?2 747 + AND status = 'active' 703 748 AND COALESCE(source_ref, '') != ?3 704 749 ORDER BY ts DESC 705 750 LIMIT 1", ··· 974 1019 stored[0].tags, 975 1020 vec!["project:klbr".to_string(), "preference".to_string()] 976 1021 ); 1022 + Ok(()) 1023 + } 1024 + 1025 + #[test] 1026 + fn test_migration_adds_all_columns_to_legacy_memory_table() -> Result<()> { 1027 + let tmp = NamedTempFile::new()?; 1028 + { 1029 + let conn = Connection::open(tmp.path())?; 1030 + conn.execute_batch( 1031 + "CREATE TABLE memories ( 1032 + id INTEGER PRIMARY KEY, 1033 + content TEXT NOT NULL, 1034 + pinned INTEGER NOT NULL DEFAULT 0, 1035 + tags TEXT NOT NULL DEFAULT '[]', 1036 + ts INTEGER NOT NULL DEFAULT 123 1037 + ); 1038 + INSERT INTO memories (content, pinned, tags, ts) 1039 + VALUES ('legacy memory', 1, '[\"legacy\"]', 123);", 1040 + )?; 1041 + } 1042 + 1043 + let store = MemoryStore::open(tmp.path().to_str().unwrap(), 4)?; 1044 + assert_eq!(store.pinned_memories()?, vec!["legacy memory"]); 1045 + 1046 + let id = store.store("new memory", &[1.0, 0.0, 0.0, 0.0], &[])?; 1047 + assert!(id > 0); 1048 + let stored = store.get_all()?; 1049 + assert_eq!(stored.len(), 1); 1050 + assert_eq!(stored[0].text, "new memory"); 1051 + Ok(()) 1052 + } 1053 + 1054 + #[test] 1055 + fn test_inactive_memories_do_not_surface_in_runtime_views() -> Result<()> { 1056 + let tmp = NamedTempFile::new()?; 1057 + let store = MemoryStore::open(tmp.path().to_str().unwrap(), 4)?; 1058 + 1059 + store.store_with_metadata(&MemoryRecordInput { 1060 + memory_id: None, 1061 + namespace: "default".to_string(), 1062 + layer: MemoryLayer::L1, 1063 + text: "archived memory".to_string(), 1064 + event_time: 1, 1065 + ingest_time: 1, 1066 + embedding_model: "test".to_string(), 1067 + embedding_dim: 4, 1068 + embedding_version: "v1".to_string(), 1069 + status: MemoryStatus::Archived, 1070 + source_ref: Some("test".to_string()), 1071 + tags: vec!["project:klbr".to_string()], 1072 + pinned: true, 1073 + embedding: vec![1.0, 0.0, 0.0, 0.0], 1074 + })?; 1075 + 1076 + assert!(store.pinned_memories()?.is_empty()); 1077 + assert!(store.recent_unpinned(10)?.is_empty()); 1078 + assert!(store 1079 + .recall(None, &["project:klbr".to_string()], false, 10)? 1080 + .is_empty()); 1081 + assert!(store 1082 + .recall(Some(&[1.0, 0.0, 0.0, 0.0]), &[], false, 10)? 1083 + .is_empty()); 977 1084 Ok(()) 978 1085 } 979 1086 }
+180
klbr-core/src/support.rs
··· 1 + use std::collections::{HashMap, HashSet}; 2 + 3 + #[derive(Debug, Clone)] 4 + pub struct SupportScorer { 5 + gram_idf: HashMap<String, f32>, 6 + default_idf: f32, 7 + n: usize, 8 + } 9 + 10 + impl SupportScorer { 11 + pub fn from_texts<'a, I>(texts: I) -> Self 12 + where 13 + I: IntoIterator<Item = &'a str>, 14 + { 15 + let n = 4usize; 16 + let mut doc_count_usize = 0usize; 17 + let mut df: HashMap<String, usize> = HashMap::new(); 18 + 19 + for text in texts { 20 + doc_count_usize += 1; 21 + for gram in char_ngrams(text, n) { 22 + *df.entry(gram).or_insert(0) += 1; 23 + } 24 + } 25 + 26 + let doc_count = doc_count_usize.max(1) as f32; 27 + let default_idf = ((doc_count + 1.0) / 1.0).ln() + 1.0; 28 + let gram_idf = df 29 + .into_iter() 30 + .map(|(gram, doc_freq)| { 31 + let idf = ((doc_count + 1.0) / (doc_freq as f32 + 1.0)).ln() + 1.0; 32 + (gram, idf) 33 + }) 34 + .collect(); 35 + 36 + Self { 37 + gram_idf, 38 + default_idf, 39 + n, 40 + } 41 + } 42 + 43 + pub fn score(&self, query_text: &str, candidate_text: &str) -> f32 { 44 + let query_grams = char_ngrams(query_text, self.n); 45 + if query_grams.is_empty() { 46 + return 0.0; 47 + } 48 + let candidate_grams = char_ngrams(candidate_text, self.n); 49 + 50 + let mut total_weight = 0.0f32; 51 + let mut matched_weight = 0.0f32; 52 + 53 + for gram in query_grams.iter() { 54 + let idf = self.idf(gram); 55 + total_weight += idf; 56 + if candidate_grams.contains(gram) { 57 + matched_weight += idf; 58 + } 59 + } 60 + 61 + if total_weight <= 0.0 { 62 + return 0.0; 63 + } 64 + 65 + let mut score = (matched_weight / total_weight).clamp(0.0, 1.0); 66 + let tokens = normalized_tokens(query_text); 67 + if tokens.len() >= 2 { 68 + let mut overlapped_tokens = 0usize; 69 + for token in tokens { 70 + let token_grams = token_ngrams(&token, self.n); 71 + if token_grams 72 + .iter() 73 + .any(|gram| candidate_grams.contains(gram)) 74 + { 75 + overlapped_tokens += 1; 76 + } 77 + } 78 + let scale = ((overlapped_tokens as f32) / 2.0).min(1.0); 79 + score = (score * scale).clamp(0.0, 1.0); 80 + } 81 + 82 + score 83 + } 84 + 85 + fn idf(&self, gram: &str) -> f32 { 86 + self.gram_idf.get(gram).copied().unwrap_or(self.default_idf) 87 + } 88 + } 89 + 90 + fn char_ngrams(text: &str, n: usize) -> HashSet<String> { 91 + let mut grams = HashSet::new(); 92 + let mut token = String::new(); 93 + 94 + let flush = |token: &mut String, grams: &mut HashSet<String>| { 95 + if token.is_empty() { 96 + return; 97 + } 98 + let chars: Vec<char> = token.chars().collect(); 99 + if chars.len() <= n { 100 + grams.insert(token.clone()); 101 + } else { 102 + for start in 0..=(chars.len() - n) { 103 + grams.insert(chars[start..start + n].iter().collect()); 104 + } 105 + } 106 + token.clear(); 107 + }; 108 + 109 + for ch in text.chars() { 110 + if ch.is_alphanumeric() { 111 + for lower in ch.to_lowercase() { 112 + token.push(lower); 113 + } 114 + } else { 115 + flush(&mut token, &mut grams); 116 + } 117 + } 118 + flush(&mut token, &mut grams); 119 + 120 + grams 121 + } 122 + 123 + fn normalized_tokens(text: &str) -> Vec<String> { 124 + let mut tokens = Vec::new(); 125 + let mut token = String::new(); 126 + 127 + let flush = |token: &mut String, tokens: &mut Vec<String>| { 128 + if !token.is_empty() { 129 + tokens.push(std::mem::take(token)); 130 + } 131 + }; 132 + 133 + for ch in text.chars() { 134 + if ch.is_alphanumeric() { 135 + for lower in ch.to_lowercase() { 136 + token.push(lower); 137 + } 138 + } else { 139 + flush(&mut token, &mut tokens); 140 + } 141 + } 142 + flush(&mut token, &mut tokens); 143 + 144 + tokens 145 + } 146 + 147 + fn token_ngrams(token: &str, n: usize) -> HashSet<String> { 148 + let mut grams = HashSet::new(); 149 + let chars: Vec<char> = token.chars().collect(); 150 + if chars.is_empty() { 151 + return grams; 152 + } 153 + if chars.len() <= n { 154 + grams.insert(token.to_string()); 155 + return grams; 156 + } 157 + for start in 0..=(chars.len() - n) { 158 + grams.insert(chars[start..start + n].iter().collect()); 159 + } 160 + grams 161 + } 162 + 163 + #[cfg(test)] 164 + mod tests { 165 + use super::SupportScorer; 166 + 167 + #[test] 168 + fn support_score_rewards_specific_overlap() { 169 + let scorer = SupportScorer::from_texts([ 170 + "switched the klbr editor to zed", 171 + "likes green tea in the morning", 172 + ]); 173 + 174 + let related = scorer.score("what editor does klbr use", "klbr editor is zed"); 175 + let unrelated = scorer.score("what editor does klbr use", "green tea preference"); 176 + 177 + assert!(related > unrelated); 178 + assert!(related > 0.0); 179 + } 180 + }
+8 -3
klbr-core/src/tools/list_memories.rs
··· 30 30 31 31 async fn execute(_args: serde_json::Value, ctx: ToolContext) -> String { 32 32 let anchor = ctx.memory.anchor_text().unwrap_or_default(); 33 - let pinned = ctx.memory.pinned_memories().unwrap_or_default(); 33 + let pinned = ctx.memory.pinned_memory_entries().unwrap_or_default(); 34 34 let unpinned = ctx.memory.recent_unpinned(10).unwrap_or_default(); 35 35 36 36 let mut out = String::new(); ··· 44 44 if pinned.is_empty() { 45 45 out.push_str("(none)\n"); 46 46 } else { 47 - for (i, content) in pinned.iter().enumerate() { 48 - out.push_str(&format!("{}. {content}\n", i + 1)); 47 + for memory in &pinned { 48 + let tag_str = if memory.tags.is_empty() { 49 + String::new() 50 + } else { 51 + format!(" [{}]", memory.tags.join(", ")) 52 + }; 53 + out.push_str(&format!("[id:{}]{tag_str} {}\n", memory.id, memory.content)); 49 54 } 50 55 } 51 56 out.push_str("\n## recent unpinned\n");