Another remote for gh:zotero-rag/zotero-rag
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge pull request #231 from zotero-rag/04-26-refactor_zqa_add_store_abstraction_implement_for_lancedb

refactor(zqa): add Store abstraction, implement for LanceDB

authored by

Rahul Yedida and committed by
GitHub
e06533f0 8b05a7b9

+576 -339
+1
Cargo.lock
··· 6680 6680 "arrow-array", 6681 6681 "arrow-ipc", 6682 6682 "arrow-schema", 6683 + "async-trait", 6683 6684 "chrono", 6684 6685 "clap", 6685 6686 "crossbeam-channel",
+1
zqa/Cargo.toml
··· 18 18 arrow-array = { version = "^57.2.0", default-features = false } 19 19 arrow-ipc = { version = "^57.2.0", default-features = false } 20 20 arrow-schema = { version = "^57.2.0", default-features = false} 21 + async-trait = "0.1.89" 21 22 chrono = "^0.4.44" 22 23 clap = { version = "^4.6.0", default-features = false, features = ["derive", "help", "usage", "suggestions", "error-context", "std"] } 23 24 crossbeam-channel = "0.5.15"
+10 -10
zqa/src/cli/app.rs
··· 166 166 use zqa_rag::embedding::common::EmbeddingProviderConfig; 167 167 use zqa_rag::llm::tools::Tool; 168 168 use zqa_rag::reranking::common::RerankProviderConfig; 169 - use zqa_rag::vector::backends::lance::LanceBackend; 170 169 171 170 use super::dispatch_command; 172 171 use crate::common::Context; 173 172 use crate::common::State; 174 173 use crate::config::{Config, VoyageAIConfig}; 174 + use crate::store::lance::LanceZoteroStore; 175 175 use crate::tools::retrieval::RetrievalTool; 176 176 177 177 pub(crate) fn get_config() -> Config { ··· 207 207 208 208 let config = get_config(); 209 209 210 + let embedding_config = config.get_embedding_config().unwrap(); 211 + 210 212 Context { 211 213 state: State::default(), 212 - backend: LanceBackend::new( 213 - config.get_embedding_config().unwrap(), 214 - Arc::new(schema), 215 - "pdf_text".into(), 216 - ), 214 + store: LanceZoteroStore::from_schema(embedding_config, schema.into()), 217 215 config, 218 216 out, 219 217 err, 220 218 } 221 219 } 222 220 223 - fn make_retrieval_tool(_schema_key: &str) -> RetrievalTool { 221 + fn make_retrieval_tool(_schema_key: &str) -> RetrievalTool<LanceZoteroStore> { 224 222 let api_key = std::env::var("VOYAGE_AI_API_KEY").unwrap_or_default(); 225 223 let config = zqa_rag::config::VoyageAIConfig { 226 224 api_key, ··· 234 232 arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 235 233 arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 236 234 ])); 237 - let backend = LanceBackend::new( 235 + let store = LanceZoteroStore::from_schema( 238 236 EmbeddingProviderConfig::VoyageAI(config.clone()), 239 237 schema, 240 - "pdf_text".into(), 241 238 ); 242 - RetrievalTool::new(backend, Some(RerankProviderConfig::VoyageAI(config))) 239 + RetrievalTool::new( 240 + Arc::new(store), 241 + Some(RerankProviderConfig::VoyageAI(config)), 242 + ) 243 243 } 244 244 245 245 #[retry(3)]
+21 -34
zqa/src/cli/handlers/library.rs
··· 5 5 6 6 use arrow_array::RecordBatch; 7 7 use arrow_ipc::{reader::FileReader, writer::FileWriter}; 8 - use zqa_rag::vector::backends::backend::VectorBackend; 9 8 use zqa_rag::vector::checkhealth::lancedb_health_check; 10 9 use zqa_rag::vector::doctor::doctor as rag_doctor; 11 10 ··· 14 13 cli::{app::BATCH_ITER_FILE, errors::CLIError}, 15 14 common::Context, 16 15 full_library_to_arrow, 16 + store::common::ZoteroStore, 17 17 utils::{ 18 - arrow::{DbFields, library_to_arrow}, 18 + arrow::library_to_arrow, 19 19 library::{ZoteroItem, ZoteroItemSet, get_new_library_items, parse_library_metadata}, 20 20 }, 21 21 }; ··· 39 39 O: Write, 40 40 E: Write, 41 41 { 42 - match ctx.backend.get_metadata().await { 42 + match ctx.store.get_metadata().await { 43 43 Ok(stats) => writeln!(&mut ctx.out, "{stats}")?, 44 44 Err(e) => writeln!(&mut ctx.err, "Could not get database statistics: {e}")?, 45 45 } ··· 73 73 { 74 74 const WARNING_THRESHOLD: usize = 100; 75 75 76 - let item_metadata = if ctx.backend.db_exists().await { 77 - get_new_library_items(&ctx.backend).await 76 + let item_metadata = if ctx.store.exists().await { 77 + get_new_library_items(&ctx.store).await 78 78 } else { 79 79 parse_library_metadata(None, None) 80 80 }; ··· 106 106 } 107 107 } 108 108 109 - let record_batch = full_library_to_arrow(&ctx.backend, None, None).await?; 109 + let record_batch = full_library_to_arrow(&ctx.store, None, None).await?; 110 110 let schema = record_batch.schema(); 111 111 let batches = vec![record_batch.clone()]; 112 112 ··· 117 117 writer.write(&record_batch)?; 118 118 writer.finish()?; 119 119 120 - let result = ctx 121 - .backend 122 - .insert_items(batches, Some(&[DbFields::LibraryKey.as_ref()])) 123 - .await; 120 + let result = ctx.store.upsert_batches(batches).await; 124 121 125 122 match result { 126 123 Ok(()) => { ··· 195 192 } 196 193 writeln!(ctx.out, ".")?; 197 194 198 - let db = ctx 199 - .backend 200 - .insert_items(batches, Some(&[DbFields::LibraryKey.as_ref()])) 201 - .await; 195 + let db = ctx.store.upsert_batches(batches).await; 202 196 203 197 if db.is_ok() { 204 198 writeln!(ctx.out, "Successfully parsed library!")?; ··· 234 228 O: Write, 235 229 E: Write, 236 230 { 237 - let result = ctx 238 - .backend 239 - .dedup_rows(DbFields::Title.as_ref(), DbFields::LibraryKey.as_ref()) 240 - .await; 231 + let result = ctx.store.dedup_by_title().await; 241 232 242 233 match result { 243 234 Ok(count) => { ··· 276 267 "Updating indices. This may take a while depending on how many items need to be added." 277 268 )?; 278 269 279 - if let Err(e) = ctx 280 - .backend 281 - .create_or_update_indices(DbFields::PdfText.as_ref(), DbFields::Embeddings.as_ref()) 282 - .await 283 - { 270 + if let Err(e) = ctx.store.create_or_update_indices().await { 284 271 writeln!(&mut ctx.err, "Failed to update indexes: {e}")?; 285 272 } 286 273 ··· 414 401 .map(|item| item.metadata.library_key.clone()) 415 402 .collect(); 416 403 417 - ctx.backend 418 - .delete_rows(DbFields::LibraryKey.as_ref(), &zero_subset_keys) 419 - .await?; 404 + ctx.store.delete_by_library_keys(&zero_subset_keys).await?; 420 405 421 406 writeln!( 422 407 ctx.out, ··· 427 412 return Ok(()); 428 413 } 429 414 430 - let nonempty_zero_subset_batch = 431 - library_to_arrow(nonempty_zero_subset, embedding_config.clone()).await?; 415 + let include_embeddings = ctx.store.exists().await; 416 + let nonempty_zero_subset_batch = library_to_arrow( 417 + nonempty_zero_subset, 418 + embedding_config.clone(), 419 + include_embeddings, 420 + ) 421 + .await?; 432 422 433 423 let batches = vec![nonempty_zero_subset_batch.clone()]; 434 424 435 - ctx.backend 436 - .insert_items(batches, Some(&[DbFields::LibraryKey.as_ref()])) 437 - .await?; 425 + ctx.store.upsert_batches(batches).await?; 438 426 439 427 writeln!(ctx.out, "Successfully fixed zero embeddings!\n")?; 440 428 ··· 541 529 .await; 542 530 let output = String::from_utf8(ctx.out.into_inner()).unwrap(); 543 531 test_ok!(stats); 544 - assert!(stats.is_ok()); 545 - assert!(output.contains("LanceDB Statistics:")); 546 - assert!(output.contains("Number of rows: 8")); 532 + test_contains!(output, "LanceDB Statistics:"); 533 + test_contains!(output, "Number of rows: 8"); 547 534 548 535 if fs::metadata(BATCH_ITER_FILE).is_ok() { 549 536 fs::remove_file(BATCH_ITER_FILE).expect("Failed to clean up BATCH_ITER_FILE");
+16 -13
zqa/src/cli/handlers/query.rs
··· 20 20 providers::registry::provider_registry, 21 21 }; 22 22 23 + use crate::store::common::ZoteroStore; 23 24 use crate::{ 24 25 cli::{ 25 26 errors::CLIError, ··· 29 30 common::Context, 30 31 tools::{retrieval::RetrievalTool, summarization::SummarizationTool}, 31 32 utils::{ 32 - arrow::vector_search, 33 33 library::get_authors, 34 34 rag::ModelResponse, 35 35 terminal::{DIM_TEXT, RESET}, ··· 86 86 } 87 87 88 88 let vector_search_start = Instant::now(); 89 - let (mut search_results, _) = vector_search( 90 - search_term.clone(), 91 - &ctx.backend, 92 - ctx.config.get_reranker_config().as_ref(), 93 - ) 94 - .await?; 89 + let (mut search_results, _) = ctx 90 + .store 91 + .vector_search( 92 + search_term.clone(), 93 + 10, 94 + ctx.config.get_reranker_config().as_ref(), 95 + ) 96 + .await?; 95 97 let _ = get_authors(&mut search_results); 96 98 97 99 let vector_search_duration = vector_search_start.elapsed(); ··· 208 210 .as_ref() 209 211 .map(|c| (c.provider_name().to_string(), c.model_name().to_string())); 210 212 211 - let retrieval_tool = RetrievalTool::new(ctx.backend.clone(), reranker_config); 212 - let retrieval_embedding_chars = std::sync::Arc::clone(&retrieval_tool.embedding_chars); 213 - let retrieval_rerank_chars = std::sync::Arc::clone(&retrieval_tool.rerank_chars); 213 + let store_arc = std::sync::Arc::new(ctx.store.clone()); 214 + let retrieval_tool = RetrievalTool::new(std::sync::Arc::clone(&store_arc), reranker_config); 215 + let retrieval_embedding_tokens = std::sync::Arc::clone(&retrieval_tool.embedding_tokens); 216 + let retrieval_rerank_tokens = std::sync::Arc::clone(&retrieval_tool.rerank_tokens); 214 217 215 - let summarization_tool = SummarizationTool::new(llm_client.clone(), ctx.backend.clone()); 218 + let summarization_tool = SummarizationTool::new(llm_client.clone(), store_arc); 216 219 let summarization_tool_clone = summarization_tool.clone(); 217 220 let mut tools: Vec<Box<dyn Tool>> = 218 221 vec![Box::new(retrieval_tool), Box::new(summarization_tool)]; ··· 289 292 } 290 293 291 294 // Add embedding cost to session cost 292 - let emb_chars = retrieval_embedding_chars.load(atomic::Ordering::Relaxed); 295 + let emb_chars = retrieval_embedding_tokens.load(atomic::Ordering::Relaxed); 293 296 if emb_chars > 0 { 294 297 let emb_provider = embedding_provider_name.clone(); 295 298 let emb_model = embedding_model_name.clone(); ··· 313 316 } 314 317 315 318 // Add reranker cost to session cost 316 - let rerank_chars_val = retrieval_rerank_chars.load(atomic::Ordering::Relaxed); 319 + let rerank_chars_val = retrieval_rerank_tokens.load(atomic::Ordering::Relaxed); 317 320 if rerank_chars_val > 0 318 321 && let Some((rerank_provider, rerank_model)) = reranker_provider_and_model 319 322 {
+4 -4
zqa/src/common.rs
··· 12 12 use humantime; 13 13 use log::LevelFilter; 14 14 use zqa_pdftools::parse::ExtractedContent; 15 - use zqa_rag::{llm::base::ChatHistoryItem, vector::backends::lance::LanceBackend}; 15 + use zqa_rag::llm::base::ChatHistoryItem; 16 16 17 - use crate::config::Config; 17 + use crate::{config::Config, store::lance::LanceZoteroStore}; 18 18 19 19 #[derive(Parser, Clone, Debug)] 20 20 #[command(version, about, long_about = None)] ··· 70 70 pub(crate) state: State, 71 71 /// Config from TOML and env 72 72 pub(crate) config: Config, 73 - /// The backend to use for storage and retrieval 74 - pub(crate) backend: LanceBackend, 73 + /// The store to use for storage and retrieval 74 + pub(crate) store: LanceZoteroStore, 75 75 /// Abstraction for `stdout()` 76 76 pub(crate) out: OutStream, 77 77 /// Abstraction for `stderr()`
+8 -19
zqa/src/lib.rs
··· 3 3 #![allow(clippy::cast_precision_loss)] 4 4 #![allow(clippy::cast_possible_wrap)] 5 5 6 - use std::{ 7 - io::{self, IsTerminal, stderr, stdout}, 8 - sync::Arc, 9 - }; 6 + use std::io::{self, IsTerminal, stderr, stdout}; 10 7 11 8 use clap::Parser; 12 9 ··· 14 11 pub mod common; 15 12 pub mod config; 16 13 pub mod state; 14 + pub mod store; 17 15 pub mod tools; 18 16 pub mod utils; 19 17 ··· 22 20 use common::{Args, Context, setup_logger}; 23 21 use config::Config; 24 22 use state::{check_or_create_first_run_file, oobe}; 23 + pub use store::lance::LanceZoteroStore; 25 24 pub use utils::arrow::full_library_to_arrow; 26 25 use zqa_rag::{ 27 26 config::LLMClientConfig, embedding::common::EmbeddingProviderConfig, 28 - reranking::common::RerankProviderConfig, vector::backends::lance::LanceBackend, 27 + reranking::common::RerankProviderConfig, 29 28 }; 30 29 31 30 use crate::{ 32 31 cli::errors::CLIError, 33 32 common::State, 34 - utils::{ 35 - arrow::get_schema, 36 - terminal::{RED, RED_BOLD, RESET, YELLOW, YELLOW_BOLD}, 37 - }, 33 + utils::terminal::{RED, RED_BOLD, RESET, YELLOW, YELLOW_BOLD}, 38 34 }; 39 35 40 36 fn load_config() -> Result<Config, CLIError> { ··· 211 207 } 212 208 } 213 209 214 - let schema = get_schema(config.embedding_provider).await; 215 - 210 + let store = LanceZoteroStore::from_config(&config).await?; 216 211 let context = Context { 217 212 state: State::default(), 218 - config: config.clone(), 219 - backend: LanceBackend::new( 220 - config.get_embedding_config().ok_or_else(|| { 221 - CLIError::ConfigError("No embedding provider configured".to_string()) 222 - })?, 223 - Arc::new(schema), 224 - "pdf_text".into(), 225 - ), 213 + config, 214 + store, 226 215 out: stdout(), 227 216 err: stderr(), 228 217 };
+55
zqa/src/store/common.rs
··· 1 + use async_trait::async_trait; 2 + use zqa_rag::reranking::common::RerankProviderConfig; 3 + 4 + use crate::utils::library::{ZoteroItem, ZoteroItemMetadata}; 5 + 6 + /// Token statistics from a vector search call, used for cost estimation. 7 + pub struct VectorSearchStats { 8 + /// Number of tokens in the query string that was embedded 9 + pub(crate) embedding_tokens: usize, 10 + /// Total tokens of documents + query sent to the reranker 11 + pub(crate) rerank_tokens: usize, 12 + } 13 + 14 + /// An application-level trait for Zotero store implementations. 15 + #[async_trait] 16 + pub trait ZoteroStore: Send + Sync { 17 + /// The error type returned by store operations. 18 + type StoreError: std::error::Error + Send + Sync; 19 + /// The metadata type associated with the store. 20 + type Metadata; 21 + 22 + /// Returns `true` if the store exists, `false` otherwise. Useful to check that the store is 23 + /// configured correctly. 24 + async fn exists(&self) -> bool; 25 + /// Returns the metadata associated with the store. 26 + async fn get_metadata(&self) -> Result<Self::Metadata, Self::StoreError>; 27 + /// Returns the metadata for all existing items in the store. This is useful for operations 28 + /// such as set differences (e.g., finding newly-added items). 29 + async fn existing_item_metadata(&self) -> Result<Vec<ZoteroItemMetadata>, Self::StoreError>; 30 + /// Performs a vector search on the store, returning the top `limit` results. 31 + async fn vector_search( 32 + &self, 33 + query: String, 34 + limit: usize, 35 + reranker_config: Option<&RerankProviderConfig>, 36 + ) -> Result<(Vec<ZoteroItem>, VectorSearchStats), Self::StoreError>; 37 + /// Upserts the given items into the store. 38 + async fn upsert_items(&self, items: Vec<ZoteroItem>) -> Result<(), Self::StoreError>; 39 + /// Searches the store for items matching the given query, returning the top `limit` results. 40 + /// This variant does not perform reranking. 41 + async fn vector_search_raw( 42 + &self, 43 + query: &str, 44 + limit: usize, 45 + ) -> Result<Vec<ZoteroItem>, Self::StoreError>; 46 + /// Returns the items with the given keys from the store. This is useful for retrieving 47 + /// items by their library keys without performing a full text search. 48 + async fn get_items_by_keys(&self, keys: &[String]) 49 + -> Result<Vec<ZoteroItem>, Self::StoreError>; 50 + /// Deletes the items with the given keys from the store. 51 + async fn delete_by_library_keys(&self, keys: &[String]) -> Result<(), Self::StoreError>; 52 + /// Deletes duplicate items from the store based on their title. Returns the number of items 53 + /// deleted. 54 + async fn dedup_by_title(&self) -> Result<usize, Self::StoreError>; 55 + }
+310
zqa/src/store/lance.rs
··· 1 + use std::sync::Arc; 2 + 3 + use arrow_array::RecordBatch; 4 + use arrow_schema::Schema; 5 + use async_trait::async_trait; 6 + use zqa_rag::{ 7 + embedding::common::EmbeddingProviderConfig, 8 + reranking::common::{RerankProviderConfig, get_reranking_provider_with_config}, 9 + vector::backends::{ 10 + backend::VectorBackend, 11 + lance::{LanceBackend, LanceMetadata}, 12 + }, 13 + }; 14 + 15 + use crate::store::common::VectorSearchStats; 16 + use crate::{ 17 + cli::errors::CLIError, 18 + config::Config, 19 + store::common::ZoteroStore, 20 + utils::{ 21 + arrow::{DbFields, get_schema, library_to_arrow}, 22 + library::{ZoteroItem, ZoteroItemSet}, 23 + }, 24 + }; 25 + 26 + /// Zotero-specific store backed by LanceDB. 27 + #[derive(Clone)] 28 + pub struct LanceZoteroStore { 29 + backend: LanceBackend, 30 + embedding_config: EmbeddingProviderConfig, 31 + } 32 + 33 + impl LanceZoteroStore { 34 + /// Create a new Lance-backed Zotero store from an existing backend and embedding config. 35 + #[must_use] 36 + fn new(backend: LanceBackend, embedding_config: EmbeddingProviderConfig) -> Self { 37 + Self { 38 + backend, 39 + embedding_config, 40 + } 41 + } 42 + 43 + /// Create a Lance-backed Zotero store from an embedding config and Arrow schema. 44 + #[must_use] 45 + pub fn from_schema(embedding_config: EmbeddingProviderConfig, schema: Arc<Schema>) -> Self { 46 + let backend = LanceBackend::new( 47 + embedding_config.clone(), 48 + schema, 49 + DbFields::PdfText.as_ref().to_string(), 50 + ); 51 + 52 + Self::new(backend, embedding_config) 53 + } 54 + 55 + /// Get a read-only embedding config 56 + #[must_use] 57 + pub fn get_embedding_config(&self) -> EmbeddingProviderConfig { 58 + self.embedding_config.clone() 59 + } 60 + 61 + /// Create a Lance-backed Zotero store from an embedding configuration. 62 + pub async fn from_embedding_config(embedding_config: EmbeddingProviderConfig) -> Self { 63 + let schema = Arc::new(get_schema(embedding_config.provider(), true).await); 64 + Self::from_schema(embedding_config, schema) 65 + } 66 + 67 + /// Create a Lance-backed Zotero store from the application config. 68 + /// 69 + /// # Errors 70 + /// 71 + /// Returns a [`CLIError`] if no embedding configuration is available. 72 + pub(crate) async fn from_config(config: &Config) -> Result<Self, CLIError> { 73 + let embedding_config = config.get_embedding_config().ok_or(CLIError::ConfigError( 74 + "Could not get embedding config".into(), 75 + ))?; 76 + 77 + Ok(Self::from_embedding_config(embedding_config).await) 78 + } 79 + 80 + /// Upsert Arrow record batches into the LanceDB table by Zotero library key. 81 + /// 82 + /// TODO: We should probably deprecate this at some point in favor of the `upsert_items` from 83 + /// the trait. I'm keeping this around for now to keep refactor scopes relatively manageable. 84 + /// Ideally, we would not have any Lance-specific architecture, but currently, commands such as 85 + /// `/process` rely on this. 86 + /// 87 + /// # Errors 88 + /// 89 + /// Returns a [`CLIError`] if LanceDB insertion fails. 90 + pub(crate) async fn upsert_batches(&self, batches: Vec<RecordBatch>) -> Result<(), CLIError> { 91 + self.backend 92 + .insert_items(batches, Some(&[DbFields::LibraryKey.as_ref()])) 93 + .await 94 + .map_err(Into::into) 95 + } 96 + 97 + /// Create or update retrieval indices for the LanceDB table. 98 + /// 99 + /// # Errors 100 + /// 101 + /// Returns a [`CLIError`] if index creation or update fails. 102 + pub(crate) async fn create_or_update_indices(&self) -> Result<(), CLIError> { 103 + self.backend 104 + .create_or_update_indices(DbFields::PdfText.as_ref(), DbFields::Embeddings.as_ref()) 105 + .await 106 + .map_err(Into::into) 107 + } 108 + } 109 + 110 + #[async_trait] 111 + impl ZoteroStore for LanceZoteroStore { 112 + type StoreError = CLIError; 113 + type Metadata = LanceMetadata; 114 + 115 + async fn exists(&self) -> bool { 116 + self.backend.db_exists().await 117 + } 118 + 119 + /// Perform vector search and optional reranking. 120 + /// 121 + /// # Errors 122 + /// 123 + /// Returns a [`CLIError`] if search or reranking fails. 124 + async fn vector_search( 125 + &self, 126 + query: String, 127 + limit: usize, 128 + reranker_config: Option<&RerankProviderConfig>, 129 + ) -> Result<(Vec<ZoteroItem>, VectorSearchStats), CLIError> { 130 + let embedding_tokens = query.len(); 131 + let items = <Self as ZoteroStore>::vector_search_raw(self, &query, limit).await?; 132 + 133 + let filtered_items: Vec<ZoteroItem> = items 134 + .into_iter() 135 + .filter(|item| !item.text.trim().is_empty()) 136 + .collect(); 137 + 138 + if filtered_items.is_empty() { 139 + return Ok(( 140 + Vec::new(), 141 + VectorSearchStats { 142 + embedding_tokens, 143 + rerank_tokens: 0, 144 + }, 145 + )); 146 + } 147 + 148 + let Some(reranker) = reranker_config else { 149 + return Ok(( 150 + filtered_items, 151 + VectorSearchStats { 152 + embedding_tokens, 153 + rerank_tokens: 0, 154 + }, 155 + )); 156 + }; 157 + 158 + let rerank_provider = get_reranking_provider_with_config(reranker)?; 159 + let item_strings = filtered_items 160 + .iter() 161 + .map(|f| f.text.as_str()) 162 + .collect::<Vec<_>>(); 163 + 164 + let rerank_tokens = item_strings.iter().map(|s| s.len()).sum::<usize>() + query.len(); 165 + let indices = rerank_provider.rerank(&item_strings, &query).await?; 166 + 167 + let reranked_items = indices 168 + .into_iter() 169 + .filter_map(|idx| filtered_items.get(idx).cloned()) 170 + .collect(); 171 + 172 + Ok(( 173 + reranked_items, 174 + VectorSearchStats { 175 + embedding_tokens, 176 + rerank_tokens, 177 + }, 178 + )) 179 + } 180 + 181 + /// Return metadata for Zotero items that already exist in the store. 182 + /// 183 + /// # Errors 184 + /// 185 + /// Returns a [`CLIError`] if the existing rows cannot be fetched. 186 + async fn existing_item_metadata( 187 + &self, 188 + ) -> Result<Vec<crate::utils::library::ZoteroItemMetadata>, CLIError> { 189 + let db_items = self 190 + .backend 191 + .get_items(&[ 192 + DbFields::LibraryKey.into(), 193 + DbFields::Title.into(), 194 + DbFields::FilePath.into(), 195 + ]) 196 + .await?; 197 + 198 + Ok(db_items 199 + .iter() 200 + .flat_map(|batch| { 201 + let library_keys = crate::utils::library::get_column_from_batch(batch, 0); 202 + let titles = crate::utils::library::get_column_from_batch(batch, 1); 203 + let file_paths = crate::utils::library::get_column_from_batch(batch, 2); 204 + 205 + crate::izip!(library_keys, titles, file_paths) 206 + .map( 207 + |(key, title, path)| crate::utils::library::ZoteroItemMetadata { 208 + library_key: key, 209 + title, 210 + file_path: std::path::PathBuf::from(path), 211 + authors: None, 212 + }, 213 + ) 214 + .collect::<Vec<_>>() 215 + }) 216 + .collect()) 217 + } 218 + 219 + /// Return metadata for the underlying LanceDB table. 220 + /// 221 + /// # Errors 222 + /// 223 + /// Returns a [`CLIError`] if LanceDB metadata could not be read. 224 + async fn get_metadata(&self) -> Result<LanceMetadata, CLIError> { 225 + self.backend.get_metadata().await.map_err(Into::into) 226 + } 227 + 228 + /// Upserts the given items into the store. 229 + /// 230 + /// # Arguments 231 + /// 232 + /// * `items` - The items to upsert. 233 + /// 234 + /// # Errors 235 + /// 236 + /// Returns a [`CLIError`] if the upsert fails. 237 + async fn upsert_items(&self, items: Vec<ZoteroItem>) -> Result<(), Self::StoreError> { 238 + let include_embeddings = self.exists().await; 239 + let batch = 240 + library_to_arrow(items, self.embedding_config.clone(), include_embeddings).await?; 241 + self.upsert_batches(vec![batch]).await 242 + } 243 + 244 + /// Performs a raw vector search on the store, returning the top `limit` results. 245 + /// 246 + /// # Arguments 247 + /// 248 + /// * `query` - The query string. 249 + /// * `limit` - The maximum number of results to return. 250 + /// 251 + /// # Errors 252 + /// 253 + /// Returns a [`CLIError`] if the search fails. 254 + async fn vector_search_raw( 255 + &self, 256 + query: &str, 257 + limit: usize, 258 + ) -> Result<Vec<ZoteroItem>, Self::StoreError> { 259 + let batches = self.backend.vector_search(query.to_string(), limit).await?; 260 + Ok(ZoteroItemSet::from(batches).into()) 261 + } 262 + 263 + /// Returns the items with the given keys from the store. 264 + /// 265 + /// # Arguments 266 + /// 267 + /// * `keys` - The keys of the items to return. 268 + /// 269 + /// # Errors 270 + /// 271 + /// Returns a [`CLIError`] if the search fails. 272 + async fn get_items_by_keys( 273 + &self, 274 + keys: &[String], 275 + ) -> Result<Vec<ZoteroItem>, Self::StoreError> { 276 + let batches = self 277 + .backend 278 + .search_by_column(DbFields::LibraryKey.as_ref(), keys) 279 + .await?; 280 + Ok(ZoteroItemSet::from(batches).into()) 281 + } 282 + 283 + /// Deletes the items with the given keys from the store. 284 + /// 285 + /// # Arguments 286 + /// 287 + /// * `keys` - The keys of the items to delete. 288 + /// 289 + /// # Errors 290 + /// 291 + /// Returns a [`CLIError`] if the deletion fails. 292 + async fn delete_by_library_keys(&self, keys: &[String]) -> Result<(), Self::StoreError> { 293 + self.backend 294 + .delete_rows(DbFields::LibraryKey.as_ref(), keys) 295 + .await 296 + .map_err(Into::into) 297 + } 298 + 299 + /// Deduplicates items in the store by title, keeping the first occurrence. 300 + /// 301 + /// # Errors 302 + /// 303 + /// Returns a [`CLIError`] if the deduplication fails. 304 + async fn dedup_by_title(&self) -> Result<usize, Self::StoreError> { 305 + self.backend 306 + .dedup_rows(DbFields::Title.as_ref(), DbFields::LibraryKey.as_ref()) 307 + .await 308 + .map_err(Into::into) 309 + } 310 + }
+4
zqa/src/store/mod.rs
··· 1 + //! Zotero store implementations. This module acts as a bridge between this crate and the backends in `zqa_rag`. 2 + 3 + pub mod common; 4 + pub mod lance;
+40 -35
zqa/src/tools/retrieval.rs
··· 9 9 use schemars::{JsonSchema, schema_for}; 10 10 use serde::Deserialize; 11 11 use serde_json::json; 12 - use zqa_rag::{ 13 - llm::tools::Tool, reranking::common::RerankProviderConfig, 14 - vector::backends::lance::LanceBackend, 15 - }; 12 + use zqa_rag::{llm::tools::Tool, reranking::common::RerankProviderConfig}; 16 13 14 + use crate::store::common::ZoteroStore; 17 15 use crate::utils::{ 18 - arrow::vector_search, 19 16 library::get_authors, 20 17 terminal::{DIM_TEXT, RESET}, 21 18 }; ··· 24 21 25 22 /// A tool to perform vector search and reranking. 26 23 #[derive(Debug)] 27 - pub(crate) struct RetrievalTool { 28 - /// The backend used for vector search. 29 - pub(crate) backend: LanceBackend, 24 + pub(crate) struct RetrievalTool<T> 25 + where 26 + T: ZoteroStore, 27 + { 28 + /// The vector store abstraction 29 + pub(crate) store: Arc<T>, 30 30 /// The reranker provider to use. 31 31 pub(crate) reranker_config: Option<RerankProviderConfig>, 32 - /// Accumulated character count of text sent to the embedding API across all calls. 33 - pub(crate) embedding_chars: Arc<AtomicU64>, 34 - /// Accumulated character count of text sent to the reranker API across all calls. 35 - pub(crate) rerank_chars: Arc<AtomicU64>, 32 + /// Accumulated token count of text sent to the embedding API across all calls. 33 + pub(crate) embedding_tokens: Arc<AtomicU64>, 34 + /// Accumulated token count of text sent to the reranker API across all calls. 35 + pub(crate) rerank_tokens: Arc<AtomicU64>, 36 36 } 37 37 38 - impl RetrievalTool { 38 + impl<T> RetrievalTool<T> 39 + where 40 + T: ZoteroStore, 41 + { 39 42 /// Create a new instance of the [`RetrievalTool`] given a backend and reranker config. 40 - pub(crate) fn new( 41 - backend: LanceBackend, 42 - reranker_provider: Option<RerankProviderConfig>, 43 - ) -> Self { 43 + pub(crate) fn new(store: Arc<T>, reranker_provider: Option<RerankProviderConfig>) -> Self { 44 44 Self { 45 - backend, 45 + store, 46 46 reranker_config: reranker_provider, 47 - embedding_chars: Arc::new(AtomicU64::new(0)), 48 - rerank_chars: Arc::new(AtomicU64::new(0)), 47 + embedding_tokens: Arc::new(AtomicU64::new(0)), 48 + rerank_tokens: Arc::new(AtomicU64::new(0)), 49 49 } 50 50 } 51 51 } ··· 56 56 pub query: String, 57 57 } 58 58 59 - impl Tool for RetrievalTool { 59 + impl<T> Tool for RetrievalTool<T> 60 + where 61 + T: ZoteroStore + 'static, 62 + { 60 63 fn name(&self) -> String { 61 64 RETRIEVAL_TOOL_NAME.into() 62 65 } ··· 89 92 ) -> std::pin::Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send + '_>> 90 93 { 91 94 let start = Instant::now(); 92 - let backend = self.backend.clone(); 93 95 let reranker_config = self.reranker_config.clone(); 94 - let embedding_chars = Arc::clone(&self.embedding_chars); 95 - let rerank_chars = Arc::clone(&self.rerank_chars); 96 + let embedding_tokens = Arc::clone(&self.embedding_tokens); 97 + let rerank_tokens = Arc::clone(&self.rerank_tokens); 98 + let store = Arc::clone(&self.store); 96 99 97 100 Box::pin(async move { 98 101 let input: RetrievalToolInput = 99 102 serde_json::from_value(args).map_err(|e| format!("Invalid arguments: {e}"))?; 100 - let (mut results, stats) = 101 - vector_search(input.query, &backend, reranker_config.as_ref()) 102 - .await 103 - .map_err(|e| format!("Search failed: {e}"))?; 104 - embedding_chars.fetch_add(stats.embedding_chars as u64, Ordering::Relaxed); 105 - rerank_chars.fetch_add(stats.rerank_chars as u64, Ordering::Relaxed); 103 + let (mut results, stats) = store 104 + .vector_search(input.query, 10, reranker_config.as_ref()) 105 + .await 106 + .map_err(|e| format!("Search failed: {e}"))?; 107 + embedding_tokens.fetch_add(stats.embedding_tokens as u64, Ordering::Relaxed); 108 + rerank_tokens.fetch_add(stats.rerank_tokens as u64, Ordering::Relaxed); 106 109 107 110 get_authors(&mut results).map_err(|e| format!("Failed to get authors: {e}"))?; 108 111 log::info!( ··· 142 145 DEFAULT_VOYAGE_EMBEDDING_DIM, DEFAULT_VOYAGE_EMBEDDING_MODEL, DEFAULT_VOYAGE_RERANK_MODEL, 143 146 }; 144 147 use zqa_rag::embedding::common::EmbeddingProviderConfig; 145 - use zqa_rag::vector::backends::lance::LanceBackend; 146 148 147 149 use super::*; 150 + use crate::LanceZoteroStore; 148 151 149 - fn make_tool() -> RetrievalTool { 152 + fn make_tool() -> RetrievalTool<LanceZoteroStore> { 150 153 let config = zqa_rag::config::VoyageAIConfig { 151 154 api_key: String::new(), 152 155 embedding_model: DEFAULT_VOYAGE_EMBEDDING_MODEL.into(), ··· 159 162 arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 160 163 arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 161 164 ])); 162 - let backend = LanceBackend::new( 165 + let store = LanceZoteroStore::from_schema( 163 166 EmbeddingProviderConfig::VoyageAI(config.clone()), 164 167 schema, 165 - "pdf_text".into(), 166 168 ); 167 - RetrievalTool::new(backend, Some(RerankProviderConfig::VoyageAI(config))) 169 + RetrievalTool::new( 170 + Arc::new(store), 171 + Some(RerankProviderConfig::VoyageAI(config)), 172 + ) 168 173 } 169 174 170 175 #[test]
+41 -39
zqa/src/tools/summarization.rs
··· 8 8 use serde::Deserialize; 9 9 use serde_json::json; 10 10 use tokio::task::JoinSet; 11 - use zqa_rag::{ 12 - llm::{ 13 - base::{ApiClient, ChatRequest, CompletionApiResponse}, 14 - errors::LLMError, 15 - factory::LLMClient, 16 - tools::Tool, 17 - }, 18 - vector::backends::{backend::VectorBackend, lance::LanceBackend}, 11 + use zqa_rag::llm::{ 12 + base::{ApiClient, ChatRequest, CompletionApiResponse}, 13 + errors::LLMError, 14 + factory::LLMClient, 15 + tools::Tool, 19 16 }; 20 17 21 18 use crate::{ 22 19 cli::prompts::get_extraction_prompt, 23 - utils::{ 24 - arrow::DbFields, 25 - library::{ZoteroItem, ZoteroItemSet}, 26 - rag::ModelResponse, 27 - }, 20 + store::common::ZoteroStore, 21 + utils::{library::ZoteroItem, rag::ModelResponse}, 28 22 }; 29 23 30 24 pub(crate) const SUMMARIZATION_TOOL_NAME: &str = "summarization_tool"; 31 25 32 26 /// A tool to summarize Zotero papers with a specified ID. 33 27 #[derive(Debug, Clone)] 34 - pub(crate) struct SummarizationTool { 28 + pub(crate) struct SummarizationTool<T: ZoteroStore> { 35 29 pub(crate) llm_client: LLMClient, 36 30 /// Backend for searching stored Zotero papers. 37 - pub(crate) backend: LanceBackend, 31 + pub(crate) store: Arc<T>, 38 32 /// The input tokens used 39 33 pub(crate) input_tokens: Arc<Mutex<u32>>, 40 34 /// The output tokens used 41 35 pub(crate) output_tokens: Arc<Mutex<u32>>, 42 36 } 43 37 44 - impl SummarizationTool { 38 + impl<T> SummarizationTool<T> 39 + where 40 + T: ZoteroStore, 41 + { 45 42 /// Create a new [`SummarizationTool`] instance, given an LLM client and a backend. 46 - pub fn new(llm_client: LLMClient, backend: LanceBackend) -> Self { 43 + pub fn new(llm_client: LLMClient, store: Arc<T>) -> Self { 47 44 Self { 48 45 llm_client, 49 - backend, 46 + store, 50 47 input_tokens: Arc::new(Mutex::new(0)), 51 48 output_tokens: Arc::new(Mutex::new(0)), 52 49 } ··· 62 59 ids: Vec<String>, 63 60 } 64 61 65 - impl Tool for SummarizationTool { 62 + impl<T> Tool for SummarizationTool<T> 63 + where 64 + T: ZoteroStore + 'static, 65 + { 66 66 fn name(&self) -> String { 67 67 SUMMARIZATION_TOOL_NAME.into() 68 68 } ··· 88 88 /// A JSON object with a `"summaries"` key mapping to a list of summary strings, 89 89 /// one per successfully processed paper, and an `"errors"` key mapping to a list 90 90 /// of error messages for papers that failed to summarize. 91 - fn call<'a>( 92 - &'a self, 91 + fn call( 92 + &self, 93 93 args: serde_json::Value, 94 - ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send + 'a>> { 94 + ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send + '_>> { 95 + let store = Arc::clone(&self.store); 96 + let input_tokens = Arc::clone(&self.input_tokens); 97 + let output_tokens = Arc::clone(&self.output_tokens); 98 + let llm_client = self.llm_client.clone(); 95 99 Box::pin(async move { 96 100 let input: SummarizationToolInput = 97 101 serde_json::from_value(args).map_err(|e| format!("Invalid arguments: {e}"))?; 98 102 99 - let results = self 100 - .backend 101 - .search_by_column(DbFields::LibraryKey.as_ref(), &input.ids) 103 + let results: Vec<ZoteroItem> = store 104 + .get_items_by_keys(&input.ids) 102 105 .await 103 106 .map_err(|e| format!("Search failed: {e}"))?; 104 107 105 - let batches: ZoteroItemSet = results.into(); 106 - let items: Vec<ZoteroItem> = batches.into(); 107 - 108 108 let mut set = JoinSet::new(); 109 - for item in items { 110 - let client = self.llm_client.clone(); 109 + for item in results { 110 + let client = llm_client.clone(); 111 111 let text = item.text; 112 112 let metadata = item.metadata; 113 113 let query_cloned = input.query.clone(); ··· 140 140 summaries.push(summary); 141 141 142 142 // Update token counts (with error handling for mutex poisoning) 143 - if let Ok(mut input_tokens) = self.input_tokens.lock() { 144 - *input_tokens += response.input_tokens; 143 + if let Ok(mut toks) = input_tokens.lock() { 144 + *toks += response.input_tokens; 145 145 } 146 - if let Ok(mut output_tokens) = self.output_tokens.lock() { 147 - *output_tokens += response.output_tokens; 146 + if let Ok(mut toks) = output_tokens.lock() { 147 + *toks += response.output_tokens; 148 148 } 149 149 } 150 150 Err(e) => { ··· 175 175 config::{AnthropicConfig, LLMClientConfig}, 176 176 constants::DEFAULT_ANTHROPIC_MODEL_SMALL, 177 177 llm::factory::get_client_with_config, 178 - vector::backends::lance::LanceBackend, 179 178 }; 180 179 181 180 use super::*; 182 - use crate::cli::app::tests::{create_test_context, get_config}; 183 181 use crate::cli::handlers::library::handle_process_cmd; 182 + use crate::{ 183 + cli::app::tests::{create_test_context, get_config}, 184 + store::lance::LanceZoteroStore, 185 + }; 184 186 185 - fn make_tool() -> SummarizationTool { 187 + fn make_tool() -> SummarizationTool<LanceZoteroStore> { 186 188 let client = get_client_with_config(&LLMClientConfig::Anthropic(AnthropicConfig { 187 189 api_key: env::var("ANTHROPIC_API_KEY").unwrap(), 188 190 model: DEFAULT_ANTHROPIC_MODEL_SMALL.into(), ··· 199 201 arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 200 202 arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 201 203 ])); 202 - let backend = LanceBackend::new(embedding_config, schema, "pdf_text".into()); 203 - SummarizationTool::new(client, backend) 204 + let store = LanceZoteroStore::from_schema(embedding_config, schema); 205 + SummarizationTool::new(client, Arc::new(store)) 204 206 } 205 207 206 208 #[test]
+35 -122
zqa/src/utils/arrow.rs
··· 1 - use std::sync::Arc; 1 + use std::{path::PathBuf, sync::Arc}; 2 2 3 3 use arrow_array::{ArrayRef, RecordBatch, StringArray, cast::AsArray}; 4 4 use arrow_schema; ··· 9 9 EmbeddingProviderConfig, get_embedding_dims_by_provider, get_embedding_provider_with_config, 10 10 }, 11 11 llm::errors::LLMError, 12 - reranking::common::{RerankProviderConfig, get_reranking_provider_with_config}, 13 - vector::backends::{ 14 - backend::VectorBackend, 15 - lance::{LanceBackend, LanceError, db_exists as lancedb_exists}, 16 - }, 12 + vector::backends::lance::{LANCE_TABLE_NAME, LanceError, get_db_uri}, 17 13 }; 18 14 19 15 use super::library::{LibraryParsingError, parse_library}; 20 - use crate::utils::library::{ZoteroItem, ZoteroItemSet}; 16 + use crate::{store::lance::LanceZoteroStore, utils::library::ZoteroItem}; 21 17 22 18 /// An enum containing the fields stored by our application in `LanceDB`, in order. Implementations 23 19 /// `as_ref()` and `into()` are provided to convert this to `&str` and `String` respectively. ··· 92 88 } 93 89 } 94 90 91 + /// Checks whether the configured LanceDB database exists and contains the expected table. 92 + pub(crate) async fn lancedb_exists() -> bool { 93 + let uri = get_db_uri(); 94 + if !PathBuf::from(&uri).exists() { 95 + return false; 96 + } 97 + 98 + if let Ok(db) = lancedb::connect(&uri).execute().await { 99 + db.open_table(LANCE_TABLE_NAME).execute().await.is_ok() 100 + } else { 101 + false 102 + } 103 + } 104 + 95 105 /// Get the schema for our `LanceDB` table. This is required for both getting library items and 96 106 /// checkhealth. 97 107 /// 98 108 /// # Arguments 99 109 /// 100 110 /// * `embedding_provider` - The embedding used by the current DB. 111 + /// * `include_embeddings` - Whether to include the embeddings field in the schema. 101 112 /// 102 113 /// # Returns 103 114 /// 104 115 /// The schema in Arrow format. 105 - pub async fn get_schema(embedding_provider: EmbeddingProvider) -> arrow_schema::Schema { 116 + pub async fn get_schema( 117 + embedding_provider: EmbeddingProvider, 118 + include_embeddings: bool, 119 + ) -> arrow_schema::Schema { 106 120 // Convert ZoteroItemMetadata to something that can be converted to Arrow 107 121 // Need to extract fields and create appropriate Arrow arrays 108 122 let mut schema_fields = vec![ ··· 112 126 arrow_schema::Field::new(DbFields::PdfText, arrow_schema::DataType::Utf8, false), 113 127 ]; 114 128 115 - if lancedb_exists().await { 129 + if include_embeddings { 116 130 schema_fields.push(arrow_schema::Field::new( 117 131 DbFields::Embeddings, 118 132 arrow_schema::DataType::FixedSizeList( ··· 137 151 /// 138 152 /// * `items` - The items to convert to a `RecordBatch` 139 153 /// * `embedding_config` - Configuration for the embedding provider to use when computing embeddings. 154 + /// * `include_embeddings` - Whether to include the embeddings field in the schema. 140 155 /// 141 156 /// # Errors 142 157 /// ··· 150 165 pub async fn library_to_arrow( 151 166 items: Vec<ZoteroItem>, 152 167 embedding_config: EmbeddingProviderConfig, 168 + include_embeddings: bool, 153 169 ) -> Result<RecordBatch, ArrowError> { 154 - let schema = Arc::new(get_schema(embedding_config.provider()).await); 170 + let schema = Arc::new(get_schema(embedding_config.provider(), include_embeddings).await); 155 171 156 172 // Convert ZoteroItemMetadata to Arrow arrays 157 173 let library_keys = StringArray::from( ··· 194 210 Arc::new(pdf_texts.clone()) as ArrayRef, 195 211 ]; 196 212 197 - if lancedb_exists().await { 213 + if include_embeddings { 198 214 let embedding_provider = get_embedding_provider_with_config(&embedding_config)?; 199 215 let query_vec = embedding_provider.compute_source_embeddings(Arc::new(pdf_texts))?; 200 216 let query_vec = query_vec.as_fixed_size_list(); ··· 232 248 /// 233 249 /// # Arguments 234 250 /// 235 - /// * `config` - Configuration containing embedding provider information. 251 + /// * `store` - [`LanceZoteroStore`] with configuration 236 252 /// * `start_from` - An optional offset for the SQL query. Useful for debugging, pagination, 237 253 /// multi-threading, etc. 238 254 /// * `limit` - Optional limit, meant to be used in conjunction with `start_from`. 239 255 pub async fn full_library_to_arrow( 240 - backend: &LanceBackend, 256 + store: &LanceZoteroStore, 241 257 start_from: Option<usize>, 242 258 limit: Option<usize>, 243 259 ) -> Result<RecordBatch, ArrowError> { 244 - let lib_items = parse_library(backend, start_from, limit).await?; 260 + let lib_items = parse_library(store, start_from, limit).await?; 245 261 log::info!("Finished parsing library items."); 246 262 247 - library_to_arrow(lib_items, backend.embedding_config().clone()).await 248 - } 249 - 250 - /// Statistics about the characters processed in a vector search call, used for cost estimation. 251 - pub struct VectorSearchStats { 252 - /// Number of characters in the query string that was embedded. 253 - pub embedding_chars: usize, 254 - /// Total characters of documents + query sent to the reranker (0 if no reranker was used). 255 - pub rerank_chars: usize, 256 - } 257 - 258 - /// Perform vector search using a query and a specified embedding method. 259 - /// 260 - /// This function is a Zotero-specific wrapper for the `vector_search` function in the `rag` crate. 261 - /// It is implemented here since the knowledge of which column is which in the `RecordBatch`es that 262 - /// we create is in this file, so there's better locality-of-behaviour; this also makes the 263 - /// underlying implementation of `vector_search` simpler and potentially allows other RAG 264 - /// applications to be built on top of it. 265 - /// 266 - /// TODO: A limit of 10 results is currently returned, but this will be changed in a future version. 267 - /// 268 - /// In some sense, this function is the reverse of the `library_to_arrow` function, which creates a 269 - /// `RecordBatch` from vectors after calling `parse_library`. 270 - /// 271 - /// This function also uses a reranking provider to perform reranking of the vector search results. 272 - /// 273 - /// # Arguments 274 - /// 275 - /// * `query` - The query to search the `LanceDB` table for. 276 - /// * `embedding_config` - The embedding provider configuration. Note that this must be the same 277 - /// embedding provider used when initially creating the database. 278 - /// * `reranker_config` - The reranker provider to use. 279 - /// 280 - /// # Returns 281 - /// 282 - /// A tuple of the matching `ZoteroItem`s and [`VectorSearchStats`] with character counts used for 283 - /// cost estimation. Returns an `ArrowError` that wraps the underlying `LanceError` if the `rag` 284 - /// crate's `vector_search` is unsuccessful for any reason. 285 - /// 286 - /// # Errors 287 - /// 288 - /// * `ArrowError::LanceError` if vector search fails. 289 - /// * `ArrowError::LLMError` if reranking fails. 290 - pub async fn vector_search( 291 - query: String, 292 - backend: &LanceBackend, 293 - reranker_config: Option<&RerankProviderConfig>, 294 - ) -> Result<(Vec<ZoteroItem>, VectorSearchStats), ArrowError> { 295 - let embedding_chars = query.len(); 296 - let batches = backend.vector_search(query.clone(), 10).await?; 297 - 298 - let items: ZoteroItemSet = batches.into(); 299 - let items: Vec<ZoteroItem> = items.into(); 300 - 301 - let filtered_items: Vec<ZoteroItem> = items 302 - .into_iter() 303 - .filter(|item| !item.text.trim().is_empty()) 304 - .collect(); 305 - 306 - if filtered_items.is_empty() { 307 - return Ok(( 308 - Vec::new(), 309 - VectorSearchStats { 310 - embedding_chars, 311 - rerank_chars: 0, 312 - }, 313 - )); 314 - } 315 - 316 - let Some(reranker) = reranker_config else { 317 - return Ok(( 318 - filtered_items, 319 - VectorSearchStats { 320 - embedding_chars, 321 - rerank_chars: 0, 322 - }, 323 - )); 324 - }; 325 - 326 - let rerank_provider = get_reranking_provider_with_config(reranker)?; 327 - let item_strings = filtered_items 328 - .iter() 329 - .map(|f| f.text.as_str()) 330 - .collect::<Vec<_>>(); 331 - 332 - let rerank_chars = item_strings.iter().map(|s| s.len()).sum::<usize>() + query.len(); 333 - 334 - let indices = rerank_provider.rerank(&item_strings, &query).await?; 335 - 336 - let reranked_items = indices 337 - .into_iter() 338 - .filter_map(|idx| filtered_items.get(idx).cloned()) 339 - .collect(); 340 - 341 - Ok(( 342 - reranked_items, 343 - VectorSearchStats { 344 - embedding_chars, 345 - rerank_chars, 346 - }, 347 - )) 263 + let include_embeddings = lancedb_exists().await; 264 + library_to_arrow(lib_items, store.get_embedding_config(), include_embeddings).await 348 265 } 349 266 350 267 #[cfg(test)] ··· 393 310 394 311 let record_batch = temp_env::async_with_vars([("LANCEDB_URI", Some(&db_uri))], async { 395 312 let embedding_config = config.get_embedding_config().unwrap(); 396 - let schema = Arc::new(get_schema(embedding_config.provider()).await); 397 - let backend = LanceBackend::new( 398 - embedding_config, 399 - schema, 400 - DbFields::PdfText.as_ref().to_string(), 401 - ); 402 - full_library_to_arrow(&backend, Some(0), Some(5)).await 313 + let schema = Arc::new(get_schema(embedding_config.provider(), true).await); 314 + let store = LanceZoteroStore::from_schema(embedding_config, schema); 315 + full_library_to_arrow(&store, Some(0), Some(5)).await 403 316 }) 404 317 .await; 405 318
+27 -48
zqa/src/utils/library.rs
··· 10 10 use std::thread; 11 11 use std::time::Instant; 12 12 13 - use arrow_array::RecordBatch; 13 + use arrow_array::{RecordBatch, cast::AsArray}; 14 14 use directories::UserDirs; 15 15 use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; 16 16 use rusqlite::Connection; 17 17 use serde::Serialize; 18 18 use thiserror::Error; 19 19 use zqa_pdftools::parse::extract_text; 20 - use zqa_rag::vector::backends::{ 21 - backend::VectorBackend, 22 - lance::{LanceBackend, LanceError, get_column_from_batch}, 23 - }; 24 20 25 - use crate::izip; 26 - use crate::utils::arrow::DbFields; 21 + use crate::store::common::ZoteroStore; 22 + use crate::{izip, utils::arrow::DbFields}; 27 23 28 24 /// Gets the Zotero library path. Works on Linux, macOS, and Windows systems. 29 25 /// On CI environments, returns a location to a toy library in assets/ instead. ··· 153 149 } 154 150 } 155 151 156 - impl From<LanceError> for LibraryParsingError { 157 - fn from(value: LanceError) -> Self { 158 - LibraryParsingError::LanceDBError(value.to_string()) 159 - } 160 - } 161 - 162 152 impl From<Box<dyn std::error::Error>> for LibraryParsingError { 163 153 fn from(value: Box<dyn std::error::Error>) -> Self { 164 154 LibraryParsingError::PdfParsingError(value.to_string()) 165 155 } 166 156 } 167 157 158 + /// From a `RecordBatch`, return all values from a specified column as a `Vec<String>`. 159 + #[must_use] 160 + pub(crate) fn get_column_from_batch(batch: &RecordBatch, column: usize) -> Vec<String> { 161 + let results = batch.column(column).as_string::<i32>(); 162 + 163 + results 164 + .iter() 165 + .filter_map(|s| Some(s?.to_string())) 166 + .collect() 167 + } 168 + 168 169 /// Assuming an existing `LanceDB` database exists, returns a list of items present in the Zotero 169 170 /// library but not in the database. The primary use case for this is to update the DB with new 170 171 /// items. Note that this does not take into account removed items. ··· 182 183 /// * `LibraryParsingError::SqliteError` if the library path was not found, the query could not be prepared, or 183 184 /// columns from the result set could not be parsed, or `query_map` fails. 184 185 /// * `LibraryParsingError::LanceDBError` if fetching the rows from LanceDB fails. 185 - pub async fn get_new_library_items( 186 - backend: &LanceBackend, 186 + pub async fn get_new_library_items<T: ZoteroStore>( 187 + store: &T, 187 188 ) -> Result<Vec<ZoteroItemMetadata>, LibraryParsingError> { 188 - let db_items = backend 189 - .get_items(&[ 190 - DbFields::LibraryKey.into(), 191 - DbFields::Title.into(), 192 - DbFields::FilePath.into(), 193 - ]) 194 - .await?; 195 - 196 - let metadata_vecs = db_items 197 - .iter() 198 - .flat_map(|batch| { 199 - let library_keys = get_column_from_batch(batch, 0); 200 - let titles = get_column_from_batch(batch, 1); 201 - let file_paths = get_column_from_batch(batch, 2); 202 - 203 - let zipped = izip!(library_keys, titles, file_paths).collect::<Vec<_>>(); 204 - zipped 205 - .iter() 206 - .map(|(key, title, path)| ZoteroItemMetadata { 207 - library_key: key.clone(), 208 - title: title.clone(), 209 - file_path: PathBuf::from(path.clone()), 210 - authors: None, 211 - }) 212 - .collect::<Vec<_>>() 213 - }) 214 - .collect::<Vec<_>>(); 189 + let metadata_vecs = store 190 + .existing_item_metadata() 191 + .await 192 + .map_err(|e| LibraryParsingError::LanceDBError(e.to_string()))?; 215 193 216 194 let library_items = parse_library_metadata(None, None)?; 217 195 ··· 425 403 /// * If a Mutex lock could not be acquired on the progress bar. 426 404 /// * If the threads could not be joined. 427 405 #[allow(clippy::too_many_lines)] 428 - pub async fn parse_library( 429 - backend: &LanceBackend, 406 + pub async fn parse_library<T: ZoteroStore>( 407 + store: &T, 430 408 start_from: Option<usize>, 431 409 limit: Option<usize>, 432 410 ) -> Result<Vec<ZoteroItem>, LibraryParsingError> { 433 411 let start_time = Instant::now(); 434 412 435 - let metadata = if backend.db_exists().await { 436 - get_new_library_items(backend).await? 413 + let metadata = if store.exists().await { 414 + get_new_library_items(store).await? 437 415 } else { 438 416 parse_library_metadata(start_from, limit)? 439 417 }; ··· 643 621 }; 644 622 645 623 use super::*; 624 + use crate::LanceZoteroStore; 646 625 use crate::common::setup_logger; 647 626 648 627 #[test] ··· 732 711 arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 733 712 arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 734 713 ])); 735 - let backend = LanceBackend::new(embedding_config, schema, "pdf_text".into()); 736 - let items = parse_library(&backend, Some(0), Some(7)).await; 714 + let store = LanceZoteroStore::from_schema(embedding_config, schema); 715 + let items = parse_library(&store, Some(0), Some(7)).await; 737 716 test_ok!(items); 738 717 739 718 // Two of the items in the toy library are HTML files, so we actually
+3 -15
zqa/tests/new_library.rs
··· 4 4 use log::LevelFilter; 5 5 use zqa::common::setup_logger; 6 6 use zqa::config::{AnthropicConfig, Config, VoyageAIConfig}; 7 - use zqa::full_library_to_arrow; 7 + use zqa::{LanceZoteroStore, full_library_to_arrow}; 8 8 use zqa_macros::test_ok; 9 9 use zqa_rag::capabilities::{EmbeddingProvider, ModelProvider, RerankerProvider}; 10 10 use zqa_rag::constants::{ 11 11 DEFAULT_MAX_CONCURRENT_REQUESTS, DEFAULT_MAX_RETRIES, DEFAULT_VOYAGE_EMBEDDING_DIM, 12 12 }; 13 - use zqa_rag::vector::backends::{backend::VectorBackend, lance::LanceBackend}; 14 13 15 14 #[tokio::test] 16 15 async fn test_integration_works() { ··· 50 49 }; 51 50 52 51 let embedding_config = config.get_embedding_config().unwrap(); 53 - let schema = zqa::utils::arrow::get_schema(embedding_config.provider()).await; 54 - let backend = LanceBackend::new( 55 - embedding_config, 56 - std::sync::Arc::new(schema), 57 - "pdf_text".into(), 58 - ); 52 + let store = LanceZoteroStore::from_embedding_config(embedding_config).await; 59 53 60 - let record_batch = full_library_to_arrow(&backend, None, None).await; 54 + let record_batch = full_library_to_arrow(&store, None, None).await; 61 55 test_ok!(record_batch); 62 - 63 - let record_batch = record_batch.unwrap(); 64 - let batches = vec![record_batch.clone()]; 65 - let db = backend.insert_items(batches, None).await; 66 - 67 - test_ok!(db); 68 56 }