From 363eb722c69f31a78c55a4da53f0bf0616f1b27a Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Tue, 16 Sep 2025 14:31:54 +0700 Subject: [PATCH 1/3] refactor(embedding): extract embedding logic to octolib and remove providers - Remove OpenAI, Google, Voyage, FastEmbed, and HuggingFace providers - Delegate embedding generation and config handling to octolib - Simplify main crate by dropping embedding features and tests - Update dependencies to reflect new modular structure --- Cargo.lock | 65 ++- Cargo.toml | 15 +- src/embedding/mod.rs | 238 ++-------- src/embedding/provider/fastembed.rs | 335 -------------- src/embedding/provider/google.rs | 145 ------ src/embedding/provider/huggingface.rs | 609 -------------------------- src/embedding/provider/jina.rs | 168 ------- src/embedding/provider/mod.rs | 112 ----- src/embedding/provider/openai.rs | 207 --------- src/embedding/provider/voyage.rs | 177 -------- src/embedding/tests.rs | 491 --------------------- src/embedding/types.rs | 210 --------- src/indexer/search.rs | 4 +- src/mcp/memory.rs | 2 +- src/mcp/semantic_code.rs | 2 +- 15 files changed, 92 insertions(+), 2688 deletions(-) delete mode 100644 src/embedding/provider/fastembed.rs delete mode 100644 src/embedding/provider/google.rs delete mode 100644 src/embedding/provider/huggingface.rs delete mode 100644 src/embedding/provider/jina.rs delete mode 100644 src/embedding/provider/mod.rs delete mode 100644 src/embedding/provider/openai.rs delete mode 100644 src/embedding/provider/voyage.rs delete mode 100644 src/embedding/tests.rs delete mode 100644 src/embedding/types.rs diff --git a/Cargo.lock b/Cargo.lock index 941e071..cb2cdb3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -408,6 +408,19 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-compression" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "977eb15ea9efd848bb8a4a1a2500347ed7f0bf794edf0dc3ddcf439f43d36b23" +dependencies = [ + "compression-codecs", + "compression-core", + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-lock" version = "3.4.1" @@ -1428,6 +1441,23 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "compression-codecs" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "485abf41ac0c8047c07c87c72c8fb3eb5197f6e9d7ded615dfd1a00ae00a0f64" +dependencies = [ + "compression-core", + "flate2", + "memchr", +] + +[[package]] +name = "compression-core" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -5084,19 +5114,14 @@ dependencies = [ "anyhow", "arrow", "async-trait", - "candle-core", - "candle-nn", - "candle-transformers", "chrono", "clap", "clap_complete", "dirs", "dotenvy", "ec4rs", - "fastembed", "futures", "globset", - "hf-hub", "ignore", "lance", "lance-table", @@ -5105,14 +5130,13 @@ dependencies = [ "lsp-types", "notify", "notify-debouncer-mini", + "octolib", "parking_lot", "regex", "reqwest", "serde", "serde_json", "sha2", - "tiktoken-rs", - "tokenizers", "tokio", "toml 0.9.5", "tracing", @@ -5137,6 +5161,32 @@ dependencies = [ "uuid", ] +[[package]] +name = "octolib" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "base64 0.22.1", + "candle-core", + "candle-nn", + "candle-transformers", + "dirs", + "fastembed", + "hf-hub", + "lazy_static", + "reqwest", + "serde", + 
"serde_json", + "sha2", + "thiserror 2.0.16", + "tiktoken-rs", + "tokenizers", + "tokio", + "tracing", + "url", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -6092,6 +6142,7 @@ version = "0.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" dependencies = [ + "async-compression", "base64 0.22.1", "bytes", "encoding_rs", diff --git a/Cargo.toml b/Cargo.toml index 3b17498..ff6652f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,9 +23,7 @@ exclude = [ ] [features] -default = ["fastembed", "huggingface"] -fastembed = ["dep:fastembed"] -huggingface = ["dep:candle-core", "dep:candle-nn", "dep:candle-transformers", "dep:tokenizers", "dep:hf-hub"] +default = [] # Optimized release profile for static linking [profile.release] @@ -74,30 +72,25 @@ clap = { version = "4.5.45", features = ["derive"] } clap_complete = "4.5.57" notify = { version = "8.2.0", default-features = false, features = ["crossbeam-channel", "macos_fsevent"] } notify-debouncer-mini = "0.7.0" -fastembed = { version = "5.0.2", optional = true } toml = "0.9.5" lazy_static = "1.5.0" futures = { version = "0.3.31", default-features = false, features = ["std"] } globset = { version = "0.4.16", default-features = false } regex = { version = "1.11.1", default-features = false, features = ["std"] } dirs = "6.0.0" -# Candle dependencies for HuggingFace support (optional) -candle-core = { version = "0.9.1", optional = true } -candle-nn = { version = "0.9.1", optional = true } -candle-transformers = { version = "0.9.1", optional = true } -tokenizers = { version = "0.21.4", optional = true } -hf-hub = { version = "0.4.3", features = ["tokio"], optional = true } # EditorConfig parsing and formatting ec4rs = "1.2.0" tracing = "0.1.41" tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } tracing-appender = "0.2.3" -tiktoken-rs = "0.7.0" # LSP integration dependencies lsp-types = "0.97.0" url = "2.5.4" dotenvy = "0.15" +# Local dependency on octolib with embedding features +octolib = { path = "../octolib", features = ["fastembed", "huggingface"] } + [profile.dev] opt-level = 1 # Basic optimizations without slowing compilation too much debug = true # Keep debug symbols for backtraces diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs index 64c3a37..892ae95 100644 --- a/src/embedding/mod.rs +++ b/src/embedding/mod.rs @@ -12,243 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub mod provider; -#[cfg(test)] -mod tests; -pub mod types; +//! 
Re-export embedding functionality from octolib

 use crate::config::Config;
 use anyhow::Result;
-use tiktoken_rs::cl100k_base;

-pub use provider::{create_embedding_provider_from_parts, EmbeddingProvider};
-pub use types::*;
+// Re-export everything from octolib::embedding
+pub use octolib::embedding::*;
+
+/// Convert octocode Config to octolib EmbeddingGenerationConfig
+impl From<&Config> for EmbeddingGenerationConfig {
+	fn from(config: &Config) -> Self {
+		Self {
+			code_model: config.embedding.code_model.clone(),
+			text_model: config.embedding.text_model.clone(),
+			batch_size: config.index.embeddings_batch_size,
+			max_tokens_per_batch: config.index.embeddings_max_tokens_per_batch,
+		}
+	}
+}

 /// Generate embeddings based on configured provider (supports provider:model format)
+/// Compatibility wrapper for octocode Config
 pub async fn generate_embeddings(
 	contents: &str,
 	is_code: bool,
 	config: &Config,
 ) -> Result<Vec<f32>> {
-	// Get the model string from config
-	let model_string = if is_code {
-		&config.embedding.code_model
-	} else {
-		&config.embedding.text_model
-	};
-
-	// Parse provider and model from the string
-	let (provider, model) = parse_provider_model(model_string);
-
-	let provider_impl = create_embedding_provider_from_parts(&provider, &model).await?;
-	provider_impl.generate_embedding(contents).await
-}
-
-/// Count tokens in a text using tiktoken (cl100k_base tokenizer)
-pub fn count_tokens(text: &str) -> usize {
-	let bpe = cl100k_base().expect("Failed to load cl100k_base tokenizer");
-	bpe.encode_with_special_tokens(text).len()
-}
-
-/// Truncate output if it exceeds token limit
-pub fn truncate_output(output: &str, max_tokens: usize) -> String {
-	if max_tokens == 0 {
-		return output.to_string();
-	}
-
-	let token_count = count_tokens(output);
-
-	if token_count <= max_tokens {
-		return output.to_string();
-	}
-
-	// Simple truncation - cut at character boundary
-	// Estimate roughly where to cut (tokens are ~4 chars average)
-	let estimated_chars = max_tokens * 3; // Conservative estimate
-	let truncated = if output.len() > estimated_chars {
-		&output[..estimated_chars]
-	} else {
-		output
-	};
-
-	// Find last newline to avoid cutting mid-line
-	let last_newline = truncated.rfind('\n').unwrap_or(truncated.len());
-	let final_truncated = &truncated[..last_newline];
-
-	format!(
-		"{}\n\n[Output truncated - {} tokens estimated, max {} allowed. Use more specific queries to reduce output size]",
-		final_truncated,
-		token_count,
-		max_tokens
-	)
-}
-
-/// Split texts into batches respecting both count and token limits
-pub fn split_texts_into_token_limited_batches(
-	texts: Vec<String>,
-	max_batch_size: usize,
-	max_tokens_per_batch: usize,
-) -> Vec<Vec<String>> {
-	let mut batches = Vec::new();
-	let mut current_batch = Vec::new();
-	let mut current_token_count = 0;
-
-	for text in texts {
-		let text_tokens = count_tokens(&text);
-
-		// If adding this text would exceed either limit, start a new batch
-		if !current_batch.is_empty()
-			&& (current_batch.len() >= max_batch_size
-				|| current_token_count + text_tokens > max_tokens_per_batch)
-		{
-			batches.push(current_batch);
-			current_batch = Vec::new();
-			current_token_count = 0;
-		}
-
-		current_batch.push(text);
-		current_token_count += text_tokens;
-	}
-
-	// Add the last batch if it's not empty
-	if !current_batch.is_empty() {
-		batches.push(current_batch);
-	}
-
-	batches
+	let embedding_config = EmbeddingGenerationConfig::from(config);
+	octolib::embedding::generate_embeddings(contents, is_code, &embedding_config).await
 }

 /// Generate batch embeddings based on configured provider (supports provider:model format)
-/// Now includes token-aware batching and input_type support
+/// Compatibility wrapper for octocode Config
 pub async fn generate_embeddings_batch(
 	texts: Vec<String>,
 	is_code: bool,
 	config: &Config,
-	input_type: types::InputType,
+	input_type: InputType,
 ) -> Result<Vec<Vec<f32>>> {
-	// Get the model string from config
-	let model_string = if is_code {
-		&config.embedding.code_model
-	} else {
-		&config.embedding.text_model
-	};
-
-	// Parse provider and model from the string
-	let (provider, model) = parse_provider_model(model_string);
-
-	let provider_impl = create_embedding_provider_from_parts(&provider, &model).await?;
-
-	// Split texts into token-limited batches
-	let batches = split_texts_into_token_limited_batches(
-		texts,
-		config.index.embeddings_batch_size,
-		config.index.embeddings_max_tokens_per_batch,
-	);
-
-	let mut all_embeddings = Vec::new();
-
-	// Process each batch with input_type
-	for batch in batches {
-		let batch_embeddings = provider_impl
-			.generate_embeddings_batch(batch, input_type.clone())
-			.await?;
-		all_embeddings.extend(batch_embeddings);
-	}
-
-	Ok(all_embeddings)
-}
-
-/// Calculate a unique hash for content including file path
-pub fn calculate_unique_content_hash(contents: &str, file_path: &str) -> String {
-	use sha2::{Digest, Sha256};
-	let mut hasher = Sha256::new();
-	hasher.update(contents.as_bytes());
-	hasher.update(file_path.as_bytes());
-	format!("{:x}", hasher.finalize())
-}
-
-/// Calculate a unique hash for content including file path and line ranges
-/// This ensures blocks are reindexed when their position changes in the file
-pub fn calculate_content_hash_with_lines(
-	contents: &str,
-	file_path: &str,
-	start_line: usize,
-	end_line: usize,
-) -> String {
-	use sha2::{Digest, Sha256};
-	let mut hasher = Sha256::new();
-	hasher.update(contents.as_bytes());
-	hasher.update(file_path.as_bytes());
-	hasher.update(start_line.to_string().as_bytes());
-	hasher.update(end_line.to_string().as_bytes());
-	format!("{:x}", hasher.finalize())
-}
-
-/// Calculate content hash without file path
-pub fn calculate_content_hash(contents: &str) -> String {
-	use sha2::{Digest, Sha256};
-	let mut hasher = Sha256::new();
-	hasher.update(contents.as_bytes());
-	format!("{:x}", hasher.finalize())
-}
-
-/// Search mode embeddings result
-#[derive(Debug, Clone)]
-pub struct SearchModeEmbeddings {
-	pub code_embeddings: Option<Vec<f32>>,
-	pub text_embeddings: Option<Vec<f32>>,
+	let embedding_config = EmbeddingGenerationConfig::from(config);
+	octolib::embedding::generate_embeddings_batch(texts, is_code, &embedding_config, input_type)
+		.await
 }

 /// Generate embeddings for search based on mode - centralized logic to avoid duplication
-/// This ensures consistent behavior across CLI and MCP interfaces
+/// Compatibility wrapper for octocode Config
 pub async fn generate_search_embeddings(
 	query: &str,
 	mode: &str,
 	config: &Config,
 ) -> Result<SearchModeEmbeddings> {
-	match mode {
-		"code" => {
-			// Use code model for code searches only
-			let embeddings = generate_embeddings(query, true, config).await?;
-			Ok(SearchModeEmbeddings {
-				code_embeddings: Some(embeddings),
-				text_embeddings: None,
-			})
-		}
-		"docs" | "text" => {
-			// Use text model for documents and text searches only
-			let embeddings = generate_embeddings(query, false, config).await?;
-			Ok(SearchModeEmbeddings {
-				code_embeddings: None,
-				text_embeddings: Some(embeddings),
-			})
-		}
-		"all" => {
-			// For "all" mode, check if code and text models are different
-			// If different, generate separate embeddings; if same, use one set
-			let code_model = &config.embedding.code_model;
-			let text_model = &config.embedding.text_model;
-
-			if code_model == text_model {
-				// Same model for both - generate once and reuse
-				let embeddings = generate_embeddings(query, true, config).await?;
-				Ok(SearchModeEmbeddings {
-					code_embeddings: Some(embeddings.clone()),
-					text_embeddings: Some(embeddings),
-				})
-			} else {
-				// Different models - generate separate embeddings
-				let code_embeddings = generate_embeddings(query, true, config).await?;
-				let text_embeddings = generate_embeddings(query, false, config).await?;
-				Ok(SearchModeEmbeddings {
-					code_embeddings: Some(code_embeddings),
-					text_embeddings: Some(text_embeddings),
-				})
-			}
-		}
-		_ => Err(anyhow::anyhow!(
-			"Invalid search mode '{}'. Use 'all', 'code', 'docs', or 'text'.",
-			mode
-		)),
-	}
+	let embedding_config = EmbeddingGenerationConfig::from(config);
+	octolib::embedding::generate_search_embeddings(query, mode, &embedding_config).await
 }
diff --git a/src/embedding/provider/fastembed.rs b/src/embedding/provider/fastembed.rs
deleted file mode 100644
index 908029a..0000000
--- a/src/embedding/provider/fastembed.rs
+++ /dev/null
@@ -1,335 +0,0 @@
-// Copyright 2025 Muvon Un Limited
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-/*!
-* FastEmbed Provider Implementation
-*
-* This module provides embedding generation using the FastEmbed library.
-* FastEmbed offers fast, local embedding generation with automatic model downloading and caching.
-* -* Key features: -* - Automatic model downloading and caching -* - Local CPU-based inference with optimized performance -* - Thread-safe model instances -* - Support for various embedding models -* - No API keys required -* -* Usage: -* - Set provider: `octocode config --embedding-provider fastembed` -* - Set models: `octocode config --code-embedding-model "fastembed:jinaai/jina-embeddings-v2-base-code"` -* - Popular models: sentence-transformers/all-MiniLM-L6-v2, BAAI/bge-base-en-v1.5, jinaai/jina-embeddings-v2-base-code -* -* Models are automatically downloaded to the system cache directory and reused across sessions. -*/ - -#[cfg(feature = "fastembed")] -use anyhow::{Context, Result}; -#[cfg(feature = "fastembed")] -use fastembed::{EmbeddingModel, InitOptions, TextEmbedding}; -#[cfg(feature = "fastembed")] -use std::sync::{Arc, Mutex}; - -#[cfg(feature = "fastembed")] -use super::super::{types::InputType, EmbeddingProvider}; - -#[cfg(feature = "fastembed")] -/// FastEmbed provider implementation for trait -pub struct FastEmbedProviderImpl { - model: Arc>, -} - -#[cfg(feature = "fastembed")] -impl FastEmbedProviderImpl { - pub fn new(model_name: &str) -> Result { - // Validate model is supported BEFORE creating - if !Self::is_model_supported_static(model_name) { - return Err(anyhow::anyhow!( - "Unsupported FastEmbed model: {}", - model_name - )); - } - - let model_enum = FastEmbedProvider::map_model_to_fastembed(model_name); - - // Use system-wide cache for FastEmbed models - let cache_dir = crate::storage::get_fastembed_cache_dir() - .context("Failed to get FastEmbed cache directory")?; - - let model = TextEmbedding::try_new( - InitOptions::new(model_enum) - .with_show_download_progress(true) - .with_cache_dir(cache_dir), - ) - .context("Failed to initialize FastEmbed model")?; - - Ok(Self { - model: Arc::new(Mutex::new(model)), - }) - } - - /// Check if model is supported using PURE dynamic API discovery - fn is_model_supported_static(model_name: &str) -> bool { - // Use FastEmbed's dynamic model discovery API - NO STATIC LISTS - let supported_models = TextEmbedding::list_supported_models(); - - // Check if the model name matches any supported model - supported_models.iter().any(|model_info| { - // Convert ModelInfo to string representation to check against model_name - let model_str = format!("{:?}", model_info); - model_str.contains(model_name) || - // Handle common aliases dynamically - (model_name == "all-MiniLM-L12-v2" && model_str.contains("sentence-transformers/all-MiniLM-L12-v2")) || - (model_name == "multilingual-e5-small" && model_str.contains("intfloat/multilingual-e5-small")) || - (model_name == "multilingual-e5-base" && model_str.contains("intfloat/multilingual-e5-base")) || - (model_name == "multilingual-e5-large" && model_str.contains("intfloat/multilingual-e5-large")) - }) - } - - /// Get list of all supported models dynamically - pub fn list_supported_models() -> Vec { - let supported_models = TextEmbedding::list_supported_models(); - supported_models - .iter() - .map(|model_info| model_info.model_code.clone()) // Use actual model name - .collect() - } - - /// Get list of all supported models with dimensions - pub fn list_supported_models_with_dimensions() -> Vec<(String, usize)> { - let supported_models = TextEmbedding::list_supported_models(); - supported_models - .iter() - .map(|model_info| (model_info.model_code.clone(), model_info.dim)) - .collect() - } - - /// Get model dimension dynamically from ModelInfo if available - pub fn 
get_model_dimension_from_api(model_name: &str) -> Option { - let supported_models = TextEmbedding::list_supported_models(); - - // Find the model in the supported list and try to extract dimension - for model_info in supported_models { - let model_str = format!("{:?}", model_info); - if model_str.contains(model_name) { - // Try to extract dimension from ModelInfo - // This is a placeholder - need to understand ModelInfo structure - // For now, we'll fall back to dynamic embedding generation - return None; - } - } - None - } -} - -#[cfg(feature = "fastembed")] -#[async_trait::async_trait] -impl EmbeddingProvider for FastEmbedProviderImpl { - async fn generate_embedding(&self, text: &str) -> Result> { - let text = text.to_string(); - let model = self.model.clone(); - - let embedding = tokio::task::spawn_blocking(move || -> Result> { - let mut model = model.lock().unwrap(); - let embedding = model.embed(vec![text], None)?; - - if embedding.is_empty() { - return Err(anyhow::anyhow!("No embeddings were generated")); - } - - Ok(embedding[0].clone()) - }) - .await??; - - Ok(embedding) - } - - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>> { - let model = self.model.clone(); - - // Apply prefix manually for FastEmbed (doesn't support input_type API) - let processed_texts: Vec = texts - .into_iter() - .map(|text| input_type.apply_prefix(&text)) - .collect(); - - let embeddings = tokio::task::spawn_blocking(move || -> Result>> { - let text_refs: Vec<&str> = processed_texts.iter().map(|s| s.as_str()).collect(); - let mut model = model.lock().unwrap(); - let embeddings = model.embed(text_refs, None)?; - - Ok(embeddings) - }) - .await??; - - Ok(embeddings) - } - - fn get_dimension(&self) -> usize { - // First try to get dimension from ModelInfo API if available - // This is more efficient than generating embeddings - // Note: This is a placeholder until we understand ModelInfo structure better - - // Fall back to dynamic embedding generation (current working method) - // Generate a single embedding to get the dimension - // This is cached by FastEmbed, so subsequent calls are fast - let model = self.model.clone(); - - // Use a simple test text to get dimension - // We need to block here since this is a sync method - let dimension = std::thread::spawn(move || { - let mut model = model.lock().unwrap(); - match model.embed(vec!["test"], None) { - Ok(embeddings) if !embeddings.is_empty() => embeddings[0].len(), - _ => { - tracing::warn!("Failed to get dimension from FastEmbed model, using fallback"); - 768 // Safe fallback - } - } - }) - .join() - .unwrap_or(768); - - dimension - } - - fn is_model_supported(&self) -> bool { - true // If we created the provider, the model is supported - } -} - -#[cfg(feature = "fastembed")] -/// FastEmbed provider implementation -pub struct FastEmbedProvider; - -#[cfg(feature = "fastembed")] -impl FastEmbedProvider { - /// Map model name to FastEmbed model enum - pub fn map_model_to_fastembed(model: &str) -> EmbeddingModel { - match model { - "sentence-transformers/all-MiniLM-L6-v2" | "Xenova/all-MiniLM-L6-v2" => { - EmbeddingModel::AllMiniLML6V2 - } - "sentence-transformers/all-MiniLM-L6-v2-quantized" | "Qdrant/all-MiniLM-L6-v2-onnx" => { - EmbeddingModel::AllMiniLML6V2Q - } - "sentence-transformers/all-MiniLM-L12-v2" - | "all-MiniLM-L12-v2" - | "Xenova/all-MiniLM-L12-v2" => EmbeddingModel::AllMiniLML12V2, - "sentence-transformers/all-MiniLM-L12-v2-quantized" => EmbeddingModel::AllMiniLML12V2Q, - "BAAI/bge-base-en-v1.5" 
| "Xenova/bge-base-en-v1.5" => EmbeddingModel::BGEBaseENV15, - "BAAI/bge-base-en-v1.5-quantized" | "Qdrant/bge-base-en-v1.5-onnx-Q" => { - EmbeddingModel::BGEBaseENV15Q - } - "BAAI/bge-large-en-v1.5" | "Xenova/bge-large-en-v1.5" => EmbeddingModel::BGELargeENV15, - "BAAI/bge-large-en-v1.5-quantized" | "Qdrant/bge-large-en-v1.5-onnx-Q" => { - EmbeddingModel::BGELargeENV15Q - } - "BAAI/bge-small-en-v1.5" - | "Xenova/bge-small-en-v1.5" - | "Qdrant/bge-small-en-v1.5-onnx-Q" => EmbeddingModel::BGESmallENV15, - "BAAI/bge-small-en-v1.5-quantized" => EmbeddingModel::BGESmallENV15Q, - "nomic-ai/nomic-embed-text-v1" => EmbeddingModel::NomicEmbedTextV1, - "nomic-ai/nomic-embed-text-v1.5" => EmbeddingModel::NomicEmbedTextV15, - "nomic-ai/nomic-embed-text-v1.5-quantized" => EmbeddingModel::NomicEmbedTextV15Q, - "sentence-transformers/paraphrase-MiniLM-L6-v2" => { - EmbeddingModel::ParaphraseMLMiniLML12V2 - } - "sentence-transformers/paraphrase-MiniLM-L6-v2-quantized" - | "Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q" => { - EmbeddingModel::ParaphraseMLMiniLML12V2Q - } - "sentence-transformers/paraphrase-mpnet-base-v2" - | "Xenova/paraphrase-multilingual-mpnet-base-v2" => EmbeddingModel::ParaphraseMLMpnetBaseV2, - "BAAI/bge-small-zh-v1.5" | "Xenova/bge-small-zh-v1.5" => EmbeddingModel::BGESmallZHV15, - "BAAI/bge-large-zh-v1.5" | "Xenova/bge-large-zh-v1.5" => EmbeddingModel::BGELargeZHV15, - "lightonai/modernbert-embed-large" => EmbeddingModel::ModernBertEmbedLarge, - "intfloat/multilingual-e5-small" | "multilingual-e5-small" => { - EmbeddingModel::MultilingualE5Small - } - "intfloat/multilingual-e5-base" | "multilingual-e5-base" => { - EmbeddingModel::MultilingualE5Base - } - "intfloat/multilingual-e5-large" - | "multilingual-e5-large" - | "Qdrant/multilingual-e5-large-onnx" => EmbeddingModel::MultilingualE5Large, - "mixedbread-ai/mxbai-embed-large-v1" => EmbeddingModel::MxbaiEmbedLargeV1, - "mixedbread-ai/mxbai-embed-large-v1-quantized" => EmbeddingModel::MxbaiEmbedLargeV1Q, - "Alibaba-NLP/gte-base-en-v1.5" => EmbeddingModel::GTEBaseENV15, - "Alibaba-NLP/gte-base-en-v1.5-quantized" => EmbeddingModel::GTEBaseENV15Q, - "Alibaba-NLP/gte-large-en-v1.5" => EmbeddingModel::GTELargeENV15, - "Alibaba-NLP/gte-large-en-v1.5-quantized" => EmbeddingModel::GTELargeENV15Q, - "Qdrant/clip-ViT-B-32-text" => EmbeddingModel::ClipVitB32, - "jinaai/jina-embeddings-v2-base-code" => EmbeddingModel::JinaEmbeddingsV2BaseCode, - _ => panic!("Unsupported embedding model: {}", model), - } - } -} - -// Stubs for when fastembed feature is disabled -#[cfg(not(feature = "fastembed"))] -use anyhow::Result; - -#[cfg(not(feature = "fastembed"))] -pub struct FastEmbedProviderImpl; - -#[cfg(not(feature = "fastembed"))] -impl FastEmbedProviderImpl { - pub fn new(_model_name: &str) -> Result { - Err(anyhow::anyhow!( - "FastEmbed support is not compiled in. Please rebuild with --features fastembed" - )) - } -} - -#[cfg(not(feature = "fastembed"))] -#[async_trait::async_trait] -impl super::super::EmbeddingProvider for FastEmbedProviderImpl { - async fn generate_embedding(&self, _text: &str) -> Result> { - Err(anyhow::anyhow!( - "FastEmbed support is not compiled in. Please rebuild with --features fastembed" - )) - } - - async fn generate_embeddings_batch( - &self, - _texts: Vec, - _input_type: crate::embedding::types::InputType, - ) -> Result>> { - Err(anyhow::anyhow!( - "FastEmbed support is not compiled in. 
Please rebuild with --features fastembed" - )) - } - - fn get_dimension(&self) -> usize { - 768 // Safe fallback when feature is disabled - } - - fn is_model_supported(&self) -> bool { - false // No support when feature is disabled - } -} - -#[cfg(not(feature = "fastembed"))] -pub struct FastEmbedProvider; - -#[cfg(not(feature = "fastembed"))] -impl FastEmbedProvider { - pub fn map_model_to_fastembed(_model: &str) { - // Return unit type when feature is disabled - } -} diff --git a/src/embedding/provider/google.rs b/src/embedding/provider/google.rs deleted file mode 100644 index f064fb7..0000000 --- a/src/embedding/provider/google.rs +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Google AI embedding provider implementation - -use anyhow::{Context, Result}; -use serde_json::{json, Value}; - -use super::super::types::InputType; -use super::{EmbeddingProvider, HTTP_CLIENT}; - -/// Google provider implementation for trait -pub struct GoogleProviderImpl { - model_name: String, - dimension: usize, -} - -impl GoogleProviderImpl { - pub fn new(model: &str) -> Result { - let dimension = Self::get_model_dimension(model)?; - Ok(Self { - model_name: model.to_string(), - dimension, - }) - } - - fn get_model_dimension(model: &str) -> Result { - match model { - "gemini-embedding-001" => Ok(3072), // Up to 3072 dimensions, state-of-the-art performance - "text-embedding-005" => Ok(768), // Specialized in English and code tasks - "text-multilingual-embedding-002" => Ok(768), // Specialized in multilingual tasks - _ => Err(anyhow::anyhow!( - "Unsupported Google model: '{}'. 
Supported models: gemini-embedding-001 (3072d), text-embedding-005 (768d), text-multilingual-embedding-002 (768d)", - model - )), - } - } -} - -#[async_trait::async_trait] -impl EmbeddingProvider for GoogleProviderImpl { - async fn generate_embedding(&self, text: &str) -> Result> { - GoogleProvider::generate_embeddings(text, &self.model_name).await - } - - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>> { - // Apply prefix manually for Google (doesn't support input_type API) - let processed_texts: Vec = texts - .into_iter() - .map(|text| input_type.apply_prefix(&text)) - .collect(); - GoogleProvider::generate_embeddings_batch(processed_texts, &self.model_name).await - } - - fn get_dimension(&self) -> usize { - self.dimension - } - - fn is_model_supported(&self) -> bool { - matches!( - self.model_name.as_str(), - "gemini-embedding-001" | "text-embedding-005" | "text-multilingual-embedding-002" - ) - } -} - -/// Google provider implementation -pub struct GoogleProvider; - -impl GoogleProvider { - /// Get list of supported models for dynamic discovery - pub fn get_supported_models() -> Vec<&'static str> { - vec![ - "gemini-embedding-001", - "text-embedding-005", - "text-multilingual-embedding-002", - ] - } - pub async fn generate_embeddings(contents: &str, model: &str) -> Result> { - let result = Self::generate_embeddings_batch(vec![contents.to_string()], model).await?; - result - .first() - .cloned() - .ok_or_else(|| anyhow::anyhow!("No embeddings found")) - } - - pub async fn generate_embeddings_batch( - texts: Vec, - model: &str, - ) -> Result>> { - let google_api_key = std::env::var("GOOGLE_API_KEY") - .context("GOOGLE_API_KEY environment variable not set")?; - - // For batch processing, we'll need to send individual requests as Google's API structure is different - let mut all_embeddings = Vec::new(); - - for text in texts { - let response = HTTP_CLIENT - .post(format!("https://generativelanguage.googleapis.com/v1beta/models/{}:embedContent?key={}", model, google_api_key)) - .header("Content-Type", "application/json") - .json(&json!({ - "content": { - "parts": [{ - "text": text - }] - } - })) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - return Err(anyhow::anyhow!("Google API error: {}", error_text)); - } - - let response_json: Value = response.json().await?; - - let embedding = response_json["embedding"]["values"] - .as_array() - .context("Failed to get embedding values")? - .iter() - .map(|v| v.as_f64().unwrap_or_default() as f32) - .collect(); - - all_embeddings.push(embedding); - } - - Ok(all_embeddings) - } -} diff --git a/src/embedding/provider/huggingface.rs b/src/embedding/provider/huggingface.rs deleted file mode 100644 index 8192085..0000000 --- a/src/embedding/provider/huggingface.rs +++ /dev/null @@ -1,609 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/*! 
- * HuggingFace Provider Implementation - * - * This module provides local embedding generation using HuggingFace models via the Candle library. - * It supports multiple model architectures with safetensors format from the HuggingFace Hub. - * - * Key features: - * - Automatic model downloading and caching - * - Local CPU-based inference (GPU support can be added) - * - Thread-safe model cache for efficient reuse - * - Mean pooling and L2 normalization for sentence embeddings - * - Full compatibility with provider:model syntax - * - Dynamic model architecture detection - * - * Usage: - * - Set provider: `octocode config --embedding-provider huggingface` - * - Set models: `octocode config --code-embedding-model "huggingface:jinaai/jina-embeddings-v2-base-code"` - * - Popular models: jinaai/jina-embeddings-v2-base-code, sentence-transformers/all-mpnet-base-v2 - * - * Models are automatically downloaded to the system cache directory and reused across sessions. - */ - -// When huggingface feature is enabled -#[cfg(feature = "huggingface")] -use anyhow::{Context, Result}; -#[cfg(feature = "huggingface")] -use candle_core::{DType, Device, Tensor}; -#[cfg(feature = "huggingface")] -use candle_nn::VarBuilder; -#[cfg(feature = "huggingface")] -use candle_transformers::models::bert::{BertModel, Config as BertConfig}; -use candle_transformers::models::jina_bert::Config as JinaBertConfig; -#[cfg(feature = "huggingface")] -use hf_hub::{api::tokio::Api, Repo, RepoType}; -#[cfg(feature = "huggingface")] -use std::collections::HashMap; -#[cfg(feature = "huggingface")] -use std::sync::Arc; -#[cfg(feature = "huggingface")] -use tokenizers::Tokenizer; -#[cfg(feature = "huggingface")] -use tokio::sync::RwLock; - -#[cfg(feature = "huggingface")] -/// HuggingFace model instance -pub struct HuggingFaceModel { - model: BertModel, - tokenizer: Tokenizer, - device: Device, -} - -#[cfg(feature = "huggingface")] -impl HuggingFaceModel { - /// Load a SentenceTransformer model from HuggingFace Hub - pub async fn load(model_name: &str) -> Result { - let device = Device::Cpu; // Use CPU for now, can be extended to support GPU - - // Use our custom cache directory for consistency with FastEmbed - // Set HF_HOME environment variable to control where models are downloaded - let cache_dir = crate::storage::get_huggingface_cache_dir() - .context("Failed to get HuggingFace cache directory")?; - - // Set the HuggingFace cache directory via environment variable - std::env::set_var("HF_HOME", &cache_dir); - - // Download model files from HuggingFace Hub with proper error handling - let api = Api::new().context("Failed to initialize HuggingFace API")?; - let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model)); - - // Download required files with enhanced error handling - let config_path = repo - .get("config.json") - .await - .with_context(|| format!("Failed to download config.json for model: {}", model_name))?; - - // Load tokenizer - try different formats - let tokenizer = if let Ok(tokenizer_json_path) = repo.get("tokenizer.json").await { - // Direct tokenizer.json file (most models) - Tokenizer::from_file(tokenizer_json_path) - .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))? 
- } else { - // Try to build tokenizer from components (for models like microsoft/codebert-base) - // Check for RoBERTa-style tokenizer (vocab.json + merges.txt) - if let (Ok(vocab_path), Ok(merges_path)) = - (repo.get("vocab.json").await, repo.get("merges.txt").await) - { - // Build RoBERTa/GPT2-style BPE tokenizer using BPE::from_file - use tokenizers::{ - models::bpe::BPE, normalizers, pre_tokenizers::byte_level::ByteLevel, - processors::roberta::RobertaProcessing, - }; - - // Use BPE::from_file which handles the vocab and merges loading - let bpe = BPE::from_file( - vocab_path - .to_str() - .ok_or_else(|| anyhow::anyhow!("Invalid vocab path"))?, - merges_path - .to_str() - .ok_or_else(|| anyhow::anyhow!("Invalid merges path"))?, - ) - .unk_token("".to_string()) - .build() - .map_err(|e| anyhow::anyhow!("Failed to build BPE tokenizer: {:?}", e))?; - - let mut tokenizer = Tokenizer::new(bpe); - - // Add ByteLevel pre-tokenizer (for RoBERTa) - tokenizer.with_pre_tokenizer(Some(ByteLevel::default())); - - // Add RoBERTa post-processing - let post_processor = RobertaProcessing::new( - ("".to_string(), 2), // SEP token - ("".to_string(), 0), // CLS token - ) - .trim_offsets(false) - .add_prefix_space(true); - tokenizer.with_post_processor(Some(post_processor)); - - // Add normalizer - let normalizer = - normalizers::Sequence::new(vec![normalizers::Strip::new(true, true).into()]); - tokenizer.with_normalizer(Some(normalizer)); - - tokenizer - } else { - return Err(anyhow::anyhow!( - "Could not find tokenizer files for model: {}. \ - Expected either tokenizer.json or (vocab.json + merges.txt). \ - This model may not be compatible.", - model_name - )); - } - }; - - // Try different weight file formats - let weights_path = if let Ok(path) = repo.get("model.safetensors").await { - path - } else if let Ok(path) = repo.get("pytorch_model.bin").await { - path - } else { - return Err(anyhow::anyhow!( - "Could not find model weights in safetensors or pytorch format" - )); - }; - - // Load configuration - let config_content = std::fs::read_to_string(config_path)?; - let config: BertConfig = serde_json::from_str(&config_content)?; - - // Load model weights - only support safetensors for now - let weights = if weights_path.to_string_lossy().ends_with(".safetensors") { - candle_core::safetensors::load(&weights_path, &device)? - } else { - return Err(anyhow::anyhow!("PyTorch .bin format not supported in this implementation. 
Please use a model with safetensors format.")); - }; - - let var_builder = VarBuilder::from_tensors(weights, DType::F32, &device); - - // Create model - let model = BertModel::load(var_builder, &config)?; - - Ok(Self { - model, - tokenizer, - device, - }) - } - - /// Generate embeddings for a single text - pub fn encode(&self, text: &str) -> Result> { - self.encode_batch(&[text.to_string()]) - .map(|embeddings| embeddings.into_iter().next().unwrap_or_default()) - } - - /// Generate embeddings for multiple texts - pub fn encode_batch(&self, texts: &[String]) -> Result>> { - let mut all_embeddings = Vec::new(); - - for text in texts { - // Tokenize input - convert String to &str - let encoding = self - .tokenizer - .encode(text.as_str(), true) - .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?; - - let tokens = encoding.get_ids(); - let token_ids = Tensor::new(tokens, &self.device)?.unsqueeze(0)?; // Add batch dimension - - // Create attention mask (all 1s for valid tokens) - let attention_mask = Tensor::ones((1, tokens.len()), DType::U8, &self.device)?; - - // Run through model - BertModel.forward takes 3 arguments: input_ids, attention_mask, token_type_ids - let output = self.model.forward(&token_ids, &attention_mask, None)?; - - // Apply mean pooling to get sentence embedding - let embeddings = self.mean_pooling(&output, &attention_mask)?; - - // Normalize embeddings - let normalized = self.normalize(&embeddings)?; - - // Convert to Vec - let embedding_vec = normalized.to_vec1::()?; - all_embeddings.push(embedding_vec); - } - - Ok(all_embeddings) - } - - /// Mean pooling operation - fn mean_pooling(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { - // Convert attention mask to f32 and expand dimensions - let attention_mask = attention_mask.to_dtype(DType::F32)?; - let attention_mask = attention_mask.unsqueeze(2)?; // (batch_size, seq_len, 1) - - // Apply attention mask to hidden states - let masked_hidden_states = hidden_states.mul(&attention_mask)?; - - // Sum along sequence dimension - let sum_hidden_states = masked_hidden_states.sum(1)?; // (batch_size, hidden_size) - - // Sum attention mask to get actual sequence lengths - let sum_mask = attention_mask.sum(1)?; // (batch_size, 1) - - // Compute mean - let mean_pooled = sum_hidden_states.div(&sum_mask)?; - - Ok(mean_pooled) - } - - /// Normalize embeddings to unit vectors - fn normalize(&self, embeddings: &Tensor) -> Result { - let norm = embeddings.sqr()?.sum_keepdim(1)?.sqrt()?; - Ok(embeddings.div(&norm)?) - } -} - -#[cfg(feature = "huggingface")] -// Global cache for loaded models using async-compatible RwLock -lazy_static::lazy_static! 
{ - static ref MODEL_CACHE: Arc>>> = - Arc::new(RwLock::new(HashMap::new())); -} - -#[cfg(feature = "huggingface")] -/// HuggingFace provider implementation -pub struct HuggingFaceProvider; - -#[cfg(feature = "huggingface")] -impl HuggingFaceProvider { - /// Get or load a model from cache - async fn get_model(model_name: &str) -> Result> { - { - let cache = MODEL_CACHE.read().await; - if let Some(model) = cache.get(model_name) { - return Ok(model.clone()); - } - } - - // Model not in cache, load it - let model = HuggingFaceModel::load(model_name) - .await - .with_context(|| format!("Failed to load HuggingFace model: {}", model_name))?; - - let model_arc = Arc::new(model); - - // Add to cache - { - let mut cache = MODEL_CACHE.write().await; - cache.insert(model_name.to_string(), model_arc.clone()); - } - - Ok(model_arc) - } - - /// Generate embeddings for a single text - pub async fn generate_embeddings(contents: &str, model: &str) -> Result> { - let model_instance = Self::get_model(model).await?; - - // Run encoding in a blocking task to avoid blocking async runtime - let contents = contents.to_string(); - let result = - tokio::task::spawn_blocking(move || model_instance.encode(&contents)).await??; - - Ok(result) - } - - /// Generate batch embeddings for multiple texts - pub async fn generate_embeddings_batch( - texts: Vec, - model: &str, - ) -> Result>> { - let model_instance = Self::get_model(model).await?; - - // Run encoding in a blocking task to avoid blocking async runtime - let result = - tokio::task::spawn_blocking(move || model_instance.encode_batch(&texts)).await??; - - Ok(result) - } -} - -// Stubs for when huggingface feature is disabled -#[cfg(not(feature = "huggingface"))] -use anyhow::Result; - -#[cfg(not(feature = "huggingface"))] -pub struct HuggingFaceProvider; - -#[cfg(not(feature = "huggingface"))] -impl HuggingFaceProvider { - pub async fn generate_embeddings(_contents: &str, _model: &str) -> Result> { - Err(anyhow::anyhow!( - "HuggingFace support is not compiled in. Please rebuild with --features huggingface" - )) - } - - pub async fn generate_embeddings_batch( - _texts: Vec, - _model: &str, - ) -> Result>> { - Err(anyhow::anyhow!( - "HuggingFace support is not compiled in. Please rebuild with --features huggingface" - )) - } -} -use super::super::types::InputType; -use super::EmbeddingProvider; - -/// HuggingFace provider implementation for trait -#[cfg(feature = "huggingface")] -pub struct HuggingFaceProviderImpl { - model_name: String, - dimension: usize, -} - -#[cfg(feature = "huggingface")] -impl HuggingFaceProviderImpl { - pub async fn new(model: &str) -> Result { - #[cfg(not(feature = "huggingface"))] - { - Err(anyhow::anyhow!("HuggingFace provider requires 'huggingface' feature to be enabled. 
Cannot validate model '{}' without Hub API access.", model)) - } - - #[cfg(feature = "huggingface")] - { - let dimension = Self::get_model_dimension(model).await?; - Ok(Self { - model_name: model.to_string(), - dimension, - }) - } - } - - #[cfg(feature = "huggingface")] - async fn get_model_dimension(model: &str) -> Result { - Self::get_dimension_from_config(model).await - } - - /// Get model dimension using Candle config structs (like examples) - #[cfg(feature = "huggingface")] - async fn get_dimension_from_config(model_name: &str) -> Result { - // Download config.json - let config_json = Self::download_config_direct(model_name).await?; - - // Try different Candle config types - JinaBert first, then standard Bert - if let Ok(config) = Self::parse_as_jina_bert_config(&config_json) { - return Ok(config.hidden_size); - } - - if let Ok(config) = Self::parse_as_bert_config(&config_json) { - return Ok(config.hidden_size); - } - - // Fallback to JSON parsing - Self::parse_hidden_size_from_json(&config_json, model_name) - } - - /// Try to parse config as JinaBert config (for Jina models) - #[cfg(feature = "huggingface")] - fn parse_as_jina_bert_config(config_json: &str) -> Result { - serde_json::from_str::(config_json) - .map_err(|e| anyhow::anyhow!("Failed to parse as JinaBertConfig: {}", e)) - } - - /// Try to parse config as standard Candle BertConfig - #[cfg(feature = "huggingface")] - fn parse_as_bert_config( - config_json: &str, - ) -> Result { - use candle_transformers::models::bert::Config as BertConfig; - serde_json::from_str::(config_json) - .map_err(|e| anyhow::anyhow!("Failed to parse as BertConfig: {}", e)) - } - - /// Parse hidden_size from JSON config flexibly - #[cfg(feature = "huggingface")] - fn parse_hidden_size_from_json(config_json: &str, model_name: &str) -> Result { - use serde_json::Value; - - let config: Value = serde_json::from_str(config_json).with_context(|| { - format!( - "Failed to parse config.json as JSON for model: {}", - model_name - ) - })?; - - // Try different field names that contain embedding dimensions - let dimension_fields = ["hidden_size", "d_model", "embedding_size", "dim"]; - - for field in &dimension_fields { - if let Some(dim) = config.get(field).and_then(|v| v.as_u64()) { - tracing::debug!( - "Found dimension {} for model {} from config.json field '{}'", - dim, - model_name, - field - ); - return Ok(dim as usize); - } - } - - Err(anyhow::anyhow!( - "No dimension field found in config.json for model '{}'. \ - Searched for fields: {:?}. Available fields: {:?}", - model_name, - dimension_fields, - config - .as_object() - .map(|obj| obj.keys().collect::>()) - .unwrap_or_default() - )) - } - - /// Download config.json directly from HuggingFace Hub using HTTP - #[cfg(feature = "huggingface")] - async fn download_config_direct(model_name: &str) -> Result { - use reqwest; - - // Construct direct URL to config.json - let config_url = format!("https://huggingface.co/{}/raw/main/config.json", model_name); - - tracing::debug!("Downloading config from: {}", config_url); - - // Use reqwest for direct HTTP download - let client = reqwest::Client::new(); - let response = client - .get(&config_url) - .header("User-Agent", "octocode/0.7.1") - .send() - .await - .with_context(|| format!("Failed to download config.json from {}", config_url))?; - - if !response.status().is_success() { - return Err(anyhow::anyhow!( - "Failed to download config.json for model '{}'. HTTP status: {}. \ - This could be due to:\n\ - 1. Model doesn't exist on HuggingFace Hub\n\ - 2. 
Network connectivity issues\n\ - 3. Model is private and requires authentication\n\ - 4. Model doesn't have a config.json file", - model_name, - response.status() - )); - } - - let config_text = response.text().await.with_context(|| { - format!( - "Failed to read config.json response for model: {}", - model_name - ) - })?; - - Ok(config_text) - } -} - -#[cfg(feature = "huggingface")] -#[async_trait::async_trait] -impl EmbeddingProvider for HuggingFaceProviderImpl { - async fn generate_embedding(&self, text: &str) -> Result> { - HuggingFaceProvider::generate_embeddings(text, &self.model_name).await - } - - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>> { - // Apply prefix manually for HuggingFace (doesn't support input_type API) - let processed_texts: Vec = texts - .into_iter() - .map(|text| input_type.apply_prefix(&text)) - .collect(); - HuggingFaceProvider::generate_embeddings_batch(processed_texts, &self.model_name).await - } - - fn get_dimension(&self) -> usize { - self.dimension - } - - fn is_model_supported(&self) -> bool { - // For HuggingFace, we support many models, so return true for most cases - // The actual validation happens when trying to load the model - true - } -} - -#[cfg(all(test, feature = "huggingface"))] -mod tests { - #[test] - fn test_roberta_tokenizer_building() { - // Test that we can build a RoBERTa-style tokenizer using BPE::from_file approach - use tokenizers::{ - models::bpe::BPE, pre_tokenizers::byte_level::ByteLevel, - processors::roberta::RobertaProcessing, Tokenizer, - }; - - // Create temporary files for testing - let vocab_file = std::env::temp_dir().join("test_vocab.json"); - let merges_file = std::env::temp_dir().join("test_merges.txt"); - - // Write test vocab - must include all tokens used in merges - let vocab_content = r#"{"":0,"":1,"":2,"":3,"h":4,"e":5,"l":6,"o":7,"r":8,"he":9,"ll":10,"or":11,"hello":12,"world":13}"#; - std::fs::write(&vocab_file, vocab_content).expect("Failed to write vocab"); - - // Write test merges - let merges_content = "#version: 0.2\nh e\nl l\no r"; - std::fs::write(&merges_file, merges_content).expect("Failed to write merges"); - - // Build BPE model using from_file - let bpe = BPE::from_file(vocab_file.to_str().unwrap(), merges_file.to_str().unwrap()) - .unk_token("".to_string()) - .build() - .expect("Failed to build BPE tokenizer"); - - let mut tokenizer = Tokenizer::new(bpe); - - // Add ByteLevel pre-tokenizer (for RoBERTa) - tokenizer.with_pre_tokenizer(Some(ByteLevel::default())); - - // Add RoBERTa post-processing - let post_processor = RobertaProcessing::new( - ("".to_string(), 2), // SEP token - ("".to_string(), 0), // CLS token - ) - .trim_offsets(false) - .add_prefix_space(true); - tokenizer.with_post_processor(Some(post_processor)); - - // Test that tokenizer works - let test_text = "hello world"; - let encoding = tokenizer - .encode(test_text, false) - .expect("Failed to encode"); - - assert!( - !encoding.get_ids().is_empty(), - "Encoding should produce tokens" - ); - println!("✓ RoBERTa-style tokenizer built successfully using BPE::from_file"); - - // Clean up - let _ = std::fs::remove_file(vocab_file); - let _ = std::fs::remove_file(merges_file); - } - - #[test] - fn test_merges_parsing() { - // Test that we correctly parse merges.txt format - let merges_content = r#"#version: 0.2 -Ġ t -Ġ a -h e -Ġt he -i n"#; - - let merges: Vec<(String, String)> = merges_content - .lines() - .skip(1) // Skip header line - .filter_map(|line| { - let parts: Vec<&str> = 
line.split_whitespace().collect(); - if parts.len() == 2 { - Some((parts[0].to_string(), parts[1].to_string())) - } else { - None - } - }) - .collect(); - - assert_eq!(merges.len(), 5); - assert_eq!(merges[0], ("Ġ".to_string(), "t".to_string())); - println!("✓ Merges parsing works correctly"); - } -} diff --git a/src/embedding/provider/jina.rs b/src/embedding/provider/jina.rs deleted file mode 100644 index 030a3db..0000000 --- a/src/embedding/provider/jina.rs +++ /dev/null @@ -1,168 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Jina AI embedding provider implementation - -use anyhow::{Context, Result}; -use serde_json::{json, Value}; - -use super::super::types::InputType; -use super::{EmbeddingProvider, HTTP_CLIENT}; - -/// Jina provider implementation for trait -pub struct JinaProviderImpl { - model_name: String, - dimension: usize, -} - -impl JinaProviderImpl { - pub fn new(model: &str) -> Result { - // Validate model first - fail fast if unsupported - let supported_models = [ - "jina-embeddings-v4", - "jina-clip-v2", - "jina-embeddings-v3", - "jina-clip-v1", - "jina-embeddings-v2-base-es", - "jina-embeddings-v2-base-code", - "jina-embeddings-v2-base-de", - "jina-embeddings-v2-base-zh", - "jina-embeddings-v2-base-en", - ]; - - if !supported_models.contains(&model) { - return Err(anyhow::anyhow!( - "Unsupported Jina model: '{}'. 
Supported models: {:?}", - model, - supported_models - )); - } - - let dimension = Self::get_model_dimension(model); - Ok(Self { - model_name: model.to_string(), - dimension, - }) - } - - fn get_model_dimension(model: &str) -> usize { - match model { - "jina-embeddings-v4" => 2048, - "jina-clip-v2" => 1024, - "jina-embeddings-v3" => 1024, - "jina-clip-v1" => 768, - "jina-embeddings-v2-base-es" => 768, - "jina-embeddings-v2-base-code" => 768, - "jina-embeddings-v2-base-de" => 768, - "jina-embeddings-v2-base-zh" => 768, - "jina-embeddings-v2-base-en" => 768, - _ => { - // This should never be reached due to validation in new() - panic!( - "Invalid Jina model '{}' passed to get_model_dimension", - model - ); - } - } - } -} - -#[async_trait::async_trait] -impl EmbeddingProvider for JinaProviderImpl { - async fn generate_embedding(&self, text: &str) -> Result> { - JinaProvider::generate_embeddings(text, &self.model_name).await - } - - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>> { - // Apply prefix manually for Jina (doesn't support input_type API) - let processed_texts: Vec = texts - .into_iter() - .map(|text| input_type.apply_prefix(&text)) - .collect(); - JinaProvider::generate_embeddings_batch(processed_texts, &self.model_name).await - } - - fn get_dimension(&self) -> usize { - self.dimension - } - - fn is_model_supported(&self) -> bool { - // REAL validation - only support actual Jina models - matches!( - self.model_name.as_str(), - "jina-embeddings-v4" - | "jina-clip-v2" - | "jina-embeddings-v3" - | "jina-clip-v1" - | "jina-embeddings-v2-base-es" - | "jina-embeddings-v2-base-code" - | "jina-embeddings-v2-base-de" - | "jina-embeddings-v2-base-zh" - | "jina-embeddings-v2-base-en" - ) - } -} - -/// Jina provider implementation -pub struct JinaProvider; - -impl JinaProvider { - pub async fn generate_embeddings(contents: &str, model: &str) -> Result> { - let result = Self::generate_embeddings_batch(vec![contents.to_string()], model).await?; - result - .first() - .cloned() - .ok_or_else(|| anyhow::anyhow!("No embeddings found")) - } - - pub async fn generate_embeddings_batch( - texts: Vec, - model: &str, - ) -> Result>> { - let jina_api_key = - std::env::var("JINA_API_KEY").context("JINA_API_KEY environment variable not set")?; - - let response = HTTP_CLIENT - .post("https://api.jina.ai/v1/embeddings") - .header("Authorization", format!("Bearer {}", jina_api_key)) - .json(&json!({ - "input": texts, - "model": model, - })) - .send() - .await?; - - let response_json: Value = response.json().await?; - - let embeddings = response_json["data"] - .as_array() - .context("Failed to get embeddings array")? - .iter() - .map(|data| { - data["embedding"] - .as_array() - .unwrap_or(&Vec::new()) - .iter() - .map(|v| v.as_f64().unwrap_or_default() as f32) - .collect() - }) - .collect(); - - Ok(embeddings) - } -} diff --git a/src/embedding/provider/mod.rs b/src/embedding/provider/mod.rs deleted file mode 100644 index 21a9d47..0000000 --- a/src/embedding/provider/mod.rs +++ /dev/null @@ -1,112 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Embedding providers module -//! -//! This module contains implementations for different embedding providers. -//! Each provider can be optionally compiled based on cargo features. - -use anyhow::Result; -use reqwest::Client; -use std::sync::LazyLock; -use std::time::Duration; - -use super::types::{EmbeddingProviderType, InputType}; - -// Shared HTTP client with connection pooling for optimal performance -static HTTP_CLIENT: LazyLock = LazyLock::new(|| { - Client::builder() - .pool_max_idle_per_host(10) - .pool_idle_timeout(Duration::from_secs(30)) - .timeout(Duration::from_secs(120)) // Increased from 60s to 120s for embedding APIs - .connect_timeout(Duration::from_secs(10)) - .build() - .expect("Failed to create HTTP client") -}); - -// Feature-specific provider modules -#[cfg(feature = "fastembed")] -pub mod fastembed; -#[cfg(feature = "huggingface")] -pub mod huggingface; - -// Always available provider modules -pub mod google; -pub mod jina; -pub mod openai; -pub mod voyage; - -// Re-export providers -#[cfg(feature = "fastembed")] -pub use fastembed::{FastEmbedProvider, FastEmbedProviderImpl}; -#[cfg(feature = "huggingface")] -pub use huggingface::{HuggingFaceProvider, HuggingFaceProviderImpl}; - -// Always available provider re-exports -pub use google::{GoogleProvider, GoogleProviderImpl}; -pub use jina::{JinaProvider, JinaProviderImpl}; -pub use openai::{OpenAIProvider, OpenAIProviderImpl}; -pub use voyage::{VoyageProvider, VoyageProviderImpl}; - -/// Trait for embedding providers -#[async_trait::async_trait] -pub trait EmbeddingProvider: Send + Sync { - async fn generate_embedding(&self, text: &str) -> Result>; - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>>; - - /// Get the vector dimension for this provider's model - fn get_dimension(&self) -> usize; - - /// Validate if the model is supported (optional, defaults to true) - fn is_model_supported(&self) -> bool { - true - } -} - -/// Create an embedding provider from provider type and model -pub async fn create_embedding_provider_from_parts( - provider: &EmbeddingProviderType, - model: &str, -) -> Result> { - match provider { - EmbeddingProviderType::FastEmbed => { - #[cfg(feature = "fastembed")] - { - Ok(Box::new(FastEmbedProviderImpl::new(model)?)) - } - #[cfg(not(feature = "fastembed"))] - { - Err(anyhow::anyhow!("FastEmbed support is not compiled in. Please rebuild with --features fastembed")) - } - } - EmbeddingProviderType::Jina => Ok(Box::new(JinaProviderImpl::new(model)?)), - EmbeddingProviderType::Voyage => Ok(Box::new(VoyageProviderImpl::new(model)?)), - EmbeddingProviderType::Google => Ok(Box::new(GoogleProviderImpl::new(model)?)), - EmbeddingProviderType::OpenAI => Ok(Box::new(OpenAIProviderImpl::new(model)?)), - EmbeddingProviderType::HuggingFace => { - #[cfg(feature = "huggingface")] - { - Ok(Box::new(HuggingFaceProviderImpl::new(model).await?)) - } - #[cfg(not(feature = "huggingface"))] - { - Err(anyhow::anyhow!("HuggingFace support is not compiled in. 
Please rebuild with --features huggingface")) - } - } - } -} diff --git a/src/embedding/provider/openai.rs b/src/embedding/provider/openai.rs deleted file mode 100644 index d5582a4..0000000 --- a/src/embedding/provider/openai.rs +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! OpenAI embedding provider implementation - -use anyhow::{Context, Result}; -use serde_json::{json, Value}; - -use super::super::types::InputType; -use super::{EmbeddingProvider, HTTP_CLIENT}; - -/// OpenAI provider implementation for trait -pub struct OpenAIProviderImpl { - model_name: String, - dimension: usize, -} - -impl OpenAIProviderImpl { - pub fn new(model: &str) -> Result { - // Validate model first - fail fast if unsupported - let supported_models = [ - "text-embedding-3-small", - "text-embedding-3-large", - "text-embedding-ada-002", - ]; - - if !supported_models.contains(&model) { - return Err(anyhow::anyhow!( - "Unsupported OpenAI model: '{}'. Supported models: {:?}", - model, - supported_models - )); - } - - let dimension = Self::get_model_dimension(model); - Ok(Self { - model_name: model.to_string(), - dimension, - }) - } - - fn get_model_dimension(model: &str) -> usize { - match model { - "text-embedding-3-small" => 1536, - "text-embedding-3-large" => 3072, - "text-embedding-ada-002" => 1536, - _ => { - // This should never be reached due to validation in new() - panic!( - "Invalid OpenAI model '{}' passed to get_model_dimension", - model - ); - } - } - } -} - -#[async_trait::async_trait] -impl EmbeddingProvider for OpenAIProviderImpl { - async fn generate_embedding(&self, text: &str) -> Result> { - OpenAIProvider::generate_embeddings(text, &self.model_name).await - } - - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>> { - OpenAIProvider::generate_embeddings_batch(texts, &self.model_name, input_type).await - } - - fn get_dimension(&self) -> usize { - self.dimension - } - - fn is_model_supported(&self) -> bool { - // REAL validation - only support actual OpenAI models, NO HALLUCINATIONS - matches!( - self.model_name.as_str(), - "text-embedding-3-small" | "text-embedding-3-large" | "text-embedding-ada-002" - ) - } -} - -/// OpenAI provider implementation -pub struct OpenAIProvider; - -impl OpenAIProvider { - pub async fn generate_embeddings(contents: &str, model: &str) -> Result> { - let result = - Self::generate_embeddings_batch(vec![contents.to_string()], model, InputType::None) - .await?; - result - .first() - .cloned() - .ok_or_else(|| anyhow::anyhow!("No embeddings found")) - } - - pub async fn generate_embeddings_batch( - texts: Vec, - model: &str, - input_type: InputType, - ) -> Result>> { - let openai_api_key = std::env::var("OPENAI_API_KEY") - .context("OPENAI_API_KEY environment variable not set")?; - - // Apply input type prefixes since OpenAI doesn't have native input_type support - let processed_texts: Vec = texts - .into_iter() - .map(|text| 
input_type.apply_prefix(&text)) - .collect(); - - // Build request body - let request_body = json!({ - "input": processed_texts, - "model": model, - "encoding_format": "float" - }); - - let response = HTTP_CLIENT - .post("https://api.openai.com/v1/embeddings") - .header("Authorization", format!("Bearer {}", openai_api_key)) - .header("Content-Type", "application/json") - .json(&request_body) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - return Err(anyhow::anyhow!("OpenAI API error: {}", error_text)); - } - - let response_json: Value = response.json().await?; - - let embeddings = response_json["data"] - .as_array() - .context("Failed to get embeddings array")? - .iter() - .map(|data| { - data["embedding"] - .as_array() - .unwrap_or(&Vec::new()) - .iter() - .map(|v| v.as_f64().unwrap_or_default() as f32) - .collect() - }) - .collect(); - - Ok(embeddings) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_openai_provider_creation() { - // Test valid models - assert!(OpenAIProviderImpl::new("text-embedding-3-small").is_ok()); - assert!(OpenAIProviderImpl::new("text-embedding-3-large").is_ok()); - assert!(OpenAIProviderImpl::new("text-embedding-ada-002").is_ok()); - - // Test invalid model - assert!(OpenAIProviderImpl::new("invalid-model").is_err()); - } - - #[test] - fn test_model_dimensions() { - let provider_small = OpenAIProviderImpl::new("text-embedding-3-small").unwrap(); - assert_eq!(provider_small.get_dimension(), 1536); - - let provider_large = OpenAIProviderImpl::new("text-embedding-3-large").unwrap(); - assert_eq!(provider_large.get_dimension(), 3072); - - let provider_ada = OpenAIProviderImpl::new("text-embedding-ada-002").unwrap(); - assert_eq!(provider_ada.get_dimension(), 1536); - } - - #[test] - fn test_model_validation() { - let provider_valid = OpenAIProviderImpl::new("text-embedding-3-small").unwrap(); - assert!(provider_valid.is_model_supported()); - - // This would panic if we tried to create an invalid model, so we test indirectly - let supported_models = [ - "text-embedding-3-small", - "text-embedding-3-large", - "text-embedding-ada-002", - ]; - for model in supported_models { - let provider = OpenAIProviderImpl::new(model).unwrap(); - assert!(provider.is_model_supported()); - } - } -} diff --git a/src/embedding/provider/voyage.rs b/src/embedding/provider/voyage.rs deleted file mode 100644 index 0383591..0000000 --- a/src/embedding/provider/voyage.rs +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! 
Voyage AI embedding provider implementation - -use anyhow::{Context, Result}; -use serde_json::{json, Value}; - -use super::super::types::InputType; -use super::{EmbeddingProvider, HTTP_CLIENT}; - -/// Voyage provider implementation for trait -pub struct VoyageProviderImpl { - model_name: String, - dimension: usize, -} - -impl VoyageProviderImpl { - pub fn new(model: &str) -> Result { - // Validate model first - fail fast if unsupported - let supported_models = [ - "voyage-3.5", - "voyage-3.5-lite", - "voyage-3-large", - "voyage-code-2", - "voyage-code-3", - "voyage-finance-2", - "voyage-law-2", - "voyage-2", - ]; - - if !supported_models.contains(&model) { - return Err(anyhow::anyhow!( - "Unsupported Voyage model: '{}'. Supported models: {:?}", - model, - supported_models - )); - } - - let dimension = Self::get_model_dimension(model); - Ok(Self { - model_name: model.to_string(), - dimension, - }) - } - - fn get_model_dimension(model: &str) -> usize { - match model { - "voyage-3.5" => 1024, - "voyage-3.5-lite" => 1024, - "voyage-3-large" => 1024, - "voyage-code-2" => 1536, - "voyage-code-3" => 1024, - "voyage-finance-2" => 1024, - "voyage-law-2" => 1024, - "voyage-2" => 1024, - _ => { - // This should never be reached due to validation in new() - panic!( - "Invalid Voyage model '{}' passed to get_model_dimension", - model - ); - } - } - } -} - -#[async_trait::async_trait] -impl EmbeddingProvider for VoyageProviderImpl { - async fn generate_embedding(&self, text: &str) -> Result> { - VoyageProvider::generate_embeddings(text, &self.model_name).await - } - - async fn generate_embeddings_batch( - &self, - texts: Vec, - input_type: InputType, - ) -> Result>> { - VoyageProvider::generate_embeddings_batch(texts, &self.model_name, input_type).await - } - - fn get_dimension(&self) -> usize { - self.dimension - } - - fn is_model_supported(&self) -> bool { - // REAL validation - only support actual Voyage models, NO HALLUCINATIONS - matches!( - self.model_name.as_str(), - "voyage-3.5" - | "voyage-3.5-lite" - | "voyage-3-large" - | "voyage-code-2" - | "voyage-code-3" - | "voyage-finance-2" - | "voyage-law-2" - | "voyage-2" - ) - } -} - -/// Voyage AI provider implementation -pub struct VoyageProvider; - -impl VoyageProvider { - pub async fn generate_embeddings(contents: &str, model: &str) -> Result> { - let result = - Self::generate_embeddings_batch(vec![contents.to_string()], model, InputType::None) - .await?; - result - .first() - .cloned() - .ok_or_else(|| anyhow::anyhow!("No embeddings found")) - } - - pub async fn generate_embeddings_batch( - texts: Vec, - model: &str, - input_type: InputType, - ) -> Result>> { - let voyage_api_key = std::env::var("VOYAGE_API_KEY") - .context("VOYAGE_API_KEY environment variable not set")?; - - // Build request body with optional input_type - let mut request_body = json!({ - "input": texts, - "model": model, - }); - - // Add input_type if specified (Voyage API native support) - if let Some(input_type_str) = input_type.as_api_str() { - request_body["input_type"] = json!(input_type_str); - } - - let response = HTTP_CLIENT - .post("https://api.voyageai.com/v1/embeddings") - .header("Authorization", format!("Bearer {}", voyage_api_key)) - .header("Content-Type", "application/json") - .json(&request_body) - .send() - .await?; - - if !response.status().is_success() { - let error_text = response.text().await?; - return Err(anyhow::anyhow!("Voyage API error: {}", error_text)); - } - - let response_json: Value = response.json().await?; - - let embeddings = 
response_json["data"] - .as_array() - .context("Failed to get embeddings array")? - .iter() - .map(|data| { - data["embedding"] - .as_array() - .unwrap_or(&Vec::new()) - .iter() - .map(|v| v.as_f64().unwrap_or_default() as f32) - .collect() - }) - .collect(); - - Ok(embeddings) - } -} diff --git a/src/embedding/tests.rs b/src/embedding/tests.rs deleted file mode 100644 index f6287fb..0000000 --- a/src/embedding/tests.rs +++ /dev/null @@ -1,491 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Example test to verify SentenceTransformer integration -// This would typically go in tests/ directory or within the module - -#[cfg(test)] -mod embedding_tests { - use crate::embedding::types::{parse_provider_model, EmbeddingConfig}; - use crate::embedding::{ - count_tokens, split_texts_into_token_limited_batches, EmbeddingProviderType, - }; - - #[cfg(any( - feature = "huggingface", - feature = "fastembed", - not(feature = "huggingface"), - not(feature = "fastembed") - ))] - use crate::embedding::create_embedding_provider_from_parts; - - #[tokio::test] - #[cfg(feature = "huggingface")] - async fn test_sentence_transformer_provider_creation() { - // Test that we can create a SentenceTransformer provider - let provider_type = EmbeddingProviderType::HuggingFace; - let model = "sentence-transformers/all-MiniLM-L6-v2"; - - let result = create_embedding_provider_from_parts(&provider_type, model).await; - if let Err(e) = &result { - eprintln!("Error creating HuggingFace provider: {}", e); - } - assert!( - result.is_ok(), - "Should be able to create SentenceTransformer provider: {:?}", - result.err() - ); - } - - #[test] - fn test_provider_model_parsing() { - // Test the new provider:model syntax parsing - let mut test_cases = vec![( - "jinaai:jina-embeddings-v3", - EmbeddingProviderType::Jina, - "jina-embeddings-v3", - )]; - - // Add SentenceTransformer test case only if feature is enabled - #[cfg(feature = "huggingface")] - test_cases.push(( - "huggingface:sentence-transformers/all-MiniLM-L6-v2", - EmbeddingProviderType::HuggingFace, - "sentence-transformers/all-MiniLM-L6-v2", - )); - - // Add FastEmbed test cases only if feature is enabled - #[cfg(feature = "fastembed")] - { - test_cases.push(( - "fastembed:all-MiniLM-L6-v2", - EmbeddingProviderType::FastEmbed, - "all-MiniLM-L6-v2", - )); - test_cases.push(( - "all-MiniLM-L6-v2", // Legacy format without provider - EmbeddingProviderType::FastEmbed, - "all-MiniLM-L6-v2", - )); - } - - // Add Voyage test case (always available) - test_cases.push(( - "voyage:voyage-code-3", - EmbeddingProviderType::Voyage, - "voyage-code-3", - )); - - for (input, expected_provider, expected_model) in test_cases { - let (provider, model) = parse_provider_model(input); - assert_eq!( - provider, expected_provider, - "Provider should match for input: {}", - input - ); - assert_eq!( - model, expected_model, - "Model should match for input: {}", - input - ); - } - } - - #[test] - fn test_default_config_format() { - // Test 
that default config uses new provider:model format - let config = - crate::config::Config::load_from_template().expect("Failed to load template config"); - - // Check that default models use provider:model format - assert!( - config.embedding.code_model.contains(':'), - "Code model should use provider:model format" - ); - assert!( - config.embedding.text_model.contains(':'), - "Text model should use provider:model format" - ); - - // Test parsing the default models - let (code_provider, _) = parse_provider_model(&config.embedding.code_model); - let (text_provider, _) = parse_provider_model(&config.embedding.text_model); - - // When FastEmbed is not available, should fall back to Voyage - assert_eq!(code_provider, EmbeddingProviderType::Voyage); - assert_eq!(text_provider, EmbeddingProviderType::Voyage); - } - - #[tokio::test] - #[cfg(feature = "huggingface")] - async fn test_embedding_config_methods() { - let config = EmbeddingConfig { - code_model: "huggingface:microsoft/codebert-base".to_string(), - text_model: "huggingface:sentence-transformers/all-mpnet-base-v2".to_string(), - }; - - // Test getting active provider - let active_provider = config.get_active_provider(); - assert_eq!(active_provider, EmbeddingProviderType::HuggingFace); - - // Test vector dimensions - let dim = config - .get_vector_dimension( - &EmbeddingProviderType::HuggingFace, - "jinaai/jina-embeddings-v2-base-code", - ) - .await; - assert_eq!(dim, 768); - - let dim2 = config - .get_vector_dimension( - &EmbeddingProviderType::HuggingFace, - "sentence-transformers/all-MiniLM-L6-v2", - ) - .await; - assert_eq!(dim2, 384); - } - - #[tokio::test] - #[cfg(not(feature = "huggingface"))] - async fn test_embedding_config_methods_without_sentence_transformer() { - let config = EmbeddingConfig { - code_model: "voyage:voyage-code-3".to_string(), - text_model: "voyage:voyage-3.5-lite".to_string(), - }; - - // Test getting active provider - let active_provider = config.get_active_provider(); - assert_eq!(active_provider, EmbeddingProviderType::Voyage); - - // Test vector dimensions for Voyage models - let dim = config - .get_vector_dimension(&EmbeddingProviderType::Voyage, "voyage-code-3") - .await; - assert_eq!(dim, 1024); - - let dim2 = config - .get_vector_dimension(&EmbeddingProviderType::Voyage, "voyage-3.5-lite") - .await; - assert_eq!(dim2, 1024); - } - - #[test] - fn test_token_counting() { - // Test basic token counting - let text = "Hello world!"; - let token_count = count_tokens(text); - assert!(token_count > 0, "Should count tokens for basic text"); - - // Test empty string - let empty_count = count_tokens(""); - assert_eq!(empty_count, 0, "Empty string should have 0 tokens"); - - // Test longer text - let long_text = "This is a longer text that should have more tokens than the simple hello world example."; - let long_count = count_tokens(long_text); - assert!( - long_count > token_count, - "Longer text should have more tokens" - ); - } - - #[test] - fn test_token_limited_batching() { - let texts = vec![ - "Short text".to_string(), - "This is a medium length text that has more tokens".to_string(), - "Another short one".to_string(), - "This is a very long text that contains many words and should definitely exceed any reasonable token limit for a single batch when combined with other texts".to_string(), - "Final text".to_string(), - ]; - - // Test with small token limit to force splitting - let batches = split_texts_into_token_limited_batches(texts.clone(), 10, 20); - - // Should create multiple batches due to token 
limit - assert!( - batches.len() > 1, - "Should create multiple batches with small token limit" - ); - - // Verify all texts are included - let total_texts: usize = batches.iter().map(|b| b.len()).sum(); - assert_eq!( - total_texts, - texts.len(), - "All texts should be included in batches" - ); - - // Test with large limits (should create single batch) - let single_batch = split_texts_into_token_limited_batches(texts.clone(), 100, 10000); - assert_eq!( - single_batch.len(), - 1, - "Should create single batch with large limits" - ); - assert_eq!( - single_batch[0].len(), - texts.len(), - "Single batch should contain all texts" - ); - } - - #[test] - fn test_config_has_token_limit() { - let config = - crate::config::Config::load_from_template().expect("Failed to load template config"); - assert!( - config.index.embeddings_max_tokens_per_batch > 0, - "Should have positive token limit" - ); - assert_eq!( - config.index.embeddings_max_tokens_per_batch, 100000, - "Should have default token limit of 100000" - ); - } - - // FastEmbed provider tests - only run when feature is enabled - #[test] - #[cfg(feature = "fastembed")] - fn test_fastembed_provider_creation() { - use crate::embedding::provider::fastembed::FastEmbedProviderImpl; - use crate::embedding::provider::EmbeddingProvider; - - // Test creating provider with a known model - let result = FastEmbedProviderImpl::new("Xenova/all-MiniLM-L6-v2"); - assert!( - result.is_ok(), - "Should create FastEmbed provider successfully: {:?}", - result.err() - ); - - let provider = result.unwrap(); - assert_eq!( - provider.get_dimension(), - 384, - "all-MiniLM-L6-v2 should have 384 dimensions" - ); - assert!(provider.is_model_supported(), "Model should be supported"); - } - - #[test] - #[cfg(feature = "fastembed")] - fn test_fastembed_model_validation() { - use crate::embedding::provider::fastembed::FastEmbedProviderImpl; - - // Test with invalid model - let result = FastEmbedProviderImpl::new("invalid-model-name"); - assert!(result.is_err(), "Should fail with invalid model name"); - - // Test basic provider creation with valid model - let valid_result = FastEmbedProviderImpl::new("Xenova/all-MiniLM-L6-v2"); - assert!( - valid_result.is_ok(), - "Should create provider with valid model" - ); - } - - #[tokio::test] - #[cfg(feature = "fastembed")] - async fn test_fastembed_embedding_generation() { - use crate::embedding::provider::fastembed::FastEmbedProviderImpl; - use crate::embedding::provider::EmbeddingProvider; - - // Use a small, fast model for testing - let provider = FastEmbedProviderImpl::new("Xenova/all-MiniLM-L6-v2") - .expect("Should create FastEmbed provider"); - - // Test basic provider functionality without actual embedding generation - // (which would require downloading models) - assert_eq!( - provider.get_dimension(), - 384, - "Should have correct dimension" - ); - assert!(provider.is_model_supported(), "Should support the model"); - - // Note: Actual embedding generation test is commented out to avoid - // model download requirements in test environment - // In a real integration test environment, you would uncomment: - /* - let text = "This is a test text for embedding generation."; - let result = provider.generate_embedding(text).await; - assert!(result.is_ok(), "Should generate embedding successfully"); - let embedding = result.unwrap(); - assert_eq!(embedding.len(), 384, "Should produce 384-dimensional embedding"); - */ - } - - // HuggingFace provider tests - only run when feature is enabled - #[tokio::test] - #[cfg(feature = 
"huggingface")] - async fn test_huggingface_provider_creation() { - // Test that the HuggingFace provider feature is available - // We test through the factory function to avoid HTTP requests - let provider_type = EmbeddingProviderType::HuggingFace; - let model = "sentence-transformers/all-MiniLM-L6-v2"; - - // This will test that the provider can be created through the factory - // without actually making HTTP requests (which would happen in new()) - let result = create_embedding_provider_from_parts(&provider_type, model).await; - - // The result might be an error due to HTTP requests, but it should not be - // a "feature not compiled" error - if let Err(error) = result { - let error_msg = format!("{}", error); - assert!( - !error_msg.contains("not compiled"), - "Should not be a 'not compiled' error when feature is enabled: {}", - error_msg - ); - } - } - - #[tokio::test] - #[cfg(feature = "huggingface")] - async fn test_huggingface_dimension_detection() { - // Test that HuggingFace provider feature is available - // We test basic functionality without making HTTP requests - - // Test that the provider type is recognized - let provider_type = EmbeddingProviderType::HuggingFace; - assert_eq!(format!("{:?}", provider_type), "HuggingFace"); - - // Test that we can attempt to create providers (even if they fail due to HTTP) - let test_models = vec![ - "sentence-transformers/all-MiniLM-L6-v2", - "sentence-transformers/all-mpnet-base-v2", - "microsoft/codebert-base", - ]; - - for model in test_models { - let result = create_embedding_provider_from_parts(&provider_type, model).await; - // We don't care if it succeeds or fails, just that it's not a "not compiled" error - if let Err(error) = result { - let error_msg = format!("{}", error); - assert!( - !error_msg.contains("not compiled"), - "Should not be a 'not compiled' error for model {}: {}", - model, - error_msg - ); - } - } - } - - #[tokio::test] - #[cfg(feature = "huggingface")] - async fn test_huggingface_embedding_generation() { - // Test that HuggingFace provider feature is compiled and available - // We avoid actual embedding generation to prevent HTTP requests and runtime issues - - let provider_type = EmbeddingProviderType::HuggingFace; - let model = "sentence-transformers/all-MiniLM-L6-v2"; - - // Test that the provider can be instantiated through factory - let result = create_embedding_provider_from_parts(&provider_type, model).await; - - // We expect this might fail due to HTTP requests, but it should not be - // a "feature not compiled" error - if let Err(error) = result { - let error_msg = format!("{}", error); - assert!( - !error_msg.contains("not compiled"), - "Should not be a 'not compiled' error when feature is enabled: {}", - error_msg - ); - } - - // Note: Actual embedding generation test is commented out to avoid - // model download requirements and runtime conflicts in test environment - // In a real integration test environment, you would test actual embedding generation - } - - // Test that disabled features return appropriate errors - #[tokio::test] - #[cfg(not(feature = "fastembed"))] - async fn test_fastembed_disabled_error() { - // When feature is disabled, we test through the factory function - let provider_type = EmbeddingProviderType::FastEmbed; - let model = "any-model"; - - let result = create_embedding_provider_from_parts(&provider_type, model).await; - assert!( - result.is_err(), - "Should return error when FastEmbed feature is disabled" - ); - - if let Err(error) = result { - let error_msg = format!("{}", 
error); - assert!( - error_msg.contains("FastEmbed") || error_msg.contains("not compiled"), - "Error should mention FastEmbed not available: {}", - error_msg - ); - } - } - - #[tokio::test] - #[cfg(not(feature = "huggingface"))] - async fn test_huggingface_disabled_error() { - // When feature is disabled, we test through the factory function - let provider_type = EmbeddingProviderType::HuggingFace; - let model = "any-model"; - - let result = create_embedding_provider_from_parts(&provider_type, model).await; - assert!( - result.is_err(), - "Should return error when HuggingFace feature is disabled" - ); - - if let Err(error) = result { - let error_msg = format!("{}", error); - assert!( - error_msg.contains("HuggingFace") || error_msg.contains("not compiled"), - "Error should mention HuggingFace not available: {}", - error_msg - ); - } - } - - // Integration test for provider factory with features - #[tokio::test] - #[cfg(feature = "fastembed")] - async fn test_provider_factory_with_fastembed() { - let provider_type = EmbeddingProviderType::FastEmbed; - let model = "Xenova/all-MiniLM-L6-v2"; - - let result = create_embedding_provider_from_parts(&provider_type, model).await; - assert!( - result.is_ok(), - "Should create FastEmbed provider through factory: {:?}", - result.err() - ); - } - - #[tokio::test] - #[cfg(feature = "huggingface")] - async fn test_provider_factory_with_huggingface() { - let provider_type = EmbeddingProviderType::HuggingFace; - let model = "sentence-transformers/all-MiniLM-L6-v2"; - - let result = create_embedding_provider_from_parts(&provider_type, model).await; - assert!( - result.is_ok(), - "Should create HuggingFace provider through factory: {:?}", - result.err() - ); - } -} diff --git a/src/embedding/types.rs b/src/embedding/types.rs deleted file mode 100644 index fbccf7b..0000000 --- a/src/embedding/types.rs +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2025 Muvon Un Limited -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use anyhow::Result; -use serde::{Deserialize, Serialize}; - -/// Input type for embedding generation -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum InputType { - /// Default - no input_type (existing behavior) - None, - /// For search operations - Query, - /// For indexing operations - Document, -} - -impl Default for InputType { - fn default() -> Self { - Self::None - } -} - -impl InputType { - /// Convert to API string for providers that support it (like Voyage) - pub fn as_api_str(&self) -> Option<&'static str> { - match self { - InputType::None => None, - InputType::Query => Some("query"), - InputType::Document => Some("document"), - } - } - - /// Get prefix for manual injection (for providers that don't support input_type API) - pub fn get_prefix(&self) -> Option<&'static str> { - match self { - InputType::None => None, - InputType::Query => Some(crate::constants::QUERY_PREFIX), - InputType::Document => Some(crate::constants::DOCUMENT_PREFIX), - } - } - - /// Apply prefix to text for manual injection - pub fn apply_prefix(&self, text: &str) -> String { - match self.get_prefix() { - Some(prefix) => format!("{}{}", prefix, text), - None => text.to_string(), - } - } -} - -/// Supported embedding providers -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] -pub enum EmbeddingProviderType { - FastEmbed, - Jina, - Voyage, - Google, - HuggingFace, - OpenAI, -} - -impl Default for EmbeddingProviderType { - fn default() -> Self { - #[cfg(feature = "fastembed")] - { - Self::FastEmbed - } - #[cfg(not(feature = "fastembed"))] - { - Self::Voyage - } - } -} - -/// Configuration for embedding models (simplified) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EmbeddingConfig { - /// Code embedding model (format: "provider:model") - pub code_model: String, - - /// Text embedding model (format: "provider:model") - pub text_model: String, -} - -impl Default for EmbeddingConfig { - fn default() -> Self { - // Use FastEmbed models if available, otherwise fall back to Voyage - #[cfg(feature = "fastembed")] - { - Self { - code_model: "fastembed:jinaai/jina-embeddings-v2-base-code".to_string(), - text_model: "fastembed:sentence-transformers/all-MiniLM-L6-v2-quantized" - .to_string(), - } - } - #[cfg(not(feature = "fastembed"))] - { - Self { - code_model: "voyage:voyage-code-3".to_string(), - text_model: "voyage:voyage-3.5-lite".to_string(), - } - } - } -} - -/// Parse provider and model from a string in format "provider:model" -pub fn parse_provider_model(input: &str) -> (EmbeddingProviderType, String) { - if let Some((provider_str, model)) = input.split_once(':') { - let provider = match provider_str.to_lowercase().as_str() { - "fastembed" => EmbeddingProviderType::FastEmbed, - "jinaai" | "jina" => EmbeddingProviderType::Jina, - "voyageai" | "voyage" => EmbeddingProviderType::Voyage, - "google" => EmbeddingProviderType::Google, - "huggingface" | "hf" => EmbeddingProviderType::HuggingFace, - "openai" => EmbeddingProviderType::OpenAI, - _ => { - // Default fallback - use FastEmbed if available, otherwise Voyage - #[cfg(feature = "fastembed")] - { - EmbeddingProviderType::FastEmbed - } - #[cfg(not(feature = "fastembed"))] - { - EmbeddingProviderType::Voyage - } - } - }; - (provider, model.to_string()) - } else { - // Legacy format - assume FastEmbed if available, otherwise Voyage - #[cfg(feature = "fastembed")] - { - (EmbeddingProviderType::FastEmbed, input.to_string()) - } - #[cfg(not(feature = "fastembed"))] - { - 
(EmbeddingProviderType::Voyage, input.to_string()) - } - } -} - -impl EmbeddingConfig { - /// Get the currently active provider based on the code model - pub fn get_active_provider(&self) -> EmbeddingProviderType { - let (provider, _) = parse_provider_model(&self.code_model); - provider - } - - /// Get API key for a specific provider (from environment variables only) - pub fn get_api_key(&self, provider: &EmbeddingProviderType) -> Option { - match provider { - EmbeddingProviderType::Jina => std::env::var("JINA_API_KEY").ok(), - EmbeddingProviderType::Voyage => std::env::var("VOYAGE_API_KEY").ok(), - EmbeddingProviderType::Google => std::env::var("GOOGLE_API_KEY").ok(), - _ => None, // FastEmbed and SentenceTransformer don't need API keys - } - } - - /// Get vector dimension by creating a provider instance - pub async fn get_vector_dimension( - &self, - provider: &EmbeddingProviderType, - model: &str, - ) -> usize { - // Try to create provider and get dimension - match crate::embedding::provider::create_embedding_provider_from_parts(provider, model) - .await - { - Ok(provider_impl) => provider_impl.get_dimension(), - Err(e) => { - panic!( - "Failed to create provider for {:?}:{}: {}. Using fallback dimension.", - provider, model, e - ); - } - } - } - - /// Validate model by trying to create provider - pub async fn validate_model( - &self, - provider: &EmbeddingProviderType, - model: &str, - ) -> Result<()> { - let provider_impl = - crate::embedding::provider::create_embedding_provider_from_parts(provider, model) - .await?; - if !provider_impl.is_model_supported() { - return Err(anyhow::anyhow!( - "Model {} is not supported by provider {:?}", - model, - provider - )); - } - Ok(()) - } -} diff --git a/src/indexer/search.rs b/src/indexer/search.rs index d27a171..dced954 100644 --- a/src/indexer/search.rs +++ b/src/indexer/search.rs @@ -1187,7 +1187,7 @@ pub async fn search_codebase_with_details_multi_query_text( if queries.is_empty() { return Err(anyhow::anyhow!("At least one query is required")); } - if queries.len() > crate::constants::MAX_QUERIES { + if queries.len() > octolib::embedding::constants::MAX_QUERIES { return Err(anyhow::anyhow!( "Maximum {} queries allowed, got {}. Use fewer, more specific terms.", crate::constants::MAX_QUERIES, @@ -1557,7 +1557,7 @@ pub async fn search_codebase_with_details_multi_query( if queries.is_empty() { return Err(anyhow::anyhow!("At least one query is required")); } - if queries.len() > crate::constants::MAX_QUERIES { + if queries.len() > octolib::embedding::constants::MAX_QUERIES { return Err(anyhow::anyhow!( "Maximum {} queries allowed, got {}. 
Use fewer, more specific terms.", crate::constants::MAX_QUERIES, diff --git a/src/mcp/memory.rs b/src/mcp/memory.rs index be7dfcd..fa98988 100644 --- a/src/mcp/memory.rs +++ b/src/mcp/memory.rs @@ -21,11 +21,11 @@ use tokio::sync::Mutex; use tracing::{debug, warn}; use crate::config::Config; -use crate::constants::MAX_QUERIES; use crate::embedding::truncate_output; use crate::mcp::logging::log_critical_anyhow_error; use crate::mcp::types::{McpError, McpTool}; use crate::memory::{MemoryManager, MemoryQuery, MemoryType}; +use octolib::embedding::constants::MAX_QUERIES; /// Memory tools provider #[derive(Clone)] diff --git a/src/mcp/semantic_code.rs b/src/mcp/semantic_code.rs index c6c5e87..2b69ab6 100644 --- a/src/mcp/semantic_code.rs +++ b/src/mcp/semantic_code.rs @@ -17,13 +17,13 @@ use serde_json::{json, Value}; use tracing::debug; use crate::config::Config; -use crate::constants::MAX_QUERIES; use crate::embedding::truncate_output; use crate::indexer::search::{ search_codebase_with_details_multi_query_text, search_codebase_with_details_text, }; use crate::indexer::{extract_file_signatures, render_signatures_text, NoindexWalker, PathUtils}; use crate::mcp::types::{McpError, McpTool}; +use octolib::embedding::constants::MAX_QUERIES; /// Semantic code search tool provider #[derive(Clone)] From 676993db666e6d9346cd2464fd0b5fedcab504a5 Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Tue, 16 Sep 2025 15:55:04 +0700 Subject: [PATCH 2/3] refactor(embedding): simplify embedding config and add octocode logic - Introduce EmbeddingGenerationConfig for octocode-specific settings - Parse provider:model strings for embedding generation calls - Implement generate_search_embeddings with mode-based logic - Add content hashing utilities including file path and line ranges - Replace broad re-export with selective imports from octolib embedding --- src/embedding/mod.rs | 169 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 159 insertions(+), 10 deletions(-) diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs index 892ae95..7a5b908 100644 --- a/src/embedding/mod.rs +++ b/src/embedding/mod.rs @@ -12,15 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! Re-export embedding functionality from octolib +//! 
Re-export embedding functionality from octolib and add octocode-specific logic use crate::config::Config; use anyhow::Result; -// Re-export everything from octolib::embedding -pub use octolib::embedding::*; +// Re-export core functionality from octolib::embedding +pub use octolib::embedding::{ + count_tokens, create_embedding_provider_from_parts, split_texts_into_token_limited_batches, + truncate_output, EmbeddingProvider, InputType, +}; -/// Convert octocode Config to octolib EmbeddingGenerationConfig +/// Configuration for embedding generation (octocode-specific) +#[derive(Debug, Clone)] +pub struct EmbeddingGenerationConfig { + /// Code embedding model (format: "provider:model") + pub code_model: String, + /// Text embedding model (format: "provider:model") + pub text_model: String, + /// Batch size for embedding generation + pub batch_size: usize, + /// Maximum tokens per batch + pub max_tokens_per_batch: usize, +} + +impl Default for EmbeddingGenerationConfig { + fn default() -> Self { + Self { + code_model: "voyage:voyage-code-3".to_string(), + text_model: "voyage:voyage-3.5-lite".to_string(), + batch_size: 16, + max_tokens_per_batch: 100_000, + } + } +} + +/// Convert octocode Config to octocode EmbeddingGenerationConfig impl From<&Config> for EmbeddingGenerationConfig { fn from(config: &Config) -> Self { Self { @@ -40,7 +67,22 @@ pub async fn generate_embeddings( config: &Config, ) -> Result> { let embedding_config = EmbeddingGenerationConfig::from(config); - octolib::embedding::generate_embeddings(contents, is_code, &embedding_config).await + + // Get the model string from config + let model_string = if is_code { + &embedding_config.code_model + } else { + &embedding_config.text_model + }; + + // Parse provider and model from the string + let (provider, model) = if let Some((p, m)) = model_string.split_once(':') { + (p, m) + } else { + return Err(anyhow::anyhow!("Invalid model format: {}", model_string)); + }; + + octolib::embedding::generate_embeddings(contents, provider, model).await } /// Generate batch embeddings based on configured provider (supports provider:model format) @@ -52,17 +94,124 @@ pub async fn generate_embeddings_batch( input_type: InputType, ) -> Result>> { let embedding_config = EmbeddingGenerationConfig::from(config); - octolib::embedding::generate_embeddings_batch(texts, is_code, &embedding_config, input_type) - .await + + // Get the model string from config + let model_string = if is_code { + &embedding_config.code_model + } else { + &embedding_config.text_model + }; + + // Parse provider and model from the string + let (provider, model) = if let Some((p, m)) = model_string.split_once(':') { + (p, m) + } else { + return Err(anyhow::anyhow!("Invalid model format: {}", model_string)); + }; + + octolib::embedding::generate_embeddings_batch( + texts, + provider, + model, + input_type, + embedding_config.batch_size, + embedding_config.max_tokens_per_batch, + ) + .await +} + +/// Search mode embeddings result (octocode-specific) +#[derive(Debug, Clone)] +pub struct SearchModeEmbeddings { + pub code_embeddings: Option>, + pub text_embeddings: Option>, } /// Generate embeddings for search based on mode - centralized logic to avoid duplication -/// Compatibility wrapper for octocode Config +/// Compatibility wrapper for octocode Config (octocode-specific) pub async fn generate_search_embeddings( query: &str, mode: &str, config: &Config, ) -> Result { - let embedding_config = EmbeddingGenerationConfig::from(config); - 
octolib::embedding::generate_search_embeddings(query, mode, &embedding_config).await + match mode { + "code" => { + // Use code model for code searches only + let embeddings = generate_embeddings(query, true, config).await?; + Ok(SearchModeEmbeddings { + code_embeddings: Some(embeddings), + text_embeddings: None, + }) + } + "docs" | "text" => { + // Use text model for documents and text searches only + let embeddings = generate_embeddings(query, false, config).await?; + Ok(SearchModeEmbeddings { + code_embeddings: None, + text_embeddings: Some(embeddings), + }) + } + "all" => { + // For "all" mode, check if code and text models are different + // If different, generate separate embeddings; if same, use one set + let embedding_config = EmbeddingGenerationConfig::from(config); + let code_model = &embedding_config.code_model; + let text_model = &embedding_config.text_model; + + if code_model == text_model { + // Same model for both - generate once and reuse + let embeddings = generate_embeddings(query, true, config).await?; + Ok(SearchModeEmbeddings { + code_embeddings: Some(embeddings.clone()), + text_embeddings: Some(embeddings), + }) + } else { + // Different models - generate separate embeddings + let code_embeddings = generate_embeddings(query, true, config).await?; + let text_embeddings = generate_embeddings(query, false, config).await?; + Ok(SearchModeEmbeddings { + code_embeddings: Some(code_embeddings), + text_embeddings: Some(text_embeddings), + }) + } + } + _ => Err(anyhow::anyhow!( + "Invalid search mode '{}'. Use 'all', 'code', 'docs', or 'text'.", + mode + )), + } +} + +/// Calculate a unique hash for content including file path (octocode-specific) +pub fn calculate_unique_content_hash(contents: &str, file_path: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(contents.as_bytes()); + hasher.update(file_path.as_bytes()); + format!("{:x}", hasher.finalize()) +} + +/// Calculate a unique hash for content including file path and line ranges (octocode-specific) +/// This ensures blocks are reindexed when their position changes in the file +pub fn calculate_content_hash_with_lines( + contents: &str, + file_path: &str, + start_line: usize, + end_line: usize, +) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(contents.as_bytes()); + hasher.update(file_path.as_bytes()); + hasher.update(start_line.to_string().as_bytes()); + hasher.update(end_line.to_string().as_bytes()); + format!("{:x}", hasher.finalize()) +} + +/// Calculate content hash without file path (octocode-specific) +pub fn calculate_content_hash(contents: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(contents.as_bytes()); + format!("{:x}", hasher.finalize()) } From 3a6d34a4ae591cb70f29107ab9dcf984dcb4e858 Mon Sep 17 00:00:00 2001 From: Don Hardman Date: Tue, 16 Sep 2025 18:16:47 +0700 Subject: [PATCH 3/3] chore(embedding): update deps and re-export octolib types - Add fastembed and huggingface features - Bump indexmap to 2.11.3 for consistency - Update multiple dependencies to newer versions - Re-export octolib embedding types for compatibility --- Cargo.lock | 166 ++++++++++++++++++++++++++----------------- Cargo.toml | 2 + src/embedding/mod.rs | 13 ++++ 3 files changed, 116 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb2cdb3..819e9da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -319,7 +319,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - 
"indexmap 2.11.1", + "indexmap 2.11.3", "lexical-core", "memchr", "num", @@ -1370,9 +1370,9 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.5.57" +version = "4.5.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d9501bd3f5f09f7bbee01da9a511073ed30a80cd7a509f1214bb74eadea71ad" +checksum = "75bf0b32ad2e152de789bb635ea4d3078f6b838ad7974143e99b99f45a04af4a" dependencies = [ "clap", ] @@ -1838,7 +1838,7 @@ dependencies = [ "base64 0.22.1", "half", "hashbrown 0.14.5", - "indexmap 2.11.1", + "indexmap 2.11.3", "libc", "log", "object_store", @@ -1976,7 +1976,7 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-functions-window-common", "datafusion-physical-expr-common", - "indexmap 2.11.1", + "indexmap 2.11.3", "paste", "serde_json", "sqlparser", @@ -1990,7 +1990,7 @@ checksum = "70fafb3a045ed6c49cfca0cd090f62cf871ca6326cc3355cb0aaf1260fa760b6" dependencies = [ "arrow", "datafusion-common", - "indexmap 2.11.1", + "indexmap 2.11.3", "itertools 0.14.0", "paste", ] @@ -2145,7 +2145,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "indexmap 2.11.1", + "indexmap 2.11.3", "itertools 0.14.0", "log", "regex", @@ -2167,7 +2167,7 @@ dependencies = [ "datafusion-physical-expr-common", "half", "hashbrown 0.14.5", - "indexmap 2.11.1", + "indexmap 2.11.3", "itertools 0.14.0", "log", "paste", @@ -2228,7 +2228,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.11.1", + "indexmap 2.11.3", "itertools 0.14.0", "log", "parking_lot", @@ -2270,7 +2270,7 @@ dependencies = [ "bigdecimal", "datafusion-common", "datafusion-expr", - "indexmap 2.11.1", + "indexmap 2.11.3", "log", "regex", "sqlparser", @@ -2595,9 +2595,9 @@ checksum = "9afc2bd4d5a73106dd53d10d73d3401c2f32730ba2c0b93ddb888a8983680471" [[package]] name = "fastembed" -version = "5.1.0" +version = "5.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b6ea4eee3e41add51440698fdde1b4afe83241e73fa90c29462a0460d69bcec" +checksum = "d2e9bf3ea201e5d338450555088e02cff23b00be92bead3eff7ed341c68f5ac6" dependencies = [ "anyhow", "hf-hub", @@ -2605,7 +2605,7 @@ dependencies = [ "ndarray", "ort", "serde_json", - "tokenizers", + "tokenizers 0.22.0", ] [[package]] @@ -3155,7 +3155,7 @@ dependencies = [ "js-sys", "libc", "r-efi", - "wasi 0.14.5+wasi-0.2.4", + "wasi 0.14.7+wasi-0.2.4", "wasm-bindgen", ] @@ -3218,7 +3218,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.11.1", + "indexmap 2.11.3", "slab", "tokio", "tokio-util", @@ -3237,7 +3237,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.3.1", - "indexmap 2.11.1", + "indexmap 2.11.3", "slab", "tokio", "tokio-util", @@ -3523,9 +3523,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ "base64 0.22.1", "bytes", @@ -3762,13 +3762,14 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.11.1" +version = "2.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206a8042aec68fa4a62e8d3f7aa4ceb508177d9324faf261e1959e495b7a1921" +checksum = "92119844f513ffa41556430369ab02c295a3578af21cf945caa3e9e0c2481ac3" dependencies = [ "equivalent", "hashbrown 0.15.5", "serde", + "serde_core", ] [[package]] @@ 
-4506,9 +4507,9 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" dependencies = [ "bitflags 2.9.4", "libc", @@ -4774,19 +4775,20 @@ dependencies = [ [[package]] name = "monostate" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de2d6d85e10b411dc84861249f08dda16accf6147d2181725543f7855eb1ce4d" +checksum = "5f766eeb5719df144c29802e56f3f7e0b3f29bd3ec8ab6c819aa1eaddec3f80c" dependencies = [ "monostate-impl", + "serde", "serde_core", ] [[package]] name = "monostate-impl" -version = "0.1.15" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95324d045d5ec8db75d6ece8c226cf06a47a78ed9ffa4b06e7cabc0992d9614" +checksum = "88d4b3dcd6ce9277a52fb05f00bf3e45d475af8cce7376de2f8d6bd065fa4adb" dependencies = [ "proc-macro2", "quote", @@ -5138,7 +5140,7 @@ dependencies = [ "serde_json", "sha2", "tokio", - "toml 0.9.5", + "toml 0.9.6", "tracing", "tracing-appender", "tracing-subscriber", @@ -5181,7 +5183,7 @@ dependencies = [ "sha2", "thiserror 2.0.16", "tiktoken-rs", - "tokenizers", + "tokenizers 0.21.4", "tokio", "tracing", "url", @@ -5419,7 +5421,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.11.1", + "indexmap 2.11.3", ] [[package]] @@ -5430,7 +5432,7 @@ checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" dependencies = [ "fixedbitset", "hashbrown 0.15.5", - "indexmap 2.11.1", + "indexmap 2.11.3", "serde", ] @@ -6301,7 +6303,7 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.5", + "rustls-webpki 0.103.6", "subtle", "zeroize", ] @@ -6370,9 +6372,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.5" +version = "0.103.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5a37813727b78798e53c2bec3f5e8fe12a6d6f8389bf9ca7802add4c9905ad8" +checksum = "8572f3c2cb9934231157b45499fc41e1f58c589fdfb81a844ba873265e80f8eb" dependencies = [ "aws-lc-rs", "ring", @@ -6504,9 +6506,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.26" +version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" [[package]] name = "seq-macro" @@ -6516,9 +6518,9 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.221" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "341877e04a22458705eb4e131a1508483c877dca2792b3781d4e5d8a6019ec43" +checksum = "fd6c24dee235d0da097043389623fb913daddf92c76e9f5a1db88607a0bcbd1d" dependencies = [ "serde_core", "serde_derive", @@ -6526,18 +6528,18 @@ dependencies = [ [[package]] name = "serde_core" -version = "1.0.221" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c459bc0a14c840cb403fc14b148620de1e0778c96ecd6e0c8c3cacb6d8d00fe" +checksum = 
"659356f9a0cb1e529b24c01e43ad2bdf520ec4ceaf83047b83ddcc2251f96383" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.221" +version = "1.0.225" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6185cf75117e20e62b1ff867b9518577271e58abe0037c40bb4794969355ab0" +checksum = "0ea936adf78b1f766949a4977b91d2f5595825bd6ec079aa9543ad2685fc4516" dependencies = [ "proc-macro2", "quote", @@ -6546,14 +6548,15 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.144" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56177480b00303e689183f110b4e727bb4211d692c62d4fcd16d02be93077d40" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ - "indexmap 2.11.1", + "indexmap 2.11.3", "itoa", "memchr", "ryu", + "serde", "serde_core", ] @@ -6588,11 +6591,11 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40734c41988f7306bb04f0ecf60ec0f3f1caa34290e4e8ea471dcd3346483b83" +checksum = "2789234a13a53fc4be1b51ea1bab45a3c338bdb884862a257d10e5a74ae009e6" dependencies = [ - "serde", + "serde_core", ] [[package]] @@ -6617,7 +6620,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.11.1", + "indexmap 2.11.3", "schemars 0.9.0", "schemars 1.0.4", "serde", @@ -7357,6 +7360,39 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tokenizers" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af10f51be57162b69d90a15cb226eef12c9e4faecbd5e3ea98a86bfb920b3d71" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.3", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.16", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.47.1" @@ -7455,14 +7491,14 @@ dependencies = [ [[package]] name = "toml" -version = "0.9.5" +version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75129e1dc5000bfbaa9fee9d1b21f974f9fbad9daec557a521ee6e080825f6e8" +checksum = "ae2a4cf385da23d1d53bc15cdfa5c2109e93d8d362393c801e87da2f72f0e201" dependencies = [ - "indexmap 2.11.1", - "serde", - "serde_spanned 1.0.0", - "toml_datetime 0.7.0", + "indexmap 2.11.3", + "serde_core", + "serde_spanned 1.0.1", + "toml_datetime 0.7.1", "toml_parser", "toml_writer", "winnow", @@ -7479,11 +7515,11 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bade1c3e902f58d73d3f294cd7f20391c1cb2fbcb643b73566bc773971df91e3" +checksum = "a197c0ec7d131bfc6f7e82c8442ba1595aeab35da7adbf05b6b73cd06a16b6be" dependencies = [ - "serde", + "serde_core", ] [[package]] @@ -7492,7 +7528,7 @@ version = "0.22.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ - "indexmap 2.11.1", + "indexmap 2.11.3", "serde", "serde_spanned 0.6.9", "toml_datetime 0.6.11", @@ -8062,18 +8098,18 @@ checksum = 
"ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.5+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4494f6290a82f5fe584817a676a34b9d6763e8d9d18204009fb31dceca98fd4" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" dependencies = [ "wasip2", ] [[package]] name = "wasip2" -version = "1.0.0+wasi-0.2.4" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03fa2761397e5bd52002cd7e73110c71af2109aca4e521a9f40473fe685b0a24" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ "wit-bindgen", ] @@ -8582,9 +8618,9 @@ dependencies = [ [[package]] name = "wit-bindgen" -version = "0.45.1" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c573471f125075647d03df72e026074b7203790d41351cd6edc96f46bcccd36" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "writeable" @@ -8761,7 +8797,7 @@ dependencies = [ "crc32fast", "crossbeam-utils", "displaydoc", - "indexmap 2.11.1", + "indexmap 2.11.3", "num_enum", "thiserror 1.0.69", ] diff --git a/Cargo.toml b/Cargo.toml index ff6652f..084f8fd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ exclude = [ [features] default = [] +fastembed = ["octolib/fastembed"] +huggingface = ["octolib/huggingface"] # Optimized release profile for static linking [profile.release] diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs index 7a5b908..eaebcf8 100644 --- a/src/embedding/mod.rs +++ b/src/embedding/mod.rs @@ -23,6 +23,19 @@ pub use octolib::embedding::{ truncate_output, EmbeddingProvider, InputType, }; +// Re-export types for backward compatibility +pub use octolib::embedding::types::{parse_provider_model, EmbeddingProviderType}; + +// Create a types module for backward compatibility +pub mod types { + pub use octolib::embedding::types::*; +} + +// Create a provider module for backward compatibility +pub mod provider { + pub use octolib::embedding::provider::*; +} + /// Configuration for embedding generation (octocode-specific) #[derive(Debug, Clone)] pub struct EmbeddingGenerationConfig {