Another remote for gh:zotero-rag/zotero-rag
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

fix: remaining minor issues

Authored and committed by Rahul Yedida
6a6e9d9a 501d57a0

+33 -22
+12 -10
zqa-rag/src/vector/backends/lance.rs
··· 82 82 } 83 83 84 84 /// Backend for LanceDB vector store. 85 + #[derive(Debug, Clone)] 85 86 pub struct LanceBackend { 86 87 /// Configuration for the LanceDB embedding provider. 87 88 config: EmbeddingProviderConfig, ··· 94 95 /// Metadata about the LanceDB database. 95 96 #[derive(Debug, PartialEq)] 96 97 pub struct LanceMetadata { 97 - /// The names of the tables in the database. 98 - table_names: Vec<String>, 98 + /// Number of tables in the database. 99 + num_tables: usize, 99 100 /// The embedding table version. Each update to a table creates a new version. 100 101 embedding_table_version: u64, 101 102 /// Number of rows in the table ··· 107 108 write!( 108 109 f, 109 110 "LanceDB Statistics:\n\tNumber of tables: {}\n\tNumber of rows: {}\n\tEmbedding table version: {}", 110 - self.table_names.len(), 111 - self.num_rows, 112 - self.embedding_table_version 111 + self.num_tables, self.num_rows, self.embedding_table_version 113 112 ) 114 113 } 115 114 } ··· 159 158 160 159 /// Returns the database URI, allowing override via `LANCEDB_URI` environment variable. 161 160 fn get_db_path(&self) -> String { 162 - std::env::var("LANCEDB_URI").unwrap_or_else(|_| LANCEDB_URI.to_string()) 161 + get_db_uri() 163 162 } 164 163 165 164 async fn get_metadata(&self) -> Result<Self::Metadata, Self::Error> { ··· 184 183 )) 185 184 })?; 186 185 187 - let table_names = db.table_names().execute().await.map_err(|e| { 188 - LanceError::InvalidStateError(format!("Failed to list table names: {e}")) 189 - })?; 186 + let num_tables = db 187 + .table_names() 188 + .execute() 189 + .await 190 + .map_err(|e| LanceError::InvalidStateError(format!("Failed to list table names: {e}")))? 191 + .len(); 190 192 191 193 let num_rows = tbl.count_rows(None).await.map_err(|e| { 192 194 LanceError::InvalidStateError(format!( ··· 195 197 })?; 196 198 197 199 Ok(LanceMetadata { 198 - table_names, 200 + num_tables, 199 201 num_rows, 200 202 embedding_table_version: table_version, 201 203 })
+1 -1
zqa/src/cli/handlers/query.rs
··· 216 216 let retrieval_embedding_chars = std::sync::Arc::clone(&retrieval_tool.embedding_chars); 217 217 let retrieval_rerank_chars = std::sync::Arc::clone(&retrieval_tool.rerank_chars); 218 218 219 - let summarization_tool = SummarizationTool::new(llm_client.clone(), embedding_config); 219 + let summarization_tool = SummarizationTool::new(llm_client.clone(), ctx.backend.clone()); 220 220 let summarization_tool_clone = summarization_tool.clone(); 221 221 let mut tools: Vec<Box<dyn Tool>> = 222 222 vec![Box::new(retrieval_tool), Box::new(summarization_tool)];
+20 -11
zqa/src/tools/summarization.rs
··· 9 9 use serde_json::json; 10 10 use tokio::task::JoinSet; 11 11 use zqa_rag::{ 12 - embedding::common::EmbeddingProviderConfig, 13 12 llm::{ 14 13 base::{ApiClient, ChatRequest, CompletionApiResponse}, 15 14 errors::LLMError, 16 15 factory::LLMClient, 17 16 tools::Tool, 18 17 }, 19 - vector::backends::backend::VectorBackend, 18 + vector::backends::{backend::VectorBackend, lance::LanceBackend}, 20 19 }; 21 20 22 21 use crate::{ 23 22 cli::prompts::get_extraction_prompt, 24 23 utils::{ 25 - arrow::{DbFields, lance_backend}, 24 + arrow::DbFields, 26 25 library::{ZoteroItem, ZoteroItemSet}, 27 26 rag::ModelResponse, 28 27 }, ··· 34 33 #[derive(Debug, Clone)] 35 34 pub(crate) struct SummarizationTool { 36 35 pub(crate) llm_client: LLMClient, 37 - /// Embedding configuration for searching stored Zotero papers. 38 - pub(crate) embedding_config: EmbeddingProviderConfig, 36 + /// Backend for searching stored Zotero papers. 37 + pub(crate) backend: LanceBackend, 39 38 /// The input tokens used 40 39 pub(crate) input_tokens: Arc<Mutex<u32>>, 41 40 /// The output tokens used ··· 43 42 } 44 43 45 44 impl SummarizationTool { 46 - /// Create a new [`SummarizationTool`] instance, given an LLM client and embedding config. 47 - pub fn new(llm_client: LLMClient, embedding_config: EmbeddingProviderConfig) -> Self { 45 + /// Create a new [`SummarizationTool`] instance, given an LLM client and a backend. 
46 + pub fn new(llm_client: LLMClient, backend: LanceBackend) -> Self { 48 47 Self { 49 48 llm_client, 50 - embedding_config, 49 + backend, 51 50 input_tokens: Arc::new(Mutex::new(0)), 52 51 output_tokens: Arc::new(Mutex::new(0)), 53 52 } ··· 97 96 let input: SummarizationToolInput = 98 97 serde_json::from_value(args).map_err(|e| format!("Invalid arguments: {e}"))?; 99 98 100 - let backend = lance_backend(self.embedding_config.clone()).await; 101 - let results = backend 99 + let results = self 100 + .backend 102 101 .search_by_column(DbFields::LibraryKey.as_ref(), &input.ids) 103 102 .await 104 103 .map_err(|e| format!("Search failed: {e}"))?; ··· 166 165 #[cfg(test)] 167 166 mod tests { 168 167 use std::env; 168 + use std::sync::Arc; 169 169 170 170 use serde_json::json; 171 171 use temp_env; ··· 175 175 config::{AnthropicConfig, LLMClientConfig}, 176 176 constants::DEFAULT_ANTHROPIC_MODEL_SMALL, 177 177 llm::factory::get_client_with_config, 178 + vector::backends::lance::LanceBackend, 178 179 }; 179 180 180 181 use super::*; ··· 191 192 .unwrap(); 192 193 193 194 let config = get_config(); 194 - SummarizationTool::new(client, config.get_embedding_config().unwrap()) 195 + let embedding_config = config.get_embedding_config().unwrap(); 196 + let schema = Arc::new(arrow_schema::Schema::new(vec![ 197 + arrow_schema::Field::new("library_key", arrow_schema::DataType::Utf8, false), 198 + arrow_schema::Field::new("title", arrow_schema::DataType::Utf8, false), 199 + arrow_schema::Field::new("file_path", arrow_schema::DataType::Utf8, false), 200 + arrow_schema::Field::new("pdf_text", arrow_schema::DataType::Utf8, false), 201 + ])); 202 + let backend = LanceBackend::new(embedding_config, schema, "pdf_text".into()); 203 + SummarizationTool::new(client, backend) 195 204 } 196 205 197 206 #[test]