semantic bufo search find-bufo.com
bufo
1
fork

Configure Feed

Select the types of activity you want to include in your feed.

refactor: extract traits and modules for provider abstraction

- add Embedder trait (providers.rs) for swappable embedding backends
- add VectorStore trait for swappable vector search backends
- extract scoring/fusion logic into scoring.rs module
- extract filter logic into filter.rs with composable Filter trait
- refactor VoyageEmbedder and TurbopufferStore to implement traits
- simplify search.rs using new abstractions
- add unit tests for scoring and filtering

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

zzstoatzz 59eb8927 5cba5945

+786 -372
+35 -20
src/embedding.rs
··· 1 - use anyhow::{Context, Result}; 1 + //! voyage AI embedding implementation 2 + //! 3 + //! implements the `Embedder` trait for voyage's multimodal-3 model. 4 + 5 + use crate::providers::{Embedder, EmbeddingError}; 2 6 use reqwest::Client; 3 7 use serde::{Deserialize, Serialize}; 4 8 9 + const VOYAGE_API_URL: &str = "https://api.voyageai.com/v1/multimodalembeddings"; 10 + const VOYAGE_MODEL: &str = "voyage-multimodal-3"; 11 + 5 12 #[derive(Debug, Serialize)] 6 - struct VoyageEmbeddingRequest { 13 + struct VoyageRequest { 7 14 inputs: Vec<MultimodalInput>, 8 15 model: String, 9 16 #[serde(skip_serializing_if = "Option::is_none")] ··· 22 29 } 23 30 24 31 #[derive(Debug, Deserialize)] 25 - struct VoyageEmbeddingResponse { 32 + struct VoyageResponse { 26 33 data: Vec<VoyageEmbeddingData>, 27 34 } 28 35 ··· 31 38 embedding: Vec<f32>, 32 39 } 33 40 34 - pub struct EmbeddingClient { 41 + /// voyage AI multimodal embedding client 42 + /// 43 + /// uses the voyage-multimodal-3 model which produces 1024-dimensional vectors. 44 + /// designed for early fusion of text and image content. 
45 + #[derive(Clone)] 46 + pub struct VoyageEmbedder { 35 47 client: Client, 36 48 api_key: String, 37 49 } 38 50 39 - impl EmbeddingClient { 51 + impl VoyageEmbedder { 40 52 pub fn new(api_key: String) -> Self { 41 53 Self { 42 54 client: Client::new(), 43 55 api_key, 44 56 } 45 57 } 58 + } 46 59 47 - pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> { 48 - let request = VoyageEmbeddingRequest { 60 + impl Embedder for VoyageEmbedder { 61 + async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> { 62 + let request = VoyageRequest { 49 63 inputs: vec![MultimodalInput { 50 64 content: vec![ContentSegment::Text { 51 65 text: text.to_string(), 52 66 }], 53 67 }], 54 - model: "voyage-multimodal-3".to_string(), 68 + model: VOYAGE_MODEL.to_string(), 55 69 input_type: Some("query".to_string()), 56 70 }; 57 71 58 72 let response = self 59 73 .client 60 - .post("https://api.voyageai.com/v1/multimodalembeddings") 74 + .post(VOYAGE_API_URL) 61 75 .header("Authorization", format!("Bearer {}", self.api_key)) 62 76 .json(&request) 63 77 .send() 64 - .await 65 - .context("failed to send embedding request")?; 78 + .await?; 66 79 67 80 if !response.status().is_success() { 68 - let status = response.status(); 81 + let status = response.status().as_u16(); 69 82 let body = response.text().await.unwrap_or_default(); 70 - anyhow::bail!("voyage api error ({}): {}", status, body); 83 + return Err(EmbeddingError::Api { status, body }); 71 84 } 72 85 73 - let embedding_response: VoyageEmbeddingResponse = response 74 - .json() 75 - .await 76 - .context("failed to parse embedding response")?; 86 + let voyage_response: VoyageResponse = response.json().await.map_err(|e| { 87 + EmbeddingError::Other(anyhow::anyhow!("failed to parse response: {}", e)) 88 + })?; 77 89 78 - let embedding = embedding_response 90 + voyage_response 79 91 .data 80 92 .into_iter() 81 93 .next() 82 94 .map(|d| d.embedding) 83 - .context("no embedding returned")?; 95 + 
.ok_or(EmbeddingError::EmptyResponse) 96 + } 84 97 85 - Ok(embedding) 98 + fn name(&self) -> &'static str { 99 + "voyage-multimodal-3" 86 100 } 87 101 } 102 +
+193
src/filter.rs
··· 1 + //! composable result filters 2 + //! 3 + //! filters are predicates that can be combined to create complex filtering logic. 4 + 5 + use regex::Regex; 6 + 7 + /// a single search result that can be filtered 8 + pub trait Filterable { 9 + fn name(&self) -> &str; 10 + } 11 + 12 + /// a predicate that can accept or reject items 13 + pub trait Filter<T: Filterable>: Send + Sync { 14 + /// returns true if the item should be kept 15 + fn matches(&self, item: &T) -> bool; 16 + } 17 + 18 + /// filters out inappropriate content based on a blocklist 19 + struct BlocklistFilter { 20 + blocklist: Vec<&'static str>, 21 + } 22 + 23 + impl BlocklistFilter { 24 + fn inappropriate_bufos() -> Self { 25 + Self { 26 + blocklist: vec![ 27 + "bufo-juicy", 28 + "good-news-bufo-offers-suppository", 29 + "bufo-declines-your-suppository-offer", 30 + "tsa-bufo-gropes-you", 31 + ], 32 + } 33 + } 34 + } 35 + 36 + impl<T: Filterable> Filter<T> for BlocklistFilter { 37 + fn matches(&self, item: &T) -> bool { 38 + !self.blocklist.iter().any(|blocked| item.name().contains(blocked)) 39 + } 40 + } 41 + 42 + /// filters out items matching any of the given regex patterns 43 + struct ExcludePatternFilter { 44 + patterns: Vec<Regex>, 45 + } 46 + 47 + impl ExcludePatternFilter { 48 + fn from_comma_separated(pattern_str: &str) -> Self { 49 + let patterns = pattern_str 50 + .split(',') 51 + .map(|p| p.trim()) 52 + .filter(|p| !p.is_empty()) 53 + .filter_map(|p| Regex::new(p).ok()) 54 + .collect(); 55 + 56 + Self { patterns } 57 + } 58 + 59 + fn empty() -> Self { 60 + Self { patterns: vec![] } 61 + } 62 + } 63 + 64 + impl<T: Filterable> Filter<T> for ExcludePatternFilter { 65 + fn matches(&self, item: &T) -> bool { 66 + !self.patterns.iter().any(|p| p.is_match(item.name())) 67 + } 68 + } 69 + 70 + /// combined filter that handles family-friendly mode and include/exclude patterns 71 + pub struct ContentFilter { 72 + family_friendly: bool, 73 + blocklist: BlocklistFilter, 74 + exclude: 
ExcludePatternFilter, 75 + include_patterns: Vec<Regex>, 76 + } 77 + 78 + impl ContentFilter { 79 + pub fn new( 80 + family_friendly: bool, 81 + exclude_str: Option<&str>, 82 + include_str: Option<&str>, 83 + ) -> Self { 84 + let exclude = exclude_str 85 + .map(ExcludePatternFilter::from_comma_separated) 86 + .unwrap_or_else(ExcludePatternFilter::empty); 87 + 88 + let include_patterns: Vec<Regex> = include_str 89 + .map(|s| { 90 + s.split(',') 91 + .map(|p| p.trim()) 92 + .filter(|p| !p.is_empty()) 93 + .filter_map(|p| Regex::new(p).ok()) 94 + .collect() 95 + }) 96 + .unwrap_or_default(); 97 + 98 + Self { 99 + family_friendly, 100 + blocklist: BlocklistFilter::inappropriate_bufos(), 101 + exclude, 102 + include_patterns, 103 + } 104 + } 105 + 106 + pub fn exclude_pattern_count(&self) -> usize { 107 + self.exclude.patterns.len() 108 + } 109 + 110 + pub fn exclude_patterns_str(&self) -> String { 111 + self.exclude 112 + .patterns 113 + .iter() 114 + .map(|r| r.as_str()) 115 + .collect::<Vec<_>>() 116 + .join(",") 117 + } 118 + } 119 + 120 + impl<T: Filterable> Filter<T> for ContentFilter { 121 + fn matches(&self, item: &T) -> bool { 122 + // check family-friendly blocklist 123 + if self.family_friendly && !self.blocklist.matches(item) { 124 + return false; 125 + } 126 + 127 + // check if explicitly included (overrides exclude) 128 + let matches_include = self.include_patterns.iter().any(|p| p.is_match(item.name())); 129 + if matches_include { 130 + return true; 131 + } 132 + 133 + // check exclude patterns 134 + self.exclude.matches(item) 135 + } 136 + } 137 + 138 + #[cfg(test)] 139 + mod tests { 140 + use super::*; 141 + 142 + struct TestItem { 143 + name: String, 144 + } 145 + 146 + impl Filterable for TestItem { 147 + fn name(&self) -> &str { 148 + &self.name 149 + } 150 + } 151 + 152 + #[test] 153 + fn test_blocklist_filter() { 154 + let filter = BlocklistFilter::inappropriate_bufos(); 155 + let good = TestItem { 156 + name: "bufo-happy".into(), 157 + }; 158 + 
let bad = TestItem { 159 + name: "bufo-juicy".into(), 160 + }; 161 + 162 + assert!(filter.matches(&good)); 163 + assert!(!filter.matches(&bad)); 164 + } 165 + 166 + #[test] 167 + fn test_exclude_pattern_filter() { 168 + let filter = ExcludePatternFilter::from_comma_separated("test, draft"); 169 + let good = TestItem { 170 + name: "bufo-happy".into(), 171 + }; 172 + let bad = TestItem { 173 + name: "bufo-test-mode".into(), 174 + }; 175 + 176 + assert!(filter.matches(&good)); 177 + assert!(!filter.matches(&bad)); 178 + } 179 + 180 + #[test] 181 + fn test_include_overrides_exclude() { 182 + let filter = ContentFilter::new(false, Some("party"), Some("birthday-party")); 183 + let excluded = TestItem { 184 + name: "bufo-party".into(), 185 + }; 186 + let included = TestItem { 187 + name: "bufo-birthday-party".into(), 188 + }; 189 + 190 + assert!(!filter.matches(&excluded)); 191 + assert!(filter.matches(&included)); 192 + } 193 + }
+3
src/main.rs
··· 1 1 mod config; 2 2 mod embedding; 3 + mod filter; 4 + mod providers; 5 + mod scoring; 3 6 mod search; 4 7 mod turbopuffer; 5 8
+99
src/providers.rs
··· 1 + //! provider abstractions for embedding and vector search backends 2 + //! 3 + //! these traits allow swapping implementations (e.g., voyage → openai embeddings) 4 + //! without changing the search logic. 5 + //! 6 + //! ## design notes 7 + //! 8 + //! we use `async fn` in traits directly (stabilized in rust 1.75). for this crate's 9 + //! use case (single-threaded actix-web), the Send bound issue doesn't apply. 10 + //! 11 + //! the trait design follows patterns from: 12 + //! - async-openai's `Config` trait for backend abstraction 13 + //! - tower's `Service` trait for composability (though simpler here) 14 + 15 + use std::future::Future; 16 + use thiserror::Error; 17 + 18 + /// errors that can occur when generating embeddings 19 + #[derive(Debug, Error)] 20 + pub enum EmbeddingError { 21 + #[error("failed to send request: {0}")] 22 + Request(#[from] reqwest::Error), 23 + 24 + #[error("api error ({status}): {body}")] 25 + Api { status: u16, body: String }, 26 + 27 + #[error("no embedding returned from provider")] 28 + EmptyResponse, 29 + 30 + #[error("{0}")] 31 + Other(#[from] anyhow::Error), 32 + } 33 + 34 + /// a provider that can generate embeddings for text 35 + /// 36 + /// implementations should be cheap to clone (wrap expensive resources in Arc). 
37 + /// 38 + /// # example 39 + /// 40 + /// ```ignore 41 + /// let client = VoyageEmbedder::new(api_key); 42 + /// let embedding = client.embed("hello world").await?; 43 + /// ``` 44 + pub trait Embedder: Send + Sync { 45 + /// generate an embedding vector for the given text 46 + fn embed(&self, text: &str) -> impl Future<Output = Result<Vec<f32>, EmbeddingError>> + Send; 47 + 48 + /// human-readable name for logging/debugging 49 + fn name(&self) -> &'static str; 50 + } 51 + 52 + /// errors that can occur during vector search 53 + #[derive(Debug, Error)] 54 + pub enum VectorSearchError { 55 + #[error("request failed: {0}")] 56 + Request(#[from] reqwest::Error), 57 + 58 + #[error("api error ({status}): {body}")] 59 + Api { status: u16, body: String }, 60 + 61 + #[error("query too long: {message}")] 62 + QueryTooLong { message: String }, 63 + 64 + #[error("parse error: {0}")] 65 + Parse(String), 66 + 67 + #[error("{0}")] 68 + Other(#[from] anyhow::Error), 69 + } 70 + 71 + /// a single result from a vector search 72 + #[derive(Debug, Clone)] 73 + pub struct SearchResult { 74 + pub id: String, 75 + /// raw distance/score from the backend (interpretation varies by method) 76 + pub score: f32, 77 + /// arbitrary key-value attributes 78 + pub attributes: std::collections::HashMap<String, String>, 79 + } 80 + 81 + /// a provider that can perform vector similarity search 82 + pub trait VectorStore: Send + Sync { 83 + /// search by vector embedding (ANN/cosine similarity) 84 + fn search_by_vector( 85 + &self, 86 + embedding: &[f32], 87 + top_k: usize, 88 + ) -> impl Future<Output = Result<Vec<SearchResult>, VectorSearchError>> + Send; 89 + 90 + /// search by keyword (BM25 full-text search) 91 + fn search_by_keyword( 92 + &self, 93 + query: &str, 94 + top_k: usize, 95 + ) -> impl Future<Output = Result<Vec<SearchResult>, VectorSearchError>> + Send; 96 + 97 + /// human-readable name for logging/debugging 98 + fn name(&self) -> &'static str; 99 + }
+164
src/scoring.rs
//! score fusion and normalization for hybrid search
//!
//! this module handles the weighted combination of semantic (vector) and
//! keyword (BM25) search scores.
//!
//! ## normalization strategies
//!
//! - **cosine distance → similarity**: `1.0 - (distance / 2.0)` maps [0, 2] → [1, 0]
//! - **BM25 max-scaling**: divide by max score so top result = 1.0
//!
//! both normalizers clamp their output to [0, 1] so that rounding noise or an
//! unexpected backend value cannot push a fused score outside the range the
//! fusion formula assumes.
//!
//! ## fusion formula
//!
//! ```text
//! score = α * semantic + (1 - α) * keyword
//! ```
//!
//! reference: https://opensourceconnections.com/blog/2023/02/27/hybrid-vigor-winning-at-hybrid-search/

use std::collections::{HashMap, HashSet};

/// smallest divisor used for BM25 max-scaling (avoids division by zero)
const MIN_BM25_DIVISOR: f32 = 0.001;

/// configuration for score fusion
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// weight for semantic scores (0.0 = pure keyword, 1.0 = pure semantic)
    pub alpha: f32,
    /// minimum fused score to include in results (filters noise)
    pub min_score: f32,
}

impl Default for FusionConfig {
    fn default() -> Self {
        Self {
            alpha: 0.7,
            min_score: 0.001,
        }
    }
}

impl FusionConfig {
    /// build a config with the given semantic weight and default `min_score`
    ///
    /// `alpha` is clamped to [0, 1]; a NaN falls back to the default weight.
    /// without this, an out-of-range value would give one signal a negative
    /// weight and a NaN would make every fused score NaN (and so drop every
    /// result).
    pub fn new(alpha: f32) -> Self {
        let defaults = Self::default();
        let alpha = if alpha.is_nan() {
            defaults.alpha
        } else {
            alpha.clamp(0.0, 1.0)
        };

        Self { alpha, ..defaults }
    }
}

/// normalize cosine distance to similarity score
///
/// cosine distance ranges from 0 (identical) to 2 (opposite).
/// we convert to similarity: 1.0 (identical) to 0.0 (opposite).
/// distances outside [0, 2] are clamped to the nearest end of that range.
#[inline]
pub fn cosine_distance_to_similarity(distance: f32) -> f32 {
    (1.0 - (distance / 2.0)).clamp(0.0, 1.0)
}

/// normalize BM25 scores using max-scaling
///
/// divides all scores by the maximum score, ensuring:
/// - top result gets score 1.0
/// - relative spacing is preserved
/// - handles edge cases (empty results, identical scores)
///
/// the divisor is never smaller than 0.001, and every normalized score is
/// clamped to [0, 1]. if an id appears more than once, the last entry wins.
pub fn normalize_bm25_scores(scores: &[(String, f32)]) -> HashMap<String, f32> {
    let max_score = scores
        .iter()
        .map(|(_, score)| *score)
        .fold(f32::NEG_INFINITY, f32::max)
        .max(MIN_BM25_DIVISOR);

    scores
        .iter()
        .map(|(id, score)| (id.clone(), (score / max_score).clamp(0.0, 1.0)))
        .collect()
}

/// fuse semantic and keyword scores using weighted combination
///
/// an id missing from one of the maps contributes 0.0 for that signal.
/// results whose fused score is not strictly greater than `config.min_score`
/// are dropped.
///
/// returns items sorted by fused score (descending). ties are broken by id
/// (ascending) so the same inputs always produce the same ordering, which
/// keeps cached responses consistent across requests.
pub fn fuse_scores(
    semantic_scores: &HashMap<String, f32>,
    keyword_scores: &HashMap<String, f32>,
    config: &FusionConfig,
) -> Vec<(String, f32)> {
    // collect all unique IDs
    let all_ids: HashSet<&String> = semantic_scores
        .keys()
        .chain(keyword_scores.keys())
        .collect();

    let mut fused: Vec<(String, f32)> = all_ids
        .into_iter()
        .filter_map(|id| {
            let semantic = semantic_scores.get(id).copied().unwrap_or(0.0);
            let keyword = keyword_scores.get(id).copied().unwrap_or(0.0);
            let score = config.alpha * semantic + (1.0 - config.alpha) * keyword;
            // NaN fails this comparison, so it never reaches the sort below
            (score > config.min_score).then(|| (id.clone(), score))
        })
        .collect();

    // ids are unique, so the comparator is a total order and an unstable sort
    // is deterministic
    fused.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));

    fused
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_distance_to_similarity() {
        assert!((cosine_distance_to_similarity(0.0) - 1.0).abs() < 0.001);
        assert!((cosine_distance_to_similarity(2.0) - 0.0).abs() < 0.001);
        assert!((cosine_distance_to_similarity(1.0) - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_cosine_distance_out_of_range_is_clamped() {
        assert_eq!(cosine_distance_to_similarity(-0.1), 1.0);
        assert_eq!(cosine_distance_to_similarity(2.1), 0.0);
    }

    #[test]
    fn test_normalize_bm25_scores() {
        let scores = vec![
            ("a".to_string(), 10.0),
            ("b".to_string(), 5.0),
            ("c".to_string(), 2.5),
        ];

        let normalized = normalize_bm25_scores(&scores);

        assert!((normalized["a"] - 1.0).abs() < 0.001);
        assert!((normalized["b"] - 0.5).abs() < 0.001);
        assert!((normalized["c"] - 0.25).abs() < 0.001);
    }

    #[test]
    fn test_normalize_bm25_scores_empty() {
        assert!(normalize_bm25_scores(&[]).is_empty());
    }

    #[test]
    fn test_fusion_config_clamps_alpha() {
        assert_eq!(FusionConfig::new(1.5).alpha, 1.0);
        assert_eq!(FusionConfig::new(-0.5).alpha, 0.0);
        assert_eq!(FusionConfig::new(f32::NAN).alpha, 0.7);
    }

    #[test]
    fn test_fuse_scores_pure_semantic() {
        let mut semantic = HashMap::new();
        semantic.insert("a".to_string(), 0.9);
        semantic.insert("b".to_string(), 0.5);

        let mut keyword = HashMap::new();
        keyword.insert("a".to_string(), 0.1);
        keyword.insert("c".to_string(), 1.0);

        let config = FusionConfig::new(1.0); // pure semantic
        let fused = fuse_scores(&semantic, &keyword, &config);

        assert_eq!(fused[0].0, "a");
        assert!((fused[0].1 - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_fuse_scores_balanced() {
        let mut semantic = HashMap::new();
        semantic.insert("a".to_string(), 0.8);

        let mut keyword = HashMap::new();
        keyword.insert("a".to_string(), 0.4);

        let config = FusionConfig::new(0.5); // balanced
        let fused = fuse_scores(&semantic, &keyword, &config);

        // 0.5 * 0.8 + 0.5 * 0.4 = 0.6
        assert!((fused[0].1 - 0.6).abs() < 0.001);
    }

    #[test]
    fn test_fuse_scores_ties_are_deterministic() {
        let mut semantic = HashMap::new();
        for id in ["c", "a", "b"] {
            semantic.insert(id.to_string(), 0.5);
        }

        let fused = fuse_scores(&semantic, &HashMap::new(), &FusionConfig::new(1.0));
        let ids: Vec<&str> = fused.iter().map(|(id, _)| id.as_str()).collect();

        assert_eq!(ids, ["a", "b", "c"]);
    }
}
+192 -260
src/search.rs
··· 27 27 //! - `α=0.5`: balanced (equal weight to semantic and keyword signals) 28 28 //! - `α=0.0`: pure keyword (best for exact filename searches) 29 29 //! 30 - //! ## empirical behavior 31 - //! 32 - //! query: "happy", top_k=3 33 - //! - α=1.0: ["proud-bufo-is-excited", "bufo-hehe", "bufo-excited"] (semantic similarity) 34 - //! - α=0.5: ["bufo-is-happy-youre-happy", ...] (exact match rises to top) 35 - //! - α=0.0: ["bufo-is-happy-youre-happy" (1.0), others (0.0)] (only exact matches score) 36 - //! 37 30 //! ## references 38 31 //! 39 32 //! - voyage multimodal embeddings: https://docs.voyageai.com/docs/multimodal-embeddings ··· 41 34 //! - weighted fusion: standard approach in modern hybrid search systems (2024) 42 35 43 36 use crate::config::Config; 44 - use crate::embedding::EmbeddingClient; 45 - use crate::turbopuffer::{QueryRequest, TurbopufferClient, TurbopufferError}; 37 + use crate::embedding::VoyageEmbedder; 38 + use crate::filter::{ContentFilter, Filter, Filterable}; 39 + use crate::providers::{Embedder, VectorSearchError, VectorStore}; 40 + use crate::scoring::{cosine_distance_to_similarity, fuse_scores, normalize_bm25_scores, FusionConfig}; 41 + use crate::turbopuffer::TurbopufferStore; 46 42 use actix_web::{web, HttpRequest, HttpResponse, Result as ActixResult}; 47 - use regex::Regex; 48 43 use serde::{Deserialize, Serialize}; 49 44 use std::collections::hash_map::DefaultHasher; 45 + use std::collections::HashMap; 50 46 use std::hash::{Hash, Hasher}; 51 47 52 48 #[derive(Debug, Deserialize)] ··· 81 77 true 82 78 } 83 79 84 - /// blocklist of inappropriate bufos (filtered when family_friendly=true) 85 - fn get_inappropriate_bufos() -> Vec<&'static str> { 86 - vec![ 87 - "bufo-juicy", 88 - "good-news-bufo-offers-suppository", 89 - "bufo-declines-your-suppository-offer", 90 - "tsa-bufo-gropes-you", 91 - ] 92 - } 93 - 94 80 #[derive(Debug, Serialize)] 95 81 pub struct SearchResponse { 96 82 pub results: Vec<BufoResult>, 97 83 } 98 84 99 - 
#[derive(Debug, Serialize)] 85 + #[derive(Debug, Serialize, Clone)] 100 86 pub struct BufoResult { 101 87 pub id: String, 102 88 pub url: String, 103 89 pub name: String, 104 - pub score: f32, // normalized 0-1 score for display 90 + pub score: f32, 91 + } 92 + 93 + impl Filterable for BufoResult { 94 + fn name(&self) -> &str { 95 + &self.name 96 + } 97 + } 98 + 99 + /// errors that can occur during search 100 + #[derive(Debug, thiserror::Error)] 101 + pub enum SearchError { 102 + #[error("embedding error: {0}")] 103 + Embedding(#[from] crate::providers::EmbeddingError), 104 + 105 + #[error("vector search error: {0}")] 106 + VectorSearch(#[from] VectorSearchError), 107 + } 108 + 109 + impl SearchError { 110 + fn into_actix_error(self) -> actix_web::Error { 111 + match &self { 112 + SearchError::VectorSearch(VectorSearchError::QueryTooLong { .. }) => { 113 + actix_web::error::ErrorBadRequest( 114 + "search query is too long (max 1024 characters for text search). try a shorter query." 115 + ) 116 + } 117 + _ => actix_web::error::ErrorInternalServerError(self.to_string()), 118 + } 119 + } 105 120 } 106 121 107 122 /// generate etag for caching based on query parameters 108 - fn generate_etag(query: &str, top_k: usize, alpha: f32, family_friendly: bool, exclude: &Option<String>, include: &Option<String>) -> String { 123 + fn generate_etag( 124 + query: &str, 125 + top_k: usize, 126 + alpha: f32, 127 + family_friendly: bool, 128 + exclude: &Option<String>, 129 + include: &Option<String>, 130 + ) -> String { 109 131 let mut hasher = DefaultHasher::new(); 110 132 query.hash(&mut hasher); 111 133 top_k.hash(&mut hasher); 112 - // convert f32 to bits for consistent hashing 113 134 alpha.to_bits().hash(&mut hasher); 114 135 family_friendly.hash(&mut hasher); 115 136 exclude.hash(&mut hasher); ··· 117 138 format!("\"{}\"", hasher.finish()) 118 139 } 119 140 120 - /// shared search implementation used by both POST and GET handlers 121 - async fn perform_search( 122 - 
query_text: String, 123 - top_k_val: usize, 124 - alpha: f32, 125 - family_friendly: bool, 126 - exclude: Option<String>, 127 - include: Option<String>, 128 - config: &Config, 129 - ) -> ActixResult<SearchResponse> { 130 - // parse and compile exclusion regex patterns from comma-separated string 131 - let exclude_patterns: Vec<Regex> = exclude 132 - .as_ref() 133 - .map(|s| { 134 - s.split(',') 135 - .map(|p| p.trim()) 136 - .filter(|p| !p.is_empty()) 137 - .filter_map(|p| Regex::new(p).ok()) // silently skip invalid patterns 138 - .collect() 139 - }) 140 - .unwrap_or_default(); 141 - 142 - // parse and compile inclusion regex patterns (these override exclusions) 143 - let include_patterns: Vec<Regex> = include 144 - .as_ref() 145 - .map(|s| { 146 - s.split(',') 147 - .map(|p| p.trim()) 148 - .filter(|p| !p.is_empty()) 149 - .filter_map(|p| Regex::new(p).ok()) 150 - .collect() 151 - }) 152 - .unwrap_or_default(); 153 - 154 - let _search_span = logfire::span!( 155 - "bufo_search", 156 - query = &query_text, 157 - top_k = top_k_val as i64, 158 - alpha = alpha as f64, 159 - family_friendly = family_friendly, 160 - exclude_patterns_count = exclude_patterns.len() as i64 161 - ).entered(); 162 - 163 - let exclude_patterns_str: String = exclude_patterns.iter().map(|r| r.as_str()).collect::<Vec<_>>().join(","); 164 - logfire::info!( 165 - "search request received", 166 - query = &query_text, 167 - top_k = top_k_val as i64, 168 - alpha = alpha as f64, 169 - exclude_patterns = &exclude_patterns_str 170 - ); 141 + /// execute hybrid search using the provided embedder and vector store 142 + async fn execute_hybrid_search<E: Embedder, V: VectorStore>( 143 + query: &str, 144 + top_k: usize, 145 + fusion_config: &FusionConfig, 146 + embedder: &E, 147 + vector_store: &V, 148 + ) -> Result<Vec<(String, f32, HashMap<String, String>)>, SearchError> { 149 + // fetch extra results to ensure we have enough after filtering 150 + let search_top_k = top_k * 5; 151 + let query_owned = 
query.to_string(); 171 152 172 - let embedding_client = EmbeddingClient::new(config.voyage_api_key.clone()); 173 - let tpuf_client = TurbopufferClient::new( 174 - config.turbopuffer_api_key.clone(), 175 - config.turbopuffer_namespace.clone(), 176 - ); 153 + // generate query embedding 154 + let _embed_span = logfire::span!( 155 + "embedding.generate", 156 + query = &query_owned, 157 + model = embedder.name() 158 + ) 159 + .entered(); 177 160 178 - // generate embedding for user query 179 - let query_embedding = { 180 - let _span = logfire::span!( 181 - "voyage.embed_text", 182 - query = &query_text, 183 - model = "voyage-3-lite" 184 - ).entered(); 185 - 186 - embedding_client 187 - .embed_text(&query_text) 188 - .await 189 - .map_err(|e| { 190 - let error_msg = e.to_string(); 191 - logfire::error!( 192 - "embedding generation failed", 193 - error = error_msg, 194 - query = &query_text 195 - ); 196 - actix_web::error::ErrorInternalServerError(format!( 197 - "failed to generate embedding: {}", 198 - e 199 - )) 200 - })? 201 - }; 161 + let query_embedding = embedder.embed(query).await?; 202 162 203 163 logfire::info!( 204 164 "embedding generated", 205 - query = &query_text, 165 + query = &query_owned, 206 166 embedding_dim = query_embedding.len() as i64 207 167 ); 208 168 209 - // run vector search (semantic) 210 - // fetch extra results to ensure we have enough after filtering by family_friendly and exclude patterns 211 - let search_top_k = top_k_val * 5; 212 - let vector_request = QueryRequest { 213 - rank_by: vec![ 214 - serde_json::json!("vector"), 215 - serde_json::json!("ANN"), 216 - serde_json::json!(query_embedding), 217 - ], 218 - top_k: search_top_k, 219 - include_attributes: Some(vec!["url".to_string(), "name".to_string(), "filename".to_string()]), 220 - }; 169 + // run both searches in sequence (could parallelize with tokio::join! 
if needed) 170 + let namespace = vector_store.name().to_string(); 221 171 222 - let namespace = config.turbopuffer_namespace.clone(); 223 172 let vector_results = { 224 173 let _span = logfire::span!( 225 174 "turbopuffer.vector_search", 226 - query = &query_text, 175 + query = &query_owned, 227 176 top_k = search_top_k as i64, 228 177 namespace = &namespace 229 - ).entered(); 178 + ) 179 + .entered(); 230 180 231 - tpuf_client.query(vector_request).await.map_err(|e| { 232 - let error_msg = e.to_string(); 233 - logfire::error!( 234 - "vector search failed", 235 - error = error_msg, 236 - query = &query_text, 237 - top_k = search_top_k as i64 238 - ); 239 - actix_web::error::ErrorInternalServerError(format!( 240 - "failed to query turbopuffer (vector): {}", 241 - e 242 - )) 243 - })? 181 + vector_store 182 + .search_by_vector(&query_embedding, search_top_k) 183 + .await? 244 184 }; 245 185 246 186 logfire::info!( 247 187 "vector search completed", 248 - query = &query_text, 188 + query = &query_owned, 249 189 results_found = vector_results.len() as i64 250 190 ); 251 191 252 - // run BM25 text search (keyword) 253 192 let bm25_results = { 254 193 let _span = logfire::span!( 255 194 "turbopuffer.bm25_search", 256 - query = &query_text, 195 + query = &query_owned, 257 196 top_k = search_top_k as i64, 258 197 namespace = &namespace 259 - ).entered(); 260 - 261 - tpuf_client.bm25_query(&query_text, search_top_k).await.map_err(|e| { 262 - let error_msg = e.to_string(); 263 - logfire::error!( 264 - "bm25 search failed", 265 - error = error_msg, 266 - query = &query_text, 267 - top_k = search_top_k as i64 268 - ); 198 + ) 199 + .entered(); 269 200 270 - // return appropriate HTTP status based on error type 271 - match e { 272 - TurbopufferError::QueryTooLong { .. } => { 273 - actix_web::error::ErrorBadRequest( 274 - "search query is too long (max 1024 characters for text search). try a shorter query." 
275 - ) 276 - } 277 - _ => { 278 - actix_web::error::ErrorInternalServerError(format!( 279 - "failed to query turbopuffer (BM25): {}", 280 - e 281 - )) 282 - } 283 - } 284 - })? 201 + vector_store.search_by_keyword(query, search_top_k).await? 285 202 }; 286 203 287 - // weighted fusion: combine vector and BM25 results 288 - use std::collections::HashMap; 204 + // normalize scores 205 + let semantic_scores: HashMap<String, f32> = vector_results 206 + .iter() 207 + .map(|r| (r.id.clone(), cosine_distance_to_similarity(r.score))) 208 + .collect(); 289 209 290 - // normalize vector scores (cosine distance -> 0-1 similarity) 291 - let mut semantic_scores: HashMap<String, f32> = HashMap::new(); 292 - for row in &vector_results { 293 - let score = 1.0 - (row.dist / 2.0); 294 - semantic_scores.insert(row.id.clone(), score); 295 - } 210 + let bm25_raw: Vec<(String, f32)> = bm25_results 211 + .iter() 212 + .map(|r| (r.id.clone(), r.score)) 213 + .collect(); 214 + let keyword_scores = normalize_bm25_scores(&bm25_raw); 296 215 297 - // normalize BM25 scores using max normalization (BM25-max-scaled approach) 298 - // this preserves relative spacing and handles edge cases (single result, similar scores) 299 - // reference: https://opensourceconnections.com/blog/2023/02/27/hybrid-vigor-winning-at-hybrid-search/ 300 - let bm25_scores_vec: Vec<f32> = bm25_results.iter().map(|r| r.dist).collect(); 301 - let max_bm25 = bm25_scores_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max).max(0.001); // avoid division by zero 302 - 303 - let mut keyword_scores: HashMap<String, f32> = HashMap::new(); 304 - for row in &bm25_results { 305 - // divide by max to ensure top result gets 1.0, others scale proportionally 306 - let normalized_score = (row.dist / max_bm25).min(1.0); 307 - keyword_scores.insert(row.id.clone(), normalized_score); 308 - } 216 + let max_bm25 = bm25_raw 217 + .iter() 218 + .map(|(_, s)| *s) 219 + .fold(f32::NEG_INFINITY, f32::max); 309 220 310 221 logfire::info!( 311 222 
"bm25 search completed", 312 - query = &query_text, 223 + query = &query_owned, 313 224 results_found = bm25_results.len() as i64, 314 225 max_bm25 = max_bm25 as f64, 315 - top_bm25_raw = bm25_scores_vec.first().copied().unwrap_or(0.0) as f64, 316 - top_bm25_normalized = keyword_scores.values().cloned().fold(f32::NEG_INFINITY, f32::max) as f64 226 + top_bm25_raw = bm25_raw.first().map(|(_, s)| *s).unwrap_or(0.0) as f64 227 + ); 228 + 229 + // fuse scores 230 + let fused = fuse_scores(&semantic_scores, &keyword_scores, fusion_config); 231 + 232 + logfire::info!( 233 + "weighted fusion completed", 234 + total_candidates = (vector_results.len() + bm25_results.len()) as i64, 235 + alpha = fusion_config.alpha as f64, 236 + pre_filter_results = fused.len() as i64 317 237 ); 318 238 319 - // collect all unique results and compute weighted fusion scores 320 - let mut all_results: HashMap<String, crate::turbopuffer::QueryRow> = HashMap::new(); 321 - for row in vector_results.into_iter().chain(bm25_results.into_iter()) { 322 - all_results.entry(row.id.clone()).or_insert(row); 239 + // collect attributes from both result sets 240 + let mut all_attributes: HashMap<String, HashMap<String, String>> = HashMap::new(); 241 + for result in vector_results.into_iter().chain(bm25_results.into_iter()) { 242 + all_attributes 243 + .entry(result.id.clone()) 244 + .or_insert(result.attributes); 323 245 } 324 246 325 - let mut fused_scores: Vec<(String, f32)> = all_results 326 - .keys() 327 - .map(|id| { 328 - let semantic = semantic_scores.get(id).copied().unwrap_or(0.0); 329 - let keyword = keyword_scores.get(id).copied().unwrap_or(0.0); 330 - let fused = alpha * semantic + (1.0 - alpha) * keyword; 331 - (id.clone(), fused) 247 + // return fused results with attributes 248 + Ok(fused 249 + .into_iter() 250 + .map(|(id, score)| { 251 + let attrs = all_attributes.remove(&id).unwrap_or_default(); 252 + (id, score, attrs) 332 253 }) 333 - .collect(); 254 + .collect()) 255 + } 334 256 335 - // 
filter out zero-scored results (irrelevant matches from the other search method) 336 - // this prevents vector-only results from appearing when alpha=0.0 (pure keyword) 337 - // and keyword-only results from appearing when alpha=1.0 (pure semantic) 338 - fused_scores.retain(|(_, score)| *score > 0.001); 257 + /// shared search implementation used by both POST and GET handlers 258 + async fn perform_search( 259 + query_text: String, 260 + top_k_val: usize, 261 + alpha: f32, 262 + family_friendly: bool, 263 + exclude: Option<String>, 264 + include: Option<String>, 265 + config: &Config, 266 + ) -> ActixResult<SearchResponse> { 267 + let content_filter = ContentFilter::new( 268 + family_friendly, 269 + exclude.as_deref(), 270 + include.as_deref(), 271 + ); 339 272 340 - // sort by fused score (descending) 341 - fused_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); 273 + let _search_span = logfire::span!( 274 + "bufo_search", 275 + query = &query_text, 276 + top_k = top_k_val as i64, 277 + alpha = alpha as f64, 278 + family_friendly = family_friendly, 279 + exclude_patterns_count = content_filter.exclude_pattern_count() as i64 280 + ) 281 + .entered(); 342 282 343 283 logfire::info!( 344 - "weighted fusion completed", 345 - total_candidates = all_results.len() as i64, 284 + "search request received", 285 + query = &query_text, 286 + top_k = top_k_val as i64, 346 287 alpha = alpha as f64, 347 - pre_filter_results = fused_scores.len() as i64 288 + exclude_patterns = &content_filter.exclude_patterns_str() 348 289 ); 349 290 350 - // convert to bufo results and apply ALL filtering BEFORE truncating 351 - // this ensures we return top_k results after filtering, not fewer 352 - let inappropriate_bufos = get_inappropriate_bufos(); 353 - let results: Vec<BufoResult> = fused_scores 354 - .into_iter() 355 - .filter_map(|(id, score)| { 356 - all_results.get(&id).map(|row| { 357 - let url = row 358 - .attributes 359 - .get("url") 360 - .and_then(|v| v.as_str()) 361 - 
.unwrap_or("") 362 - .to_string(); 291 + // create clients 292 + let embedder = VoyageEmbedder::new(config.voyage_api_key.clone()); 293 + let vector_store = TurbopufferStore::new( 294 + config.turbopuffer_api_key.clone(), 295 + config.turbopuffer_namespace.clone(), 296 + ); 363 297 364 - let name = row 365 - .attributes 366 - .get("name") 367 - .and_then(|v| v.as_str()) 368 - .unwrap_or(&row.id) 369 - .to_string(); 298 + let fusion_config = FusionConfig::new(alpha); 370 299 371 - BufoResult { 372 - id: row.id.clone(), 373 - url, 374 - name, 375 - score, 376 - } 377 - }) 378 - }) 379 - .filter(|result| { 380 - // filter out inappropriate bufos if family_friendly mode is enabled 381 - if family_friendly && inappropriate_bufos.iter().any(|&blocked| result.name.contains(blocked)) { 382 - return false; 383 - } 300 + // execute hybrid search 301 + let fused_results = execute_hybrid_search( 302 + &query_text, 303 + top_k_val, 304 + &fusion_config, 305 + &embedder, 306 + &vector_store, 307 + ) 308 + .await 309 + .map_err(|e| e.into_actix_error())?; 384 310 385 - // check if result matches any include pattern (allowlist override) 386 - let matches_include = include_patterns.iter().any(|p| p.is_match(&result.name)); 387 - 388 - // check if result matches any exclude pattern 389 - let matches_exclude = exclude_patterns.iter().any(|p| p.is_match(&result.name)); 390 - 391 - // keep if: matches include OR doesn't match exclude 392 - // i.e., include overrides exclude 393 - if matches_exclude && !matches_include { 394 - return false; 395 - } 396 - 397 - true 311 + // convert to BufoResults and apply filtering 312 + let results: Vec<BufoResult> = fused_results 313 + .into_iter() 314 + .map(|(id, score, attrs)| BufoResult { 315 + id: id.clone(), 316 + url: attrs.get("url").cloned().unwrap_or_default(), 317 + name: attrs.get("name").cloned().unwrap_or_else(|| id.clone()), 318 + score, 398 319 }) 399 - .take(top_k_val) // take top_k AFTER filtering 320 + .filter(|result| 
content_filter.matches(result)) 321 + .take(top_k_val) 400 322 .collect(); 401 323 402 324 let results_count = results.len() as i64; 403 - let top_result_name = results.first().map(|r| r.name.clone()).unwrap_or_else(|| "none".to_string()); 325 + let top_result_name = results 326 + .first() 327 + .map(|r| r.name.clone()) 328 + .unwrap_or_else(|| "none".to_string()); 404 329 let top_score_val = results.first().map(|r| r.score as f64).unwrap_or(0.0); 405 330 let avg_score_val = if !results.is_empty() { 406 331 results.iter().map(|r| r.score as f64).sum::<f64>() / results.len() as f64 ··· 432 357 query.family_friendly, 433 358 query.exclude.clone(), 434 359 query.include.clone(), 435 - &config 436 - ).await?; 360 + &config, 361 + ) 362 + .await?; 437 363 Ok(HttpResponse::Ok().json(response)) 438 364 } 439 365 ··· 443 369 config: web::Data<Config>, 444 370 req: HttpRequest, 445 371 ) -> ActixResult<HttpResponse> { 446 - // generate etag for caching 447 - let etag = generate_etag(&query.query, query.top_k, query.alpha, query.family_friendly, &query.exclude, &query.include); 372 + let etag = generate_etag( 373 + &query.query, 374 + query.top_k, 375 + query.alpha, 376 + query.family_friendly, 377 + &query.exclude, 378 + &query.include, 379 + ); 448 380 449 - // check if client has cached version 450 381 if let Some(if_none_match) = req.headers().get("if-none-match") { 451 382 if if_none_match.to_str().unwrap_or("") == etag { 452 383 return Ok(HttpResponse::NotModified() ··· 462 393 query.family_friendly, 463 394 query.exclude.clone(), 464 395 query.include.clone(), 465 - &config 466 - ).await?; 396 + &config, 397 + ) 398 + .await?; 467 399 468 400 Ok(HttpResponse::Ok() 469 401 .insert_header(("etag", etag.clone())) 470 - .insert_header(("cache-control", "public, max-age=300")) // cache for 5 minutes 402 + .insert_header(("cache-control", "public, max-age=300")) 471 403 .json(response)) 472 404 }
+100 -92
src/turbopuffer.rs
··· 1 - use anyhow::{Context, Result}; 1 + //! turbopuffer vector database implementation 2 + //! 3 + //! implements the `VectorStore` trait for turbopuffer's hybrid search API. 4 + 5 + use crate::providers::{SearchResult, VectorSearchError, VectorStore}; 2 6 use reqwest::Client; 3 7 use serde::{Deserialize, Serialize}; 4 - use thiserror::Error; 5 8 6 - #[derive(Debug, Error)] 7 - pub enum TurbopufferError { 8 - #[error("query too long: {message}")] 9 - QueryTooLong { message: String }, 10 - #[error("turbopuffer API error: {0}")] 11 - ApiError(String), 12 - #[error("request failed: {0}")] 13 - RequestFailed(#[from] reqwest::Error), 14 - #[error("{0}")] 15 - Other(#[from] anyhow::Error), 9 + const TURBOPUFFER_API_BASE: &str = "https://api.turbopuffer.com/v1/vectors"; 10 + 11 + /// raw response row from turbopuffer API 12 + #[derive(Debug, Deserialize, Serialize, Clone)] 13 + pub struct QueryRow { 14 + pub id: String, 15 + pub dist: f32, 16 + pub attributes: serde_json::Map<String, serde_json::Value>, 17 + } 18 + 19 + impl From<QueryRow> for SearchResult { 20 + fn from(row: QueryRow) -> Self { 21 + let attributes = row 22 + .attributes 23 + .iter() 24 + .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string()))) 25 + .collect(); 26 + 27 + SearchResult { 28 + id: row.id, 29 + score: row.dist, 30 + attributes, 31 + } 32 + } 16 33 } 17 34 18 35 #[derive(Debug, Deserialize)] 19 - struct TurbopufferErrorResponse { 36 + struct ErrorResponse { 20 37 error: String, 21 38 #[allow(dead_code)] 22 39 status: String, 23 40 } 24 41 25 - #[derive(Debug, Serialize)] 26 - pub struct QueryRequest { 27 - pub rank_by: Vec<serde_json::Value>, 28 - pub top_k: usize, 29 - #[serde(skip_serializing_if = "Option::is_none")] 30 - pub include_attributes: Option<Vec<String>>, 31 - } 32 - 33 - pub type QueryResponse = Vec<QueryRow>; 34 - 35 - #[derive(Debug, Deserialize, Serialize, Clone)] 36 - pub struct QueryRow { 37 - pub id: String, 38 - pub dist: f32, // for vector: cosine distance; 
for BM25: BM25 score 39 - pub attributes: serde_json::Map<String, serde_json::Value>, 40 - } 41 - 42 - pub struct TurbopufferClient { 42 + /// turbopuffer vector database client 43 + /// 44 + /// supports both ANN vector search and BM25 full-text search. 45 + #[derive(Clone)] 46 + pub struct TurbopufferStore { 43 47 client: Client, 44 48 api_key: String, 45 49 namespace: String, 46 50 } 47 51 48 - impl TurbopufferClient { 52 + impl TurbopufferStore { 49 53 pub fn new(api_key: String, namespace: String) -> Self { 50 54 Self { 51 55 client: Client::new(), ··· 54 58 } 55 59 } 56 60 57 - pub async fn query(&self, request: QueryRequest) -> Result<QueryResponse> { 58 - let url = format!( 59 - "https://api.turbopuffer.com/v1/vectors/{}/query", 60 - self.namespace 61 - ); 62 - 63 - let request_json = serde_json::to_string_pretty(&request)?; 64 - log::debug!("turbopuffer query request: {}", request_json); 61 + fn query_url(&self) -> String { 62 + format!("{}/{}/query", TURBOPUFFER_API_BASE, self.namespace) 63 + } 65 64 65 + async fn execute_query( 66 + &self, 67 + request: serde_json::Value, 68 + ) -> Result<Vec<QueryRow>, VectorSearchError> { 66 69 let response = self 67 70 .client 68 - .post(&url) 71 + .post(self.query_url()) 69 72 .header("Authorization", format!("Bearer {}", self.api_key)) 70 73 .json(&request) 71 74 .send() 72 - .await 73 - .context("failed to send query request")?; 75 + .await?; 74 76 75 77 if !response.status().is_success() { 76 - let status = response.status(); 78 + let status = response.status().as_u16(); 77 79 let body = response.text().await.unwrap_or_default(); 78 - anyhow::bail!("turbopuffer query failed with status {}: {}", status, body); 80 + 81 + // check for specific error types 82 + if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) { 83 + if error_resp.error.contains("too long") && error_resp.error.contains("max 1024") { 84 + return Err(VectorSearchError::QueryTooLong { 85 + message: error_resp.error, 86 + }); 87 + } 88 
+ } 89 + 90 + return Err(VectorSearchError::Api { status, body }); 79 91 } 80 92 81 - let body = response.text().await.context("failed to read response body")?; 93 + let body = response.text().await.map_err(|e| { 94 + VectorSearchError::Other(anyhow::anyhow!("failed to read response: {}", e)) 95 + })?; 82 96 83 97 serde_json::from_str(&body) 84 - .context(format!("failed to parse query response: {}", body)) 98 + .map_err(|e| VectorSearchError::Parse(format!("failed to parse response: {}", e))) 85 99 } 86 - 87 - pub async fn bm25_query(&self, query_text: &str, top_k: usize) -> Result<QueryResponse, TurbopufferError> { 88 - let url = format!( 89 - "https://api.turbopuffer.com/v1/vectors/{}/query", 90 - self.namespace 91 - ); 100 + } 92 101 102 + impl VectorStore for TurbopufferStore { 103 + async fn search_by_vector( 104 + &self, 105 + embedding: &[f32], 106 + top_k: usize, 107 + ) -> Result<Vec<SearchResult>, VectorSearchError> { 93 108 let request = serde_json::json!({ 94 - "rank_by": ["name", "BM25", query_text], 109 + "rank_by": ["vector", "ANN", embedding], 95 110 "top_k": top_k, 96 111 "include_attributes": ["url", "name", "filename"], 97 112 }); 98 113 99 - if let Ok(pretty) = serde_json::to_string_pretty(&request) { 100 - log::debug!("turbopuffer BM25 query request: {}", pretty); 101 - } 102 - 103 - let response = self 104 - .client 105 - .post(&url) 106 - .header("Authorization", format!("Bearer {}", self.api_key)) 107 - .json(&request) 108 - .send() 109 - .await?; 110 - 111 - if !response.status().is_success() { 112 - let status = response.status(); 113 - let body = response.text().await.unwrap_or_default(); 114 + log::debug!( 115 + "turbopuffer vector query: {}", 116 + serde_json::to_string_pretty(&request).unwrap_or_default() 117 + ); 114 118 115 - // try to parse turbopuffer error response 116 - if let Ok(error_resp) = serde_json::from_str::<TurbopufferErrorResponse>(&body) { 117 - // check if it's a query length error 118 - if 
error_resp.error.contains("too long") && error_resp.error.contains("max 1024") { 119 - return Err(TurbopufferError::QueryTooLong { 120 - message: error_resp.error, 121 - }); 122 - } 123 - } 119 + let rows = self.execute_query(request).await?; 120 + Ok(rows.into_iter().map(SearchResult::from).collect()) 121 + } 124 122 125 - return Err(TurbopufferError::ApiError(format!( 126 - "turbopuffer BM25 query failed with status {}: {}", 127 - status, body 128 - ))); 129 - } 123 + async fn search_by_keyword( 124 + &self, 125 + query: &str, 126 + top_k: usize, 127 + ) -> Result<Vec<SearchResult>, VectorSearchError> { 128 + let request = serde_json::json!({ 129 + "rank_by": ["name", "BM25", query], 130 + "top_k": top_k, 131 + "include_attributes": ["url", "name", "filename"], 132 + }); 130 133 131 - let body = response.text().await 132 - .map_err(|e| TurbopufferError::Other(anyhow::anyhow!("failed to read response body: {}", e)))?; 133 - log::debug!("turbopuffer BM25 response: {}", body); 134 + log::debug!( 135 + "turbopuffer BM25 query: {}", 136 + serde_json::to_string_pretty(&request).unwrap_or_default() 137 + ); 134 138 135 - let parsed: QueryResponse = serde_json::from_str(&body) 136 - .map_err(|e| TurbopufferError::Other(anyhow::anyhow!("failed to parse BM25 query response: {}", e)))?; 139 + let rows = self.execute_query(request).await?; 137 140 138 - // DEBUG: log first result to see what BM25 returns 139 - if let Some(first) = parsed.first() { 140 - log::info!("BM25 first result - id: {}, dist: {}, name: {:?}", 141 + if let Some(first) = rows.first() { 142 + log::info!( 143 + "BM25 first result - id: {}, dist: {}, name: {:?}", 141 144 first.id, 142 145 first.dist, 143 146 first.attributes.get("name") 144 147 ); 145 148 } 146 149 147 - Ok(parsed) 150 + Ok(rows.into_iter().map(SearchResult::from).collect()) 151 + } 152 + 153 + fn name(&self) -> &'static str { 154 + "turbopuffer" 148 155 } 149 156 } 157 +