···124124 source_col,
125125 }
126126 }
127127+128128+ /// Returns a reference to the embedding provider configuration.
129129+ #[must_use]
130130+ pub fn embedding_config(&self) -> &EmbeddingProviderConfig {
131131+ &self.config
132132+ }
127133}
128134129135/// From a `RecordBatch`, return all values from a specified column as a `Vec<String>`.
+11-6
zqa/src/cli/app.rs
···228228 embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize,
229229 reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(),
230230 };
231231-232232- // Build a minimal tool; the embedding config is only used in `call`, not in the metadata
233233- // methods, so we use a dummy VoyageAI config here.
234234- RetrievalTool::new(
231231+ let schema = Arc::new(arrow_schema::Schema::new(vec![
232232+ arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false),
233233+ arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false),
234234+ arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false),
235235+ arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false),
236236+ ]));
237237+ let backend = LanceBackend::new(
235238 EmbeddingProviderConfig::VoyageAI(config.clone()),
236236- Some(RerankProviderConfig::VoyageAI(config)),
237237- )
239239+ schema,
240240+ "pdf_text".into(),
241241+ );
242242+ RetrievalTool::new(backend, Some(RerankProviderConfig::VoyageAI(config)))
238243 }
239244240245 #[retry(3)]
+6-10
zqa/src/cli/handlers/library.rs
···7373{
7474 const WARNING_THRESHOLD: usize = 100;
75757676- let item_metadata =
7777- if ctx.backend.db_exists().await {
7878- get_new_library_items(&ctx.config.get_embedding_config().ok_or(
7979- CLIError::ConfigError("Could not get embedding config".into()),
8080- )?)
8181- .await
8282- } else {
8383- parse_library_metadata(None, None)
8484- };
7676+ let item_metadata = if ctx.backend.db_exists().await {
7777+ get_new_library_items(&ctx.backend).await
7878+ } else {
7979+ parse_library_metadata(None, None)
8080+ };
85818682 if let Err(parse_err) = item_metadata {
8783 writeln!(
···110106 }
111107 }
112108113113- let record_batch = full_library_to_arrow(&ctx.config, None, None).await?;
109109+ let record_batch = full_library_to_arrow(&ctx.backend, None, None).await?;
114110 let schema = record_batch.schema();
115111 let batches = vec![record_batch.clone()];
116112
+2-6
zqa/src/cli/handlers/query.rs
···8888 let vector_search_start = Instant::now();
8989 let (mut search_results, _) = vector_search(
9090 search_term.clone(),
9191- &ctx.config
9292- .get_embedding_config()
9393- .ok_or(CLIError::ConfigError(
9494- "Could not get embedding config".into(),
9595- ))?,
9191+ &ctx.backend,
9692 ctx.config.get_reranker_config().as_ref(),
9793 )
9894 .await?;
···212208 .as_ref()
213209 .map(|c| (c.provider_name().to_string(), c.model_name().to_string()));
214210215215- let retrieval_tool = RetrievalTool::new(embedding_config.clone(), reranker_config);
211211+ let retrieval_tool = RetrievalTool::new(ctx.backend.clone(), reranker_config);
216212 let retrieval_embedding_chars = std::sync::Arc::clone(&retrieval_tool.embedding_chars);
217213 let retrieval_rerank_chars = std::sync::Arc::clone(&retrieval_tool.rerank_chars);
218214
+23-16
zqa/src/tools/retrieval.rs
···1010use serde::Deserialize;
1111use serde_json::json;
1212use zqa_rag::{
1313- embedding::common::EmbeddingProviderConfig, llm::tools::Tool,
1414- reranking::common::RerankProviderConfig,
1313+ llm::tools::Tool, reranking::common::RerankProviderConfig,
1414+ vector::backends::lance::LanceBackend,
1515};
16161717use crate::utils::{
···2525/// A tool to perform vector search and reranking.
2626#[derive(Debug)]
2727pub(crate) struct RetrievalTool {
2828- /// The embedding provider configuration. Note that this must be the same embedding provider
2929- /// used when initially creating the database.
3030- pub(crate) embedding_config: EmbeddingProviderConfig,
2828+ /// The backend used for vector search.
2929+ pub(crate) backend: LanceBackend,
3130 /// The reranker provider to use.
3231 pub(crate) reranker_config: Option<RerankProviderConfig>,
3332 /// Accumulated character count of text sent to the embedding API across all calls.
···3736}
38373938impl RetrievalTool {
4040- /// Create a new instance of the [`RetrievalTool`] given an embedding config and reranker config.
3939+ /// Create a new instance of the [`RetrievalTool`] given a backend and reranker config.
4140 pub(crate) fn new(
4242- embedding_config: EmbeddingProviderConfig,
4141+ backend: LanceBackend,
4342 reranker_provider: Option<RerankProviderConfig>,
4443 ) -> Self {
4544 Self {
4646- embedding_config,
4545+ backend,
4746 reranker_config: reranker_provider,
4847 embedding_chars: Arc::new(AtomicU64::new(0)),
4948 rerank_chars: Arc::new(AtomicU64::new(0)),
···9089 ) -> std::pin::Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send + '_>>
9190 {
9291 let start = Instant::now();
9393- let embedding_config = self.embedding_config.clone();
9292+ let backend = self.backend.clone();
9493 let reranker_config = self.reranker_config.clone();
9594 let embedding_chars = Arc::clone(&self.embedding_chars);
9695 let rerank_chars = Arc::clone(&self.rerank_chars);
···9998 let input: RetrievalToolInput =
10099 serde_json::from_value(args).map_err(|e| format!("Invalid arguments: {e}"))?;
101100 let (mut results, stats) =
102102- vector_search(input.query, &embedding_config, reranker_config.as_ref())
101101+ vector_search(input.query, &backend, reranker_config.as_ref())
103102 .await
104103 .map_err(|e| format!("Search failed: {e}"))?;
105104 embedding_chars.fetch_add(stats.embedding_chars as u64, Ordering::Relaxed);
···136135137136#[cfg(test)]
138137mod tests {
138138+ use std::sync::Arc;
139139+139140 use serde_json::json;
140141 use zqa_rag::constants::{
141142 DEFAULT_VOYAGE_EMBEDDING_DIM, DEFAULT_VOYAGE_EMBEDDING_MODEL, DEFAULT_VOYAGE_RERANK_MODEL,
142143 };
143144 use zqa_rag::embedding::common::EmbeddingProviderConfig;
145145+ use zqa_rag::vector::backends::lance::LanceBackend;
144146145147 use super::*;
146148···151153 embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize,
152154 reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(),
153155 };
154154-155155- // Build a minimal tool; the embedding config is only used in `call`, not in the metadata
156156- // methods, so we use a dummy VoyageAI config here.
157157- RetrievalTool::new(
156156+ let schema = Arc::new(arrow_schema::Schema::new(vec![
157157+ arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false),
158158+ arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false),
159159+ arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false),
160160+ arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false),
161161+ ]));
162162+ let backend = LanceBackend::new(
158163 EmbeddingProviderConfig::VoyageAI(config.clone()),
159159- Some(RerankProviderConfig::VoyageAI(config)),
160160- )
164164+ schema,
165165+ "pdf_text".into(),
166166+ );
167167+ RetrievalTool::new(backend, Some(RerankProviderConfig::VoyageAI(config)))
161168 }
162169163170 #[test]
+17-34
zqa/src/utils/arrow.rs
···1717};
18181919use super::library::{LibraryParsingError, parse_library};
2020-use crate::{
2121- config::Config,
2222- utils::library::{ZoteroItem, ZoteroItemSet},
2323-};
2020+use crate::utils::library::{ZoteroItem, ZoteroItemSet};
24212522/// An enum containing the fields stored by our application in `LanceDB`, in order. Implementations
2623/// `as_ref()` and `into()` are provided to convert this to `&str` and `String` respectively.
···9390 fn from(value: LanceError) -> Self {
9491 Self::LanceError(value.to_string())
9592 }
9696-}
9797-9898-/// Build the LanceDB backend used by the CLI for the supplied embedding configuration.
9999-pub(crate) async fn lance_backend(embedding_config: EmbeddingProviderConfig) -> LanceBackend {
100100- let schema = Arc::new(get_schema(embedding_config.provider()).await);
101101- LanceBackend::new(
102102- embedding_config,
103103- schema,
104104- DbFields::PdfText.as_ref().to_string(),
105105- )
10693}
1079410895/// Get the schema for our `LanceDB` table. This is required for both getting library items and
···250237/// multi-threading, etc.
251238/// * `limit` - Optional limit, meant to be used in conjunction with `start_from`.
252239pub async fn full_library_to_arrow(
253253- config: &Config,
240240+ backend: &LanceBackend,
254241 start_from: Option<usize>,
255242 limit: Option<usize>,
256243) -> Result<RecordBatch, ArrowError> {
257257- let lib_items = parse_library(
258258- &config.get_embedding_config().ok_or(ArrowError::Other(
259259- "Failed to get embedding config from application config".to_string(),
260260- ))?,
261261- start_from,
262262- limit,
263263- )
264264- .await?;
244244+ let lib_items = parse_library(backend, start_from, limit).await?;
265245 log::info!("Finished parsing library items.");
266246267267- library_to_arrow(
268268- lib_items,
269269- config.get_embedding_config().ok_or(ArrowError::Other(
270270- "Failed to get embedding config from application config".to_string(),
271271- ))?,
272272- )
273273- .await
247247+ library_to_arrow(lib_items, backend.embedding_config().clone()).await
274248}
275249276250/// Statistics about the characters processed in a vector search call, used for cost estimation.
···315289/// * `ArrowError::LLMError` if reranking fails.
316290pub async fn vector_search(
317291 query: String,
318318- embedding_config: &EmbeddingProviderConfig,
292292+ backend: &LanceBackend,
319293 reranker_config: Option<&RerankProviderConfig>,
320294) -> Result<(Vec<ZoteroItem>, VectorSearchStats), ArrowError> {
321295 let embedding_chars = query.len();
322322- let backend = lance_backend(embedding_config.clone()).await;
323296 let batches = backend.vector_search(query.clone(), 10).await?;
324297325298 let items: ZoteroItemSet = batches.into();
···383356 };
384357385358 use super::*;
386386- use crate::{common::setup_logger, config::VoyageAIConfig};
359359+ use crate::{
360360+ common::setup_logger,
361361+ config::{Config, VoyageAIConfig},
362362+ };
387363388364 fn get_config() -> Config {
389365 let mut config = Config {
···416392 let config = get_config();
417393418394 let record_batch = temp_env::async_with_vars([("LANCEDB_URI", Some(&db_uri))], async {
419419- full_library_to_arrow(&config, Some(0), Some(5)).await
395395+ let embedding_config = config.get_embedding_config().unwrap();
396396+ let schema = Arc::new(get_schema(embedding_config.provider()).await);
397397+ let backend = LanceBackend::new(
398398+ embedding_config,
399399+ schema,
400400+ DbFields::PdfText.as_ref().to_string(),
401401+ );
402402+ full_library_to_arrow(&backend, Some(0), Some(5)).await
420403 })
421404 .await;
422405
+23-19
zqa/src/utils/library.rs
···1717use serde::Serialize;
1818use thiserror::Error;
1919use zqa_pdftools::parse::extract_text;
2020-use zqa_rag::embedding::common::EmbeddingProviderConfig;
2120use zqa_rag::vector::backends::{
2221 backend::VectorBackend,
2323- lance::{LanceError, db_exists as lancedb_exists, get_column_from_batch},
2222+ lance::{LanceBackend, LanceError, get_column_from_batch},
2423};
25242625use crate::izip;
2727-use crate::utils::arrow::{DbFields, lance_backend};
2626+use crate::utils::arrow::DbFields;
28272928/// Gets the Zotero library path. Works on Linux, macOS, and Windows systems.
3029/// On CI environments, returns a location to a toy library in assets/ instead.
···184183/// columns from the result set could not be parsed, or `query_map` fails.
185184/// * `LibraryParsingError::LanceDBError` if fetching the rows from LanceDB fails.
186185pub async fn get_new_library_items(
187187- embedding_config: &EmbeddingProviderConfig,
186186+ backend: &LanceBackend,
188187) -> Result<Vec<ZoteroItemMetadata>, LibraryParsingError> {
189189- let backend = lance_backend(embedding_config.clone()).await;
190188 let db_items = backend
191189 .get_items(&[
192190 DbFields::LibraryKey.into(),
···428426/// * If the threads could not be joined.
429427#[allow(clippy::too_many_lines)]
430428pub async fn parse_library(
431431- embedding_config: &EmbeddingProviderConfig,
429429+ backend: &LanceBackend,
432430 start_from: Option<usize>,
433431 limit: Option<usize>,
434432) -> Result<Vec<ZoteroItem>, LibraryParsingError> {
435433 let start_time = Instant::now();
436434437437- let metadata = if lancedb_exists().await {
438438- get_new_library_items(embedding_config).await?
435435+ let metadata = if backend.db_exists().await {
436436+ get_new_library_items(backend).await?
439437 } else {
440438 parse_library_metadata(start_from, limit)?
441439 };
···631629632630#[cfg(test)]
633631mod tests {
632632+ use std::sync::Arc;
633633+634634 use dotenv::dotenv;
635635 use zqa_macros::{test_eq, test_ok};
636636 use zqa_rag::{
···639639 DEFAULT_VOYAGE_EMBEDDING_DIM, DEFAULT_VOYAGE_EMBEDDING_MODEL,
640640 DEFAULT_VOYAGE_RERANK_MODEL,
641641 },
642642+ embedding::common::EmbeddingProviderConfig,
642643 };
643644644645 use super::*;
···719720 // SAFETY: single-threaded async test, no concurrent env var access
720721 unsafe { env::set_var("LANCEDB_URI", &db_uri) };
721722722722- let items = parse_library(
723723- &EmbeddingProviderConfig::VoyageAI(VoyageAIConfig {
724724- embedding_model: DEFAULT_VOYAGE_EMBEDDING_MODEL.into(),
725725- embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize,
726726- api_key: env::var("VOYAGE_AI_API_KEY").expect("VOYAGE_AI_API_KEY not set"),
727727- reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(),
728728- }),
729729- Some(0),
730730- Some(7),
731731- )
732732- .await;
723723+ let embedding_config = EmbeddingProviderConfig::VoyageAI(VoyageAIConfig {
724724+ embedding_model: DEFAULT_VOYAGE_EMBEDDING_MODEL.into(),
725725+ embedding_dims: DEFAULT_VOYAGE_EMBEDDING_DIM as usize,
726726+ api_key: env::var("VOYAGE_AI_API_KEY").expect("VOYAGE_AI_API_KEY not set"),
727727+ reranker: DEFAULT_VOYAGE_RERANK_MODEL.into(),
728728+ });
729729+ let schema = Arc::new(arrow_schema::Schema::new(vec![
730730+ arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false),
731731+ arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false),
732732+ arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false),
733733+ arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false),
734734+ ]));
735735+ let backend = LanceBackend::new(embedding_config, schema, "pdf_text".into());
736736+ let items = parse_library(&backend, Some(0), Some(7)).await;
733737 test_ok!(items);
734738735739 // Two of the items in the toy library are HTML files, so we actually
+9-8
zqa/tests/new_library.rs
···4949 }),
5050 };
51515252- let record_batch = full_library_to_arrow(&config, None, None).await;
5252+ let embedding_config = config.get_embedding_config().unwrap();
5353+ let schema = zqa::utils::arrow::get_schema(embedding_config.provider()).await;
5454+ let backend = LanceBackend::new(
5555+ embedding_config,
5656+ std::sync::Arc::new(schema),
5757+ "pdf_text".into(),
5858+ );
5959+6060+ let record_batch = full_library_to_arrow(&backend, None, None).await;
5361 test_ok!(record_batch);
54625563 let record_batch = record_batch.unwrap();
5656- let _schema = record_batch.schema();
5764 let batches = vec![record_batch.clone()];
5858-5959- let backend = LanceBackend::new(
6060- config.get_embedding_config().unwrap(),
6161- record_batch.schema(),
6262- "pdf_text".into(),
6363- );
6465 let db = backend.insert_items(batches, None).await;
65666667 test_ok!(db);