Another remote for gh:zotero-rag/zotero-rag
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix: pass backend instead of recreating instances

authored by

Rahul Yedida and committed by
Rahul Yedida
430e46e9 6a6e9d9a

+97 -99
+6
zqa-rag/src/vector/backends/lance.rs
··· 124 124 source_col, 125 125 } 126 126 } 127 + 128 + /// Returns a reference to the embedding provider configuration. 129 + #[must_use] 130 + pub fn embedding_config(&self) -> &EmbeddingProviderConfig { 131 + &self.config 132 + } 127 133 } 128 134 129 135 /// From a `RecordBatch`, return all values from a specified column as a `Vec<String>`.
+11 -6
zqa/src/cli/app.rs
··· 228 228 embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize, 229 229 reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(), 230 230 }; 231 - 232 - // Build a minimal tool; the embedding config is only used in `call`, not in the metadata 233 - // methods, so we use a dummy VoyageAI config here. 234 - RetrievalTool::new( 231 + let schema = Arc::new(arrow_schema::Schema::new(vec![ 232 + arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false), 233 + arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false), 234 + arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 235 + arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 236 + ])); 237 + let backend = LanceBackend::new( 235 238 EmbeddingProviderConfig::VoyageAI(config.clone()), 236 - Some(RerankProviderConfig::VoyageAI(config)), 237 - ) 239 + schema, 240 + "pdf_text".into(), 241 + ); 242 + RetrievalTool::new(backend, Some(RerankProviderConfig::VoyageAI(config))) 238 243 } 239 244 240 245 #[retry(3)]
+6 -10
zqa/src/cli/handlers/library.rs
··· 73 73 { 74 74 const WARNING_THRESHOLD: usize = 100; 75 75 76 - let item_metadata = 77 - if ctx.backend.db_exists().await { 78 - get_new_library_items(&ctx.config.get_embedding_config().ok_or( 79 - CLIError::ConfigError("Could not get embedding config".into()), 80 - )?) 81 - .await 82 - } else { 83 - parse_library_metadata(None, None) 84 - }; 76 + let item_metadata = if ctx.backend.db_exists().await { 77 + get_new_library_items(&ctx.backend).await 78 + } else { 79 + parse_library_metadata(None, None) 80 + }; 85 81 86 82 if let Err(parse_err) = item_metadata { 87 83 writeln!( ··· 110 106 } 111 107 } 112 108 113 - let record_batch = full_library_to_arrow(&ctx.config, None, None).await?; 109 + let record_batch = full_library_to_arrow(&ctx.backend, None, None).await?; 114 110 let schema = record_batch.schema(); 115 111 let batches = vec![record_batch.clone()]; 116 112
+2 -6
zqa/src/cli/handlers/query.rs
··· 88 88 let vector_search_start = Instant::now(); 89 89 let (mut search_results, _) = vector_search( 90 90 search_term.clone(), 91 - &ctx.config 92 - .get_embedding_config() 93 - .ok_or(CLIError::ConfigError( 94 - "Could not get embedding config".into(), 95 - ))?, 91 + &ctx.backend, 96 92 ctx.config.get_reranker_config().as_ref(), 97 93 ) 98 94 .await?; ··· 212 208 .as_ref() 213 209 .map(|c| (c.provider_name().to_string(), c.model_name().to_string())); 214 210 215 - let retrieval_tool = RetrievalTool::new(embedding_config.clone(), reranker_config); 211 + let retrieval_tool = RetrievalTool::new(ctx.backend.clone(), reranker_config); 216 212 let retrieval_embedding_chars = std::sync::Arc::clone(&retrieval_tool.embedding_chars); 217 213 let retrieval_rerank_chars = std::sync::Arc::clone(&retrieval_tool.rerank_chars); 218 214
+23 -16
zqa/src/tools/retrieval.rs
··· 10 10 use serde::Deserialize; 11 11 use serde_json::json; 12 12 use zqa_rag::{ 13 - embedding::common::EmbeddingProviderConfig, llm::tools::Tool, 14 - reranking::common::RerankProviderConfig, 13 + llm::tools::Tool, reranking::common::RerankProviderConfig, 14 + vector::backends::lance::LanceBackend, 15 15 }; 16 16 17 17 use crate::utils::{ ··· 25 25 /// A tool to perform vector search and reranking. 26 26 #[derive(Debug)] 27 27 pub(crate) struct RetrievalTool { 28 - /// The embedding provider configuration. Note that this must be the same embedding provider 29 - /// used when initially creating the database. 30 - pub(crate) embedding_config: EmbeddingProviderConfig, 28 + /// The backend used for vector search. 29 + pub(crate) backend: LanceBackend, 31 30 /// The reranker provider to use. 32 31 pub(crate) reranker_config: Option<RerankProviderConfig>, 33 32 /// Accumulated character count of text sent to the embedding API across all calls. ··· 37 36 } 38 37 39 38 impl RetrievalTool { 40 - /// Create a new instance of the [`RetrievalTool`] given an embedding config and reranker config. 39 + /// Create a new instance of the [`RetrievalTool`] given a backend and reranker config. 41 40 pub(crate) fn new( 42 - embedding_config: EmbeddingProviderConfig, 41 + backend: LanceBackend, 43 42 reranker_provider: Option<RerankProviderConfig>, 44 43 ) -> Self { 45 44 Self { 46 - embedding_config, 45 + backend, 47 46 reranker_config: reranker_provider, 48 47 embedding_chars: Arc::new(AtomicU64::new(0)), 49 48 rerank_chars: Arc::new(AtomicU64::new(0)), ··· 90 89 ) -> std::pin::Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send + '_>> 91 90 { 92 91 let start = Instant::now(); 93 - let embedding_config = self.embedding_config.clone(); 92 + let backend = self.backend.clone(); 94 93 let reranker_config = self.reranker_config.clone(); 95 94 let embedding_chars = Arc::clone(&self.embedding_chars); 96 95 let rerank_chars = Arc::clone(&self.rerank_chars); ··· 99 98 let input: RetrievalToolInput = 100 99 serde_json::from_value(args).map_err(|e| format!("Invalid arguments: {e}"))?; 101 100 let (mut results, stats) = 102 - vector_search(input.query, &embedding_config, reranker_config.as_ref()) 101 + vector_search(input.query, &backend, reranker_config.as_ref()) 103 102 .await 104 103 .map_err(|e| format!("Search failed: {e}"))?; 105 104 embedding_chars.fetch_add(stats.embedding_chars as u64, Ordering::Relaxed); ··· 136 135 137 136 #[cfg(test)] 138 137 mod tests { 138 + use std::sync::Arc; 139 + 139 140 use serde_json::json; 140 141 use zqa_rag::constants::{ 141 142 DEFAULT_VOYAGE_EMBEDDING_DIM, DEFAULT_VOYAGE_EMBEDDING_MODEL, DEFAULT_VOYAGE_RERANK_MODEL, 142 143 }; 143 144 use zqa_rag::embedding::common::EmbeddingProviderConfig; 145 + use zqa_rag::vector::backends::lance::LanceBackend; 144 146 145 147 use super::*; 146 148 ··· 151 153 embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize, 152 154 reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(), 153 155 }; 154 - 155 - // Build a minimal tool; the embedding config is only used in `call`, not in the metadata 156 - // methods, so we use a dummy VoyageAI config here. 157 - RetrievalTool::new( 156 + let schema = Arc::new(arrow_schema::Schema::new(vec![ 157 + arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false), 158 + arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false), 159 + arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 160 + arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 161 + ])); 162 + let backend = LanceBackend::new( 158 163 EmbeddingProviderConfig::VoyageAI(config.clone()), 159 - Some(RerankProviderConfig::VoyageAI(config)), 160 - ) 164 + schema, 165 + "pdf_text".into(), 166 + ); 167 + RetrievalTool::new(backend, Some(RerankProviderConfig::VoyageAI(config))) 161 168 } 162 169 163 170 #[test]
+17 -34
zqa/src/utils/arrow.rs
··· 17 17 }; 18 18 19 19 use super::library::{LibraryParsingError, parse_library}; 20 - use crate::{ 21 - config::Config, 22 - utils::library::{ZoteroItem, ZoteroItemSet}, 23 - }; 20 + use crate::utils::library::{ZoteroItem, ZoteroItemSet}; 24 21 25 22 /// An enum containing the fields stored by our application in `LanceDB`, in order. Implementations 26 23 /// `as_ref()` and `into()` are provided to convert this to `&str` and `String` respectively. ··· 93 90 fn from(value: LanceError) -> Self { 94 91 Self::LanceError(value.to_string()) 95 92 } 96 - } 97 - 98 - /// Build the LanceDB backend used by the CLI for the supplied embedding configuration. 99 - pub(crate) async fn lance_backend(embedding_config: EmbeddingProviderConfig) -> LanceBackend { 100 - let schema = Arc::new(get_schema(embedding_config.provider()).await); 101 - LanceBackend::new( 102 - embedding_config, 103 - schema, 104 - DbFields::PdfText.as_ref().to_string(), 105 - ) 106 93 } 107 94 108 95 /// Get the schema for our `LanceDB` table. This is required for both getting library items and ··· 250 237 /// multi-threading, etc. 251 238 /// * `limit` - Optional limit, meant to be used in conjunction with `start_from`. 252 239 pub async fn full_library_to_arrow( 253 - config: &Config, 240 + backend: &LanceBackend, 254 241 start_from: Option<usize>, 255 242 limit: Option<usize>, 256 243 ) -> Result<RecordBatch, ArrowError> { 257 - let lib_items = parse_library( 258 - &config.get_embedding_config().ok_or(ArrowError::Other( 259 - "Failed to get embedding config from application config".to_string(), 260 - ))?, 261 - start_from, 262 - limit, 263 - ) 264 - .await?; 244 + let lib_items = parse_library(backend, start_from, limit).await?; 265 245 log::info!("Finished parsing library items."); 266 246 267 - library_to_arrow( 268 - lib_items, 269 - config.get_embedding_config().ok_or(ArrowError::Other( 270 - "Failed to get embedding config from application config".to_string(), 271 - ))?, 272 - ) 273 - .await 247 + library_to_arrow(lib_items, backend.embedding_config().clone()).await 274 248 } 275 249 276 250 /// Statistics about the characters processed in a vector search call, used for cost estimation. ··· 315 289 /// * `ArrowError::LLMError` if reranking fails. 316 290 pub async fn vector_search( 317 291 query: String, 318 - embedding_config: &EmbeddingProviderConfig, 292 + backend: &LanceBackend, 319 293 reranker_config: Option<&RerankProviderConfig>, 320 294 ) -> Result<(Vec<ZoteroItem>, VectorSearchStats), ArrowError> { 321 295 let embedding_chars = query.len(); 322 - let backend = lance_backend(embedding_config.clone()).await; 323 296 let batches = backend.vector_search(query.clone(), 10).await?; 324 297 325 298 let items: ZoteroItemSet = batches.into(); ··· 383 356 }; 384 357 385 358 use super::*; 386 - use crate::{common::setup_logger, config::VoyageAIConfig}; 359 + use crate::{ 360 + common::setup_logger, 361 + config::{Config, VoyageAIConfig}, 362 + }; 387 363 388 364 fn get_config() -> Config { 389 365 let mut config = Config { ··· 416 392 let config = get_config(); 417 393 418 394 let record_batch = temp_env::async_with_vars([("LANCEDB_URI", Some(&db_uri))], async { 419 - full_library_to_arrow(&config, Some(0), Some(5)).await 395 + let embedding_config = config.get_embedding_config().unwrap(); 396 + let schema = Arc::new(get_schema(embedding_config.provider()).await); 397 + let backend = LanceBackend::new( 398 + embedding_config, 399 + schema, 400 + DbFields::PdfText.as_ref().to_string(), 401 + ); 402 + full_library_to_arrow(&backend, Some(0), Some(5)).await 420 403 }) 421 404 .await; 422 405
+23 -19
zqa/src/utils/library.rs
··· 17 17 use serde::Serialize; 18 18 use thiserror::Error; 19 19 use zqa_pdftools::parse::extract_text; 20 - use zqa_rag::embedding::common::EmbeddingProviderConfig; 21 20 use zqa_rag::vector::backends::{ 22 21 backend::VectorBackend, 23 - lance::{LanceError, db_exists as lancedb_exists, get_column_from_batch}, 22 + lance::{LanceBackend, LanceError, get_column_from_batch}, 24 23 }; 25 24 26 25 use crate::izip; 27 - use crate::utils::arrow::{DbFields, lance_backend}; 26 + use crate::utils::arrow::DbFields; 28 27 29 28 /// Gets the Zotero library path. Works on Linux, macOS, and Windows systems. 30 29 /// On CI environments, returns a location to a toy library in assets/ instead. ··· 184 183 /// columns from the result set could not be parsed, or `query_map` fails. 185 184 /// * `LibraryParsingError::LanceDBError` if fetching the rows from LanceDB fails. 186 185 pub async fn get_new_library_items( 187 - embedding_config: &EmbeddingProviderConfig, 186 + backend: &LanceBackend, 188 187 ) -> Result<Vec<ZoteroItemMetadata>, LibraryParsingError> { 189 - let backend = lance_backend(embedding_config.clone()).await; 190 188 let db_items = backend 191 189 .get_items(&[ 192 190 DbFields::LibraryKey.into(), ··· 428 426 /// * If the threads could not be joined. 429 427 #[allow(clippy::too_many_lines)] 430 428 pub async fn parse_library( 431 - embedding_config: &EmbeddingProviderConfig, 429 + backend: &LanceBackend, 432 430 start_from: Option<usize>, 433 431 limit: Option<usize>, 434 432 ) -> Result<Vec<ZoteroItem>, LibraryParsingError> { 435 433 let start_time = Instant::now(); 436 434 437 - let metadata = if lancedb_exists().await { 438 - get_new_library_items(embedding_config).await? 435 + let metadata = if backend.db_exists().await { 436 + get_new_library_items(backend).await? 439 437 } else { 440 438 parse_library_metadata(start_from, limit)? 441 439 }; ··· 631 629 632 630 #[cfg(test)] 633 631 mod tests { 632 + use std::sync::Arc; 633 + 634 634 use dotenv::dotenv; 635 635 use zqa_macros::{test_eq, test_ok}; 636 636 use zqa_rag::{ ··· 639 639 DEFAULT_VOYAGE_EMBEDDING_DIM, DEFAULT_VOYAGE_EMBEDDING_MODEL, 640 640 DEFAULT_VOYAGE_RERANK_MODEL, 641 641 }, 642 + embedding::common::EmbeddingProviderConfig, 642 643 }; 643 644 644 645 use super::*; ··· 719 720 // SAFETY: single-threaded async test, no concurrent env var access 720 721 unsafe { env::set_var("LANCEDB_URI", &db_uri) }; 721 722 722 - let items = parse_library( 723 - &EmbeddingProviderConfig::VoyageAI(VoyageAIConfig { 724 - embedding_model: DEFAULT_VOYAGE_EMBEDDING_MODEL.into(), 725 - embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize, 726 - api_key: env::var("VOYAGE_AI_API_KEY").expect("VOYAGE_AI_API_KEY not set"), 727 - reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(), 728 - }), 729 - Some(0), 730 - Some(7), 731 - ) 732 - .await; 723 + let embedding_config = EmbeddingProviderConfig::VoyageAI(VoyageAIConfig { 724 + embedding_model: DEFAULT_VOYAGE_EMBEDDING_MODEL.into(), 725 + embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize, 726 + api_key: env::var("VOYAGE_AI_API_KEY").expect("VOYAGE_AI_API_KEY not set"), 727 + reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(), 728 + }); 729 + let schema = Arc::new(arrow_schema::Schema::new(vec![ 730 + arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false), 731 + arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false), 732 + arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 733 + arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 734 + ])); 735 + let backend = LanceBackend::new(embedding_config, schema, "pdf_text".into()); 736 + let items = parse_library(&backend, Some(0), Some(7)).await; 733 737 test_ok!(items); 734 738 735 739 // Two of the items in the toy library are HTML files, so we actually
+9 -8
zqa/tests/new_library.rs
··· 49 49 }), 50 50 }; 51 51 52 - let record_batch = full_library_to_arrow(&config, None, None).await; 52 + let embedding_config = config.get_embedding_config().unwrap(); 53 + let schema = zqa::utils::arrow::get_schema(embedding_config.provider()).await; 54 + let backend = LanceBackend::new( 55 + embedding_config, 56 + std::sync::Arc::new(schema), 57 + "pdf_text".into(), 58 + ); 59 + 60 + let record_batch = full_library_to_arrow(&backend, None, None).await; 53 61 test_ok!(record_batch); 54 62 55 63 let record_batch = record_batch.unwrap(); 56 - let _schema = record_batch.schema(); 57 64 let batches = vec![record_batch.clone()]; 58 - 59 - let backend = LanceBackend::new( 60 - config.get_embedding_config().unwrap(), 61 - record_batch.schema(), 62 - "pdf_text".into(), 63 - ); 64 65 let db = backend.insert_items(batches, None).await; 65 66 66 67 test_ok!(db);