From e2d763e98f1f564e14da1be96a11512e964fd627 Mon Sep 17 00:00:00 2001 From: Even Date: Tue, 27 Jan 2026 11:51:16 +0800 Subject: [PATCH 01/23] feat: Add native hybrid search (#197) * add sparse vector embedding * hybrid search add sparse vector search * add checking version logic * add qwen sparse vector * adjust weight * update sparse vector function * update sparse vector function * fix bug * fix bug * optimise function * optimise function * optimise function * optimise function * optimise function * optimise function * optimise function * fix bug * add migrate function * update alembic function * update alembic function * update alembic function * adjust file struct * update alembic * update version * optimise * fix bug * update * update schema update method * update schema update method * update schema update method * update schema update method * update schema update method * update migrate method * update migrate method * update env.example * update env.example * update migrate sparse vector * update migrate sparse vector * adjust threshold score logic * update remark * add guides and examples * add benchmark param * fix bug * fulltext parsers support * adjust enable sparse vector setting * adjust env.example * adjust docs * update version * fix bug * optimise check * adjust file construct * adjust file construct * add native search * add file * remove log * remove log * fix bug * update pyobvector * add rerank * adjust * add limit * adjust config --- .github/workflows/test.yml | 4 +- benchmark/README.md | 2 +- benchmark/{lomoco => locomo}/.env.example | 0 benchmark/{lomoco => locomo}/LICENSE | 0 benchmark/{lomoco => locomo}/README.md | 0 .../{lomoco => locomo}/dataset/locomo10.json | 0 benchmark/{lomoco => locomo}/evals.py | 0 .../{lomoco => locomo}/generate_scores.py | 0 benchmark/{lomoco => locomo}/methods/add.py | 0 .../{lomoco => locomo}/methods/search.py | 0 .../{lomoco => locomo}/metrics/llm_judge.py | 0 benchmark/{lomoco => locomo}/metrics/utils.py | 0 benchmark/{lomoco => locomo}/prompts.py | 0 benchmark/{lomoco => locomo}/requirements.txt | 0 benchmark/{lomoco => locomo}/run.sh | 0 .../{lomoco => locomo}/run_experiments.py | 0 benchmark/server/main.py | 3 +- pyproject.toml | 2 +- src/powermem/config_loader.py | 13 +- src/powermem/core/memory.py | 19 +- src/powermem/storage/adapter.py | 40 +- src/powermem/storage/oceanbase/oceanbase.py | 445 +++++++++++++----- src/powermem/utils/oceanbase_util.py | 358 +++++++++++++- 23 files changed, 729 insertions(+), 157 deletions(-) rename benchmark/{lomoco => locomo}/.env.example (100%) rename benchmark/{lomoco => locomo}/LICENSE (100%) rename benchmark/{lomoco => locomo}/README.md (100%) rename benchmark/{lomoco => locomo}/dataset/locomo10.json (100%) rename benchmark/{lomoco => locomo}/evals.py (100%) rename benchmark/{lomoco => locomo}/generate_scores.py (100%) rename benchmark/{lomoco => locomo}/methods/add.py (100%) rename benchmark/{lomoco => locomo}/methods/search.py (100%) rename benchmark/{lomoco => locomo}/metrics/llm_judge.py (100%) rename benchmark/{lomoco => locomo}/metrics/utils.py (100%) rename benchmark/{lomoco => locomo}/prompts.py (100%) rename benchmark/{lomoco => locomo}/requirements.txt (100%) rename benchmark/{lomoco => locomo}/run.sh (100%) rename benchmark/{lomoco => locomo}/run_experiments.py (100%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6271775..493a8a5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,7 @@ on: - 'pyproject.toml' - 'Makefile'
- '.github/workflows/test.yml' - - 'benchmark/lomoco/requirements.txt' + - 'benchmark/locomo/requirements.txt' pull_request: branches: [main, develop] paths: - 'src/**' - 'tests/**' - 'pyproject.toml' - 'Makefile' - '.github/workflows/test.yml' - - 'benchmark/lomoco/requirements.txt' + - 'benchmark/locomo/requirements.txt' workflow_dispatch: jobs: diff --git a/benchmark/README.md b/benchmark/README.md index 27bee19..750ae7a 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -5,7 +5,7 @@ This directory contains the benchmarking suite for PowerMem, including a REST AP ## Documentation - **[Benchmark Overview](../docs/benchmark/overview.md)**: Complete documentation covering setup, configuration, and usage of the benchmark suite -- **[LOCOMO Benchmark](./lomoco/README.md)**: Details about the LOCOMO load testing tool and evaluation metrics +- **[LOCOMO Benchmark](locomo/README.md)**: Details about the LOCOMO load testing tool and evaluation metrics ## Quick Links diff --git a/benchmark/lomoco/.env.example b/benchmark/locomo/.env.example similarity index 100% rename from benchmark/lomoco/.env.example rename to benchmark/locomo/.env.example diff --git a/benchmark/lomoco/LICENSE b/benchmark/locomo/LICENSE similarity index 100% rename from benchmark/lomoco/LICENSE rename to benchmark/locomo/LICENSE diff --git a/benchmark/lomoco/README.md b/benchmark/locomo/README.md similarity index 100% rename from benchmark/lomoco/README.md rename to benchmark/locomo/README.md diff --git a/benchmark/lomoco/dataset/locomo10.json b/benchmark/locomo/dataset/locomo10.json similarity index 100% rename from benchmark/lomoco/dataset/locomo10.json rename to benchmark/locomo/dataset/locomo10.json diff --git a/benchmark/lomoco/evals.py b/benchmark/locomo/evals.py similarity index 100% rename from benchmark/lomoco/evals.py rename to benchmark/locomo/evals.py diff --git a/benchmark/lomoco/generate_scores.py b/benchmark/locomo/generate_scores.py similarity index 100% rename from benchmark/lomoco/generate_scores.py rename to benchmark/locomo/generate_scores.py diff --git a/benchmark/lomoco/methods/add.py b/benchmark/locomo/methods/add.py similarity index 100% rename from benchmark/lomoco/methods/add.py rename to benchmark/locomo/methods/add.py diff --git a/benchmark/lomoco/methods/search.py b/benchmark/locomo/methods/search.py similarity index 100% rename from benchmark/lomoco/methods/search.py rename to benchmark/locomo/methods/search.py diff --git a/benchmark/lomoco/metrics/llm_judge.py b/benchmark/locomo/metrics/llm_judge.py similarity index 100% rename from benchmark/lomoco/metrics/llm_judge.py rename to benchmark/locomo/metrics/llm_judge.py diff --git a/benchmark/lomoco/metrics/utils.py b/benchmark/locomo/metrics/utils.py similarity index 100% rename from benchmark/lomoco/metrics/utils.py rename to benchmark/locomo/metrics/utils.py diff --git a/benchmark/lomoco/prompts.py b/benchmark/locomo/prompts.py similarity index 100% rename from benchmark/lomoco/prompts.py rename to benchmark/locomo/prompts.py diff --git a/benchmark/lomoco/requirements.txt b/benchmark/locomo/requirements.txt similarity index 100% rename from benchmark/lomoco/requirements.txt rename to benchmark/locomo/requirements.txt diff --git a/benchmark/lomoco/run.sh b/benchmark/locomo/run.sh similarity index 100% rename from benchmark/lomoco/run.sh rename to benchmark/locomo/run.sh diff --git a/benchmark/lomoco/run_experiments.py b/benchmark/locomo/run_experiments.py similarity index 100% rename from benchmark/lomoco/run_experiments.py
rename to benchmark/locomo/run_experiments.py diff --git a/benchmark/server/main.py b/benchmark/server/main.py index 7873145..57bbd18 100644 --- a/benchmark/server/main.py +++ b/benchmark/server/main.py @@ -92,7 +92,8 @@ def load_config() -> Dict[str, Any]: "vidx_metric_type": os.getenv("OCEANBASE_VIDX_METRIC_TYPE", "l2"), "vector_weight": VECTOR_WEIGHT, "fts_weight": FTS_WEIGHT, - 'include_sparse': os.getenv('SPARSE_VECTOR_ENABLE', 'false').lower() == 'true' + 'include_sparse': os.getenv('SPARSE_VECTOR_ENABLE', 'false').lower() == 'true', + 'enable_native_hybrid': os.getenv('OCEANBASE_ENABLE_NATIVE_HYBRID', 'false').lower() == 'true' }, } elif DB_TYPE == "postgres": diff --git a/pyproject.toml b/pyproject.toml index 160a582..5a4823c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "slowapi>=0.1.9", "click>=8.0.0", "rank-bm25>=0.2.2", - "pyobvector>=0.2.21,<0.3.0", + "pyobvector>=0.2.22,<0.3.0", "jieba>=0.42.1", "azure-identity>=1.24.0", "psycopg2-binary>=2.9.0", diff --git a/src/powermem/config_loader.py b/src/powermem/config_loader.py index 2796140..51c006a 100644 --- a/src/powermem/config_loader.py +++ b/src/powermem/config_loader.py @@ -200,6 +200,10 @@ class DatabaseSettings(_BasePowermemSettings): default=False, validation_alias=AliasChoices("SPARSE_VECTOR_ENABLE"), ) + enable_native_hybrid: bool = Field( + default=False, + validation_alias=AliasChoices("OCEANBASE_ENABLE_NATIVE_HYBRID"), + ) postgres_collection: str = Field( default="memories", validation_alias=AliasChoices("POSTGRES_COLLECTION"), @@ -257,6 +261,7 @@ def _build_oceanbase_config(self) -> Dict[str, Any]: "metadata_field": self.oceanbase_metadata_field, "vidx_name": self.oceanbase_vidx_name, "include_sparse": self.oceanbase_include_sparse, + "enable_native_hybrid": self.enable_native_hybrid, } def _build_postgres_config(self) -> Dict[str, Any]: @@ -839,7 +844,7 @@ def load_config_from_env() -> Dict[str, Any]: Load configuration from environment variables. Deprecated for direct use: prefer `auto_config()` or `create_memory()`. - + This function reads configuration from environment variables and builds a config dictionary. You can use this when you have .env file set up to avoid manually building config dict. @@ -913,7 +918,7 @@ def create_config( Deprecated: prefer `auto_config()` or `create_memory()` unless you need a minimal manual config. - + Args: database_provider: Database provider ('sqlite', 'oceanbase', 'postgres') llm_provider: LLM provider ('qwen', 'openai', etc.) @@ -1007,7 +1012,7 @@ def validate_config(config: Dict[str, Any]) -> bool: Validate a configuration dictionary. Deprecated for new code paths: prefer `create_memory()` or `auto_config()`. - + Args: config: Configuration dictionary to validate @@ -1046,7 +1051,7 @@ def auto_config() -> Dict[str, Any]: It automatically loads .env file and returns the config. Preferred entrypoint for configuration loading. 
- + Returns: Configuration dictionary from environment variables diff --git a/src/powermem/core/memory.py b/src/powermem/core/memory.py index 28f0dbd..84dc5bc 100644 --- a/src/powermem/core/memory.py +++ b/src/powermem/core/memory.py @@ -239,12 +239,12 @@ def __init__( if self.storage_type.lower() == 'oceanbase' and include_sparse: sparse_config_obj = None sparse_embedder_provider = None - + if self.memory_config and hasattr(self.memory_config, 'sparse_embedder') and self.memory_config.sparse_embedder: sparse_config_obj = self.memory_config.sparse_embedder elif self.config.get('sparse_embedder'): sparse_config_obj = self.config.get('sparse_embedder') - + if sparse_config_obj: try: # Handle SparseEmbedderConfig (BaseModel with provider and config) or dict format @@ -260,13 +260,13 @@ def __init__( logger.warning(f"Unknown sparse_embedder config format: {type(sparse_config_obj)}. Expected SparseEmbedderConfig or dict with 'provider' and 'config' keys.") sparse_embedder_provider = None config_dict = {} - + if sparse_embedder_provider: self.sparse_embedder = SparseEmbedderFactory.create(sparse_embedder_provider, config_dict) logger.info(f"Sparse embedder initialized: {sparse_embedder_provider}") except Exception as e: logger.warning(f"Failed to initialize sparse embedder: {e}") - + # Initialize storage adapter with embedding service and sparse embedder service # Automatically select adapter based on sub_stores configuration sub_stores_list = self.config.get('sub_stores', []) @@ -736,7 +736,7 @@ def _intelligent_add( # Get intelligent memory config to check fallback setting intelligent_config = self._get_intelligent_memory_config() fallback_to_simple = intelligent_config.get("fallback_to_simple_add", False) - + # Step 1: Extract facts from messages logger.info("Extracting facts from messages...") facts = self._extract_facts(messages) @@ -1126,7 +1126,8 @@ def search( run_id=run_id, filters=filters, limit=limit, - query=query # Pass query text for hybrid search (vector + full-text + sparse vector) + query=query, # Pass query text for hybrid search (vector + full-text + sparse vector) + threshold=threshold, # Pass threshold to storage for native hybrid search condition check ) # Process results with intelligence manager (only if enabled to avoid unnecessary calls) @@ -1154,18 +1155,18 @@ def search( transformed_results = [] for result in processed_results: score = result.get("score", 0.0) - + # Get quality score for threshold filtering # Quality score represents absolute similarity quality (0-1 range) # It's calculated from weighted average of all search paths' similarity scores metadata = result.get("metadata", {}) quality_score = metadata.get("_quality_score") - + # If quality_score is not available (e.g., from older data or non-hybrid search), # fall back to using the ranking score if quality_score is None: quality_score = score - + # Apply threshold filtering using quality score # Only include results if threshold is None or quality_score >= threshold if threshold is not None and quality_score < threshold: diff --git a/src/powermem/storage/adapter.py b/src/powermem/storage/adapter.py index 581cb22..edd1f38 100644 --- a/src/powermem/storage/adapter.py +++ b/src/powermem/storage/adapter.py @@ -37,23 +37,23 @@ def __init__(self, vector_store: VectorStoreBase, embedding_service=None, sparse def _generate_sparse_embedding(self, content: str, memory_action: str) -> Optional[Any]: """ Generate sparse embedding for given content. 
- + Args: content: The text content to generate embedding for memory_action: The action context ("add", "search", "update") - + Returns: Sparse embedding if successful, None otherwise """ if not self.sparse_embedder_service or not content: return None - + try: return self.sparse_embedder_service.embed_sparse(content, memory_action=memory_action) except Exception as e: logger.warning(f"Failed to generate sparse embedding ({memory_action}): {e}") return None - + def add_memory(self, memory_data: Dict[str, Any]) -> int: """Add a memory to the store.""" # ID will be generated using Snowflake algorithm before insertion @@ -108,7 +108,7 @@ def add_memory(self, memory_data: Dict[str, Any]) -> int: # Add sparse embedding to payload if available if sparse_embedding is not None: payload["sparse_embedding"] = sparse_embedding - + # Add only user-defined metadata (not system fields) user_metadata = memory_data.get("metadata", {}) payload["metadata"] = serialize_datetime(user_metadata) if user_metadata else {} @@ -136,6 +136,7 @@ def search_memories( filters: Optional[Dict[str, Any]] = None, limit: int = 30, query: Optional[str] = None, + threshold: Optional[float] = None, ) -> List[Dict[str, Any]]: """Search for memories.""" # Use the provided query embedding or generate one @@ -148,7 +149,7 @@ def search_memories( # Generate sparse embedding if sparse embedder service is available and query is provided sparse_embedding = self._generate_sparse_embedding(query, "search") if query else None - + # Merge user_id/agent_id/run_id into filters to ensure consistency # This ensures filters are applied at the database level, avoiding redundant filtering effective_filters = filters.copy() if filters else {} @@ -167,13 +168,24 @@ def search_memories( try: # Try OceanBase format first - pass query text for hybrid search search_query = query if query else "" - # Check if target_store.search supports sparse_embedding parameter + # Check if target_store.search supports sparse_embedding and threshold parameters import inspect search_sig = inspect.signature(target_store.search) - if 'sparse_embedding' in search_sig.parameters: - results = target_store.search(search_query, vectors=query_vector, limit=limit, filters=effective_filters, sparse_embedding=sparse_embedding) - else: - results = target_store.search(search_query, vectors=query_vector, limit=limit, filters=effective_filters) + search_params = search_sig.parameters + + # Build search kwargs based on supported parameters + search_kwargs = { + "query": search_query, + "vectors": query_vector, + "limit": limit, + "filters": effective_filters, + } + if 'sparse_embedding' in search_params: + search_kwargs["sparse_embedding"] = sparse_embedding + if 'threshold' in search_params: + search_kwargs["threshold"] = threshold + + results = target_store.search(**search_kwargs) except TypeError: # Fallback to SQLite format (doesn't support query text parameter) # Pass filters to ensure filtering works correctly @@ -374,12 +386,12 @@ def update_memory( if "content" in update_data: updated_payload["data"] = update_data["content"] updated_payload["fulltext_content"] = update_data["content"] - + # Generate sparse embedding if sparse embedder service is available and content is updated sparse_embedding = self._generate_sparse_embedding(update_data["content"], "update") if sparse_embedding is not None: updated_payload["sparse_embedding"] = sparse_embedding - + # Remove content from update_data to avoid confusion update_data = update_data.copy() del update_data["content"] @@ -395,7 
+407,7 @@ def update_memory( merged_metadata = {**existing_metadata, **new_metadata} serialized_update_data = serialized_update_data.copy() serialized_update_data["metadata"] = merged_metadata - + # Update other fields updated_payload.update(serialized_update_data) diff --git a/src/powermem/storage/oceanbase/oceanbase.py b/src/powermem/storage/oceanbase/oceanbase.py index b326bd4..8873f58 100644 --- a/src/powermem/storage/oceanbase/oceanbase.py +++ b/src/powermem/storage/oceanbase/oceanbase.py @@ -71,6 +71,7 @@ def __init__( fts_weight: float = 0.5, sparse_weight: float = 0.25, reranker: Optional[Any] = None, + enable_native_hybrid: bool = False, **kwargs, ): """ @@ -101,6 +102,7 @@ def __init__( fts_weight (float): Weight for full-text search in hybrid search (default: 0.5). sparse_weight (Optional[float]): Weight for sparse vector search in hybrid search. reranker (Optional[Any]): Reranker model for fine ranking. + enable_native_hybrid (bool): Whether to enable OceanBase native hybrid search (DBMS_HYBRID_SEARCH.SEARCH). """ self.normalize = normalize self.include_sparse = include_sparse @@ -111,6 +113,7 @@ def __init__( self.fts_weight = fts_weight self.sparse_weight = sparse_weight self.reranker = reranker + self.enable_native_hybrid = enable_native_hybrid # Validate fulltext parser if self.fulltext_parser not in constants.OCEANBASE_SUPPORTED_FULLTEXT_PARSERS: @@ -171,6 +174,12 @@ def __init__( self._create_client(**kwargs) assert self.obvector is not None + # Check if native hybrid search is supported by version and table type + if self.enable_native_hybrid: + if not OceanBaseUtil.check_native_hybrid_version_support(self.obvector, self.collection_name): + logger.warning("Falling back to application-level hybrid search.") + self.enable_native_hybrid = False + # Autoconfigure vector index settings if enabled if self.auto_configure_vector_index: self._configure_vector_index_settings() @@ -238,9 +247,12 @@ def _configure_vector_index_settings(self): def _create_table_with_index_by_embedding_model_dims(self) -> None: """Create table with vector index based on embedding dimension. - + If include_sparse is True and database supports sparse vector, the sparse_embedding column will be included in the table schema. + + If enable_native_hybrid is True, creates heap table (ORGANIZATION HEAP) + for native hybrid search support. """ cols = [ # Primary key - Snowflake ID (BIGINT without AUTO_INCREMENT) @@ -264,7 +276,7 @@ def _create_table_with_index_by_embedding_model_dims(self) -> None: # Create vector index parameters vidx_params = self.obvector.prepare_index_params() - + # Add dense vector index vidx_params.add_index( field_name=self.vector_field, @@ -303,8 +315,13 @@ def _create_table_with_index_by_embedding_model_dims(self) -> None: "Creating table without sparse vector support. " "Upgrade to seekdb or OceanBase >= 4.5.0 for sparse vector." 
) - - # Create table with vector indexes (both dense and sparse if configured) + + # Determine table options based on native hybrid search setting + table_kwargs = {} + if self.enable_native_hybrid: + table_kwargs['mysql_organization'] = 'heap' + logger.info(f"Creating heap table '{self.collection_name}' for native hybrid search support") + self.obvector.create_table_with_index_params( table_name=self.collection_name, columns=cols, @@ -312,6 +329,7 @@ def _create_table_with_index_by_embedding_model_dims(self) -> None: vidxs=vidx_params, fts_idxs=[fts_index_param] if fts_index_param is not None else None, partitions=None, + **table_kwargs, ) logger.debug(f"Table '{self.collection_name}' created successfully") @@ -358,7 +376,7 @@ def create_col(self, name: str, vector_size: Optional[int] = None, distance: str def _create_col(self): """Create a new collection.""" - + if self.embedding_model_dims is None: raise ValueError( "embedding_model_dims is required for OceanBase vector operations. " @@ -379,7 +397,7 @@ def _create_col(self): else: # Existing table: validate schema logger.info(f"Table {self.collection_name} already exists, preserving existing data") - + # Check if the existing table's vector dimension matches the requested dimension existing_dim = self._get_existing_vector_dimension() if existing_dim is not None and existing_dim != self.embedding_model_dims: @@ -391,12 +409,12 @@ def _create_col(self): if self.hybrid_search: self._check_and_create_fulltext_index() - + # Validate sparse vector support if enabled if self.include_sparse: if not OceanBaseUtil.check_sparse_vector_ready(self.obvector, self.collection_name, self.sparse_vector_field): self.include_sparse = False - + self.model_class = create_memory_model( table_name=self.collection_name, embedding_dims=self.embedding_model_dims, @@ -408,7 +426,7 @@ def _create_col(self): fulltext_field=self.fulltext_field, sparse_vector_field=self.sparse_vector_field ) - + # Use model_class.__table__ as table reference self.table = self.model_class.__table__ @@ -569,7 +587,7 @@ def process_condition(cond): def _row_to_model(self, row): """ Convert SQLAlchemy Row object to ORM Model instance. 
- + Args: row: SQLAlchemy Row object (query result) @@ -578,14 +596,14 @@ def _row_to_model(self, row): """ # Create a new Model instance (not bound to Session) record = self.model_class() - + # Iterate through all columns in the table, map values from Row to Model instance for col_name in self.model_class.__table__.c.keys(): # Check if Row contains this column (queries may not include all columns) if col_name in row._mapping.keys(): attr_name = 'metadata_' if col_name == 'metadata' else col_name setattr(record, attr_name, row._mapping[col_name]) - + return record def _get_standard_select_columns(self) -> List: @@ -605,11 +623,11 @@ def _get_standard_select_columns(self) -> List: self.table.c["updated_at"], self.table.c["category"], ] - + # Only include sparse_embedding if sparse search is enabled if self.include_sparse: columns.append(self.table.c[self.sparse_vector_field]) - + return columns def _get_standard_column_names(self, include_vector_field: bool = False) -> List[str]: @@ -619,11 +637,11 @@ def _get_standard_column_names(self, include_vector_field: bool = False) -> List column_names = [ self.text_field, ] - + # Include vector_field if requested if include_vector_field: column_names.append(self.vector_field) - + column_names.extend([ self.metadata_field, self.primary_field, @@ -636,23 +654,23 @@ def _get_standard_column_names(self, include_vector_field: bool = False) -> List "updated_at", "category", ]) - + # Only include sparse_embedding if sparse search is enabled if self.include_sparse: column_names.append(self.sparse_vector_field) - + return column_names def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: bool = True) -> Dict: """ Parse a database row and return all fields as a dictionary. Now uses ORM Model instance internally for cleaner field access. 
- + Args: row: Database row result include_vector: Whether the row includes vector field (for get/list methods) extract_score: Whether to extract score/distance field (for search methods) - + Returns: Dict containing all parsed fields: - text_content: Text content from the row @@ -665,7 +683,7 @@ def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: b - score_or_distance: Score or distance value (only if extract_score=True) """ record = self._row_to_model(row) - + text_content = record.document metadata_json = record.metadata_ vector_id = record.id @@ -677,12 +695,12 @@ def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: b created_at = record.created_at updated_at = record.updated_at category = record.category - + # Handle optional fields vector = None sparse_embedding = None score_or_distance = None - + if include_vector: # get/list scenario: includes vector field vector = record.embedding @@ -692,7 +710,7 @@ def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: b # Search scenario: does not include vector, but may include sparse_embedding if self.include_sparse and hasattr(record, 'sparse_embedding') and record.sparse_embedding is not None: sparse_embedding = record.sparse_embedding - + # Extract additional score/distance fields (these fields are not in Model, need to get from original row) if extract_score: if 'score' in row._mapping.keys(): @@ -701,7 +719,7 @@ def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: b score_or_distance = row._mapping['distance'] elif 'anon_1' in row._mapping.keys(): score_or_distance = row._mapping['anon_1'] - + # Build standard metadata metadata = { "user_id": user_id, @@ -715,7 +733,7 @@ def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: b # Store user metadata as nested structure to preserve it "metadata": OceanBaseUtil.parse_metadata(metadata_json) } - + # Build result dictionary result = { "text_content": text_content, @@ -731,14 +749,14 @@ def _parse_row_to_dict(self, row, include_vector: bool = False, extract_score: b "category": category, "metadata": metadata, } - + # Add optional fields if include_vector: result["vector"] = vector - + if extract_score: result["score_or_distance"] = score_or_distance - + return result def _create_output_data(self, vector_id: int, text_content: str, score: float, @@ -799,11 +817,12 @@ def search(self, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None, - sparse_embedding: Optional[Dict[int, float]] = None) -> list[OutputData]: + sparse_embedding: Optional[Dict[int, float]] = None, + threshold: Optional[float] = None) -> list[OutputData]: # Check if hybrid search is enabled, and we have query text # Full-text search is always enabled by default if self.hybrid_search and query: - return self._hybrid_search(query, vectors, limit, filters, sparse_embedding) + return self._hybrid_search(query, vectors, limit, filters, sparse_embedding, threshold=threshold) else: return self._vector_search(query, vectors, limit, filters) @@ -826,13 +845,13 @@ def _vector_search(self, return [] table = Table(self.collection_name, self.obvector.metadata_obj, autoload_with=self.obvector.engine) - + # Build where clause from filters using the same table object where_clause = self._generate_where_clause(filters, table=table) # Build output column names list output_columns = self._get_standard_column_names() - + # Perform vector search - pyobvector expects a single vector, not a list of 
vectors results = self.obvector.ann_search( table_name=self.collection_name, @@ -849,14 +868,14 @@ def _vector_search(self, search_results = [] for row in results.fetchall(): parsed = self._parse_row_to_dict(row, include_vector=False, extract_score=True) - + # Convert distance to similarity score (0-1 range, higher is better) # Handle None distance (shouldn't happen but be defensive) distance = parsed["score_or_distance"] vector_id = parsed["vector_id"] text_content = parsed["text_content"] metadata = parsed["metadata"] - + if distance is None: logger.warning(f"Distance is None for vector_id {vector_id}, using default similarity 0.0") similarity = 0.0 @@ -877,10 +896,10 @@ def _vector_search(self, else: # Unknown metric, use default similarity = 0.0 - + # Store original similarity in metadata metadata['_vector_similarity'] = similarity - + # For pure vector search (no fusion), quality score equals vector similarity metadata['_quality_score'] = similarity @@ -986,18 +1005,18 @@ def _fulltext_search(self, query: str, limit: int = 5, filters: Optional[Dict] = fts_results = [] for row in rows: parsed = self._parse_row_to_dict(row, include_vector=False, extract_score=True) - + # FTS score is already in 0-1 range (higher is better) fts_score = float(parsed["score_or_distance"]) - + # Store original similarity in metadata metadata = parsed["metadata"] metadata['_fts_score'] = fts_score - + fts_results.append(self._create_output_data( - parsed["vector_id"], - parsed["text_content"], - fts_score, + parsed["vector_id"], + parsed["text_content"], + fts_score, metadata )) @@ -1007,12 +1026,12 @@ def _fulltext_search(self, query: str, limit: int = 5, filters: Optional[Dict] = def _sparse_search(self, sparse_embedding: Dict[int, float], limit: int = 5, filters: Optional[Dict] = None) -> list[OutputData]: """ Perform sparse vector search using OceanBase SPARSEVECTOR. 
- + Args: sparse_embedding: Sparse embedding dictionary (token_id -> weight) limit: Maximum number of results to return filters: Optional filter conditions - + Returns: List of OutputData objects with search results """ @@ -1020,20 +1039,20 @@ def _sparse_search(self, sparse_embedding: Dict[int, float], limit: int = 5, fil if not self.include_sparse: logger.debug("Sparse vector search is not enabled") return [] - + # Check if sparse embedding is provided if not sparse_embedding or not isinstance(sparse_embedding, dict): logger.debug("Sparse embedding not provided, skipping sparse search") return [] - + try: - + # Format sparse vector for SQL query sparse_vector_str = OceanBaseUtil.format_sparse_vector(sparse_embedding) - + # Generate where clause from filters filter_where_clause = self._generate_where_clause(filters) - + # Build the sparse vector search query # Use negative_inner_product for ordering (lower is better, so we negate) columns = self._get_standard_select_columns() + [ @@ -1041,21 +1060,21 @@ def _sparse_search(self, sparse_embedding: Dict[int, float], limit: int = 5, fil # Directly embed sparse_vector_str in SQL as per OceanBase syntax text(f"negative_inner_product({self.sparse_vector_field}, '{sparse_vector_str}') as score") ] - + stmt = select(*columns) - + # Add where conditions if filter_where_clause: for condition in filter_where_clause: stmt = stmt.where(condition) - + # Order by score ASC (lower negative_inner_product means higher similarity) stmt = stmt.order_by(text('score ASC')) - + # Add APPROXIMATE LIMIT (using regular LIMIT as APPROXIMATE may not be supported in all versions) if limit: stmt = stmt.limit(limit) - + # Execute the query with self.obvector.engine.connect() as conn: with conn.begin(): @@ -1063,12 +1082,12 @@ def _sparse_search(self, sparse_embedding: Dict[int, float], limit: int = 5, fil # Execute the query results = conn.execute(stmt) rows = results.fetchall() - + # Convert results to OutputData objects sparse_results = [] for row in rows: parsed = self._parse_row_to_dict(row, include_vector=False, extract_score=True) - + # Convert negative_inner_product to similarity (0-1 range, higher is better) # negative_inner_product returns negative values, negate to get inner product sparse_score = parsed["score_or_distance"] @@ -1084,57 +1103,235 @@ def _sparse_search(self, sparse_embedding: Dict[int, float], limit: int = 5, fil similarity = max(0.0, 1.0 / (1.0 - inner_prod)) else: similarity = 0.0 - + # Store original similarity in metadata metadata = parsed["metadata"] metadata['_sparse_similarity'] = similarity - + sparse_results.append(self._create_output_data( - parsed["vector_id"], - parsed["text_content"], - similarity, + parsed["vector_id"], + parsed["text_content"], + similarity, metadata )) - + logger.debug(f"_sparse_search results, len : {len(sparse_results)}") return sparse_results - + except Exception as e: logger.error(f"Sparse vector search failed: {e}", exc_info=True) # Return empty results on error rather than raising return [] + def _native_hybrid_search( + self, + query: str, + vectors: List[List[float]], + limit: int, + filters: Optional[Dict], + sparse_embedding: Optional[Dict[int, float]] = None, + k: int = 60 + ) -> List[OutputData]: + """ + Perform hybrid search using OceanBase native DBMS_HYBRID_SEARCH.SEARCH. + + This method leverages the database's native hybrid search capabilities to combine + full-text search and vector search (and optionally sparse vector search) with + RRF (Reciprocal Rank Fusion) ranking. 
+ + Note: This method does NOT perform normalization or weight adjustments. + It uses the database's native RRF fusion. + + Args: + query: Text query for full-text search + vectors: Query vector(s) for vector search + limit: Maximum number of results to return + filters: Filter conditions in mem0 format + sparse_embedding: Optional sparse vector for sparse search + k: RRF rank_constant parameter (default: 60) + + Returns: + List[OutputData]: Search results sorted by relevance score + """ + try: + # 1. Extract query vector (no normalization for native hybrid search) + if isinstance(vectors[0], (int, float)): + query_vector = vectors + else: + query_vector = vectors[0] + + # 2. Convert filters to native format + native_filters = OceanBaseUtil.convert_filters_to_native_format( + filters, self.model_class, self.metadata_field + ) + + # 3. Build search parameters JSON + search_params = { + "query": { + "bool": { + "must": [ + { + "query_string": { + "fields": [self.fulltext_field], + "query": query + } + } + ] + } + }, + "rank": { + "rrf": { + "rank_window_size": limit, + "rank_constant": k + } + }, + "size":limit + } + + # Add filters if present + if native_filters: + search_params["query"]["bool"]["filter"] = native_filters + + # 4. Build knn array (supports both dense and sparse vectors) + knn_list = [ + { + "field": self.vector_field, + "k": limit, + "query_vector": query_vector + } + ] + + # Add filters to dense vector search if present + if native_filters: + knn_list[0]["filter"] = native_filters + + # Add sparse vector search if sparse_embedding is provided + if sparse_embedding is not None and self.include_sparse: + sparse_vector_str = OceanBaseUtil.format_sparse_vector(sparse_embedding) + knn_list.append({ + "field": self.sparse_vector_field, + "k": limit, + "query_vector": sparse_vector_str + }) + + search_params["knn"] = knn_list + + # 5. Execute native hybrid search with simplified SQL + body_str = json.dumps(search_params) + sql = text("SELECT DBMS_HYBRID_SEARCH.SEARCH(:index, :body_str)") + + with self.obvector.engine.connect() as conn: + with conn.begin(): + res = conn.execute(sql, {"index": self.collection_name, "body_str": body_str}).fetchone() + result_json_str = res[0] if res else None + + # 6. Parse and return results + if not result_json_str: + logger.warning("Native hybrid search returned empty result") + return [] + + parsed_results = OceanBaseUtil.parse_native_hybrid_results( + result_json_str, + self.primary_field, + self.text_field, + self.metadata_field + ) + + # 7. 
Convert to OutputData objects + output_list = [] + for doc in parsed_results: + metadata = { + "user_id": doc["user_id"], + "agent_id": doc["agent_id"], + "run_id": doc["run_id"], + "actor_id": doc["actor_id"], + "hash": doc["hash"], + "created_at": doc["created_at"], + "updated_at": doc["updated_at"], + "category": doc["category"], + "metadata": OceanBaseUtil.parse_metadata(doc["metadata_json"]) + } + + output_list.append( + self._create_output_data(doc["vector_id"], doc["text_content"], doc["score"], metadata) + ) + + logger.debug(f"Native hybrid search returned {len(output_list)} results") + return output_list + + except Exception as e: + logger.error(f"Native hybrid search failed: {e}") + raise # Re-raise to trigger fallback in _hybrid_search + def _hybrid_search(self, query: str, vectors: List[List[float]], limit: int = 5, filters: Optional[Dict] = None, sparse_embedding: Optional[Dict[int, float]] = None, - fusion_method: str = "rrf", k: int = 60): - """Perform hybrid search combining vector, full-text, and sparse vector search with optional reranking.""" + fusion_method: str = "rrf", k: int = 60, + threshold: Optional[float] = None): + """Perform hybrid search combining vector, full-text, and sparse vector search with optional reranking. + + When enable_native_hybrid is True and conditions are met, uses OceanBase native + DBMS_HYBRID_SEARCH.SEARCH for better performance. + """ + # Check if native hybrid search can be used: + # 1. enable_native_hybrid must be True + # 2. threshold must be None (native search doesn't support threshold filtering) + # 3. All filter fields must be in table columns + use_native = ( + self.enable_native_hybrid + and threshold is None + and OceanBaseUtil.check_filters_all_in_columns(filters, self.model_class) + ) + + if use_native: + try: + logger.debug("Using OceanBase native hybrid search (DBMS_HYBRID_SEARCH.SEARCH)") + native_candidate_limit = limit * 2 + native_results = self._native_hybrid_search( + query, vectors, native_candidate_limit, filters, sparse_embedding, k + ) + + # Fine ranking (optional): use reranker for precision sorting + if self.reranker and query and native_results: + try: + final_results = self._apply_rerank(query, native_results, limit) + logger.debug(f"Native results reranked, final results: {len(final_results)}") + return final_results + except Exception as e: + logger.warning(f"Native rerank failed, falling back to native coarse ranking: {e}") + return native_results[:limit] + + return native_results[:limit] + except Exception as e: + logger.warning(f"Native hybrid search failed: {e}, falling back to application-level hybrid search") + + # Application-level hybrid search # Determine candidate limit for reranking candidate_limit = limit * 3 if self.reranker else limit # Determine which searches to perform perform_sparse = self.include_sparse and sparse_embedding is not None - + # Perform searches in parallel for better performance search_tasks = [] with ThreadPoolExecutor(max_workers=3 if perform_sparse else 2) as executor: # Submit vector search vector_future = executor.submit(self._vector_search, query, vectors, candidate_limit, filters) search_tasks.append(('vector', vector_future)) - + # Submit full-text search fts_future = executor.submit(self._fulltext_search, query, candidate_limit, filters) search_tasks.append(('fts', fts_future)) - + # Submit sparse vector search if enabled if perform_sparse: sparse_future = executor.submit(self._sparse_search, sparse_embedding, candidate_limit, filters) search_tasks.append(('sparse', 
sparse_future)) - + # Wait for all searches to complete and get results vector_results = None fts_results = None sparse_results = None - + for search_type, future in search_tasks: try: results = future.result() @@ -1245,11 +1442,11 @@ def _calculate_quality_score( ) -> float: """ Calculate quality score from multiple search paths. - + Quality score represents the absolute similarity quality (0-1 range), used for threshold filtering. Unlike fusion scores used for ranking, quality scores maintain semantic meaning across different search scenarios. - + Args: vector_similarity: Vector search similarity (0-1, higher is better) fts_score: Full-text search score (0-1, higher is better) @@ -1257,10 +1454,10 @@ def _calculate_quality_score( vector_weight: Weight for vector search (default: 0.5) fts_weight: Weight for full-text search (default: 0.3) sparse_weight: Weight for sparse vector search (default: 0.2) - + Returns: Quality score in range [0, 1], where higher means better quality - + Algorithm: 1. Identify which search paths participated (have non-None scores) 2. Sum the weights of active paths @@ -1269,33 +1466,33 @@ def _calculate_quality_score( """ # Collect active search paths and their scores active_paths = [] - + if vector_similarity is not None: active_paths.append((vector_weight, vector_similarity)) - + if fts_score is not None: active_paths.append((fts_weight, fts_score)) - + if sparse_similarity is not None: active_paths.append((sparse_weight, sparse_similarity)) - + # If no active paths, return 0 if not active_paths: return 0.0 - + # Calculate total weight of active paths total_weight = sum(weight for weight, _ in active_paths) - + # Handle edge case where total weight is 0 if total_weight == 0: return 0.0 - + # Calculate weighted average quality score quality_score = sum( - (weight / total_weight) * score + (weight / total_weight) * score for weight, score in active_paths ) - + # Ensure result is in [0, 1] range return max(0.0, min(1.0, quality_score)) @@ -1305,7 +1502,7 @@ def _combine_search_results(self, vector_results: List[OutputData], fts_results: """Combine and rerank vector, full-text, and sparse vector search results using RRF or weighted fusion.""" if sparse_results is None: sparse_results = [] - + if fusion_method == "rrf": return self._rrf_fusion(vector_results, fts_results, sparse_results, limit, k, sparse_embedding) else: @@ -1321,18 +1518,18 @@ def _normalize_weights_adaptively( ) -> Dict: """ Adaptively normalize weights for each document. - + Principle: Dynamically adjust the total weight to 1.0 based on how many retrieval paths the document was actually retrieved from, solving the unfairness issue in mixed states (some data has sparse vectors, some don't). 
- + Args: all_docs: Document dictionary {doc_id: {'result': ..., 'vector_rank': ..., 'fts_rank': ..., 'sparse_rank': ..., 'rrf_score': ...}} vector_w: Vector search weight fts_w: Full-text search weight sparse_w: Sparse vector search weight k: RRF constant (default: 60) - + Returns: Normalized all_docs (rrf_score modified) """ @@ -1345,22 +1542,22 @@ def _normalize_weights_adaptively( active_weights.append(('fts', fts_w, doc_data['fts_rank'])) if doc_data['sparse_rank'] is not None: active_weights.append(('sparse', sparse_w, doc_data['sparse_rank'])) - + # Calculate total effective weight total_weight = sum(w for _, w, _ in active_weights) - + if total_weight == 0: continue - + # Normalize and recalculate rrf_score normalized_score = 0.0 - + for path, weight, rank in active_weights: normalized_weight = weight / total_weight normalized_score += normalized_weight * (1.0 / (k + rank)) - + doc_data['rrf_score'] = normalized_score - + return all_docs def _rrf_fusion(self, vector_results: List[OutputData], fts_results: List[OutputData], @@ -1373,14 +1570,14 @@ def _rrf_fusion(self, vector_results: List[OutputData], fts_results: List[Output """ if sparse_results is None: sparse_results = [] - + vector_w = self.vector_weight if self.vector_weight is not None else 0 fts_w = self.fts_weight if self.fts_weight is not None else 0 sparse_w = 0 if self.include_sparse and sparse_results and sparse_embedding: sparse_w = self.sparse_weight if self.sparse_weight is not None else 0 - + # Create mapping of document ID to result data all_docs = {} @@ -1450,12 +1647,12 @@ def _rrf_fusion(self, vector_results: List[OutputData], fts_results: List[Output final_results = [] for score, _, doc_data in sorted(heap, key=lambda x: x[0], reverse=True): result = doc_data['result'] - + # Extract original similarity scores from metadata vector_similarity = result.payload.get('_vector_similarity') fts_score = result.payload.get('_fts_score') sparse_similarity = result.payload.get('_sparse_similarity') - + # Calculate quality score for threshold filtering quality_score = self._calculate_quality_score( vector_similarity=vector_similarity, @@ -1465,13 +1662,13 @@ def _rrf_fusion(self, vector_results: List[OutputData], fts_results: List[Output fts_weight=fts_w, sparse_weight=sparse_w ) - + # Store quality score in payload result.payload['_quality_score'] = quality_score - + # Store fusion score (RRF score) in payload for debugging result.payload['_fusion_score'] = score - + # Set result.score to fusion score (used for ranking) result.score = score # Add ranking information to metadata for debugging @@ -1494,20 +1691,20 @@ def _weighted_fusion(self, vector_results: List[OutputData], fts_results: List[O limit: int, vector_weight: float = 0.7, text_weight: float = 0.3, sparse_weight: float = 0.0): """ Traditional weighted score fusion (fallback method). - + Note: All input scores are already in 0-1 similarity range (higher is better), so no normalization is needed. 
""" if sparse_results is None: sparse_results = [] - + # Use instance weights if available vector_w = self.vector_weight if self.vector_weight is not None else vector_weight fts_w = self.fts_weight if self.fts_weight is not None else text_weight sparse_w = 0.0 if self.include_sparse and sparse_results: sparse_w = self.sparse_weight if self.sparse_weight is not None else sparse_weight - + # Create a mapping of id to results for deduplication combined_results = {} @@ -1533,7 +1730,7 @@ def _weighted_fusion(self, vector_results: List[OutputData], fts_results: List[O 'fts_score': result.score, 'sparse_score': 0.0 } - + # Add sparse vector search results (scores are already 0-1 similarity) for result in sparse_results: if result.id in combined_results: @@ -1563,12 +1760,12 @@ def _weighted_fusion(self, vector_results: List[OutputData], fts_results: List[O final_results = [] for score, _, doc_data in sorted(heap, key=lambda x: x[0], reverse=True): result = doc_data['result'] - + # Extract original similarity scores from metadata vector_similarity = result.payload.get('_vector_similarity') fts_score = result.payload.get('_fts_score') sparse_similarity = result.payload.get('_sparse_similarity') - + # Calculate quality score for threshold filtering quality_score = self._calculate_quality_score( vector_similarity=vector_similarity, @@ -1578,13 +1775,13 @@ def _weighted_fusion(self, vector_results: List[OutputData], fts_results: List[O fts_weight=fts_w, sparse_weight=sparse_w ) - + # Store quality score in payload result.payload['_quality_score'] = quality_score - + # Store fusion score in payload for debugging result.payload['_fusion_score'] = score - + # Set result.score to fusion score (used for ranking) result.score = score # Add fusion info for debugging @@ -1625,7 +1822,7 @@ def update(self, vector_id: int, vector: Optional[List[float]] = None, payload: has_sparse_column = OceanBaseUtil.check_column_exists(self.obvector, self.collection_name, self.sparse_vector_field) if has_sparse_column: output_columns.append(self.sparse_vector_field) - + existing_result = self.obvector.get( table_name=self.collection_name, ids=[vector_id], @@ -1689,7 +1886,7 @@ def get(self, vector_id: int): try: # Build output column name list output_columns = self._get_standard_column_names(include_vector_field=True) - + results = self.obvector.get( table_name=self.collection_name, ids=[vector_id], @@ -1705,9 +1902,9 @@ def get(self, vector_id: int): logger.debug(f"Successfully retrieved vector with ID: {vector_id} from collection '{self.collection_name}'") return self._create_output_data( - parsed["vector_id"], - parsed["text_content"], - 0.0, + parsed["vector_id"], + parsed["text_content"], + 0.0, parsed["metadata"] ) @@ -1792,7 +1989,7 @@ def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None): """List all memories.""" try: table = Table(self.collection_name, self.obvector.metadata_obj, autoload_with=self.obvector.engine) - + # Build where clause from filters using the same table object where_clause = self._generate_where_clause(filters, table=table) @@ -1812,9 +2009,9 @@ def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None): parsed = self._parse_row_to_dict(row, include_vector=True, extract_score=False) memories.append(self._create_output_data( - parsed["vector_id"], - parsed["text_content"], - 0.0, + parsed["vector_id"], + parsed["text_content"], + 0.0, parsed["metadata"] )) @@ -1831,11 +2028,11 @@ def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None): 
def reset(self): """ Reset collection by deleting and recreating it. - + Note: After reset, the table will be recreated with the current configuration. If include_sparse=True and database supports it, sparse vector will be included. For existing tables that need sparse vector support, use the upgrade script: - + from script import ScriptManager from powermem import auto_config config = auto_config() @@ -1844,21 +2041,21 @@ def reset(self): try: logger.info(f"Resetting collection '{self.collection_name}'") self.delete_col() - + if self.embedding_model_dims is not None: # Create baseline table (020: dense vector + fulltext only) self._create_table_with_index_by_embedding_model_dims() if self.hybrid_search: self._check_and_create_fulltext_index() - + # Note: Sparse vector support is NOT created in reset() # Users should reinitialize OceanBaseVectorStore to get upgraded features logger.info( f"Successfully reset collection '{self.collection_name}' to baseline schema (020). " ) - + except Exception as e: logger.error(f"Failed to reset collection '{self.collection_name}': {e}", exc_info=True) raise diff --git a/src/powermem/utils/oceanbase_util.py b/src/powermem/utils/oceanbase_util.py index c65377c..03e654b 100644 --- a/src/powermem/utils/oceanbase_util.py +++ b/src/powermem/utils/oceanbase_util.py @@ -10,7 +10,12 @@ try: from sqlalchemy import text + from sqlalchemy.schema import CreateTable from pyobvector import FtsParser + from pyobvector.schema import ObTable, VectorIndex, FtsIndex + from pyobvector.client.index_param import IndexParams + from pyobvector.client.fts_index_param import FtsIndexParam + from pyobvector.client.partitions import ObPartition except ImportError as e: raise ImportError( f"Required dependencies not found: {e}. Please install sqlalchemy and pyobvector." @@ -285,6 +290,108 @@ def check_sparse_vector_version_support(obvector) -> bool: ) return False + @staticmethod + def check_native_hybrid_version_support(obvector, table_name: str) -> bool: + """ + Check if the database version and table type support native hybrid search (DBMS_HYBRID_SEARCH.SEARCH). + + Args: + obvector: The ObVecClient instance. + table_name: The name of the table. + + Returns: + True if version is seekdb or OceanBase >= 4.4.1, and table is heap table or doesn't exist. + False otherwise. + """ + # Check if it's seekdb + if OceanBaseUtil.is_seekdb(obvector): + logger.info("Detected seekdb, native hybrid search is supported") + # Also check if table is heap table (or doesn't exist) + if not OceanBaseUtil.check_table_is_heap_or_not_exists(obvector, table_name): + logger.warning( + f"Table '{table_name}' is not a heap table (ORGANIZATION HEAP). " + "Native hybrid search requires heap table." + ) + return False + return True + + # Check if it's OceanBase and version >= 4.4.1 + version_dict = OceanBaseUtil.get_version_number(obvector) + if version_dict is None: + logger.warning("Could not determine database version, assuming native hybrid search not supported") + return False + + major = version_dict["major"] + minor = version_dict["minor"] + patch = version_dict["patch"] + + # Check version >= 4.4.1 + if major > 4 or (major == 4 and (minor > 4 or (minor == 4 and patch >= 1))): + logger.info(f"Detected OceanBase version {major}.{minor}.{patch}, native hybrid search is supported") + # Also check if table is heap table (or doesn't exist) + if not OceanBaseUtil.check_table_is_heap_or_not_exists(obvector, table_name): + logger.warning( + f"Table '{table_name}' is not a heap table (ORGANIZATION HEAP). 
" + "Native hybrid search requires heap table." + ) + return False + return True + else: + logger.warning( + f"Detected OceanBase version {major}.{minor}.{patch}, " + "native hybrid search requires OceanBase >= 4.4.1" + ) + return False + + @staticmethod + def check_table_is_heap_or_not_exists(obvector, table_name: str) -> bool: + """ + Check if the table is a heap table (ORGANIZATION HEAP) or doesn't exist. + + DBMS_HYBRID_SEARCH.SEARCH requires heap table, not index-organized table. + + Args: + obvector: The ObVecClient instance. + table_name: The name of the table. + + Returns: + True if table is heap table or doesn't exist, False if it's index-organized table. + """ + try: + with obvector.engine.connect() as conn: + # Check if table exists first + result = conn.execute(text( + f"SELECT COUNT(*) FROM information_schema.TABLES " + f"WHERE TABLE_SCHEMA = DATABASE() " + f"AND TABLE_NAME = '{table_name}'" + )) + if result.scalar() == 0: + # Table doesn't exist, will be created as heap table + logger.debug(f"Table '{table_name}' doesn't exist, will be created as heap table") + return True + + # Check table organization type using SHOW CREATE TABLE + result = conn.execute(text(f"SHOW CREATE TABLE `{table_name}`")) + row = result.fetchone() + if row and len(row) >= 2: + create_statement = row[1] + # Check for ORGANIZATION keyword + if "ORGANIZATION INDEX" in create_statement.upper(): + logger.debug(f"Table '{table_name}' is an index-organized table") + return False + elif "ORGANIZATION HEAP" in create_statement.upper(): + logger.debug(f"Table '{table_name}' is a heap table") + return True + else: + # No ORGANIZATION keyword, default is index-organized in OceanBase + # when table has primary key + logger.debug(f"Table '{table_name}' has no explicit ORGANIZATION, assuming index-organized") + return False + return False + except Exception as e: + logger.error(f"An error occurred while checking table organization type: {e}") + return False + @staticmethod def format_sparse_vector(sparse_dict: Dict[int, float]) -> str: """ @@ -427,4 +534,253 @@ def check_sparse_vector_ready(obvector, collection_name: str, sparse_vector_fiel return False logger.info(f"Sparse vector support validated successfully for table '{collection_name}'") - return True \ No newline at end of file + return True + + @staticmethod + def check_filters_all_in_columns(filters: Optional[Dict], model_class) -> bool: + """ + Check if all filter fields are in standard table columns. + + This is used to determine if native hybrid search can be used. + Native hybrid search doesn't support JSON path filtering well, + so we only use it when all filters are on actual table columns. + + Args: + filters: The filter conditions in mem0 format. + model_class: SQLAlchemy ORM model class with __table__ attribute. + + Returns: + True if all filter fields are in table columns, False otherwise. 
+ """ + if not filters: + return True + + # Get column names from model_class + try: + table_columns = set(model_class.__table__.c.keys()) + except AttributeError: + logger.warning("model_class does not have __table__ attribute, native hybrid search disabled") + return False + + def check_filter_keys(filter_dict: Dict) -> bool: + """Recursively check if all keys are in table columns.""" + for key, value in filter_dict.items(): + # Handle AND/OR logic + if key in ["AND", "OR"]: + if isinstance(value, list): + for sub_filter in value: + if not check_filter_keys(sub_filter): + return False + else: + # Check if this key is a table column + if key not in table_columns: + logger.debug(f"Filter field '{key}' not in table columns, native hybrid search disabled") + return False + return True + + return check_filter_keys(filters) + + @staticmethod + def convert_filters_to_native_format( + filters: Optional[Dict], + model_class, + metadata_field: str = "metadata" + ) -> List[Dict]: + """ + Convert filter format to OceanBase native DBMS_HYBRID_SEARCH.SEARCH filter format. + + Follows the SEARCH API specification: + - term: Exact match (strings, numbers, booleans) + - range: Range queries (gte, gt, lte, lt) + - match: Fuzzy match (full-text) + - bool: Complex logic (must, should, must_not, filter) + + Args: + filters: The filter conditions in mem0 format. + model_class: SQLAlchemy ORM model class with __table__ attribute. + metadata_field: Name of the metadata field (default: "metadata"). + + Returns: + List[Dict]: Filter conditions in OceanBase SEARCH native format. + """ + if not filters: + return [] + + # Get column names from model_class + try: + table_columns = set(model_class.__table__.c.keys()) + except AttributeError: + logger.warning("model_class does not have __table__ attribute") + table_columns = set() + + def get_field_name(key: str) -> Optional[str]: + """Get the field name for filter.""" + if key in table_columns: + return key + else: + # Skip non-table fields for native hybrid search + logger.debug(f"Skipping non-table field in native hybrid search: {key}") + return None + + def process_single_filter(key: str, value) -> Optional[Dict]: + """Process a single filter condition.""" + field_name = get_field_name(key) + if field_name is None: + return None + + # List values -> IN query -> bool.should with multiple term queries + if isinstance(value, list): + if not value: + return None + return { + "bool": { + "should": [{"term": {field_name: v}} for v in value] + } + } + + # Dict values -> May be range query or other operators + if isinstance(value, dict): + range_ops = {"gte", "gt", "lte", "lt"} + if any(op in value for op in range_ops): + range_params = {op: value[op] for op in range_ops if op in value} + return {"range": {field_name: range_params}} + + if "eq" in value: + return {"term": {field_name: value["eq"]}} + + if "ne" in value: + return {"bool": {"must_not": [{"term": {field_name: value["ne"]}}]}} + + if "in" in value: + if not isinstance(value["in"], list) or not value["in"]: + return None + return { + "bool": { + "should": [{"term": {field_name: v}} for v in value["in"]] + } + } + + if "nin" in value: + if not isinstance(value["nin"], list) or not value["nin"]: + return None + return { + "bool": { + "must_not": [ + {"bool": {"should": [{"term": {field_name: v}} for v in value["nin"]]}} + ] + } + } + + if "like" in value or "ilike" in value: + query_str = value.get("like") or value.get("ilike") + query_str = str(query_str).replace("%", "").replace("_", " ").strip() + if 
query_str: + return {"match": {field_name: {"query": query_str}}} + return None + + # None values -> Not supported, skip + if value is None: + logger.warning(f"NULL filter not supported in native search, skipping: {key}") + return None + + # Simple values -> term query + return {"term": {field_name: value}} + + def process_complex_filter(filters_dict: Dict) -> List[Dict]: + """Process complex filters with AND/OR logic.""" + if "AND" in filters_dict: + conditions = [] + for sub_filter in filters_dict["AND"]: + result = process_complex_filter(sub_filter) + if result: + conditions.extend(result) + if conditions: + return [{"bool": {"filter": conditions}}] + return [] + + if "OR" in filters_dict: + conditions = [] + for sub_filter in filters_dict["OR"]: + result = process_complex_filter(sub_filter) + if result: + conditions.extend(result) + if conditions: + return [{"bool": {"should": conditions}}] + return [] + + results = [] + for k, v in filters_dict.items(): + if k not in ["AND", "OR"]: + result = process_single_filter(k, v) + if result: + results.append(result) + return results + + return process_complex_filter(filters) + + @staticmethod + def parse_native_hybrid_results( + result_json_str: str, + primary_field: str = "id", + text_field: str = "document", + metadata_field: str = "metadata" + ) -> List[Dict]: + """ + Parse the JSON results from OceanBase native DBMS_HYBRID_SEARCH.SEARCH. + + Args: + result_json_str: JSON string returned from DBMS_HYBRID_SEARCH.SEARCH + primary_field: Name of the primary key field + text_field: Name of the text content field + metadata_field: Name of the metadata field + + Returns: + List[Dict]: List of parsed result dictionaries, each containing: + - vector_id: Primary key value + - text_content: Text content + - score: Relevance score from _score field + - user_id, agent_id, run_id, actor_id: Standard ID fields + - hash, created_at, updated_at, category: Standard fields + - metadata_json: Raw metadata JSON + """ + try: + result_data = json.loads(result_json_str) + + if not isinstance(result_data, list): + logger.warning(f"Unexpected result format: {type(result_data)}, expected list") + return [] + + output_list = [] + + for doc in result_data: + vector_id = doc.get(primary_field) + score = doc.get("_score", 0.0) + + if not vector_id: + logger.warning(f"Document missing primary key '{primary_field}', skipping") + continue + + output_list.append({ + "vector_id": vector_id, + "text_content": doc.get(text_field, ""), + "score": score, + "user_id": doc.get("user_id", ""), + "agent_id": doc.get("agent_id", ""), + "run_id": doc.get("run_id", ""), + "actor_id": doc.get("actor_id", ""), + "hash": doc.get("hash", ""), + "created_at": doc.get("created_at", ""), + "updated_at": doc.get("updated_at", ""), + "category": doc.get("category", ""), + "metadata_json": doc.get(metadata_field, {}), + }) + + logger.debug(f"Parsed {len(output_list)} results from native hybrid search") + return output_list + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse native hybrid search JSON results: {e}") + return [] + except Exception as e: + logger.error(f"Error parsing native hybrid search results: {e}") + return [] \ No newline at end of file From 4c0eae71078d9612edc058cca0d44582b91b7edf Mon Sep 17 00:00:00 2001 From: Even Date: Tue, 27 Jan 2026 17:48:17 +0800 Subject: [PATCH 02/23] feat: User Profile Support native language output (#198) * support native language * support native language * add docs --- docs/examples/scenario_9_user_memory.ipynb | 81 
+++++++++++++++++++- docs/examples/scenario_9_user_memory.md | 75 +++++++++++++++++- docs/guides/0010-user_memory.md | 26 +++++++ src/powermem/prompts/user_profile_prompts.py | 56 +++++++++++++- src/powermem/user_memory/user_memory.py | 17 +++- src/server/api/v1/users.py | 1 + src/server/models/request.py | 1 + src/server/services/user_service.py | 3 + 8 files changed, 254 insertions(+), 6 deletions(-) diff --git a/docs/examples/scenario_9_user_memory.ipynb b/docs/examples/scenario_9_user_memory.ipynb index d974c68..7d08cf1 100644 --- a/docs/examples/scenario_9_user_memory.ipynb +++ b/docs/examples/scenario_9_user_memory.ipynb @@ -393,12 +393,91 @@ " print(f\" {i}. {result.get('memory', '')} (score: {result.get('score', 0):.2f})\")" ] }, + { + "cell_type": "markdown", + "id": "native_language_step", + "metadata": {}, + "source": [ + "## Step 7: Extract Profile in Native Language\n", + "\n", + "You can specify a native language for profile extraction, ensuring the profile is written in the user's preferred language regardless of the conversation language:\n", + "\n", + "**Supported Language Codes (ISO 639-1):**\n", + "\n", + "| Code | Language | Code | Language |\n", + "|------|----------|------|----------|\n", + "| zh | Chinese | en | English |\n", + "| ja | Japanese | ko | Korean |\n", + "| fr | French | de | German |\n", + "| es | Spanish | it | Italian |\n", + "| pt | Portuguese | ru | Russian |\n", + "| ar | Arabic | hi | Hindi |\n", + "| th | Thai | vi | Vietnamese |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "native_language_code", + "metadata": {}, + "outputs": [], + "source": [ + "from powermem import UserMemory, auto_config\n", + "\n", + "config = auto_config()\n", + "user_memory = UserMemory(config=config)\n", + "\n", + "# Example 1: English conversation, Chinese profile\n", + "conversation_en = [\n", + " {\"role\": \"user\", \"content\": \"I am a software engineer working in Beijing. 
I love drinking tea and reading books.\"},\n", + " {\"role\": \"assistant\", \"content\": \"That sounds great!\"}\n", + "]\n", + "\n", + "result_zh = user_memory.add(\n", + " messages=conversation_en,\n", + " user_id=\"user_bilingual_001\",\n", + " native_language=\"zh\" # Extract profile in Chinese\n", + ")\n", + "\n", + "print(\"✓ English conversation processed\")\n", + "if result_zh.get('profile_content'):\n", + " print(f\" - Profile (Chinese): {result_zh['profile_content']}\")\n", + "\n", + "# Example 2: Chinese conversation, English profile\n", + "conversation_zh = [\n", + " {\"role\": \"user\", \"content\": \"I'm 25 years old, working at Microsoft in Seattle.\"},\n", + " {\"role\": \"assistant\", \"content\": \"Nice to meet you\"}\n", + "]\n", + "\n", + "result_en = user_memory.add(\n", + " messages=conversation_zh,\n", + " user_id=\"user_bilingual_002\",\n", + " native_language=\"en\" # Extract profile in English\n", + ")\n", + "\n", + "print(\"\\n✓ Chinese conversation processed\")\n", + "if result_en.get('profile_content'):\n", + " print(f\" - Profile (English): {result_en['profile_content']}\")\n", + "\n", + "# Example 3: Structured topics with native language\n", + "result_topics = user_memory.add(\n", + " messages=\"I'm 25 years old, working at Microsoft in Seattle.\",\n", + " user_id=\"user_bilingual_003\",\n", + " profile_type=\"topics\",\n", + " native_language=\"zh\" # Topic values in Chinese, keys remain English\n", + ")\n", + "\n", + "print(\"\\n✓ Structured topics extracted\")\n", + "if result_topics.get('topics'):\n", + " print(f\" - Topics: {result_topics['topics']}\")" + ] + }, { "cell_type": "markdown", "id": "c1757147ee07603e", "metadata": {}, "source": [ - "## Step 7: Delete User Profile\n", + "## Step 8: Delete User Profile\n", "\n", "Delete a user profile:" ] diff --git a/docs/examples/scenario_9_user_memory.md b/docs/examples/scenario_9_user_memory.md index 5fcef5a..5859462 100644 --- a/docs/examples/scenario_9_user_memory.md +++ b/docs/examples/scenario_9_user_memory.md @@ -421,7 +421,80 @@ python user_profile_example.py 1. Works as a UX designer, loves creating beautiful interfaces (score: 0.92) ``` -## Step 7: Delete User Profile +## Step 7: Extract Profile in Native Language + +You can specify a native language for profile extraction, ensuring the profile is written in the user's preferred language regardless of the conversation language: + +```python +# user_profile_example.py +from powermem import UserMemory, auto_config + +config = auto_config() +user_memory = UserMemory(config=config) + +# Example 1: English conversation, Chinese profile +conversation_en = [ + {"role": "user", "content": "I am a software engineer working in Beijing. 
I love drinking tea and reading books."}, + {"role": "assistant", "content": "That sounds great!"} +] + +result_zh = user_memory.add( + messages=conversation_en, + user_id="user_bilingual_001", + native_language="zh" # Extract profile in Chinese +) + +print("✓ English conversation processed") +if result_zh.get('profile_content'): + print(f" - Profile (Chinese): {result_zh['profile_content']}") + +# Example 2: Chinese conversation, English profile +conversation_zh = [ + {"role": "user", "content": "I'm 25 years old, working at Microsoft in Seattle."}, + {"role": "assistant", "content": "Nice to meet you"} +] + +result_en = user_memory.add( + messages=conversation_zh, + user_id="user_bilingual_002", + native_language="en" # Extract profile in English +) + +print("\n✓ Chinese conversation processed") +if result_en.get('profile_content'): + print(f" - Profile (English): {result_en['profile_content']}") + +# Example 3: Structured topics with native language +result_topics = user_memory.add( + messages="I'm 25 years old, working at Microsoft in Seattle.", + user_id="user_bilingual_003", + profile_type="topics", + native_language="zh" # Topic values in Chinese, keys remain English +) + +print("\n✓ Structured topics extracted") +if result_topics.get('topics'): + print(f" - Topics: {result_topics['topics']}") +``` + +**Run this code:** +```bash +python user_profile_example.py +``` + +**Supported Language Codes:** + +| Code | Language | Code | Language | +|------|----------|------|----------| +| zh | Chinese | en | English | +| ja | Japanese | ko | Korean | +| fr | French | de | German | +| es | Spanish | it | Italian | +| pt | Portuguese | ru | Russian | +| ar | Arabic | hi | Hindi | +| th | Thai | vi | Vietnamese | + +## Step 8: Delete User Profile Delete a user profile: diff --git a/docs/guides/0010-user_memory.md b/docs/guides/0010-user_memory.md index d8ae2e7..091164d 100644 --- a/docs/guides/0010-user_memory.md +++ b/docs/guides/0010-user_memory.md @@ -104,6 +104,7 @@ def add( strict_mode: bool = False, include_roles: Optional[List[str]] = ["user"], exclude_roles: Optional[List[str]] = ["assistant"], + native_language: Optional[str] = None, ) -> Dict[str, Any] ``` @@ -137,6 +138,7 @@ Same as `Memory.add()` - `strict_mode` (bool): If True, only output topics from the provided list. Only used when profile_type="topics". Default: False - `include_roles` (List[str], optional): List of roles to include when filtering messages for profile extraction. Default: `["user"]`. If explicitly set to `None` or `[]`, no include filter is applied. - `exclude_roles` (List[str], optional): List of roles to exclude when filtering messages for profile extraction. Default: `["assistant"]`. If explicitly set to `None` or `[]`, no exclude filter is applied. +- `native_language` (str, optional): ISO 639-1 language code (e.g., "zh", "en", "ja") to specify the target language for profile extraction. If specified, the extracted profile will be written in this language regardless of the languages used in the conversation. If not specified, the profile language will follow the conversation language. Default: None #### Return value @@ -251,6 +253,30 @@ result = user_memory.add( include_roles=["user", "system"], exclude_roles=["tool"] ) + +# Example 6: specify native language for profile extraction +# User speaks English, but wants profile in Chinese +result = user_memory.add( + messages="I am a software engineer working in Beijing. 
I love drinking tea.", + user_id="user_002", + native_language="zh" # Extract profile in Chinese +) + +if result.get('profile_content'): + print(f"Profile (in Chinese): {result['profile_content']}") + # Output: "用户是一名在北京工作的软件工程师。喜欢喝茶。" + +# Extract structured topics with native language +result = user_memory.add( + messages="I'm 28 years old, working at Google in California.", + user_id="user_003", + profile_type="topics", + native_language="zh" # Topic values in Chinese, keys remain in English +) + +if result.get('topics'): + print(f"Topics: {result['topics']}") + # Output: {"basic_information": {"user_age": "28"}, "employment": {"company": "谷歌", "location": "加利福尼亚"}} ``` ### 2. `search()` — Search memories (optionally include profile) diff --git a/src/powermem/prompts/user_profile_prompts.py b/src/powermem/prompts/user_profile_prompts.py index 041133a..d36672d 100644 --- a/src/powermem/prompts/user_profile_prompts.py +++ b/src/powermem/prompts/user_profile_prompts.py @@ -11,6 +11,27 @@ logger = logging.getLogger(__name__) +# Language code to language name mapping +# Supports ISO 639-1 two-letter language codes +# If a language code is not found in this mapping, it will be passed directly to the prompt +LANGUAGE_CODE_MAPPING = { + "zh": "Chinese", + "en": "English", + "ja": "Japanese", + "ko": "Korean", + "fr": "French", + "de": "German", + "es": "Spanish", + "it": "Italian", + "pt": "Portuguese", + "ru": "Russian", + "ar": "Arabic", + "hi": "Hindi", + "th": "Thai", + "vi": "Vietnamese", +} + + # User profile topics for reference in prompt USER_PROFILE_TOPICS = """ - Basic Information @@ -98,13 +119,20 @@ """ -def get_user_profile_extraction_prompt(conversation: str, existing_profile: Optional[str] = None) -> Tuple[str, str]: +def get_user_profile_extraction_prompt( + conversation: str, + existing_profile: Optional[str] = None, + native_language: Optional[str] = None, +) -> Tuple[str, str]: """ Generate the system prompt and user message for user profile extraction. Args: conversation: The conversation text to analyze existing_profile: Optional existing user profile content to update + native_language: Optional ISO 639-1 language code (e.g., "zh", "en") to specify the target language + for profile extraction. If specified, the extracted profile will be written in this language + regardless of the languages used in the conversation. Returns: Tuple of (system_prompt, user_message): @@ -120,8 +148,17 @@ def get_user_profile_extraction_prompt(conversation: str, existing_profile: Opti ``` {existing_profile} ```""" + + # Build language instruction section + language_instruction = "" + if native_language: + target_language = LANGUAGE_CODE_MAPPING.get(native_language, native_language) + language_instruction = f""" + +[Language Requirement]: +You MUST extract and write the profile content in {target_language}, regardless of what languages are used in the conversation.""" - system_prompt = f"""{USER_PROFILE_EXTRACTION_PROMPT}{current_profile_section} + system_prompt = f"""{USER_PROFILE_EXTRACTION_PROMPT}{current_profile_section}{language_instruction} [Target]: Extract and return the user profile information as a text description:""" @@ -136,6 +173,7 @@ def get_user_profile_topics_extraction_prompt( existing_topics: Optional[Dict[str, Any]] = None, custom_topics: Optional[str] = None, strict_mode: bool = False, + native_language: Optional[str] = None, ) -> Tuple[str, str]: """ Generate the system prompt and user message for structured topic extraction. 
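
For reference, a minimal sketch of the language-code fallback noted above: codes present in `LANGUAGE_CODE_MAPPING` resolve to a language name, while unknown codes are passed to the prompt verbatim. This is an illustrative sketch only, assuming the patched powermem package is importable; it uses just the names introduced in this file.

```python
# Sketch of the fallback behavior described in LANGUAGE_CODE_MAPPING above.
from powermem.prompts.user_profile_prompts import (
    LANGUAGE_CODE_MAPPING,
    get_user_profile_extraction_prompt,
)

print(LANGUAGE_CODE_MAPPING.get("zh", "zh"))  # -> "Chinese" (known ISO 639-1 code)
print(LANGUAGE_CODE_MAPPING.get("sv", "sv"))  # -> "sv" (not in the mapping, used as-is)

# The resolved name is embedded in the [Language Requirement] section of the system prompt.
system_prompt, user_message = get_user_profile_extraction_prompt(
    conversation="user: I am a software engineer in Beijing.",
    native_language="zh",
)
```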
@@ -153,6 +191,9 @@ def get_user_profile_topics_extraction_prompt( - All keys must be in snake_case (lowercase, underscores, no spaces) - Descriptions are for reference only and should NOT be used as keys in the output strict_mode: If True, only output topics from the provided list; if False, can extend + native_language: Optional ISO 639-1 language code (e.g., "zh", "en") to specify the target language + for topic value extraction. If specified, the extracted topic values will be written in this + language regardless of the languages used in the conversation. Returns: Tuple of (system_prompt, user_message): @@ -225,6 +266,15 @@ def get_user_profile_topics_extraction_prompt( {formatted_topics} """ + # Build language instruction section + language_instruction = "" + if native_language: + target_language = LANGUAGE_CODE_MAPPING.get(native_language, native_language) + language_instruction = f""" + +[Language Requirement]: +You MUST extract and write all topic values in {target_language}, regardless of what languages are used in the conversation. Keep the topic keys in snake_case English format, but write the values in {target_language}.""" + system_prompt = f"""You are a user profile topic extraction specialist. Your task is to analyze conversations and extract user profile information as structured topics. {topics_section}{description_warning} @@ -242,7 +292,7 @@ def get_user_profile_topics_extraction_prompt( 7. If no relevant profile information is found in the conversation, return the current topics as-is 8. If no user profile information can be extracted from the conversation at all, return an empty JSON object {{}} 9. Focus on current state and characteristics of the user -{strict_instruction}{existing_topics_section} +{strict_instruction}{existing_topics_section}{language_instruction} [Output Format]: Return a valid JSON object with the following structure: diff --git a/src/powermem/user_memory/user_memory.py b/src/powermem/user_memory/user_memory.py index c3be9d7..9e8307b 100644 --- a/src/powermem/user_memory/user_memory.py +++ b/src/powermem/user_memory/user_memory.py @@ -153,6 +153,7 @@ def add( strict_mode: bool = False, include_roles: Optional[List[str]] = ["user"], exclude_roles: Optional[List[str]] = ["assistant"], + native_language: Optional[str] = None, ) -> Dict[str, Any]: """ Add messages and extract user profile information. @@ -188,6 +189,10 @@ def add( Defaults to ["user"]. If explicitly set to None or [], no include filter is applied. exclude_roles: List of roles to exclude when filtering messages for profile extraction. Defaults to ["assistant"]. If explicitly set to None or [], no exclude filter is applied. + native_language: Optional ISO 639-1 language code (e.g., "zh", "en") to specify the target language + for profile extraction. If specified, the extracted profile will be written in this language + regardless of the languages used in the conversation. If not specified, the profile language + will follow the conversation language. 
Default: None Returns: Dict[str, Any]: A dictionary containing the add operation results with the following structure: @@ -229,6 +234,7 @@ def add( user_id=user_id, custom_topics=custom_topics, strict_mode=strict_mode, + native_language=native_language, ) result_key = "topics" else: @@ -236,6 +242,7 @@ def add( extracted_data = self._extract_profile( messages=filtered_messages, user_id=user_id, + native_language=native_language, ) result_key = "profile_content" @@ -346,6 +353,7 @@ def _extract_profile( self, messages: Any, user_id: str, + native_language: Optional[str] = None, ) -> str: """ Extract user profile information from conversation using LLM. @@ -354,6 +362,8 @@ def _extract_profile( Args: messages: Conversation messages (str, dict, or list[dict]) user_id: User identifier + native_language: Optional ISO 639-1 language code (e.g., "zh", "en") to specify the target language + for profile extraction. If specified, the extracted profile will be written in this language. Returns: Extracted profile content as text string, or empty string if no profile found @@ -373,7 +383,8 @@ def _extract_profile( # Generate system prompt and user message system_prompt, user_message = get_user_profile_extraction_prompt( conversation_text, - existing_profile=existing_profile + existing_profile=existing_profile, + native_language=native_language, ) # Call LLM to extract profile @@ -396,6 +407,7 @@ def _extract_topics( user_id: str, custom_topics: Optional[str] = None, strict_mode: bool = False, + native_language: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """ Extract structured user profile topics from conversation using LLM. @@ -406,6 +418,8 @@ def _extract_topics( user_id: User identifier custom_topics: Optional custom topics JSON string. Format: {"main_topic": {"sub_topic": "description", ...}} strict_mode: If True, only output topics from the provided list + native_language: Optional ISO 639-1 language code (e.g., "zh", "en") to specify the target language + for topic value extraction. If specified, the topic values will be written in this language. Returns: Extracted topics as dictionary, or None if no topics found @@ -428,6 +442,7 @@ def _extract_topics( existing_topics=existing_topics, custom_topics=custom_topics, strict_mode=strict_mode, + native_language=native_language, ) # Call LLM to extract topics diff --git a/src/server/api/v1/users.py b/src/server/api/v1/users.py index 256662e..201fc57 100644 --- a/src/server/api/v1/users.py +++ b/src/server/api/v1/users.py @@ -77,6 +77,7 @@ async def add_user_profile( strict_mode=body.strict_mode, include_roles=body.include_roles, exclude_roles=body.exclude_roles, + native_language=body.native_language, ) return APIResponse( diff --git a/src/server/models/request.py b/src/server/models/request.py index f163405..3baa5e2 100644 --- a/src/server/models/request.py +++ b/src/server/models/request.py @@ -91,6 +91,7 @@ class UserProfileAddRequest(BaseModel): strict_mode: bool = Field(False, description="Only output topics from provided list (only used when profile_type='topics')") include_roles: Optional[List[str]] = Field(["user"], description="Roles to include when filtering messages. Default: ['user']. Set to None or [] to disable.") exclude_roles: Optional[List[str]] = Field(["assistant"], description="Roles to exclude when filtering messages. Default: ['assistant']. Set to None or [] to disable.") + native_language: Optional[str] = Field(None, description="ISO 639-1 language code (e.g., 'zh', 'en') for profile extraction. 
If specified, profile will be extracted in this language.") class UserProfileUpdateRequest(BaseModel): diff --git a/src/server/services/user_service.py b/src/server/services/user_service.py index 06c7098..1ea651b 100644 --- a/src/server/services/user_service.py +++ b/src/server/services/user_service.py @@ -85,6 +85,7 @@ def add_user_profile( strict_mode: bool = False, include_roles: Optional[List[str]] = ["user"], exclude_roles: Optional[List[str]] = ["assistant"], + native_language: Optional[str] = None, ) -> Dict[str, Any]: """ Add messages and extract user profile. @@ -105,6 +106,7 @@ def add_user_profile( strict_mode: Only output topics from provided list include_roles: Roles to include when filtering messages exclude_roles: Roles to exclude when filtering messages + native_language: ISO 639-1 language code (e.g., 'zh', 'en') for profile extraction Returns: Result dict with memory and profile extraction results @@ -136,6 +138,7 @@ def add_user_profile( strict_mode=strict_mode, include_roles=include_roles, exclude_roles=exclude_roles, + native_language=native_language, ) logger.info(f"User profile added: {user_id}") From 2be7d4f220e692147f4556bfa813ef991b286823 Mon Sep 17 00:00:00 2001 From: Even Date: Wed, 28 Jan 2026 14:34:17 +0800 Subject: [PATCH 03/23] Reconstruct LLM setting (#200) * feat(llm): enhance configuration management with pydantic-settings - Introduced a unified configuration system for LLM providers using pydantic-settings. - Added provider-specific settings for Anthropic, Azure, DeepSeek, Ollama, OpenAI, Qwen, Vllm, and Zai. - Improved environment variable handling and validation through Field and AliasChoices. - Removed legacy initialization methods in favor of a cleaner, more maintainable structure. - Updated LLMFactory to utilize the new provider registration mechanism. * chore: Update LLM configuration management and improve environment variable handling - Refactor LLM configuration imports to use BaseLLMConfig. - Replace direct attribute access with getattr for safer environment variable retrieval. - Remove deprecated LLMConfig and streamline related code for better maintainability. 
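
A compressed sketch of the provider-registration mechanism described above, trimmed to the registry logic only; class and field names follow the BaseLLMConfig / QwenConfig classes in the diff below, and it assumes pydantic-settings v2 is available. The full field set, env aliases, and class-path lookup live in the actual diff.

```python
from typing import ClassVar, Optional

from pydantic_settings import BaseSettings


class BaseLLMConfig(BaseSettings):
    """Common LLM settings; each provider subclass registers itself by name."""

    _provider_name: ClassVar[Optional[str]] = None
    _registry: ClassVar[dict] = {}

    api_key: Optional[str] = None
    model: Optional[str] = None
    temperature: float = 0.1

    def __init_subclass__(cls, **kwargs) -> None:
        # Runs for every subclass: providers that set _provider_name join the registry.
        super().__init_subclass__(**kwargs)
        if cls._provider_name:
            BaseLLMConfig._registry[cls._provider_name] = cls

    @classmethod
    def get_provider_config_cls(cls, provider: str):
        return cls._registry.get(provider)


class QwenConfig(BaseLLMConfig):
    _provider_name: ClassVar[str] = "qwen"

    # Provider-specific field, loaded from environment variables by pydantic-settings.
    dashscope_base_url: Optional[str] = None


# The config loader resolves the provider name to its settings class and instantiates it;
# provider-specific fields are picked up from the environment automatically.
config_cls = BaseLLMConfig.get_provider_config_cls("qwen") or BaseLLMConfig
settings = config_cls()
```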
--- src/powermem/config_loader.py | 122 ++++++------ src/powermem/configs.py | 9 +- src/powermem/integrations/llm/__init__.py | 6 +- src/powermem/integrations/llm/azure.py | 14 +- .../integrations/llm/config/anthropic.py | 73 +++---- src/powermem/integrations/llm/config/azure.py | 113 +++++------ src/powermem/integrations/llm/config/base.py | 178 +++++++++++++----- .../integrations/llm/config/deepseek.py | 73 +++---- .../integrations/llm/config/gemini.py | 30 +++ .../integrations/llm/config/langchain.py | 24 +++ .../integrations/llm/config/ollama.py | 58 ++---- .../integrations/llm/config/openai.py | 126 ++++++------- .../llm/config/openai_structured.py | 14 ++ src/powermem/integrations/llm/config/qwen.py | 96 +++++----- .../integrations/llm/config/qwen_asr.py | 72 +++---- .../integrations/llm/config/siliconflow.py | 42 +++++ src/powermem/integrations/llm/config/vllm.py | 58 ++---- src/powermem/integrations/llm/config/zai.py | 82 ++++---- src/powermem/integrations/llm/configs.py | 32 ---- src/powermem/integrations/llm/deepseek.py | 2 +- src/powermem/integrations/llm/factory.py | 85 ++++----- src/powermem/integrations/llm/ollama.py | 2 +- src/powermem/integrations/llm/openai.py | 24 ++- .../integrations/llm/openai_structured.py | 2 +- src/powermem/integrations/llm/qwen.py | 14 +- src/powermem/integrations/llm/qwen_asr.py | 9 +- src/powermem/integrations/llm/siliconflow.py | 7 +- src/powermem/integrations/llm/vllm.py | 2 +- src/powermem/integrations/llm/zai.py | 5 +- src/powermem/storage/configs.py | 4 +- .../regression/test_scenario_7_multimodal.py | 28 ++- 31 files changed, 721 insertions(+), 685 deletions(-) create mode 100644 src/powermem/integrations/llm/config/gemini.py create mode 100644 src/powermem/integrations/llm/config/langchain.py create mode 100644 src/powermem/integrations/llm/config/openai_structured.py create mode 100644 src/powermem/integrations/llm/config/siliconflow.py delete mode 100644 src/powermem/integrations/llm/configs.py diff --git a/src/powermem/config_loader.py b/src/powermem/config_loader.py index 51c006a..894be2e 100644 --- a/src/powermem/config_loader.py +++ b/src/powermem/config_loader.py @@ -13,6 +13,7 @@ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig from powermem.integrations.embeddings.config.providers import CustomEmbeddingConfig +from powermem.integrations.llm.config.base import BaseLLMConfig from powermem.settings import _DEFAULT_ENV_FILE, settings_config @@ -294,6 +295,18 @@ def to_config(self) -> Dict[str, Any]: class LLMSettings(_BasePowermemSettings): + """ + Unified LLM configuration settings. + + This class provides a common interface for configuring LLM providers. + It only contains fields that are common across all providers. + Provider-specific fields (e.g., dashscope_base_url for Qwen) should be + set via environment variables and will be loaded by the respective provider config classes. + + Design rationale: This follows the same pattern as EmbeddingSettings, + keeping the unified settings simple and delegating provider-specific + configuration to the provider config classes. 
+ """ model_config = settings_config("LLM_") provider: str = Field(default="qwen") @@ -310,83 +323,54 @@ class LLMSettings(_BasePowermemSettings): max_tokens: int = Field(default=1000) top_p: float = Field(default=0.8) top_k: int = Field(default=50) - enable_search: bool = Field(default=False) - qwen_base_url: str = Field( - default="https://dashscope.aliyuncs.com/api/v1", - validation_alias=AliasChoices("QWEN_LLM_BASE_URL"), - ) - openai_base_url: str = Field( - default="https://api.openai.com/v1", - validation_alias=AliasChoices("OPENAI_LLM_BASE_URL"), - ) - siliconflow_base_url: str = Field( - default="https://api.siliconflow.cn/v1", - validation_alias=AliasChoices("SILICONFLOW_LLM_BASE_URL"), - ) - ollama_base_url: Optional[str] = Field( - default=None, - validation_alias=AliasChoices("OLLAMA_LLM_BASE_URL"), - ) - vllm_base_url: Optional[str] = Field( - default=None, - validation_alias=AliasChoices("VLLM_LLM_BASE_URL"), - ) - anthropic_base_url: str = Field( - default="https://api.anthropic.com", - validation_alias=AliasChoices("ANTHROPIC_LLM_BASE_URL"), - ) - deepseek_base_url: str = Field( - default="https://api.deepseek.com", - validation_alias=AliasChoices("DEEPSEEK_LLM_BASE_URL"), - ) - - def _apply_provider_config( - self, provider: str, config: Dict[str, Any] - ) -> None: - configurer = getattr(self, f"_configure_{provider}", None) - if callable(configurer): - configurer(config) - - def _configure_qwen(self, config: Dict[str, Any]) -> None: - config["dashscope_base_url"] = self.qwen_base_url - config["enable_search"] = self.enable_search - - def _configure_openai(self, config: Dict[str, Any]) -> None: - config["openai_base_url"] = self.openai_base_url - - def _configure_siliconflow(self, config: Dict[str, Any]) -> None: - config["openai_base_url"] = self.siliconflow_base_url - - def _configure_ollama(self, config: Dict[str, Any]) -> None: - if self.ollama_base_url is not None: - config["ollama_base_url"] = self.ollama_base_url - - def _configure_vllm(self, config: Dict[str, Any]) -> None: - if self.vllm_base_url is not None: - config["vllm_base_url"] = self.vllm_base_url - - def _configure_anthropic(self, config: Dict[str, Any]) -> None: - config["anthropic_base_url"] = self.anthropic_base_url - - def _configure_deepseek(self, config: Dict[str, Any]) -> None: - config["deepseek_base_url"] = self.deepseek_base_url def to_config(self) -> Dict[str, Any]: + """ + Convert settings to LLM configuration dictionary. + + This method: + 1. Gets the appropriate provider config class + 2. Creates an instance (loading provider-specific fields from environment) + 3. Overrides with explicitly set common fields from this settings object + 4. Returns the final configuration + + Provider-specific fields (e.g., dashscope_base_url, enable_search) are + automatically loaded from environment variables by the provider config class. + """ llm_provider = self.provider.lower() + + # Determine model name llm_model = self.model if llm_model is None: llm_model = "qwen-plus" if llm_provider == "qwen" else "gpt-4o-mini" - llm_config = { - "api_key": self.api_key, - "model": llm_model, - "temperature": self.temperature, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "top_k": self.top_k, - } + # 1. Get provider config class from registry + config_cls = ( + BaseLLMConfig.get_provider_config_cls(llm_provider) + or BaseLLMConfig # fallback to base config + ) + + # 2. 
Create provider settings from environment variables + # Provider-specific fields are automatically loaded here + provider_settings = config_cls() + + # 3. Collect common fields to override + overrides = {} + for field in ("api_key", "temperature", "max_tokens", "top_p", "top_k"): + if field in self.model_fields_set: + value = getattr(self, field) + if value is not None: + overrides[field] = value + + # Always set model + overrides["model"] = llm_model + + # 4. Update configuration with overrides + if overrides: + provider_settings = provider_settings.model_copy(update=overrides) - self._apply_provider_config(llm_provider, llm_config) + # 5. Export to dict + llm_config = provider_settings.model_dump(exclude_none=True) return {"provider": llm_provider, "config": llm_config} diff --git a/src/powermem/configs.py b/src/powermem/configs.py index 32182be..1c39702 100644 --- a/src/powermem/configs.py +++ b/src/powermem/configs.py @@ -11,7 +11,8 @@ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig from powermem.integrations.embeddings.config.providers import OpenAIEmbeddingConfig from powermem.integrations.embeddings.config.sparse_base import SparseEmbedderConfig -from powermem.integrations.llm import LlmConfig +from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.integrations.llm.config.qwen import QwenConfig from powermem.storage.configs import VectorStoreConfig, GraphStoreConfig from powermem.integrations.rerank.configs import RerankConfig @@ -198,9 +199,9 @@ class MemoryConfig(BaseModel): description="Configuration for the vector store", default_factory=VectorStoreConfig, ) - llm: LlmConfig = Field( + llm: BaseLLMConfig = Field( description="Configuration for the language model", - default_factory=LlmConfig, + default_factory=QwenConfig, ) embedder: BaseEmbedderConfig = Field( description="Configuration for the embedding model", @@ -254,7 +255,7 @@ class MemoryConfig(BaseModel): description="Configuration for application logging", default=None, ) - audio_llm: Optional[LlmConfig] = Field( + audio_llm: Optional[BaseLLMConfig] = Field( description="Configuration for audio language model", default=None, ) diff --git a/src/powermem/integrations/llm/__init__.py b/src/powermem/integrations/llm/__init__.py index f874210..753305a 100644 --- a/src/powermem/integrations/llm/__init__.py +++ b/src/powermem/integrations/llm/__init__.py @@ -4,15 +4,15 @@ This module provides LLM integrations and factory. 
""" from .base import LLMBase -from .configs import LLMConfig from .factory import LLMFactory +from .config.base import BaseLLMConfig # provider alias name LlmFactory = LLMFactory -LlmConfig = LLMConfig __all__ = [ "LLMBase", "LlmFactory", - "LlmConfig" + "LLMFactory", + "BaseLLMConfig", ] diff --git a/src/powermem/integrations/llm/azure.py b/src/powermem/integrations/llm/azure.py index b7147f5..32623ac 100644 --- a/src/powermem/integrations/llm/azure.py +++ b/src/powermem/integrations/llm/azure.py @@ -46,7 +46,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, AzureOpenAIConfig, Dict # Get Azure endpoint from config or environment azure_endpoint = ( - self.config.azure_endpoint + getattr(self.config, "azure_endpoint", None) or os.getenv("AZURE_OPENAI_ENDPOINT") or os.getenv("ENDPOINT_URL") ) @@ -59,18 +59,19 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, AzureOpenAIConfig, Dict # Get API version from config or environment api_version = ( - self.config.api_version + getattr(self.config, "api_version", None) or os.getenv("AZURE_OPENAI_API_VERSION") or "2025-01-01-preview" ) # Initialize Azure OpenAI client # Support both API key and Azure AD token authentication - if self.config.azure_ad_token_provider: + azure_ad_token_provider = getattr(self.config, "azure_ad_token_provider", None) + if azure_ad_token_provider: # Use Azure AD token provider (Entra ID authentication) self.client = AzureOpenAI( azure_endpoint=azure_endpoint, - azure_ad_token_provider=self.config.azure_ad_token_provider, + azure_ad_token_provider=azure_ad_token_provider, api_version=api_version, ) else: @@ -190,9 +191,10 @@ def generate_response( response = self.client.chat.completions.create(**params) parsed_response = self._parse_response(response, tools) - if hasattr(self.config, "response_callback") and self.config.response_callback: + response_callback = getattr(self.config, "response_callback", None) + if response_callback: try: - self.config.response_callback(self, response, params) + response_callback(self, response, params) except Exception as e: # Log error but don't propagate logging.error(f"Error due to callback: {e}") diff --git a/src/powermem/integrations/llm/config/anthropic.py b/src/powermem/integrations/llm/config/anthropic.py index 26c4d5f..b0139aa 100644 --- a/src/powermem/integrations/llm/config/anthropic.py +++ b/src/powermem/integrations/llm/config/anthropic.py @@ -1,6 +1,9 @@ from typing import Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class AnthropicConfig(BaseLLMConfig): @@ -9,48 +12,28 @@ class AnthropicConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds Anthropic-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # Anthropic-specific parameters - anthropic_base_url: Optional[str] = None, - ): - """ - Initialize Anthropic configuration. 
- - Args: - model: Anthropic model to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: Anthropic API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - anthropic_base_url: Anthropic API base URL, defaults to None - """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) - - # Anthropic-specific parameters - self.anthropic_base_url = anthropic_base_url + _provider_name = "anthropic" + _class_path = "powermem.integrations.llm.anthropic.AnthropicLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Override base fields with Anthropic-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "ANTHROPIC_API_KEY", + ), + description="Anthropic API key" + ) + + # Anthropic-specific fields + anthropic_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "anthropic_base_url", + "ANTHROPIC_LLM_BASE_URL", + ), + description="Anthropic API base URL" + ) diff --git a/src/powermem/integrations/llm/config/azure.py b/src/powermem/integrations/llm/config/azure.py index 1b24b66..21e8fe0 100644 --- a/src/powermem/integrations/llm/config/azure.py +++ b/src/powermem/integrations/llm/config/azure.py @@ -1,6 +1,9 @@ -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional + +from pydantic import AliasChoices, Field from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class AzureOpenAIConfig(BaseLLMConfig): @@ -9,59 +12,61 @@ class AzureOpenAIConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds Azure OpenAI-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # Azure OpenAI-specific parameters - azure_endpoint: Optional[str] = None, - api_version: Optional[str] = "2025-01-01-preview", - azure_ad_token_provider: Optional[Callable[[], str]] = None, - deployment_name: Optional[str] = None, - ): - """ - Initialize Azure OpenAI configuration. 
+ _provider_name = "azure" + _class_path = "powermem.integrations.llm.azure.AzureLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Override base fields with Azure-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "AZURE_OPENAI_API_KEY", + "AZURE_API_KEY", + ), + description="Azure OpenAI API key" + ) - Args: - model: Azure OpenAI deployment name to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: Azure OpenAI API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - azure_endpoint: Azure OpenAI endpoint URL, defaults to None - api_version: Azure OpenAI API version, defaults to "2025-01-01-preview" - azure_ad_token_provider: Callable that returns an Azure AD token, defaults to None - deployment_name: Azure OpenAI deployment name (alias for model), defaults to None - """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) + # Azure OpenAI-specific fields + azure_endpoint: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "azure_endpoint", + "AZURE_ENDPOINT", + "AZURE_OPENAI_ENDPOINT", + ), + description="Azure OpenAI endpoint URL" + ) + + api_version: Optional[str] = Field( + default="2025-01-01-preview", + validation_alias=AliasChoices( + "api_version", + "AZURE_API_VERSION", + ), + description="Azure OpenAI API version" + ) + + azure_ad_token_provider: Optional[Callable[[], str]] = Field( + default=None, + exclude=True, + description="Callable that returns an Azure AD token" + ) + + deployment_name: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "deployment_name", + "AZURE_DEPLOYMENT", + ), + description="Azure OpenAI deployment name (alias for model)" + ) - # Azure OpenAI-specific parameters - self.azure_endpoint = azure_endpoint - self.api_version = api_version - self.azure_ad_token_provider = azure_ad_token_provider + def model_post_init(self, __context: Any) -> None: + """Initialize fields after model creation.""" + super().model_post_init(__context) # Use deployment_name if provided, otherwise use model - if deployment_name: - self.model = deployment_name + if self.deployment_name: + self.model = self.deployment_name diff --git a/src/powermem/integrations/llm/config/base.py b/src/powermem/integrations/llm/config/base.py index 9a9c129..2a5ac4e 100644 --- a/src/powermem/integrations/llm/config/base.py +++ b/src/powermem/integrations/llm/config/base.py @@ -1,62 +1,144 @@ -from abc import ABC -from typing import Dict, Optional, Union +from typing import Any, ClassVar, Dict, Optional, Union import httpx +from pydantic import Field +from pydantic_settings import BaseSettings +from powermem.settings import settings_config -class BaseLLMConfig(ABC): + +class BaseLLMConfig(BaseSettings): """ Base configuration for LLMs with only common parameters. Provider-specific configurations should be handled by separate config classes. 
This class contains only the parameters that are common across all LLM providers. For provider-specific parameters, use the appropriate provider config class. + + Now uses pydantic-settings for automatic environment variable loading. """ - def __init__( - self, - model: Optional[Union[str, Dict]] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[Union[Dict, str]] = None, - ): + model_config = settings_config("LLM_", extra="allow", env_file=None) + + # Registry for provider configurations + _provider_name: ClassVar[Optional[str]] = None + _class_path: ClassVar[Optional[str]] = None + _registry: ClassVar[dict[str, type["BaseLLMConfig"]]] = {} + _class_paths: ClassVar[dict[str, str]] = {} + + # Field definitions + model: Optional[Union[str, Dict]] = Field( + default=None, + description="The model identifier to use (e.g., 'gpt-4o-mini', 'claude-3-5-sonnet-20240620'). " + "Defaults to None (will be set by provider-specific configs)" + ) + + temperature: float = Field( + default=0.1, + description="Controls the randomness of the model's output. " + "Higher values (closer to 1) make output more random, lower values make it more deterministic. " + "Range: 0.0 to 2.0" + ) + + api_key: Optional[str] = Field( + default=None, + description="API key for the LLM provider. If None, will try to get from environment variables" + ) + + max_tokens: int = Field( + default=2000, + description="Maximum number of tokens to generate in the response. " + "Range: 1 to 4096 (varies by model)" + ) + + top_p: float = Field( + default=0.1, + description="Nucleus sampling parameter. Controls diversity via nucleus sampling. " + "Higher values (closer to 1) make word selection more diverse. " + "Range: 0.0 to 1.0" + ) + + top_k: int = Field( + default=1, + description="Top-k sampling parameter. Limits the number of tokens considered for each step. " + "Higher values make word selection more diverse. " + "Range: 1 to 40" + ) + + enable_vision: bool = Field( + default=False, + description="Whether to enable vision capabilities for the model. " + "Only applicable to vision-enabled models" + ) + + vision_details: Optional[str] = Field( + default="auto", + description="Level of detail for vision processing. Options: 'low', 'high', 'auto'" + ) + + http_client_proxies: Optional[Union[Dict, str]] = Field( + default=None, + description="Proxy settings for HTTP client. 
Can be a dict or string" + ) + + http_client: Optional[httpx.Client] = Field( + default=None, + exclude=True, + description="HTTP client instance (automatically initialized from http_client_proxies)" + ) + + @classmethod + def _register_provider(cls) -> None: + """Register provider in the global registry.""" + provider = getattr(cls, "_provider_name", None) + class_path = getattr(cls, "_class_path", None) + if provider: + BaseLLMConfig._registry[provider] = cls + if class_path: + BaseLLMConfig._class_paths[provider] = class_path + + def __init_subclass__(cls, **kwargs) -> None: + """Called when a class inherits from BaseLLMConfig.""" + super().__init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs) -> None: + """Called by Pydantic when a class inherits from BaseLLMConfig.""" + super().__pydantic_init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def get_provider_config_cls(cls, provider: str) -> Optional[type["BaseLLMConfig"]]: + """Get the config class for a specific provider.""" + return cls._registry.get(provider) + + @classmethod + def get_provider_class_path(cls, provider: str) -> Optional[str]: + """Get the class path for a specific provider.""" + return cls._class_paths.get(provider) + + @classmethod + def has_provider(cls, provider: str) -> bool: + """Check if a provider is registered.""" + return provider in cls._registry + + def model_post_init(self, __context: Any) -> None: + """Initialize http_client after model creation.""" + if self.http_client_proxies and not self.http_client: + self.http_client = httpx.Client(proxies=self.http_client_proxies) + + def to_component_dict(self) -> Dict[str, Any]: """ - Initialize a base configuration class instance for the LLM. - - Args: - model: The model identifier to use (e.g., "gpt-4o-mini", "claude-3-5-sonnet-20240620") - Defaults to None (will be set by provider-specific configs) - temperature: Controls the randomness of the model's output. - Higher values (closer to 1) make output more random, lower values make it more deterministic. - Range: 0.0 to 2.0. Defaults to 0.1 - api_key: API key for the LLM provider. If None, will try to get from environment variables. - Defaults to None - max_tokens: Maximum number of tokens to generate in the response. - Range: 1 to 4096 (varies by model). Defaults to 2000 - top_p: Nucleus sampling parameter. Controls diversity via nucleus sampling. - Higher values (closer to 1) make word selection more diverse. - Range: 0.0 to 1.0. Defaults to 0.1 - top_k: Top-k sampling parameter. Limits the number of tokens considered for each step. - Higher values make word selection more diverse. - Range: 1 to 40. Defaults to 1 - enable_vision: Whether to enable vision capabilities for the model. - Only applicable to vision-enabled models. Defaults to False - vision_details: Level of detail for vision processing. - Options: "low", "high", "auto". Defaults to "auto" - http_client_proxies: Proxy settings for HTTP client. - Can be a dict or string. Defaults to None + Convert config to component dictionary format. + + This method is used by MemoryConfig.to_dict() to serialize + LLM configuration in a consistent format. 
+ + Returns: + Dict with 'provider' and 'config' keys """ - self.model = model - self.temperature = temperature - self.api_key = api_key - self.max_tokens = max_tokens - self.top_p = top_p - self.top_k = top_k - self.enable_vision = enable_vision - self.vision_details = vision_details - self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None + return { + "provider": self._provider_name, + "config": self.model_dump(exclude_none=True) + } diff --git a/src/powermem/integrations/llm/config/deepseek.py b/src/powermem/integrations/llm/config/deepseek.py index 3fb6076..c42085c 100644 --- a/src/powermem/integrations/llm/config/deepseek.py +++ b/src/powermem/integrations/llm/config/deepseek.py @@ -1,6 +1,9 @@ from typing import Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class DeepSeekConfig(BaseLLMConfig): @@ -9,48 +12,28 @@ class DeepSeekConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds DeepSeek-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # DeepSeek-specific parameters - deepseek_base_url: Optional[str] = None, - ): - """ - Initialize DeepSeek configuration. - - Args: - model: DeepSeek model to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: DeepSeek API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - deepseek_base_url: DeepSeek API base URL, defaults to None - """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) - - # DeepSeek-specific parameters - self.deepseek_base_url = deepseek_base_url + _provider_name = "deepseek" + _class_path = "powermem.integrations.llm.deepseek.DeepSeekLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Override base fields with DeepSeek-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "DEEPSEEK_API_KEY", + ), + description="DeepSeek API key" + ) + + # DeepSeek-specific fields + deepseek_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "deepseek_base_url", + "DEEPSEEK_LLM_BASE_URL", + ), + description="DeepSeek API base URL" + ) diff --git a/src/powermem/integrations/llm/config/gemini.py b/src/powermem/integrations/llm/config/gemini.py new file mode 100644 index 0000000..59e0d35 --- /dev/null +++ b/src/powermem/integrations/llm/config/gemini.py @@ -0,0 +1,30 @@ +from typing import Optional + +from pydantic import AliasChoices, Field + +from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config + + +class 
GeminiConfig(BaseLLMConfig): + """ + Configuration class for Google Gemini-specific parameters. + Inherits from BaseLLMConfig. + """ + + _provider_name = "gemini" + _class_path = "powermem.integrations.llm.gemini.GeminiLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Override base fields with Gemini-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "GEMINI_API_KEY", + "GOOGLE_API_KEY", + ), + description="Google Gemini API key" + ) diff --git a/src/powermem/integrations/llm/config/langchain.py b/src/powermem/integrations/llm/config/langchain.py new file mode 100644 index 0000000..be1a013 --- /dev/null +++ b/src/powermem/integrations/llm/config/langchain.py @@ -0,0 +1,24 @@ +from typing import Any, Optional + +from pydantic import Field + +from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config + + +class LangchainConfig(BaseLLMConfig): + """ + Configuration class for Langchain LLM wrapper. + Inherits from BaseLLMConfig. + """ + + _provider_name = "langchain" + _class_path = "powermem.integrations.llm.langchain.LangchainLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Langchain uses a model object instead of string + model: Optional[Any] = Field( + default=None, + description="Langchain LLM model object" + ) diff --git a/src/powermem/integrations/llm/config/ollama.py b/src/powermem/integrations/llm/config/ollama.py index 6437dad..f177993 100644 --- a/src/powermem/integrations/llm/config/ollama.py +++ b/src/powermem/integrations/llm/config/ollama.py @@ -1,6 +1,9 @@ from typing import Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class OllamaConfig(BaseLLMConfig): @@ -9,48 +12,17 @@ class OllamaConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds Ollama-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # Ollama-specific parameters - ollama_base_url: Optional[str] = None, - ): - """ - Initialize Ollama configuration. 
+ _provider_name = "ollama" + _class_path = "powermem.integrations.llm.ollama.OllamaLLM" - Args: - model: Ollama model to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: Ollama API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - ollama_base_url: Ollama base URL, defaults to None - """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) + model_config = settings_config("LLM_", extra="forbid", env_file=None) - # Ollama-specific parameters - self.ollama_base_url = ollama_base_url + # Ollama-specific fields + ollama_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "ollama_base_url", + "OLLAMA_LLM_BASE_URL", + ), + description="Ollama base URL" + ) diff --git a/src/powermem/integrations/llm/config/openai.py b/src/powermem/integrations/llm/config/openai.py index e0fa8a1..fdf5f6b 100644 --- a/src/powermem/integrations/llm/config/openai.py +++ b/src/powermem/integrations/llm/config/openai.py @@ -1,6 +1,9 @@ from typing import Any, Callable, List, Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class OpenAIConfig(BaseLLMConfig): @@ -9,71 +12,64 @@ class OpenAIConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds OpenAI-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # OpenAI-specific parameters - openai_base_url: Optional[str] = None, - models: Optional[List[str]] = None, - route: Optional[str] = "fallback", - openrouter_base_url: Optional[str] = None, - site_url: Optional[str] = None, - app_name: Optional[str] = None, - store: bool = False, - # Response monitoring callback - response_callback: Optional[Callable[[Any, dict, dict], None]] = None, - ): - """ - Initialize OpenAI configuration. 
+ _provider_name = "openai" + _class_path = "powermem.integrations.llm.openai.OpenAILLM" - Args: - model: OpenAI model to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: OpenAI API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - openai_base_url: OpenAI API base URL, defaults to None - models: List of models for OpenRouter, defaults to None - route: OpenRouter route strategy, defaults to "fallback" - openrouter_base_url: OpenRouter base URL, defaults to None - site_url: Site URL for OpenRouter, defaults to None - app_name: Application name for OpenRouter, defaults to None - response_callback: Optional callback for monitoring LLM responses. - """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) + model_config = settings_config("LLM_", extra="forbid", env_file=None) - # OpenAI-specific parameters - self.openai_base_url = openai_base_url - self.models = models - self.route = route - self.openrouter_base_url = openrouter_base_url - self.site_url = site_url - self.app_name = app_name - self.store = store + # Override base fields with OpenAI-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "OPENAI_API_KEY", + ), + description="OpenAI API key" + ) - # Response monitoring - self.response_callback = response_callback + # OpenAI-specific fields + openai_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "openai_base_url", + "OPENAI_LLM_BASE_URL", + ), + description="OpenAI API base URL" + ) + + models: Optional[List[str]] = Field( + default=None, + description="List of models for OpenRouter" + ) + + route: Optional[str] = Field( + default="fallback", + description="OpenRouter route strategy" + ) + + openrouter_base_url: Optional[str] = Field( + default=None, + description="OpenRouter base URL" + ) + + site_url: Optional[str] = Field( + default=None, + description="Site URL for OpenRouter" + ) + + app_name: Optional[str] = Field( + default=None, + description="Application name for OpenRouter" + ) + + store: bool = Field( + default=False, + description="Whether to store conversations" + ) + + response_callback: Optional[Callable[[Any, dict, dict], None]] = Field( + default=None, + exclude=True, + description="Optional callback for monitoring LLM responses" + ) diff --git a/src/powermem/integrations/llm/config/openai_structured.py b/src/powermem/integrations/llm/config/openai_structured.py new file mode 100644 index 0000000..1702a57 --- /dev/null +++ b/src/powermem/integrations/llm/config/openai_structured.py @@ -0,0 +1,14 @@ +from powermem.integrations.llm.config.openai import OpenAIConfig +from powermem.settings import settings_config + + +class OpenAIStructuredConfig(OpenAIConfig): + """ + Configuration class for OpenAI Structured Output. + Inherits all configuration from OpenAIConfig, only overrides metadata. 
+ """ + + _provider_name = "openai_structured" + _class_path = "powermem.integrations.llm.openai_structured.OpenAIStructuredLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) diff --git a/src/powermem/integrations/llm/config/qwen.py b/src/powermem/integrations/llm/config/qwen.py index 8ed87ea..a919349 100644 --- a/src/powermem/integrations/llm/config/qwen.py +++ b/src/powermem/integrations/llm/config/qwen.py @@ -1,6 +1,9 @@ from typing import Any, Callable, Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class QwenConfig(BaseLLMConfig): @@ -9,60 +12,45 @@ class QwenConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds Qwen-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # Qwen-specific parameters - dashscope_base_url: Optional[str] = None, - enable_search: bool = False, - search_params: Optional[dict] = None, - # Response monitoring callback - response_callback: Optional[Callable[[Any, dict, dict], None]] = None, - ): - """ - Initialize Qwen configuration. + _provider_name = "qwen" + _class_path = "powermem.integrations.llm.qwen.QwenLLM" - Args: - model: Qwen model to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: DashScope API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - dashscope_base_url: DashScope API base URL, defaults to None - enable_search: Enable web search capability, defaults to False - search_params: Parameters for web search, defaults to None - response_callback: Optional callback for monitoring LLM responses. 
- """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) + model_config = settings_config("LLM_", extra="forbid", env_file=None) - # Qwen-specific parameters - self.dashscope_base_url = dashscope_base_url - self.enable_search = enable_search - self.search_params = search_params or {} + # Override base fields with Qwen-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "QWEN_API_KEY", + "DASHSCOPE_API_KEY", + ), + description="DashScope API key for Qwen models" + ) - # Response monitoring - self.response_callback = response_callback + # Qwen-specific fields + dashscope_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "dashscope_base_url", + "QWEN_LLM_BASE_URL", + ), + description="DashScope API base URL" + ) + + enable_search: bool = Field( + default=False, + description="Enable web search capability for Qwen models" + ) + + search_params: Optional[dict] = Field( + default_factory=dict, + description="Parameters for web search functionality" + ) + + response_callback: Optional[Callable[[Any, dict, dict], None]] = Field( + default=None, + exclude=True, + description="Optional callback for monitoring LLM responses" + ) diff --git a/src/powermem/integrations/llm/config/qwen_asr.py b/src/powermem/integrations/llm/config/qwen_asr.py index 5ed09b1..7bd2ee1 100644 --- a/src/powermem/integrations/llm/config/qwen_asr.py +++ b/src/powermem/integrations/llm/config/qwen_asr.py @@ -1,6 +1,9 @@ from typing import Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class QwenASRConfig(BaseLLMConfig): @@ -9,38 +12,39 @@ class QwenASRConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds ASR-specific settings. """ - def __init__( - self, - # Base parameters (only model and api_key are used for ASR) - model: Optional[str] = None, - api_key: Optional[str] = None, - # ASR-specific parameters - dashscope_base_url: Optional[str] = None, - asr_options: Optional[dict] = None, - result_format: str = "message", - ): - """ - Initialize Qwen ASR configuration. 
- - Args: - model: Qwen ASR model to use, defaults to "qwen3-asr-flash" - api_key: DashScope API key, defaults to None - dashscope_base_url: DashScope API base URL, defaults to None - asr_options: ASR-specific options (e.g., language, enable_itn), defaults to {"enable_itn": True} - result_format: Result format for ASR response, defaults to "message" - """ - # Initialize base parameters with defaults (ASR doesn't use these parameters) - super().__init__( - model=model, - api_key=api_key, - ) - - # ASR-specific parameters - self.dashscope_base_url = dashscope_base_url - # Default asr_options with enable_itn enabled - if asr_options is None: - self.asr_options = {"enable_itn": True} - else: - self.asr_options = asr_options - self.result_format = result_format + _provider_name = "qwen_asr" + _class_path = "powermem.integrations.llm.qwen_asr.QwenASR" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Override base fields with ASR-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "QWEN_API_KEY", + "DASHSCOPE_API_KEY", + ), + description="DashScope API key for Qwen ASR" + ) + # ASR-specific fields + dashscope_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "dashscope_base_url", + "QWEN_LLM_BASE_URL", + ), + description="DashScope API base URL" + ) + + asr_options: Optional[dict] = Field( + default_factory=lambda: {"enable_itn": True}, + description="ASR-specific options (e.g., language, enable_itn)" + ) + + result_format: str = Field( + default="message", + description="Result format for ASR response" + ) diff --git a/src/powermem/integrations/llm/config/siliconflow.py b/src/powermem/integrations/llm/config/siliconflow.py new file mode 100644 index 0000000..dd6db61 --- /dev/null +++ b/src/powermem/integrations/llm/config/siliconflow.py @@ -0,0 +1,42 @@ +from typing import Optional + +from pydantic import AliasChoices, Field + +from powermem.integrations.llm.config.openai import OpenAIConfig +from powermem.settings import settings_config + + +class SiliconFlowConfig(OpenAIConfig): + """ + Configuration class for SiliconFlow-specific parameters. + SiliconFlow is OpenAI-compatible, so it inherits from OpenAIConfig. + Only overrides provider-specific metadata and fields. 
+ """ + + _provider_name = "siliconflow" + _class_path = "powermem.integrations.llm.siliconflow.SiliconFlowLLM" + + model_config = settings_config("LLM_", extra="forbid", env_file=None) + + # Override api_key to add SiliconFlow-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "OPENAI_API_KEY", + "SILICONFLOW_API_KEY", + ), + description="SiliconFlow API key" + ) + + # Override openai_base_url with SiliconFlow default + openai_base_url: Optional[str] = Field( + default="https://api.siliconflow.cn/v1", + validation_alias=AliasChoices( + "openai_base_url", + "OPENAI_LLM_BASE_URL", + "SILICONFLOW_LLM_BASE_URL", + ), + description="SiliconFlow API base URL (OpenAI-compatible)" + ) \ No newline at end of file diff --git a/src/powermem/integrations/llm/config/vllm.py b/src/powermem/integrations/llm/config/vllm.py index 91dfa87..e564372 100644 --- a/src/powermem/integrations/llm/config/vllm.py +++ b/src/powermem/integrations/llm/config/vllm.py @@ -1,6 +1,9 @@ from typing import Optional +from pydantic import AliasChoices, Field + from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class VllmConfig(BaseLLMConfig): @@ -9,48 +12,17 @@ class VllmConfig(BaseLLMConfig): Inherits from BaseLLMConfig and adds vLLM-specific settings. """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # vLLM-specific parameters - vllm_base_url: Optional[str] = None, - ): - """ - Initialize vLLM configuration. 
+ _provider_name = "vllm" + _class_path = "powermem.integrations.llm.vllm.VllmLLM" - Args: - model: vLLM model to use, defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: vLLM API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities, defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - vllm_base_url: vLLM base URL, defaults to None - """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) + model_config = settings_config("LLM_", extra="forbid", env_file=None) - # vLLM-specific parameters - self.vllm_base_url = vllm_base_url or "http://localhost:8000/v1" + # vLLM-specific fields + vllm_base_url: Optional[str] = Field( + default="http://localhost:8000/v1", + validation_alias=AliasChoices( + "vllm_base_url", + "VLLM_LLM_BASE_URL", + ), + description="vLLM base URL" + ) diff --git a/src/powermem/integrations/llm/config/zai.py b/src/powermem/integrations/llm/config/zai.py index ce12eb8..335d486 100644 --- a/src/powermem/integrations/llm/config/zai.py +++ b/src/powermem/integrations/llm/config/zai.py @@ -1,6 +1,9 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Optional + +from pydantic import AliasChoices, Field from powermem.integrations.llm.config.base import BaseLLMConfig +from powermem.settings import settings_config class ZaiConfig(BaseLLMConfig): @@ -11,54 +14,35 @@ class ZaiConfig(BaseLLMConfig): Reference: https://docs.bigmodel.cn/cn/guide/develop/python/introduction """ - def __init__( - self, - # Base parameters - model: Optional[str] = None, - temperature: float = 0.1, - api_key: Optional[str] = None, - max_tokens: int = 2000, - top_p: float = 0.1, - top_k: int = 1, - enable_vision: bool = False, - vision_details: Optional[str] = "auto", - http_client_proxies: Optional[dict] = None, - # Zhipu AI-specific parameters - zai_base_url: Optional[str] = None, - # Response monitoring callback - response_callback: Optional[Callable[[Any, dict, dict], None]] = None, - ): - """ - Initialize Zhipu AI configuration. + _provider_name = "zai" + _class_path = "powermem.integrations.llm.zai.ZaiLLM" - Args: - model: Zhipu AI model to use (e.g., 'glm-4.7', 'glm-4.6v'), defaults to None - temperature: Controls randomness, defaults to 0.1 - api_key: Zhipu AI API key, defaults to None - max_tokens: Maximum tokens to generate, defaults to 2000 - top_p: Nucleus sampling parameter, defaults to 0.1 - top_k: Top-k sampling parameter, defaults to 1 - enable_vision: Enable vision capabilities (use glm-4.6v model), defaults to False - vision_details: Vision detail level, defaults to "auto" - http_client_proxies: HTTP client proxy settings, defaults to None - zai_base_url: Zhipu AI API base URL, defaults to None - response_callback: Optional callback for monitoring LLM responses. 
- """ - # Initialize base parameters - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - max_tokens=max_tokens, - top_p=top_p, - top_k=top_k, - enable_vision=enable_vision, - vision_details=vision_details, - http_client_proxies=http_client_proxies, - ) + model_config = settings_config("LLM_", extra="forbid", env_file=None) - # Zhipu AI-specific parameters - self.zai_base_url = zai_base_url or "https://open.bigmodel.cn/api/paas/v4/" + # Override base fields with Zhipu AI-specific validation_alias + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "LLM_API_KEY", + "ZAI_API_KEY", + "ZHIPU_API_KEY", + ), + description="Zhipu AI API key" + ) - # Response monitoring - self.response_callback = response_callback + # Zhipu AI-specific fields + zai_base_url: Optional[str] = Field( + default="https://open.bigmodel.cn/api/paas/v4/", + validation_alias=AliasChoices( + "zai_base_url", + "ZAI_BASE_URL", + ), + description="Zhipu AI API base URL" + ) + + response_callback: Optional[Callable[[Any, dict, dict], None]] = Field( + default=None, + exclude=True, + description="Optional callback for monitoring LLM responses" + ) diff --git a/src/powermem/integrations/llm/configs.py b/src/powermem/integrations/llm/configs.py deleted file mode 100644 index 5a1c0b5..0000000 --- a/src/powermem/integrations/llm/configs.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel, Field, field_validator - -from powermem.integrations.llm.factory import LLMFactory - - -class LLMConfig(BaseModel): - provider: str = Field(description="Provider of the LLM (e.g., 'ollama', 'openai')", default="openai") - config: Optional[dict] = Field(description="Configuration for the specific LLM", default={}) - - @field_validator("config") - def validate_config(cls, v, info): - provider = info.data.get("provider") - initialized_providers = ( - "openai", - "ollama", - "anthropic", - "openai_structured", - "azure", - "gemini", - "deepseek", - "vllm", - "langchain", - "qwen", - "siliconflow", - "zai", - ) - if provider in initialized_providers or provider in LLMFactory.provider_to_class: - return v - else: - raise ValueError(f"Unsupported LLM provider: {provider}") diff --git a/src/powermem/integrations/llm/deepseek.py b/src/powermem/integrations/llm/deepseek.py index 2c717f6..b6b9997 100644 --- a/src/powermem/integrations/llm/deepseek.py +++ b/src/powermem/integrations/llm/deepseek.py @@ -36,7 +36,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, DeepSeekConfig, Dict]] self.config.model = "deepseek-chat" api_key = self.config.api_key or os.getenv("DEEPSEEK_API_KEY") - base_url = self.config.deepseek_base_url or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com" + base_url = getattr(self.config, "deepseek_base_url", None) or os.getenv("DEEPSEEK_API_BASE") or "https://api.deepseek.com" self.client = OpenAI(api_key=api_key, base_url=base_url) def _parse_response(self, response, tools): diff --git a/src/powermem/integrations/llm/factory.py b/src/powermem/integrations/llm/factory.py index 635a939..45972e4 100644 --- a/src/powermem/integrations/llm/factory.py +++ b/src/powermem/integrations/llm/factory.py @@ -5,10 +5,14 @@ from powermem.integrations.llm.config.azure import AzureOpenAIConfig from powermem.integrations.llm.config.base import BaseLLMConfig from powermem.integrations.llm.config.deepseek import DeepSeekConfig +from powermem.integrations.llm.config.gemini import GeminiConfig +from 
powermem.integrations.llm.config.langchain import LangchainConfig from powermem.integrations.llm.config.ollama import OllamaConfig from powermem.integrations.llm.config.openai import OpenAIConfig +from powermem.integrations.llm.config.openai_structured import OpenAIStructuredConfig from powermem.integrations.llm.config.qwen import QwenConfig from powermem.integrations.llm.config.qwen_asr import QwenASRConfig +from powermem.integrations.llm.config.siliconflow import SiliconFlowConfig from powermem.integrations.llm.config.vllm import VllmConfig from powermem.integrations.llm.config.zai import ZaiConfig @@ -22,26 +26,9 @@ def load_class(class_type): class LLMFactory: """ Factory for creating LLM instances with appropriate configurations. - Supports both old-style BaseLLMConfig and new provider-specific configs. + Uses provider registration mechanism from BaseLLMConfig. """ - # Provider mappings with their config classes - provider_to_class = { - "ollama": ("powermem.integrations.llm.ollama.OllamaLLM", OllamaConfig), - "openai": ("powermem.integrations.llm.openai.OpenAILLM", OpenAIConfig), - "openai_structured": ("powermem.integrations.llm.openai_structured.OpenAIStructuredLLM", OpenAIConfig), - "anthropic": ("powermem.integrations.llm.anthropic.AnthropicLLM", AnthropicConfig), - "azure": ("powermem.integrations.llm.azure.AzureLLM", AzureOpenAIConfig), - "gemini": ("powermem.integrations.llm.gemini.GeminiLLM", BaseLLMConfig), - "deepseek": ("powermem.integrations.llm.deepseek.DeepSeekLLM", DeepSeekConfig), - "vllm": ("powermem.integrations.llm.vllm.VllmLLM", VllmConfig), - "langchain": ("powermem.integrations.llm.langchain.LangchainLLM", BaseLLMConfig), - "qwen": ("powermem.integrations.llm.qwen.QwenLLM", QwenConfig), - "qwen_asr": ("powermem.integrations.llm.qwen_asr.QwenASR", QwenASRConfig), - "siliconflow": ("powermem.integrations.llm.siliconflow.SiliconFlowLLM", OpenAIConfig), - "zai": ("powermem.integrations.llm.zai.ZaiLLM", ZaiConfig), - } - @classmethod def create(cls, provider_name: str, config: Optional[Union[BaseLLMConfig, Dict]] = None, **kwargs): """ @@ -49,8 +36,8 @@ def create(cls, provider_name: str, config: Optional[Union[BaseLLMConfig, Dict]] Args: provider_name (str): The provider name (e.g., 'openai', 'anthropic') - config: Configuration object or dict. If None, will create default config - **kwargs: Additional configuration parameters + config: Configuration object or dict. If None, will create default config from environment + **kwargs: Additional configuration parameters (overrides) Returns: Configured LLM instance @@ -58,45 +45,34 @@ def create(cls, provider_name: str, config: Optional[Union[BaseLLMConfig, Dict]] Raises: ValueError: If provider is not supported """ - if provider_name not in cls.provider_to_class: + # 1. Get class_path from registry + class_path = BaseLLMConfig.get_provider_class_path(provider_name) + if not class_path: raise ValueError(f"Unsupported Llm provider: {provider_name}") - class_type, config_class = cls.provider_to_class[provider_name] - llm_class = load_class(class_type) + # 2. Get config_cls from registry + config_cls = BaseLLMConfig.get_provider_config_cls(provider_name) or BaseLLMConfig - # Handle configuration + # 3. 
Handle configuration if config is None: - # Create default config with kwargs - config = config_class(**kwargs) + # Create default config from environment variables + provider_settings = config_cls() elif isinstance(config, dict): - # Merge dict config with kwargs - config.update(kwargs) - config = config_class(**config) + # Create config from dict + provider_settings = config_cls(**config) elif isinstance(config, BaseLLMConfig): - # Convert base config to provider-specific config if needed - if config_class != BaseLLMConfig: - # Convert to provider-specific config - config_dict = { - "model": config.model, - "temperature": config.temperature, - "api_key": config.api_key, - "max_tokens": config.max_tokens, - "top_p": config.top_p, - "top_k": config.top_k, - "enable_vision": config.enable_vision, - "vision_details": config.vision_details, - "http_client_proxies": config.http_client, - } - config_dict.update(kwargs) - config = config_class(**config_dict) - else: - # Use base config as-is - pass + # Use existing config as-is + provider_settings = config else: - # Assume it's already the correct config type - pass + raise TypeError(f"config must be BaseLLMConfig, dict, or None, got {type(config)}") + + # 4. Apply overrides (kwargs) + if kwargs: + provider_settings = provider_settings.model_copy(update=kwargs) - return llm_class(config) + # 5. Create LLM instance + llm_class = load_class(class_path) + return llm_class(provider_settings) @classmethod def register_provider(cls, name: str, class_path: str, config_class=None): @@ -110,7 +86,10 @@ def register_provider(cls, name: str, class_path: str, config_class=None): """ if config_class is None: config_class = BaseLLMConfig - cls.provider_to_class[name] = (class_path, config_class) + + # Register directly in BaseLLMConfig registry + BaseLLMConfig._registry[name] = config_class + BaseLLMConfig._class_paths[name] = class_path @classmethod def get_supported_providers(cls) -> list: @@ -120,4 +99,4 @@ def get_supported_providers(cls) -> list: Returns: list: List of supported provider names """ - return list(cls.provider_to_class.keys()) + return list(BaseLLMConfig._registry.keys()) diff --git a/src/powermem/integrations/llm/ollama.py b/src/powermem/integrations/llm/ollama.py index 153a597..92569ac 100644 --- a/src/powermem/integrations/llm/ollama.py +++ b/src/powermem/integrations/llm/ollama.py @@ -36,7 +36,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, OllamaConfig, Dict]] = if not self.config.model: self.config.model = "llama3.1:70b" - self.client = Client(host=self.config.ollama_base_url) + self.client = Client(host=getattr(self.config, "ollama_base_url", "http://localhost:11434")) def _parse_response(self, response, tools): """ diff --git a/src/powermem/integrations/llm/openai.py b/src/powermem/integrations/llm/openai.py index c4b2d76..99d5db0 100644 --- a/src/powermem/integrations/llm/openai.py +++ b/src/powermem/integrations/llm/openai.py @@ -39,13 +39,13 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, OpenAIConfig, Dict]] = if os.environ.get("OPENROUTER_API_KEY"): # Use OpenRouter self.client = OpenAI( api_key=os.environ.get("OPENROUTER_API_KEY"), - base_url=self.config.openrouter_base_url + base_url=getattr(self.config, "openrouter_base_url", None) or os.getenv("OPENROUTER_API_BASE") or "https://openrouter.ai/api/v1", ) else: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") - base_url = self.config.openai_base_url or os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1" + base_url = 
getattr(self.config, "openai_base_url", None) or os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) @@ -129,15 +129,18 @@ def generate_response( if os.getenv("OPENROUTER_API_KEY"): openrouter_params = {} - if self.config.models: - openrouter_params["models"] = self.config.models - openrouter_params["route"] = self.config.route + models = getattr(self.config, "models", None) + if models: + openrouter_params["models"] = models + openrouter_params["route"] = getattr(self.config, "route", "fallback") params.pop("model") - if self.config.site_url and self.config.app_name: + site_url = getattr(self.config, "site_url", None) + app_name = getattr(self.config, "app_name", None) + if site_url and app_name: extra_headers = { - "HTTP-Referer": self.config.site_url, - "X-Title": self.config.app_name, + "HTTP-Referer": site_url, + "X-Title": app_name, } openrouter_params["extra_headers"] = extra_headers @@ -156,9 +159,10 @@ def generate_response( params["tool_choice"] = tool_choice response = self.client.chat.completions.create(**params) parsed_response = self._parse_response(response, tools) - if self.config.response_callback: + response_callback = getattr(self.config, "response_callback", None) + if response_callback: try: - self.config.response_callback(self, response, params) + response_callback(self, response, params) except Exception as e: # Log error but don't propagate logging.error(f"Error due to callback: {e}") diff --git a/src/powermem/integrations/llm/openai_structured.py b/src/powermem/integrations/llm/openai_structured.py index ae22a31..c26860f 100644 --- a/src/powermem/integrations/llm/openai_structured.py +++ b/src/powermem/integrations/llm/openai_structured.py @@ -16,7 +16,7 @@ def __init__(self, config: Optional[BaseLLMConfig] = None): self.config.model = "gpt-5" api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") - base_url = self.config.openai_base_url or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" + base_url = getattr(self.config, "openai_base_url", None) or os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1" self.client = OpenAI(api_key=api_key, base_url=base_url) def generate_response( diff --git a/src/powermem/integrations/llm/qwen.py b/src/powermem/integrations/llm/qwen.py index 57969e3..aaf76b7 100644 --- a/src/powermem/integrations/llm/qwen.py +++ b/src/powermem/integrations/llm/qwen.py @@ -62,7 +62,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, QwenConfig, Dict]] = No dashscope.api_key = api_key # Set base URL - base_url = self.config.dashscope_base_url or os.getenv( + base_url = getattr(self.config, "dashscope_base_url", None) or os.getenv( "DASHSCOPE_BASE_URL") or "https://dashscope.aliyuncs.com/api/v1" if base_url: @@ -175,10 +175,11 @@ def generate_response( } # Add Qwen-specific parameters - if self.config.enable_search: + if getattr(self.config, "enable_search", False): generation_params["enable_search"] = True - if self.config.search_params: - generation_params.update(self.config.search_params) + search_params = getattr(self.config, "search_params", None) + if search_params: + generation_params.update(search_params) # Add tools if provided if tools: @@ -193,9 +194,10 @@ def generate_response( response = Generation.call(**generation_params) parsed_response = self._parse_response(response, tools) - if self.config.response_callback: + response_callback = getattr(self.config, "response_callback", None) + if response_callback: try: - 
self.config.response_callback(self, response, generation_params) + response_callback(self, response, generation_params) except Exception as e: # Log error but don't propagate logging.error(f"Error due to callback: {e}") diff --git a/src/powermem/integrations/llm/qwen_asr.py b/src/powermem/integrations/llm/qwen_asr.py index b917fda..ca186b0 100644 --- a/src/powermem/integrations/llm/qwen_asr.py +++ b/src/powermem/integrations/llm/qwen_asr.py @@ -55,7 +55,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, QwenASRConfig, Dict]] = dashscope.api_key = api_key # Set base URL - base_url = self.config.dashscope_base_url or os.getenv( + base_url = getattr(self.config, "dashscope_base_url", None) or os.getenv( "DASHSCOPE_BASE_URL") or "https://dashscope.aliyuncs.com/api/v1" if base_url: @@ -119,11 +119,12 @@ def generate_response( "api_key": self.config.api_key or os.getenv("DASHSCOPE_API_KEY"), "model": self.config.model, "messages": messages, - "result_format": self.config.result_format, + "result_format": getattr(self.config, "result_format", "message"), } # Add ASR options - asr_options = kwargs.get("asr_options", self.config.asr_options) + config_asr_options = getattr(self.config, "asr_options", None) + asr_options = kwargs.get("asr_options", config_asr_options) if asr_options: asr_params["asr_options"] = asr_options @@ -167,5 +168,5 @@ def transcribe( } ] - return self.generate_response(messages, asr_options=asr_options or self.config.asr_options) + return self.generate_response(messages, asr_options=asr_options or getattr(self.config, "asr_options", None)) diff --git a/src/powermem/integrations/llm/siliconflow.py b/src/powermem/integrations/llm/siliconflow.py index 0c48fcb..4bf534e 100644 --- a/src/powermem/integrations/llm/siliconflow.py +++ b/src/powermem/integrations/llm/siliconflow.py @@ -47,7 +47,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, OpenAIConfig, Dict]] = api_key = self.config.api_key or os.getenv("SILICONFLOW_API_KEY") or os.getenv("LLM_API_KEY") # Default base URL for SiliconFlow base_url = ( - self.config.openai_base_url + getattr(self.config, "openai_base_url", None) or os.getenv("SILICONFLOW_LLM_BASE_URL") or "https://api.siliconflow.cn/v1" ) @@ -141,9 +141,10 @@ def generate_response( response = self.client.chat.completions.create(**params) parsed_response = self._parse_response(response, tools) - if self.config.response_callback: + response_callback = getattr(self.config, "response_callback", None) + if response_callback: try: - self.config.response_callback(self, response, params) + response_callback(self, response, params) except Exception as e: # Log error but don't propagate logging.error(f"Error due to callback: {e}") diff --git a/src/powermem/integrations/llm/vllm.py b/src/powermem/integrations/llm/vllm.py index 1787806..53642c5 100644 --- a/src/powermem/integrations/llm/vllm.py +++ b/src/powermem/integrations/llm/vllm.py @@ -36,7 +36,7 @@ def __init__(self, config: Optional[Union[BaseLLMConfig, VllmConfig, Dict]] = No self.config.model = "Qwen/Qwen2.5-32B-Instruct" self.config.api_key = self.config.api_key or os.getenv("VLLM_API_KEY") or "vllm-api-key" - base_url = self.config.vllm_base_url or os.getenv("VLLM_LLM_BASE_URL") + base_url = getattr(self.config, "vllm_base_url", None) or os.getenv("VLLM_LLM_BASE_URL") self.client = OpenAI(api_key=self.config.api_key, base_url=base_url) def _parse_response(self, response, tools): diff --git a/src/powermem/integrations/llm/zai.py b/src/powermem/integrations/llm/zai.py index 676b2ac..ab1e0b7 
100644 --- a/src/powermem/integrations/llm/zai.py +++ b/src/powermem/integrations/llm/zai.py @@ -142,9 +142,10 @@ def generate_response( response = self.client.chat.completions.create(**params) parsed_response = self._parse_response(response, tools) - if self.config.response_callback: + response_callback = getattr(self.config, "response_callback", None) + if response_callback: try: - self.config.response_callback(self, response, params) + response_callback(self, response, params) except Exception as e: # Log error but don't propagate logging.error(f"Error due to callback: {e}") diff --git a/src/powermem/storage/configs.py b/src/powermem/storage/configs.py index add7e42..f5c3e8f 100644 --- a/src/powermem/storage/configs.py +++ b/src/powermem/storage/configs.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, model_validator -from powermem.integrations.llm.configs import LLMConfig +from powermem.integrations.llm.config.base import BaseLLMConfig from powermem.storage.config.oceanbase import OceanBaseGraphConfig from powermem.storage.factory import VectorStoreFactory @@ -109,7 +109,7 @@ class GraphStoreConfig(BaseModel): description="Configuration for the specific data store", default=None ) - llm: Optional[LLMConfig] = Field( + llm: Optional[BaseLLMConfig] = Field( description="LLM configuration for querying the graph store", default=None ) diff --git a/tests/regression/test_scenario_7_multimodal.py b/tests/regression/test_scenario_7_multimodal.py index bc109be..26e45a6 100644 --- a/tests/regression/test_scenario_7_multimodal.py +++ b/tests/regression/test_scenario_7_multimodal.py @@ -36,15 +36,29 @@ print(f" 2. Edit {env_path} and add your API keys") print("\n create_memory will fall back to mock providers if keys are missing.") -# Get API key from environment variable (GitHub Secrets) or .env file -# Priority: DASHSCOPE_API_KEY (GitHub Secrets) > LLM_API_KEY > .env file > default fallback -dashscope_api_key = os.getenv("QWEN_API_KEY") +# Get API key from environment variable with multiple fallback options +# Priority: QWEN_API_KEY > DASHSCOPE_API_KEY > LLM_API_KEY > EMBEDDING_API_KEY +# This follows the same pattern as config_loader.py LLMSettings +dashscope_api_key = ( + os.getenv("QWEN_API_KEY") or + os.getenv("DASHSCOPE_API_KEY") or + os.getenv("LLM_API_KEY") or + os.getenv("EMBEDDING_API_KEY") +) + # Handle empty string from GitHub Secrets (if secret is not set, it returns empty string) -if not dashscope_api_key or dashscope_api_key.strip() == "": - # Fallback to default for local development (not recommended for production) - print("⚠ Warning: Using default API key. For production, set QWEN environment variable or GitHub Secret.") +if dashscope_api_key: + dashscope_api_key = dashscope_api_key.strip() + +if not dashscope_api_key: + # Skip all tests in this module if no API key is found + pytest.skip( + "No API key found. 
Please set one of: QWEN_API_KEY, DASHSCOPE_API_KEY, LLM_API_KEY, or EMBEDDING_API_KEY\n" + f"You can also create a .env file at: {env_path}", + allow_module_level=True + ) else: - print("✓ API key loaded from environment variable or GitHub Secrets") + print(f"✓ API key loaded successfully (length: {len(dashscope_api_key)})") custom_config = { "llm": { From fafcf640a396363c2ac9df0bfc7d08b439e7f20f Mon Sep 17 00:00:00 2001 From: "jingshun.tq" <35712518+Teingi@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:43:44 +0800 Subject: [PATCH 04/23] refactor: unify configuration governance for agent, core, and server modules (#199) --- src/powermem/agent/agent.py | 82 +++++++++++++++++++++++------------ src/powermem/config_loader.py | 45 +++++++++---------- src/powermem/core/memory.py | 31 ++++++++++++- src/powermem/utils/utils.py | 52 ++++++++++++++++------ src/server/config.py | 9 ++-- 5 files changed, 149 insertions(+), 70 deletions(-) diff --git a/src/powermem/agent/agent.py b/src/powermem/agent/agent.py index 275d94d..b209a91 100644 --- a/src/powermem/agent/agent.py +++ b/src/powermem/agent/agent.py @@ -221,15 +221,24 @@ def _initialize_auto(self) -> None: self._initialize_multi_agent() def _get_default_multi_agent_config(self) -> Dict[str, Any]: - """Get default multi-agent configuration with environment variable support.""" - import os - - # Get environment variables with defaults - enabled = os.getenv('AGENT_ENABLED', 'true').lower() == 'true' - default_scope = os.getenv('AGENT_DEFAULT_SCOPE', 'AGENT') - default_privacy_level = os.getenv('AGENT_DEFAULT_PRIVACY_LEVEL', 'PRIVATE') - default_collaboration_level = os.getenv('AGENT_DEFAULT_COLLABORATION_LEVEL', 'READ_ONLY') - default_access_permission = os.getenv('AGENT_DEFAULT_ACCESS_PERMISSION', 'OWNER_ONLY') + """Get default multi-agent configuration from config or use AgentMemorySettings defaults.""" + # Try to get from config first + agent_config = None + if 'agent_memory' in self.config: + agent_config = self.config['agent_memory'] + if isinstance(agent_config, ConfigObject): + agent_config = agent_config._data if hasattr(agent_config, '_data') else agent_config.to_dict() + + # Use AgentMemorySettings defaults (aligned with config_loader) + from powermem.config_loader import AgentMemorySettings + settings = AgentMemorySettings() + + # Extract values from config if available, otherwise use Settings defaults + enabled = agent_config.get('enabled', settings.enabled) if agent_config else settings.enabled + default_scope = agent_config.get('default_scope', settings.default_scope) if agent_config else settings.default_scope + default_privacy_level = agent_config.get('default_privacy_level', settings.default_privacy_level) if agent_config else settings.default_privacy_level + default_collaboration_level = agent_config.get('default_collaboration_level', settings.default_collaboration_level) if agent_config else settings.default_collaboration_level + default_access_permission = agent_config.get('default_access_permission', settings.default_access_permission) if agent_config else settings.default_access_permission return { 'enabled': enabled, @@ -267,15 +276,25 @@ def _get_default_multi_agent_config(self) -> Dict[str, Any]: } def _get_default_multi_user_config(self) -> Dict[str, Any]: - """Get default multi-user configuration with environment variable support.""" - import os - - # Get environment variables with defaults - enabled = os.getenv('AGENT_ENABLED', 'true').lower() == 'true' - default_scope = os.getenv('AGENT_DEFAULT_SCOPE', 'USER_GROUP') 
- default_privacy_level = os.getenv('AGENT_DEFAULT_PRIVACY_LEVEL', 'PRIVATE') - default_collaboration_level = os.getenv('AGENT_DEFAULT_COLLABORATION_LEVEL', 'READ_ONLY') - default_access_permission = os.getenv('AGENT_DEFAULT_ACCESS_PERMISSION', 'OWNER_ONLY') + """Get default multi-user configuration from config or use AgentMemorySettings defaults.""" + # Try to get from config first + agent_config = None + if 'agent_memory' in self.config: + agent_config = self.config['agent_memory'] + if isinstance(agent_config, ConfigObject): + agent_config = agent_config._data if hasattr(agent_config, '_data') else agent_config.to_dict() + + # Use AgentMemorySettings defaults (aligned with config_loader) + from powermem.config_loader import AgentMemorySettings + settings = AgentMemorySettings() + + # Extract values from config if available, otherwise use Settings defaults + # For multi_user mode, default_scope should be 'USER_GROUP' if not specified + enabled = agent_config.get('enabled', settings.enabled) if agent_config else settings.enabled + default_scope = agent_config.get('default_scope', 'USER_GROUP') if agent_config else 'USER_GROUP' + default_privacy_level = agent_config.get('default_privacy_level', settings.default_privacy_level) if agent_config else settings.default_privacy_level + default_collaboration_level = agent_config.get('default_collaboration_level', settings.default_collaboration_level) if agent_config else settings.default_collaboration_level + default_access_permission = agent_config.get('default_access_permission', settings.default_access_permission) if agent_config else settings.default_access_permission return { 'enabled': enabled, @@ -300,15 +319,24 @@ def _get_default_multi_user_config(self) -> Dict[str, Any]: } def _get_default_hybrid_config(self) -> Dict[str, Any]: - """Get default hybrid configuration with environment variable support.""" - import os - - # Get environment variables with defaults - enabled = os.getenv('AGENT_ENABLED', 'true').lower() == 'true' - default_scope = os.getenv('AGENT_DEFAULT_SCOPE', 'AGENT') - default_privacy_level = os.getenv('AGENT_DEFAULT_PRIVACY_LEVEL', 'PRIVATE') - default_collaboration_level = os.getenv('AGENT_DEFAULT_COLLABORATION_LEVEL', 'READ_ONLY') - default_access_permission = os.getenv('AGENT_DEFAULT_ACCESS_PERMISSION', 'OWNER_ONLY') + """Get default hybrid configuration from config or use AgentMemorySettings defaults.""" + # Try to get from config first + agent_config = None + if 'agent_memory' in self.config: + agent_config = self.config['agent_memory'] + if isinstance(agent_config, ConfigObject): + agent_config = agent_config._data if hasattr(agent_config, '_data') else agent_config.to_dict() + + # Use AgentMemorySettings defaults (aligned with config_loader) + from powermem.config_loader import AgentMemorySettings + settings = AgentMemorySettings() + + # Extract values from config if available, otherwise use Settings defaults + enabled = agent_config.get('enabled', settings.enabled) if agent_config else settings.enabled + default_scope = agent_config.get('default_scope', settings.default_scope) if agent_config else settings.default_scope + default_privacy_level = agent_config.get('default_privacy_level', settings.default_privacy_level) if agent_config else settings.default_privacy_level + default_collaboration_level = agent_config.get('default_collaboration_level', settings.default_collaboration_level) if agent_config else settings.default_collaboration_level + default_access_permission = 
agent_config.get('default_access_permission', settings.default_access_permission) if agent_config else settings.default_access_permission return { 'enabled': enabled, diff --git a/src/powermem/config_loader.py b/src/powermem/config_loader.py index 894be2e..b9e8233 100644 --- a/src/powermem/config_loader.py +++ b/src/powermem/config_loader.py @@ -422,28 +422,17 @@ def to_config(self) -> Dict[str, Any]: class MemoryDecaySettings(_BasePowermemSettings): - model_config = settings_config() + model_config = settings_config("MEMORY_DECAY_") - enabled: bool = Field( - default=True, - validation_alias=AliasChoices("MEMORY_DECAY_ENABLED"), - ) - algorithm: str = Field( - default="ebbinghaus", - validation_alias=AliasChoices("MEMORY_DECAY_ALGORITHM"), - ) - base_retention: float = Field( - default=1.0, - validation_alias=AliasChoices("MEMORY_DECAY_BASE_RETENTION"), - ) - forgetting_rate: float = Field( - default=0.1, - validation_alias=AliasChoices("MEMORY_DECAY_FORGETTING_RATE"), - ) - reinforcement_factor: float = Field( - default=0.3, - validation_alias=AliasChoices("MEMORY_DECAY_REINFORCEMENT_FACTOR"), - ) + enabled: bool = Field(default=True) + algorithm: str = Field(default="ebbinghaus") + base_retention: float = Field(default=1.0) + forgetting_rate: float = Field(default=0.1) + reinforcement_factor: float = Field(default=0.3) + + def to_config(self) -> Dict[str, Any]: + """Convert MemoryDecaySettings to config dict.""" + return self.model_dump() class AgentMemorySettings(_BasePowermemSettings): @@ -576,6 +565,10 @@ class PerformanceSettings(_BasePowermemSettings): validation_alias=AliasChoices("VECTOR_STORE_INDEX_REBUILD_INTERVAL"), ) + def to_config(self) -> Dict[str, Any]: + """Convert PerformanceSettings to config dict.""" + return self.model_dump() + class SecuritySettings(_BasePowermemSettings): model_config = settings_config() @@ -605,6 +598,10 @@ class SecuritySettings(_BasePowermemSettings): validation_alias=AliasChoices("ACCESS_CONTROL_ADMIN_USERS"), ) + def to_config(self) -> Dict[str, Any]: + """Convert SecuritySettings to config dict.""" + return self.model_dump() + class GraphStoreSettings(_BasePowermemSettings): model_config = settings_config("GRAPH_STORE_") @@ -794,15 +791,15 @@ class PowermemSettings: "telemetry": ("telemetry", TelemetrySettings), "audit": ("audit", AuditSettings), "logging": ("logging", LoggingSettings), + "performance": ("performance", PerformanceSettings), + "security": ("security", SecuritySettings), + "memory_decay": ("memory_decay", MemoryDecaySettings), } def __init__(self) -> None: for _, (attr_name, component_cls) in self._COMPONENTS.items(): setattr(self, attr_name, component_cls()) self.graph_store = GraphStoreSettings() - self.memory_decay = MemoryDecaySettings() - self.performance = PerformanceSettings() - self.security = SecuritySettings() self.sparse_embedder = SparseEmbedderSettings() def to_config(self) -> Dict[str, Any]: diff --git a/src/powermem/core/memory.py b/src/powermem/core/memory.py index 84dc5bc..d36cb69 100644 --- a/src/powermem/core/memory.py +++ b/src/powermem/core/memory.py @@ -26,7 +26,7 @@ from .telemetry import TelemetryManager from .audit import AuditLogger from ..intelligence.plugin import IntelligentMemoryPlugin, EbbinghausIntelligencePlugin -from ..utils.utils import remove_code_blocks, convert_config_object_to_dict, parse_vision_messages +from ..utils.utils import remove_code_blocks, convert_config_object_to_dict, parse_vision_messages, set_timezone from ..prompts.intelligent_memory_prompts import ( FACT_RETRIEVAL_PROMPT, 
FACT_EXTRACTION_PROMPT, @@ -160,6 +160,12 @@ def __init__( self.agent_id = agent_id + # Set timezone from config if provided (priority: config > env) + timezone_config = self.config.get('timezone') + if timezone_config: + set_timezone(timezone_config) + logger.debug(f"Timezone set from config: {timezone_config}") + # Extract providers from config with fallbacks self.storage_type = storage_type or self._get_provider('vector_store', 'oceanbase') self.llm_provider = llm_provider or self._get_provider('llm', 'mock') @@ -373,6 +379,7 @@ def _get_intelligent_memory_config(self) -> Dict[str, Any]: """ Helper method to get intelligent memory configuration. Supports both "intelligence" and "intelligent_memory" config keys for backward compatibility. + Also merges "memory_decay" config into intelligent_memory config for Ebbinghaus algorithm. Returns: Merged intelligent memory configuration dictionary @@ -383,6 +390,17 @@ def _get_intelligent_memory_config(self) -> Dict[str, Any]: # Merge custom_importance_evaluation_prompt from top level if present if self.memory_config.custom_importance_evaluation_prompt: cfg["custom_importance_evaluation_prompt"] = self.memory_config.custom_importance_evaluation_prompt + # Merge memory_decay config if present (for Ebbinghaus algorithm parameters) + memory_decay_cfg = self.config.get("memory_decay", {}) + if memory_decay_cfg: + # Merge memory_decay fields into intelligent_memory config + # These fields are used by EbbinghausAlgorithm + if memory_decay_cfg.get("base_retention") is not None: + cfg["initial_retention"] = memory_decay_cfg["base_retention"] + if memory_decay_cfg.get("forgetting_rate") is not None: + cfg["decay_rate"] = memory_decay_cfg["forgetting_rate"] + if memory_decay_cfg.get("reinforcement_factor") is not None: + cfg["reinforcement_factor"] = memory_decay_cfg["reinforcement_factor"] return cfg else: # Fallback to dict access @@ -392,6 +410,17 @@ def _get_intelligent_memory_config(self) -> Dict[str, Any]: # Merge custom_importance_evaluation_prompt from top level if present if "custom_importance_evaluation_prompt" in self.config: merged_cfg["custom_importance_evaluation_prompt"] = self.config["custom_importance_evaluation_prompt"] + # Merge memory_decay config if present (for Ebbinghaus algorithm parameters) + memory_decay_cfg = (self.config or {}).get("memory_decay", {}) + if memory_decay_cfg: + # Merge memory_decay fields into intelligent_memory config + # These fields are used by EbbinghausAlgorithm + if memory_decay_cfg.get("base_retention") is not None: + merged_cfg["initial_retention"] = memory_decay_cfg["base_retention"] + if memory_decay_cfg.get("forgetting_rate") is not None: + merged_cfg["decay_rate"] = memory_decay_cfg["forgetting_rate"] + if memory_decay_cfg.get("reinforcement_factor") is not None: + merged_cfg["reinforcement_factor"] = memory_decay_cfg["reinforcement_factor"] return merged_cfg def _extract_facts(self, messages: Any) -> List[str]: diff --git a/src/powermem/utils/utils.py b/src/powermem/utils/utils.py index 50d4441..6ee5335 100644 --- a/src/powermem/utils/utils.py +++ b/src/powermem/utils/utils.py @@ -30,23 +30,44 @@ # Cache for timezone to avoid repeated lookups _timezone_cache: Optional[Any] = None +_timezone_str: Optional[str] = None # Store timezone string from config _timezone_lock = threading.Lock() +def set_timezone(timezone_str: str) -> None: + """ + Set the timezone from configuration. + + This function should be called during Memory initialization if timezone + is specified in the config. 
It takes precedence over environment variables. + + Args: + timezone_str: Timezone string (e.g., 'Asia/Shanghai', 'UTC') + """ + global _timezone_cache, _timezone_str + + with _timezone_lock: + _timezone_str = timezone_str + _timezone_cache = None # Reset cache to force re-initialization + + def get_timezone() -> Any: """ - Get the configured timezone from environment variable. + Get the configured timezone from config or environment variable. - This function reads the TIMEZONE environment variable to determine the timezone - to use for all datetime operations in powermem. The timezone is cached after first + This function first checks if timezone was set via set_timezone() (from config), + then falls back to TIMEZONE environment variable. The timezone is cached after first access for performance. Configuration: - Set TIMEZONE in your .env file or environment variables: - - TIMEZONE=Asia/Shanghai (for China Standard Time) - - TIMEZONE=America/New_York (for US Eastern Time) - - TIMEZONE=Europe/London (for UK Time) - - TIMEZONE=UTC (default, if not specified) + Timezone can be configured in two ways: + 1. Via config dict/JSON: Set 'timezone' in your config, which will be + automatically applied during Memory initialization. + 2. Via environment variable: Set TIMEZONE in your .env file or environment: + - TIMEZONE=Asia/Shanghai (for China Standard Time) + - TIMEZONE=America/New_York (for US Eastern Time) + - TIMEZONE=Europe/London (for UK Time) + - TIMEZONE=UTC (default, if not specified) Common timezone names: - Asia/Shanghai, Asia/Tokyo, Asia/Hong_Kong @@ -60,9 +81,9 @@ def get_timezone() -> Any: Note: The timezone is cached globally. To reset the cache (e.g., after changing - the TIMEZONE environment variable), call reset_timezone_cache(). + the timezone), call reset_timezone_cache(). """ - global _timezone_cache + global _timezone_cache, _timezone_str if _timezone_cache is not None: return _timezone_cache @@ -71,8 +92,12 @@ def get_timezone() -> Any: if _timezone_cache is not None: return _timezone_cache - # Try to get timezone from environment variable - timezone_str = os.getenv('TIMEZONE', 'UTC') + # Priority: config > environment variable > default + if _timezone_str is not None: + timezone_str = _timezone_str + else: + # Fallback to environment variable (for backward compatibility) + timezone_str = os.getenv('TIMEZONE', 'UTC') try: if _HAS_ZONEINFO: @@ -134,9 +159,10 @@ def reset_timezone_cache(): """ Reset the timezone cache. Useful for testing or when timezone changes. 
""" - global _timezone_cache + global _timezone_cache, _timezone_str with _timezone_lock: _timezone_cache = None + _timezone_str = None def generate_memory_id(content: str, user_id: Optional[str] = None) -> str: diff --git a/src/server/config.py b/src/server/config.py index 2e815e2..2f645ca 100644 --- a/src/server/config.py +++ b/src/server/config.py @@ -6,7 +6,9 @@ from typing import List, Optional from pydantic import Field, field_validator -from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic_settings import BaseSettings + +from powermem.settings import settings_config def _parse_boolish(value: object) -> object: @@ -32,11 +34,8 @@ def _parse_boolish(value: object) -> object: class ServerSettings(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", - env_file_encoding="utf-8", + model_config = settings_config( env_prefix="POWERMEM_SERVER_", - case_sensitive=False, extra="ignore", ) From 282f5276ab31e62e1144284743dd2f8d5d79a4a6 Mon Sep 17 00:00:00 2001 From: Even Date: Mon, 2 Feb 2026 16:42:56 +0800 Subject: [PATCH 05/23] Reconstruct setting in Rerank,Vector,Graph (#202) * feat(llm): enhance configuration management with pydantic-settings - Introduced a unified configuration system for LLM providers using pydantic-settings. - Added provider-specific settings for Anthropic, Azure, DeepSeek, Ollama, OpenAI, Qwen, Vllm, and Zai. - Improved environment variable handling and validation through Field and AliasChoices. - Removed legacy initialization methods in favor of a cleaner, more maintainable structure. - Updated LLMFactory to utilize the new provider registration mechanism. * chore: Update LLM configuration management and improve environment variable handling - Refactor LLM configuration imports to use BaseLLMConfig. - Replace direct attribute access with getattr for safer environment variable retrieval. - Remove deprecated LLMConfig and streamline related code for better maintainability. * feat: Enhance rerank configuration and integration - Introduced BaseRerankConfig for improved configuration management across rerank providers. - Updated rerank integration files to utilize the new base configuration structure. - Added support for additional configuration fields such as api_base_url and top_n. - Refactored rerank factory to accommodate new configuration handling and provider registration. - Removed deprecated RerankConfig and streamlined related code for better maintainability. - Updated API request handling in rerank classes to support custom HTTP clients. * * refactor(powermem): remove unused storage configuration management module - Removed `VectorStoreConfig` and `GraphStoreConfig` classes - Deleted associated validation logic and import statements - Streamlined codebase by eliminating unused components * feat(powermem): enhance sparse embedder configuration management - Introduced BaseSparseEmbedderConfig for unified sparse embedding configuration. - Updated MemoryConfig to utilize BaseSparseEmbedderConfig. - Refactored SparseEmbedderFactory to support new configuration handling. - Improved handling of sparse embedder settings across various components. * feat(powermem): enhance user profile storage with provider registration - Added a registry mechanism to UserProfileStoreBase for automatic provider registration. - Implemented class paths for OceanBase and SQLite user profile storage implementations. - Updated UserProfileStoreFactory to utilize the new registry for provider class retrieval. 
- Refactored imports to trigger auto-registration of user profile storage classes. - Improved handling of provider names in the factory for better compatibility. * feat(powermem): synchronize embedding model dimensions across configurations - Added logic to sync `embedding_model_dims` from the embedder to both `vector_store` and `graph_store` if not already set. - Updated `config_loader.py` and `configs.py` to ensure consistent embedding dimensions across components. * feat(powermem): enhance OceanBase configuration and query handling - Added `enable_native_hybrid` field to `OceanBaseConfig` for native hybrid search support. - Updated query handling in `OceanBaseVectorStore` to use a safe query format, preventing SQL injection risks. --- src/powermem/config_loader.py | 467 +++++------------- src/powermem/configs.py | 37 +- src/powermem/core/memory.py | 26 +- src/powermem/integrations/__init__.py | 4 +- .../embeddings/config/__init__.py | 4 + .../embeddings/config/sparse_base.py | 110 +++-- .../embeddings/config/sparse_providers.py | 32 ++ .../integrations/embeddings/qwen_sparse.py | 4 +- .../integrations/embeddings/sparse_base.py | 2 +- .../integrations/embeddings/sparse_factory.py | 62 +-- src/powermem/integrations/rerank/__init__.py | 14 +- .../integrations/rerank/config/__init__.py | 14 +- .../integrations/rerank/config/base.py | 127 ++++- .../integrations/rerank/config/providers.py | 127 +++++ src/powermem/integrations/rerank/configs.py | 23 - src/powermem/integrations/rerank/factory.py | 81 +-- src/powermem/integrations/rerank/generic.py | 40 +- src/powermem/integrations/rerank/jina.py | 44 +- src/powermem/integrations/rerank/qwen.py | 19 +- src/powermem/integrations/rerank/zai.py | 44 +- src/powermem/storage/__init__.py | 5 +- src/powermem/storage/config/__init__.py | 19 + src/powermem/storage/config/base.py | 266 +++++++++- src/powermem/storage/config/oceanbase.py | 241 +++++++-- src/powermem/storage/config/pgvector.py | 148 ++++-- src/powermem/storage/config/sqlite.py | 27 +- src/powermem/storage/configs.py | 162 ------ src/powermem/storage/factory.py | 162 +++++- src/powermem/storage/oceanbase/oceanbase.py | 3 +- .../storage/oceanbase/oceanbase_graph.py | 13 +- src/powermem/user_memory/storage/base.py | 31 +- src/powermem/user_memory/storage/factory.py | 24 +- .../user_memory/storage/user_profile.py | 3 + .../storage/user_profile_sqlite.py | 3 + .../test_scenario_5_custom_integration.py | 16 +- 35 files changed, 1561 insertions(+), 843 deletions(-) create mode 100644 src/powermem/integrations/embeddings/config/sparse_providers.py create mode 100644 src/powermem/integrations/rerank/config/providers.py delete mode 100644 src/powermem/integrations/rerank/configs.py create mode 100644 src/powermem/storage/config/__init__.py delete mode 100644 src/powermem/storage/configs.py diff --git a/src/powermem/config_loader.py b/src/powermem/config_loader.py index b9e8233..f8ccfc4 100644 --- a/src/powermem/config_loader.py +++ b/src/powermem/config_loader.py @@ -13,6 +13,7 @@ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig from powermem.integrations.embeddings.config.providers import CustomEmbeddingConfig +from powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig from powermem.integrations.llm.config.base import BaseLLMConfig from powermem.settings import _DEFAULT_ENV_FILE, settings_config @@ -110,188 +111,39 @@ class DatabaseSettings(_BasePowermemSettings): model_config = settings_config() provider: str = Field( - default="oceanbase", + 
default="sqlite", validation_alias=AliasChoices("DATABASE_PROVIDER"), ) - database_sslmode: Optional[str] = Field( - default=None, - validation_alias=AliasChoices("DATABASE_SSLMODE"), - ) - database_pool_size: Optional[int] = Field( - default=None, - validation_alias=AliasChoices("DATABASE_POOL_SIZE"), - ) - database_max_overflow: Optional[int] = Field( - default=None, - validation_alias=AliasChoices("DATABASE_MAX_OVERFLOW"), - ) - sqlite_path: str = Field( - default="./data/powermem_dev.db", - validation_alias=AliasChoices("SQLITE_PATH"), - ) - sqlite_collection: str = Field( - default="memories", - validation_alias=AliasChoices("SQLITE_COLLECTION"), - ) - sqlite_enable_wal: bool = Field( - default=True, - validation_alias=AliasChoices("SQLITE_ENABLE_WAL"), - ) - sqlite_timeout: int = Field( - default=30, - validation_alias=AliasChoices("SQLITE_TIMEOUT"), - ) - oceanbase_host: str = Field( - default="127.0.0.1", - validation_alias=AliasChoices("OCEANBASE_HOST"), - ) - oceanbase_port: int = Field( - default=2881, - validation_alias=AliasChoices("OCEANBASE_PORT"), - ) - oceanbase_user: str = Field( - default="root@sys", - validation_alias=AliasChoices("OCEANBASE_USER"), - ) - oceanbase_password: str = Field( - default="password", - validation_alias=AliasChoices("OCEANBASE_PASSWORD"), - ) - oceanbase_database: str = Field( - default="powermem", - validation_alias=AliasChoices("OCEANBASE_DATABASE"), - ) - oceanbase_collection: str = Field( - default="memories", - validation_alias=AliasChoices("OCEANBASE_COLLECTION"), - ) - oceanbase_vector_metric_type: str = Field( - default="cosine", - validation_alias=AliasChoices("OCEANBASE_VECTOR_METRIC_TYPE"), - ) - oceanbase_index_type: str = Field( - default="IVF_FLAT", - validation_alias=AliasChoices("OCEANBASE_INDEX_TYPE"), - ) - oceanbase_embedding_model_dims: int = Field( - default=1536, - validation_alias=AliasChoices("OCEANBASE_EMBEDDING_MODEL_DIMS"), - ) - oceanbase_primary_field: str = Field( - default="id", - validation_alias=AliasChoices("OCEANBASE_PRIMARY_FIELD"), - ) - oceanbase_vector_field: str = Field( - default="embedding", - validation_alias=AliasChoices("OCEANBASE_VECTOR_FIELD"), - ) - oceanbase_text_field: str = Field( - default="document", - validation_alias=AliasChoices("OCEANBASE_TEXT_FIELD"), - ) - oceanbase_metadata_field: str = Field( - default="metadata", - validation_alias=AliasChoices("OCEANBASE_METADATA_FIELD"), - ) - oceanbase_vidx_name: str = Field( - default="memories_vidx", - validation_alias=AliasChoices("OCEANBASE_VIDX_NAME"), - ) - oceanbase_include_sparse: bool = Field( - default=False, - validation_alias=AliasChoices("SPARSE_VECTOR_ENABLE"), - ) - enable_native_hybrid: bool = Field( - default=False, - validation_alias=AliasChoices("OCEANBASE_ENABLE_NATIVE_HYBRID"), - ) - postgres_collection: str = Field( - default="memories", - validation_alias=AliasChoices("POSTGRES_COLLECTION"), - ) - postgres_database: str = Field( - default="powermem", - validation_alias=AliasChoices("POSTGRES_DATABASE"), - ) - postgres_host: str = Field( - default="127.0.0.1", - validation_alias=AliasChoices("POSTGRES_HOST"), - ) - postgres_port: int = Field( - default=5432, - validation_alias=AliasChoices("POSTGRES_PORT"), - ) - postgres_user: str = Field( - default="postgres", - validation_alias=AliasChoices("POSTGRES_USER"), - ) - postgres_password: str = Field( - default="password", - validation_alias=AliasChoices("POSTGRES_PASSWORD"), - ) - postgres_embedding_model_dims: int = Field( - default=1536, - 
validation_alias=AliasChoices("POSTGRES_EMBEDDING_MODEL_DIMS"), - ) - postgres_diskann: bool = Field( - default=True, - validation_alias=AliasChoices("POSTGRES_DISKANN"), - ) - postgres_hnsw: bool = Field( - default=True, - validation_alias=AliasChoices("POSTGRES_HNSW"), - ) - - def _build_oceanbase_config(self) -> Dict[str, Any]: - connection_args = { - "host": self.oceanbase_host, - "port": self.oceanbase_port, - "user": self.oceanbase_user, - "password": self.oceanbase_password, - "db_name": self.oceanbase_database, - } - return { - "collection_name": self.oceanbase_collection, - "connection_args": connection_args, - "vidx_metric_type": self.oceanbase_vector_metric_type, - "index_type": self.oceanbase_index_type, - "embedding_model_dims": self.oceanbase_embedding_model_dims, - "primary_field": self.oceanbase_primary_field, - "vector_field": self.oceanbase_vector_field, - "text_field": self.oceanbase_text_field, - "metadata_field": self.oceanbase_metadata_field, - "vidx_name": self.oceanbase_vidx_name, - "include_sparse": self.oceanbase_include_sparse, - "enable_native_hybrid": self.enable_native_hybrid, - } - - def _build_postgres_config(self) -> Dict[str, Any]: - return { - "collection_name": self.postgres_collection, - "dbname": self.postgres_database, - "host": self.postgres_host, - "port": self.postgres_port, - "user": self.postgres_user, - "password": self.postgres_password, - "embedding_model_dims": self.postgres_embedding_model_dims, - "diskann": self.postgres_diskann, - "hnsw": self.postgres_hnsw, - } - - def _build_sqlite_config(self) -> Dict[str, Any]: - return { - "database_path": self.sqlite_path, - "collection_name": self.sqlite_collection, - "enable_wal": self.sqlite_enable_wal, - "timeout": self.sqlite_timeout, - } def to_config(self) -> Dict[str, Any]: + """ + Convert settings to VectorStore configuration dictionary. + + Provider-specific fields are automatically loaded from environment + variables by the provider config class. + """ + from powermem.storage.config.base import BaseVectorStoreConfig + db_provider = self.provider.lower() - builder = getattr(self, f"_build_{db_provider}_config", None) - if not callable(builder): - builder = self._build_sqlite_config - return {"provider": db_provider, "config": builder()} + + # Handle postgres alias + if db_provider == "postgres": + db_provider = "pgvector" + + # 1. Get provider config class from registry + config_cls = ( + BaseVectorStoreConfig.get_provider_config_cls(db_provider) + or BaseVectorStoreConfig + ) + + # 2. Create provider settings from environment variables + # All provider-specific fields are loaded here automatically + provider_settings = config_cls() + + # 3. Export to dict + vector_store_config = provider_settings.model_dump(exclude_none=True) + + return {"provider": db_provider, "config": vector_store_config} class LLMSettings(_BasePowermemSettings): @@ -475,16 +327,50 @@ class RerankerSettings(_BasePowermemSettings): provider: str = Field(default="qwen") model: Optional[str] = Field(default=None) api_key: Optional[str] = Field(default=None) + api_base_url: Optional[str] = Field(default=None) + top_n: Optional[int] = Field(default=None) def to_config(self) -> Dict[str, Any]: - return { - "enabled": self.enabled, - "provider": self.provider, - "config": { - "model": self.model, - "api_key": self.api_key, - }, - } + """ + Convert settings to Rerank configuration dictionary. + + This method: + 1. Gets the appropriate provider config class + 2. 
Creates an instance (loading provider-specific fields from environment) + 3. Overrides with explicitly set fields from this settings object + 4. Returns the final configuration + + Provider-specific fields (e.g., api_base_url) are automatically loaded + from environment variables by the provider config class. + """ + from powermem.integrations.rerank.config.base import BaseRerankConfig + + rerank_provider = self.provider.lower() + + # 1. Get provider config class from registry + config_cls = ( + BaseRerankConfig.get_provider_config_cls(rerank_provider) + or BaseRerankConfig # fallback to base config + ) + + # 2. Create provider settings from environment variables + # Provider-specific fields are automatically loaded here + provider_settings = config_cls() + + # 3. Collect fields to override + overrides = {} + for field in ("enabled", "model", "api_key", "api_base_url", "top_n"): + if field in self.model_fields_set: + value = getattr(self, field) + if value is not None: + overrides[field] = value + + # 4. Update configuration with overrides + if overrides: + provider_settings = provider_settings.model_copy(update=overrides) + + # 5. Export using to_component_dict() to match RerankConfig structure + return provider_settings.to_component_dict() class QueryRewriteSettings(_BasePowermemSettings): @@ -519,14 +405,21 @@ class SparseEmbedderSettings(_BasePowermemSettings): def to_config(self) -> Optional[Dict[str, Any]]: if not self.provider: return None - config = { - "api_key": self.api_key, - "model": self.model, - "base_url": self.base_url, - "embedding_dims": self.embedding_dims, - } - config = {key: value for key, value in config.items() if value is not None} - return {"provider": self.provider.lower(), "config": config} + provider = self.provider.lower() + config_cls = ( + BaseSparseEmbedderConfig.get_provider_config_cls(provider) + or BaseSparseEmbedderConfig + ) + provider_settings = config_cls() + overrides = {} + for field in ("api_key", "model", "base_url", "embedding_dims"): + if field in self.model_fields_set: + value = getattr(self, field) + if value is not None: + overrides[field] = value + if overrides: + provider_settings = provider_settings.model_copy(update=overrides) + return provider_settings.to_component_dict() class PerformanceSettings(_BasePowermemSettings): @@ -608,122 +501,47 @@ class GraphStoreSettings(_BasePowermemSettings): enabled: bool = Field(default=False) provider: str = Field(default="oceanbase") - host: Optional[str] = Field(default=None) - port: Optional[int] = Field(default=None) - user: Optional[str] = Field(default=None) - password: Optional[str] = Field(default=None) - db_name: Optional[str] = Field(default=None) - vector_metric_type: Optional[str] = Field(default=None) - index_type: Optional[str] = Field(default=None) - embedding_model_dims: Optional[int] = Field(default=None) - max_hops: Optional[int] = Field(default=None) custom_prompt: Optional[str] = Field(default=None) custom_extract_relations_prompt: Optional[str] = Field(default=None) custom_update_graph_prompt: Optional[str] = Field(default=None) custom_delete_relations_prompt: Optional[str] = Field(default=None) - def _build_oceanbase_config( - self, database_settings: "DatabaseSettings" - ) -> Dict[str, Any]: - graph_connection_args = { - "host": _get_graph_value( - self, - "host", - _get_db_value( - database_settings, - "oceanbase_host", - ), - "127.0.0.1", - ), - "port": _get_graph_value( - self, - "port", - _get_db_value( - database_settings, - "oceanbase_port", - ), - 2881, - ), - "user": 
_get_graph_value( - self, - "user", - _get_db_value( - database_settings, - "oceanbase_user", - ), - "root@sys", - ), - "password": _get_graph_value( - self, - "password", - _get_db_value( - database_settings, - "oceanbase_password", - ), - "password", - ), - "db_name": _get_graph_value( - self, - "db_name", - _get_db_value( - database_settings, - "oceanbase_database", - ), - "powermem", - ), - } - return { - "host": graph_connection_args["host"], - "port": graph_connection_args["port"], - "user": graph_connection_args["user"], - "password": graph_connection_args["password"], - "db_name": graph_connection_args["db_name"], - "vidx_metric_type": _get_graph_value_with_database( - self, - "vector_metric_type", - database_settings, - "oceanbase_vector_metric_type", - "l2", - ), - "index_type": _get_graph_value_with_database( - self, - "index_type", - database_settings, - "oceanbase_index_type", - "HNSW", - ), - "embedding_model_dims": _get_graph_value_with_database( - self, - "embedding_model_dims", - database_settings, - "oceanbase_embedding_model_dims", - 1536, - ), - "max_hops": _get_graph_value( - self, - "max_hops", - None, - 3, - ), - } - def to_config( self, - database_settings: "DatabaseSettings", ) -> Optional[Dict[str, Any]]: + """ + Convert settings to GraphStore configuration dictionary. + + Provider-specific fields are automatically loaded from environment + variables by the provider config class (with fallback to VectorStore env vars). + """ if not self.enabled: return None - - graph_store_provider = self.provider.lower() - builder = getattr(self, f"_build_{graph_store_provider}_config", None) - graph_config = builder(database_settings) if callable(builder) else {} - + + from powermem.storage.config.base import BaseGraphStoreConfig + + graph_provider = self.provider.lower() + + # 1. Get provider config class from registry + config_cls = ( + BaseGraphStoreConfig.get_provider_config_cls(graph_provider) + or BaseGraphStoreConfig + ) + + # 2. Create provider settings from environment variables + provider_settings = config_cls() + + # 3. Export to dict + graph_config = provider_settings.model_dump(exclude_none=True) + + # 4. Build final config graph_store_config = { "enabled": True, - "provider": graph_store_provider, + "provider": graph_provider, "config": graph_config, } - + + # 5. 
Add custom prompts if configured if self.custom_prompt: graph_store_config["custom_prompt"] = self.custom_prompt if self.custom_extract_relations_prompt: @@ -742,42 +560,6 @@ def to_config( return graph_store_config -def _get_graph_value( - settings: GraphStoreSettings, - field: str, - fallback: Optional[Any], - default: Any, -) -> Any: - if field in settings.model_fields_set: - return getattr(settings, field) - if fallback is not None: - return fallback - return default - - -def _get_db_value( - settings: DatabaseSettings, - field: str, -) -> Optional[Any]: - if field in settings.model_fields_set: - return getattr(settings, field) - return None - - -def _get_graph_value_with_database( - settings: GraphStoreSettings, - field: str, - database_settings: DatabaseSettings, - database_field: str, - default: Any, -) -> Any: - if field in settings.model_fields_set: - return getattr(settings, field) - if database_field in database_settings.model_fields_set: - return getattr(database_settings, database_field) - return default - - class PowermemSettings: _COMPONENTS = { "vector_store": ("database", DatabaseSettings), @@ -809,7 +591,7 @@ def to_config(self) -> Dict[str, Any]: if component_config is not None: config[output_key] = component_config - graph_store_config = self.graph_store.to_config(self.database) + graph_store_config = self.graph_store.to_config() if graph_store_config: config["graph_store"] = graph_store_config @@ -817,6 +599,23 @@ def to_config(self) -> Dict[str, Any]: if sparse_embedder_config: config["sparse_embedder"] = sparse_embedder_config + # Sync embedding_model_dims from embedder to vector_store and graph_store + embedder_config = config.get("embedder", {}) + embedder_dims = embedder_config.get("config", {}).get("embedding_dims") + + if embedder_dims is not None: + # Sync to vector_store if not set + vector_store_config = config.get("vector_store", {}) + vector_store_inner_config = vector_store_config.get("config", {}) + if vector_store_inner_config.get("embedding_model_dims") is None: + vector_store_inner_config["embedding_model_dims"] = embedder_dims + + # Sync to graph_store if not set + if graph_store_config: + graph_store_inner_config = graph_store_config.get("config", {}) + if graph_store_inner_config.get("embedding_model_dims") is None: + graph_store_inner_config["embedding_model_dims"] = embedder_dims + return config @@ -985,6 +784,10 @@ def create_config( }, } + # Sync embedding_model_dims from embedder to vector_store if not set + if config["vector_store"]["config"].get("embedding_model_dims") is None: + config["vector_store"]["config"]["embedding_model_dims"] = options.embedding_dims + return config diff --git a/src/powermem/configs.py b/src/powermem/configs.py index 1c39702..dcc4641 100644 --- a/src/powermem/configs.py +++ b/src/powermem/configs.py @@ -10,11 +10,13 @@ from powermem.integrations.embeddings.config.base import BaseEmbedderConfig from powermem.integrations.embeddings.config.providers import OpenAIEmbeddingConfig -from powermem.integrations.embeddings.config.sparse_base import SparseEmbedderConfig +from powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig from powermem.integrations.llm.config.base import BaseLLMConfig from powermem.integrations.llm.config.qwen import QwenConfig -from powermem.storage.configs import VectorStoreConfig, GraphStoreConfig -from powermem.integrations.rerank.configs import RerankConfig +from powermem.storage.config.base import BaseVectorStoreConfig, BaseGraphStoreConfig +from 
powermem.storage.config.sqlite import SQLiteConfig +from powermem.storage.config.oceanbase import OceanBaseGraphConfig +from powermem.integrations.rerank.config.base import BaseRerankConfig class IntelligentMemoryConfig(BaseModel): @@ -195,9 +197,9 @@ class QueryRewriteConfig(BaseModel): class MemoryConfig(BaseModel): """Main memory configuration class.""" - vector_store: VectorStoreConfig = Field( + vector_store: BaseVectorStoreConfig = Field( description="Configuration for the vector store", - default_factory=VectorStoreConfig, + default_factory=SQLiteConfig, ) llm: BaseLLMConfig = Field( description="Configuration for the language model", @@ -207,15 +209,15 @@ class MemoryConfig(BaseModel): description="Configuration for the embedding model", default_factory=OpenAIEmbeddingConfig, ) - graph_store: GraphStoreConfig = Field( - description="Configuration for the graph", - default_factory=GraphStoreConfig, + graph_store: Optional[BaseGraphStoreConfig] = Field( + description="Configuration for the graph store (None means disabled)", + default=None, ) - reranker: Optional[RerankConfig] = Field( + reranker: Optional[BaseRerankConfig] = Field( description="Configuration for the reranker", default=None, ) - sparse_embedder: Optional[SparseEmbedderConfig] = Field( + sparse_embedder: Optional[BaseSparseEmbedderConfig] = Field( description="Configuration for the sparse embedder (only supported for OceanBase)", default=None, ) @@ -277,14 +279,25 @@ def __init__(self, **data): if self.logging is None: self.logging = LoggingConfig() if self.reranker is None: - self.reranker = RerankConfig() + self.reranker = BaseRerankConfig() if self.query_rewrite is None: self.query_rewrite = QueryRewriteConfig() + + # Sync embedding_model_dims from embedder if not set in vector_store/graph_store + embedder_dims = getattr(self.embedder, 'embedding_dims', None) + if embedder_dims is not None: + # Sync to vector_store if not set + if hasattr(self.vector_store, 'embedding_model_dims') and self.vector_store.embedding_model_dims is None: + self.vector_store.embedding_model_dims = embedder_dims + # Sync to graph_store if not set + if self.graph_store is not None: + if hasattr(self.graph_store, 'embedding_model_dims') and self.graph_store.embedding_model_dims is None: + self.graph_store.embedding_model_dims = embedder_dims def to_dict(self) -> Dict[str, Any]: result = self.model_dump(exclude_none=True) - for field in ['embedder', 'llm', 'vector_store']: + for field in ['embedder', 'llm', 'vector_store', 'reranker', 'graph_store', 'sparse_embedder']: obj = getattr(self, field, None) if obj and hasattr(obj, 'to_component_dict'): result[field] = obj.to_component_dict() diff --git a/src/powermem/core/memory.py b/src/powermem/core/memory.py index d36cb69..9ccc55b 100644 --- a/src/powermem/core/memory.py +++ b/src/powermem/core/memory.py @@ -15,7 +15,7 @@ from .base import MemoryBase from ..configs import MemoryConfig -from ..integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig, SparseEmbedderConfig +from ..integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig from ..storage.factory import VectorStoreFactory, GraphStoreFactory from ..storage.adapter import StorageAdapter, SubStorageAdapter from ..intelligence.manager import IntelligenceManager @@ -253,9 +253,12 @@ def __init__( if sparse_config_obj: try: - # Handle SparseEmbedderConfig (BaseModel with provider and config) or dict format - if hasattr(sparse_config_obj, 'provider') and hasattr(sparse_config_obj, 'config'): - # It's a 
SparseEmbedderConfig (BaseModel) object + # Handle BaseSparseEmbedderConfig, legacy wrapper, or dict format + if isinstance(sparse_config_obj, BaseSparseEmbedderConfig): + sparse_embedder_provider = sparse_config_obj._provider_name + config_dict = sparse_config_obj.model_dump(exclude_none=True) + elif hasattr(sparse_config_obj, 'provider') and hasattr(sparse_config_obj, 'config'): + # Legacy wrapper with provider + config fields sparse_embedder_provider = sparse_config_obj.provider config_dict = sparse_config_obj.config or {} elif isinstance(sparse_config_obj, dict): @@ -263,7 +266,10 @@ def __init__( sparse_embedder_provider = sparse_config_obj.get('provider') config_dict = sparse_config_obj.get('config', {}) else: - logger.warning(f"Unknown sparse_embedder config format: {type(sparse_config_obj)}. Expected SparseEmbedderConfig or dict with 'provider' and 'config' keys.") + logger.warning( + "Unknown sparse_embedder config format: %s. Expected BaseSparseEmbedderConfig or dict with 'provider' and 'config' keys.", + type(sparse_config_obj), + ) sparse_embedder_provider = None config_dict = {} @@ -370,10 +376,16 @@ def _get_graph_enabled(self) -> bool: Boolean indicating whether graph store is enabled """ if self.memory_config: - return self.memory_config.graph_store.enabled if self.memory_config.graph_store else False + # graph_store is None means disabled, otherwise enabled + return self.memory_config.graph_store is not None else: graph_store_config = self.config.get('graph_store', {}) - return graph_store_config.get('enabled', False) if graph_store_config else False + # Support both old format (dict with 'enabled') and new format (config object) + if isinstance(graph_store_config, dict): + return graph_store_config.get('enabled', False) if graph_store_config else False + else: + # New format: config object means enabled + return graph_store_config is not None def _get_intelligent_memory_config(self) -> Dict[str, Any]: """ diff --git a/src/powermem/integrations/__init__.py b/src/powermem/integrations/__init__.py index 28140a2..15a94de 100644 --- a/src/powermem/integrations/__init__.py +++ b/src/powermem/integrations/__init__.py @@ -7,11 +7,11 @@ from .llm.factory import LLMFactory from .embeddings.factory import EmbedderFactory from .rerank.factory import RerankFactory -from .rerank.configs import RerankConfig +from .rerank.config.base import BaseRerankConfig __all__ = [ "LLMFactory", "EmbedderFactory", "RerankFactory", - "RerankConfig", + "BaseRerankConfig", ] diff --git a/src/powermem/integrations/embeddings/config/__init__.py b/src/powermem/integrations/embeddings/config/__init__.py index 4d7d5a3..24d2625 100644 --- a/src/powermem/integrations/embeddings/config/__init__.py +++ b/src/powermem/integrations/embeddings/config/__init__.py @@ -16,6 +16,9 @@ VertexAIEmbeddingConfig, ZaiEmbeddingConfig, ) +from powermem.integrations.embeddings.config.sparse_providers import ( + QwenSparseEmbeddingConfig, +) __all__ = [ "AWSBedrockEmbeddingConfig", @@ -29,6 +32,7 @@ "MockEmbeddingConfig", "OllamaEmbeddingConfig", "OpenAIEmbeddingConfig", + "QwenSparseEmbeddingConfig", "QwenEmbeddingConfig", "SiliconFlowEmbeddingConfig", "TogetherEmbeddingConfig", diff --git a/src/powermem/integrations/embeddings/config/sparse_base.py b/src/powermem/integrations/embeddings/config/sparse_base.py index a00de0e..2f8ae33 100644 --- a/src/powermem/integrations/embeddings/config/sparse_base.py +++ b/src/powermem/integrations/embeddings/config/sparse_base.py @@ -1,65 +1,73 @@ -from abc import ABC -from typing import 
Optional +from typing import Any, ClassVar, Dict, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import Field +from pydantic_settings import BaseSettings +from powermem.settings import settings_config -class BaseSparseEmbedderConfig(ABC): - """ - Base config for Sparse Embeddings. - This is an abstract base class used by specific sparse embedding implementations. - """ - def __init__( - self, - model: Optional[str] = None, - api_key: Optional[str] = None, - embedding_dims: Optional[int] = None, - base_url: Optional[str] = None, - ): - """ - Initializes a configuration class instance for the Sparse Embeddings. +class BaseSparseEmbedderConfig(BaseSettings): + """Common sparse embedding configuration shared by all providers.""" - :param model: Embedding model to use, defaults to None - :type model: Optional[str], optional - :param api_key: API key to use, defaults to None - :type api_key: Optional[str], optional - :param embedding_dims: The number of dimensions in the embedding, defaults to None - :type embedding_dims: Optional[int], optional - :param base_url: Base URL for the API, defaults to None - :type base_url: Optional[str], optional - """ + model_config = settings_config("SPARSE_EMBEDDER_", extra="allow", env_file=None) - self.model = model - self.api_key = api_key - self.embedding_dims = embedding_dims - self.base_url = base_url + _provider_name: ClassVar[Optional[str]] = None + _class_path: ClassVar[Optional[str]] = None + _registry: ClassVar[dict[str, type["BaseSparseEmbedderConfig"]]] = {} + _class_paths: ClassVar[dict[str, str]] = {} + @classmethod + def _register_provider(cls) -> None: + provider = getattr(cls, "_provider_name", None) + class_path = getattr(cls, "_class_path", None) + if provider: + BaseSparseEmbedderConfig._registry[provider] = cls + if class_path: + BaseSparseEmbedderConfig._class_paths[provider] = class_path -class SparseEmbedderConfig(BaseModel): - """ - Configuration for sparse embedder in MemoryConfig. - This is a Pydantic model used in MemoryConfig, similar to EmbedderConfig. 
- """ + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + cls._register_provider() - provider: str = Field( - description="Provider of the sparse embedding model (e.g., 'qwen')", + @classmethod + def __pydantic_init_subclass__(cls, **kwargs) -> None: + super().__pydantic_init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def get_provider_config_cls( + cls, provider: str + ) -> Optional[type["BaseSparseEmbedderConfig"]]: + return cls._registry.get(provider) + + @classmethod + def get_provider_class_path(cls, provider: str) -> Optional[str]: + return cls._class_paths.get(provider) + + @classmethod + def has_provider(cls, provider: str) -> bool: + return provider in cls._registry + + model: Optional[str] = Field( default=None, + description="Sparse embedding model identifier.", ) - config: Optional[dict] = Field( - description="Configuration for the specific sparse embedding model", - default={} + api_key: Optional[str] = Field( + default=None, + description="API key used for provider authentication.", + ) + embedding_dims: Optional[int] = Field( + default=None, + description="Sparse embedding vector dimensions, when configurable.", + ) + base_url: Optional[str] = Field( + default=None, + description="Base URL for the sparse embedding provider.", ) - @field_validator("config") - def validate_config(cls, v, values): - provider = values.data.get("provider") - - # Import here to avoid circular import - from powermem.integrations.embeddings.sparse_factory import SparseEmbedderFactory - - if provider in SparseEmbedderFactory.provider_to_class: - return v - else: - raise ValueError(f"Unsupported sparse embedding provider: {provider}") + def to_component_dict(self) -> Dict[str, Any]: + return { + "provider": self._provider_name, + "config": self.model_dump(exclude_none=True), + } diff --git a/src/powermem/integrations/embeddings/config/sparse_providers.py b/src/powermem/integrations/embeddings/config/sparse_providers.py new file mode 100644 index 0000000..598282a --- /dev/null +++ b/src/powermem/integrations/embeddings/config/sparse_providers.py @@ -0,0 +1,32 @@ +from typing import Optional + +from pydantic import AliasChoices, Field + +from powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig +from powermem.settings import settings_config + + +class QwenSparseEmbeddingConfig(BaseSparseEmbedderConfig): + _provider_name = "qwen" + _class_path = "powermem.integrations.embeddings.qwen_sparse.QwenSparseEmbedding" + + model_config = settings_config("SPARSE_EMBEDDER_", extra="forbid", env_file=None) + + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "SPARSE_EMBEDDER_API_KEY", + "DASHSCOPE_API_KEY", + ), + ) + base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "base_url", + "SPARSE_EMBEDDING_BASE_URL", + "DASHSCOPE_BASE_URL", + ), + ) + model: Optional[str] = Field(default=None) + embedding_dims: Optional[int] = Field(default=None) diff --git a/src/powermem/integrations/embeddings/qwen_sparse.py b/src/powermem/integrations/embeddings/qwen_sparse.py index 272ecbf..7a0b03e 100644 --- a/src/powermem/integrations/embeddings/qwen_sparse.py +++ b/src/powermem/integrations/embeddings/qwen_sparse.py @@ -1,8 +1,8 @@ import os from typing import Literal, Optional -from src.powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig -from src.powermem.integrations.embeddings.sparse_base import SparseEmbeddingBase +from 
powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig +from powermem.integrations.embeddings.sparse_base import SparseEmbeddingBase try: from dashscope import TextEmbedding diff --git a/src/powermem/integrations/embeddings/sparse_base.py b/src/powermem/integrations/embeddings/sparse_base.py index abb076e..9d33ff7 100644 --- a/src/powermem/integrations/embeddings/sparse_base.py +++ b/src/powermem/integrations/embeddings/sparse_base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional -from src.powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig +from powermem.integrations.embeddings.config.sparse_base import BaseSparseEmbedderConfig class SparseEmbeddingBase(ABC): diff --git a/src/powermem/integrations/embeddings/sparse_factory.py b/src/powermem/integrations/embeddings/sparse_factory.py index 0ff85d9..2139db0 100644 --- a/src/powermem/integrations/embeddings/sparse_factory.py +++ b/src/powermem/integrations/embeddings/sparse_factory.py @@ -17,45 +17,47 @@ def load_class(class_type): class SparseEmbedderFactory: """Factory for creating sparse embedding instances.""" - - provider_to_class = { - "qwen": "powermem.integrations.embeddings.qwen_sparse.QwenSparseEmbedding", - } @classmethod def create(cls, provider_name: str, config): """ Create a sparse embedding instance. - + Args: provider_name: Name of the sparse embedding provider (e.g., 'qwen') - config: Configuration dictionary, BaseSparseEmbedderConfig object, or SparseEmbedderConfig object - + config: Configuration dictionary or BaseSparseEmbedderConfig object + Returns: Sparse embedding instance """ - class_type = cls.provider_to_class.get(provider_name) - if class_type: - # Handle different config types + provider = provider_name.lower() + class_type = BaseSparseEmbedderConfig.get_provider_class_path(provider) + if not class_type: + raise ValueError(f"Unsupported SparseEmbedder provider: {provider_name}") + + if isinstance(config, BaseSparseEmbedderConfig): + config_obj = config + else: if isinstance(config, dict): - # Filter out 'provider' if present in dict - config_dict = {k: v for k, v in config.items() if k != 'provider'} - config_obj = BaseSparseEmbedderConfig(**config_dict) - elif hasattr(config, 'provider') and hasattr(config, 'config'): - # It's a SparseEmbedderConfig object, extract the inner config - inner_config = config.config if isinstance(config.config, dict) else config.model_dump().get('config', {}) - config_obj = BaseSparseEmbedderConfig(**inner_config) - elif hasattr(config, 'model') or hasattr(config, 'api_key'): - # It's already a BaseSparseEmbedderConfig object, use it directly - config_obj = config + if isinstance(config.get("config"), dict): + config_dict = config.get("config", {}) + else: + config_dict = {k: v for k, v in config.items() if k != "provider"} + elif hasattr(config, "provider") and hasattr(config, "config"): + config_dict = ( + config.config + if isinstance(config.config, dict) + else config.model_dump().get("config", {}) + ) + elif hasattr(config, "model_dump"): + config_dict = config.model_dump() else: - # Try to convert to dict (e.g., Pydantic model) - config_dict = config.model_dump() if hasattr(config, 'model_dump') else {} - # Filter out 'provider' if present - config_dict = {k: v for k, v in config_dict.items() if k != 'provider'} - config_obj = BaseSparseEmbedderConfig(**config_dict) - - sparse_embedder_class = load_class(class_type) - return sparse_embedder_class(config_obj) - else: - raise 
ValueError(f"Unsupported SparseEmbedder provider: {provider_name}") + config_dict = {} + config_cls = ( + BaseSparseEmbedderConfig.get_provider_config_cls(provider) + or BaseSparseEmbedderConfig + ) + config_obj = config_cls(**config_dict) + + sparse_embedder_class = load_class(class_type) + return sparse_embedder_class(config_obj) diff --git a/src/powermem/integrations/rerank/__init__.py b/src/powermem/integrations/rerank/__init__.py index 9faee1a..16717a0 100644 --- a/src/powermem/integrations/rerank/__init__.py +++ b/src/powermem/integrations/rerank/__init__.py @@ -9,8 +9,14 @@ from .qwen import QwenRerank from .jina import JinaRerank from .generic import GenericRerank +from .zai import ZaiRerank from .config.base import BaseRerankConfig -from .configs import RerankConfig +from .config.providers import ( + QwenRerankConfig, + JinaRerankConfig, + ZaiRerankConfig, + GenericRerankConfig, +) __all__ = [ "RerankBase", @@ -18,7 +24,11 @@ "QwenRerank", "JinaRerank", "GenericRerank", + "ZaiRerank", "BaseRerankConfig", - "RerankConfig", + "QwenRerankConfig", + "JinaRerankConfig", + "ZaiRerankConfig", + "GenericRerankConfig", ] diff --git a/src/powermem/integrations/rerank/config/__init__.py b/src/powermem/integrations/rerank/config/__init__.py index 89d31f6..f65f253 100644 --- a/src/powermem/integrations/rerank/config/__init__.py +++ b/src/powermem/integrations/rerank/config/__init__.py @@ -2,6 +2,18 @@ Rerank configuration module """ from .base import BaseRerankConfig +from .providers import ( + QwenRerankConfig, + JinaRerankConfig, + ZaiRerankConfig, + GenericRerankConfig, +) -__all__ = ["BaseRerankConfig"] +__all__ = [ + "BaseRerankConfig", + "QwenRerankConfig", + "JinaRerankConfig", + "ZaiRerankConfig", + "GenericRerankConfig", +] diff --git a/src/powermem/integrations/rerank/config/base.py b/src/powermem/integrations/rerank/config/base.py index bf4c245..3f25dda 100644 --- a/src/powermem/integrations/rerank/config/base.py +++ b/src/powermem/integrations/rerank/config/base.py @@ -1,27 +1,122 @@ """ Base configuration for rerank models """ -from typing import Optional +from typing import Any, ClassVar, Dict, Optional, Union +try: + import httpx +except ImportError: + httpx = None -class BaseRerankConfig: +from pydantic import Field +from pydantic_settings import BaseSettings + +from powermem.settings import settings_config + + +class BaseRerankConfig(BaseSettings): """Base configuration for rerank models - Args: - model (str): The rerank model to use - api_key (Optional[str]): API key for the rerank service + This class uses pydantic-settings to support automatic loading from environment variables. + All rerank provider configurations should inherit from this base class. 
+ + Environment Variables: + RERANK_ENABLED: Whether to enable reranker (default: False) + RERANK_MODEL: The rerank model to use + RERANK_API_KEY: API key for the rerank service + RERANK_API_BASE_URL: Base URL for the rerank API endpoint + RERANK_TOP_N: Default number of top results to return """ - def __init__( - self, - model: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs, - ): - self.model = model - self.api_key = api_key + model_config = settings_config("RERANK_", extra="allow", env_file=None) + + # Class variables for provider registration + _provider_name: ClassVar[Optional[str]] = None + _class_path: ClassVar[Optional[str]] = None + _registry: ClassVar[dict[str, type["BaseRerankConfig"]]] = {} + _class_paths: ClassVar[dict[str, str]] = {} + + # Configuration fields + enabled: bool = Field( + default=False, + description="Whether to enable reranker" + ) + model: Optional[str] = Field( + default=None, + description="The rerank model identifier to use" + ) + api_key: Optional[str] = Field( + default=None, + description="API key for the rerank provider" + ) + api_base_url: Optional[str] = Field( + default=None, + description="Base URL for the rerank API endpoint" + ) + top_n: Optional[int] = Field( + default=None, + description="Default number of top results to return (can be overridden at runtime)" + ) + http_client_proxies: Optional[Union[Dict, str]] = Field( + default=None, + description="Proxy settings for HTTP client" + ) + http_client: Optional[Any] = Field( # httpx.Client type + default=None, + exclude=True, + description="HTTP client instance" + ) + + @classmethod + def _register_provider(cls) -> None: + """Register provider in the global registry.""" + provider = getattr(cls, "_provider_name", None) + class_path = getattr(cls, "_class_path", None) + if provider: + BaseRerankConfig._registry[provider] = cls + if class_path: + BaseRerankConfig._class_paths[provider] = class_path + + def __init_subclass__(cls, **kwargs) -> None: + """Called when a class inherits from BaseRerankConfig.""" + super().__init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs) -> None: + """Called by Pydantic when a class inherits from BaseRerankConfig.""" + super().__pydantic_init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def get_provider_config_cls(cls, provider: str) -> Optional[type["BaseRerankConfig"]]: + """Get the config class for a specific provider.""" + return cls._registry.get(provider) + + @classmethod + def get_provider_class_path(cls, provider: str) -> Optional[str]: + """Get the class path for a specific provider.""" + return cls._class_paths.get(provider) + + @classmethod + def has_provider(cls, provider: str) -> bool: + """Check if a provider is registered.""" + return provider in cls._registry + + def model_post_init(self, __context: Any) -> None: + """Initialize http_client after model creation.""" + if self.http_client_proxies and not self.http_client and httpx: + self.http_client = httpx.Client(proxies=self.http_client_proxies) + + def to_component_dict(self) -> Dict[str, Any]: + """Convert config to component dict format matching RerankConfig structure. 
- # Store any additional kwargs - for key, value in kwargs.items(): - setattr(self, key, value) + Returns: + Dict matching RerankConfig schema with 'enabled', 'provider', 'config' fields + """ + return { + "enabled": self.enabled, + "provider": self._provider_name, + "config": self.model_dump(exclude_none=True) + } diff --git a/src/powermem/integrations/rerank/config/providers.py b/src/powermem/integrations/rerank/config/providers.py new file mode 100644 index 0000000..7991841 --- /dev/null +++ b/src/powermem/integrations/rerank/config/providers.py @@ -0,0 +1,127 @@ +""" +Provider-specific rerank configurations +""" +from typing import Optional + +from pydantic import AliasChoices, Field + +from powermem.integrations.rerank.config.base import BaseRerankConfig +from powermem.settings import settings_config + + +class QwenRerankConfig(BaseRerankConfig): + """Configuration for Qwen (DashScope) rerank service""" + + _provider_name = "qwen" + _class_path = "powermem.integrations.rerank.qwen.QwenRerank" + + model_config = settings_config("RERANK_", extra="forbid", env_file=None) + + # Override base fields with Qwen-specific aliases + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", # Must include field name itself! + "RERANK_API_KEY", + "QWEN_API_KEY", + "DASHSCOPE_API_KEY", + ), + description="API key for Qwen rerank service" + ) + + model: Optional[str] = Field( + default="qwen3-rerank", + description="Qwen rerank model name" + ) + + api_base_url: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_base_url", + "RERANK_API_BASE_URL", + "QWEN_RERANK_BASE_URL", + "DASHSCOPE_BASE_URL", + ), + description="Base URL for Qwen/DashScope API" + ) + + +class JinaRerankConfig(BaseRerankConfig): + """Configuration for Jina AI rerank service""" + + _provider_name = "jina" + _class_path = "powermem.integrations.rerank.jina.JinaRerank" + + model_config = settings_config("RERANK_", extra="forbid", env_file=None) + + # Override base fields with Jina-specific aliases + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "RERANK_API_KEY", + "JINA_API_KEY", + ), + description="API key for Jina AI" + ) + + model: Optional[str] = Field( + default="jina-reranker-v3", + description="Jina rerank model name" + ) + + api_base_url: Optional[str] = Field( + default="https://api.jina.ai/v1/rerank", + validation_alias=AliasChoices( + "api_base_url", + "RERANK_API_BASE_URL", + "JINA_API_BASE_URL", + ), + description="Base URL for Jina AI rerank API" + ) + + +class ZaiRerankConfig(BaseRerankConfig): + """Configuration for Zhipu AI rerank service""" + + _provider_name = "zai" + _class_path = "powermem.integrations.rerank.zai.ZaiRerank" + + model_config = settings_config("RERANK_", extra="forbid", env_file=None) + + # Override base fields with Zhipu AI-specific aliases + api_key: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "api_key", + "RERANK_API_KEY", + "ZAI_API_KEY", + ), + description="API key for Zhipu AI" + ) + + model: Optional[str] = Field( + default="rerank", + description="Zhipu AI rerank model name" + ) + + api_base_url: Optional[str] = Field( + default="https://open.bigmodel.cn/api/paas/v4/rerank", + validation_alias=AliasChoices( + "api_base_url", + "RERANK_API_BASE_URL", + "ZAI_API_BASE_URL", + ), + description="Base URL for Zhipu AI rerank API" + ) + + +class GenericRerankConfig(BaseRerankConfig): + """Configuration for generic rerank service""" + + _provider_name = 
"generic" + _class_path = "powermem.integrations.rerank.generic.GenericRerank" + + model_config = settings_config("RERANK_", extra="forbid", env_file=None) + + # Generic uses base class default configuration diff --git a/src/powermem/integrations/rerank/configs.py b/src/powermem/integrations/rerank/configs.py deleted file mode 100644 index ea63166..0000000 --- a/src/powermem/integrations/rerank/configs.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Configuration for rerank models -""" -from typing import Optional, Dict, Any -from pydantic import BaseModel, Field - - -class RerankConfig(BaseModel): - """Configuration for rerank functionality.""" - - enabled: bool = Field( - description="Whether to enable reranker", - default=False, - ) - provider: str = Field( - description="Reranker provider (e.g., 'qwen', 'zai', 'jina')", - default="qwen", - ) - config: Optional[Dict[str, Any]] = Field( - description="Configuration for the specific reranker provider", - default=None - ) - diff --git a/src/powermem/integrations/rerank/factory.py b/src/powermem/integrations/rerank/factory.py index a8cffb6..a376e7c 100644 --- a/src/powermem/integrations/rerank/factory.py +++ b/src/powermem/integrations/rerank/factory.py @@ -5,60 +5,79 @@ """ import importlib -from typing import Optional +from typing import Optional, Union from powermem.integrations.rerank.config.base import BaseRerankConfig - - -def load_class(class_type): - """Dynamically load a class from a string path""" - module_path, class_name = class_type.rsplit(".", 1) - module = importlib.import_module(module_path) - return getattr(module, class_name) +# Trigger automatic registration by importing provider configs +from powermem.integrations.rerank.config.providers import ( + QwenRerankConfig, + JinaRerankConfig, + ZaiRerankConfig, + GenericRerankConfig, +) class RerankFactory: - """Factory class for creating rerank model instances - """ - - provider_to_class = { - "qwen": "powermem.integrations.rerank.qwen.QwenRerank", - "jina": "powermem.integrations.rerank.jina.JinaRerank", - "generic": "powermem.integrations.rerank.generic.GenericRerank", - "zai": "powermem.integrations.rerank.zai.ZaiRerank", - } + """Factory class for creating rerank model instances""" @classmethod - def create(cls, provider_name: str = "qwen", config: Optional[dict] = None): + def create( + cls, + provider_name: str = "qwen", + config: Optional[Union[dict, BaseRerankConfig]] = None + ): """ Create a rerank instance based on provider name. Args: provider_name (str): Name of the rerank provider. Defaults to "qwen" - config (Optional[dict]): Configuration dictionary for the rerank model + config (Optional[Union[dict, BaseRerankConfig]]): + Configuration dictionary or BaseRerankConfig instance for the rerank model Returns: RerankBase: An instance of the requested rerank provider Raises: ValueError: If the provider is not supported + TypeError: If config type is invalid """ - class_type = cls.provider_to_class.get(provider_name) - if class_type: - reranker_class = load_class(class_type) - # Create config if provided - if config: - base_config = BaseRerankConfig(**config) - return reranker_class(base_config) - else: - return reranker_class() - else: - supported = ", ".join(cls.provider_to_class.keys()) + # Get config class from registry + config_cls = BaseRerankConfig.get_provider_config_cls(provider_name) + if not config_cls: + supported = ", ".join(BaseRerankConfig._registry.keys()) raise ValueError( f"Unsupported rerank provider: {provider_name}. 
" f"Supported providers: {supported}" ) + # Get class path from registry + class_path = BaseRerankConfig.get_provider_class_path(provider_name) + if not class_path: + raise ValueError(f"No class path registered for provider: {provider_name}") + + # Load reranker class + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + reranker_class = getattr(module, class_name) + + # Create config instance + if config is None: + # Use default config + config_instance = config_cls() + elif isinstance(config, dict): + # Create config from dict + config_instance = config_cls(**config) + elif isinstance(config, BaseRerankConfig): + # Use provided config instance directly + config_instance = config + else: + raise TypeError( + f"config must be dict or BaseRerankConfig, got {type(config)}" + ) + + # Create and return reranker instance + return reranker_class(config_instance) + @classmethod def list_providers(cls) -> list: """ @@ -67,5 +86,5 @@ def list_providers(cls) -> list: Returns: list: List of supported provider names """ - return list(cls.provider_to_class.keys()) + return list(BaseRerankConfig._registry.keys()) diff --git a/src/powermem/integrations/rerank/generic.py b/src/powermem/integrations/rerank/generic.py index a946921..e93e324 100644 --- a/src/powermem/integrations/rerank/generic.py +++ b/src/powermem/integrations/rerank/generic.py @@ -55,24 +55,26 @@ def __init__(self, config: Optional[BaseRerankConfig] = None): "httpx is not installed. Please install it with: pip install httpx" ) - # Set API base URL (required) - self.api_base_url = getattr(self.config, 'api_base_url', None) or os.getenv( - "RERANK_API_BASE_URL" - ) - if not self.api_base_url: + # Validate API base URL (required, config handles env var loading) + if not self.config.api_base_url: raise ValueError( "api_base_url is required. Set RERANK_API_BASE_URL environment variable " "or pass api_base_url in config." ) - # Set model (required) + # Validate model (required) if not self.config.model: raise ValueError( - "model is required. Pass model name or UID in config." + "model is required. Set RERANK_MODEL environment variable " + "or pass model name/UID in config." 
) - # API key is optional - self.api_key = self.config.api_key or os.getenv("RERANK_API_KEY") + # Set API base URL and optional API key + self.api_base_url = self.config.api_base_url + self.api_key = self.config.api_key + + # Use http_client from config if available + self.http_client = self.config.http_client def rerank( self, @@ -127,17 +129,27 @@ def rerank( if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" - # Make API request - with httpx.Client(timeout=30.0) as client: - response = client.post( + # Use configured http_client or create temporary one + if self.http_client: + response = self.http_client.post( self.api_base_url, json=payload, headers=headers, ) - - # Check response status response.raise_for_status() result = response.json() + else: + # Make API request with temporary client + with httpx.Client(timeout=30.0) as client: + response = client.post( + self.api_base_url, + json=payload, + headers=headers, + ) + + # Check response status + response.raise_for_status() + result = response.json() # Parse results results = [] diff --git a/src/powermem/integrations/rerank/jina.py b/src/powermem/integrations/rerank/jina.py index 9b22840..216b4ab 100644 --- a/src/powermem/integrations/rerank/jina.py +++ b/src/powermem/integrations/rerank/jina.py @@ -26,7 +26,8 @@ def __init__(self, config: Optional[BaseRerankConfig] = None): super().__init__(config) # Set default model - self.config.model = self.config.model or "jina-reranker-v3" + if not self.config.model: + self.config.model = "jina-reranker-v3" # Check if httpx is available if httpx is None: @@ -34,17 +35,19 @@ def __init__(self, config: Optional[BaseRerankConfig] = None): "httpx is not installed. Please install it with: pip install httpx" ) - # Set API key - api_key = self.config.api_key or os.getenv("JINA_API_KEY") - if not api_key: + # Validate API key (config already handles env var loading) + if not self.config.api_key: raise ValueError( - "API key is required. Set JINA_API_KEY environment variable or pass api_key in config." + "API key is required. Set JINA_API_KEY or RERANK_API_KEY environment variable, " + "or pass api_key in config." 
) - self.api_key = api_key - self.api_base_url = getattr(self.config, 'api_base_url', None) or os.getenv( - "JINA_API_BASE_URL", "https://api.jina.ai/v1/rerank" - ) + # Set API key and base URL + self.api_key = self.config.api_key + self.api_base_url = self.config.api_base_url or "https://api.jina.ai/v1/rerank" + + # Use http_client from config if available + self.http_client = self.config.http_client def rerank( self, @@ -92,9 +95,9 @@ def rerank( "top_n": effective_top_n, } - # Make API request - with httpx.Client(timeout=30.0) as client: - response = client.post( + # Use configured http_client or create temporary one + if self.http_client: + response = self.http_client.post( self.api_base_url, json=payload, headers={ @@ -102,10 +105,23 @@ def rerank( "Authorization": f"Bearer {self.api_key}", }, ) - - # Check response status response.raise_for_status() result = response.json() + else: + # Make API request with temporary client + with httpx.Client(timeout=30.0) as client: + response = client.post( + self.api_base_url, + json=payload, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + }, + ) + + # Check response status + response.raise_for_status() + result = response.json() # Parse results results = [] diff --git a/src/powermem/integrations/rerank/qwen.py b/src/powermem/integrations/rerank/qwen.py index f80827f..f5367e3 100644 --- a/src/powermem/integrations/rerank/qwen.py +++ b/src/powermem/integrations/rerank/qwen.py @@ -27,8 +27,9 @@ class QwenRerank(RerankBase): def __init__(self, config: Optional[BaseRerankConfig] = None): super().__init__(config) - # Set default model - self.config.model = self.config.model or "qwen3-rerank" + # Set default model (if not already set in config) + if not self.config.model: + self.config.model = "qwen3-rerank" # Check if dashscope is available if TextReRank is None or dashscope is None: @@ -36,15 +37,19 @@ def __init__(self, config: Optional[BaseRerankConfig] = None): "DashScope SDK is not installed. Please install it with: pip install dashscope" ) - # Set API key - api_key = self.config.api_key or os.getenv("DASHSCOPE_API_KEY") - if not api_key: + # Validate API key (config already handles env var loading) + if not self.config.api_key: raise ValueError( - "API key is required. Set DASHSCOPE_API_KEY environment variable or pass api_key in config." + "API key is required. Set DASHSCOPE_API_KEY, QWEN_API_KEY, or RERANK_API_KEY environment variable, " + "or pass api_key in config." ) # Set API key for DashScope SDK - dashscope.api_key = api_key + dashscope.api_key = self.config.api_key + + # Set base URL if provided + if self.config.api_base_url: + dashscope.base_http_api_url = self.config.api_base_url def rerank( self, diff --git a/src/powermem/integrations/rerank/zai.py b/src/powermem/integrations/rerank/zai.py index 79d5ece..43f5317 100644 --- a/src/powermem/integrations/rerank/zai.py +++ b/src/powermem/integrations/rerank/zai.py @@ -28,7 +28,8 @@ def __init__(self, config: Optional[BaseRerankConfig] = None): super().__init__(config) # Set default model - self.config.model = self.config.model or "rerank" + if not self.config.model: + self.config.model = "rerank" # Check if httpx is available if httpx is None: @@ -36,17 +37,19 @@ def __init__(self, config: Optional[BaseRerankConfig] = None): "httpx is not installed. 
Please install it with: pip install httpx" ) - # Set API key - api_key = self.config.api_key or os.getenv("ZAI_API_KEY") - if not api_key: + # Validate API key (config already handles env var loading) + if not self.config.api_key: raise ValueError( - "API key is required. Set ZAI_API_KEY environment variable or pass api_key in config." + "API key is required. Set ZAI_API_KEY or RERANK_API_KEY environment variable, " + "or pass api_key in config." ) - self.api_key = api_key - self.api_base_url = getattr(self.config, 'api_base_url', None) or os.getenv( - "ZAI_API_BASE_URL", "https://open.bigmodel.cn/api/paas/v4/rerank" - ) + # Set API key and base URL + self.api_key = self.config.api_key + self.api_base_url = self.config.api_base_url or "https://open.bigmodel.cn/api/paas/v4/rerank" + + # Use http_client from config if available + self.http_client = self.config.http_client def rerank( self, @@ -101,9 +104,9 @@ def rerank( payload["top_n"] = top_n try: - # Make API request - with httpx.Client(timeout=60.0) as client: - response = client.post( + # Use configured http_client or create temporary one + if self.http_client: + response = self.http_client.post( self.api_base_url, json=payload, headers={ @@ -111,10 +114,23 @@ def rerank( "Authorization": f"Bearer {self.api_key}", }, ) - - # Check response status response.raise_for_status() result = response.json() + else: + # Make API request with temporary client + with httpx.Client(timeout=60.0) as client: + response = client.post( + self.api_base_url, + json=payload, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}", + }, + ) + + # Check response status + response.raise_for_status() + result = response.json() # Parse results results = [] diff --git a/src/powermem/storage/__init__.py b/src/powermem/storage/__init__.py index f903c87..b232773 100644 --- a/src/powermem/storage/__init__.py +++ b/src/powermem/storage/__init__.py @@ -6,9 +6,12 @@ from .base import VectorStoreBase from .factory import VectorStoreFactory, GraphStoreFactory +from .config.base import BaseVectorStoreConfig, BaseGraphStoreConfig __all__ = [ "VectorStoreBase", "VectorStoreFactory", - "GraphStoreFactory", + "GraphStoreFactory", + "BaseVectorStoreConfig", + "BaseGraphStoreConfig", ] diff --git a/src/powermem/storage/config/__init__.py b/src/powermem/storage/config/__init__.py new file mode 100644 index 0000000..85c9f4b --- /dev/null +++ b/src/powermem/storage/config/__init__.py @@ -0,0 +1,19 @@ +""" +Storage configuration module + +This module provides configuration classes for different storage providers. 
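+
+Importing this package also registers each provider config class with the
+base registries. A minimal lookup sketch (assuming the provider names used
+by the modules below, e.g. "oceanbase"):
+
+    from powermem.storage.config import BaseVectorStoreConfig
+
+    config_cls = BaseVectorStoreConfig.get_provider_config_cls("oceanbase")
+    config = config_cls()  # provider-specific fields load from environment variables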
+""" + +from .base import BaseVectorStoreConfig, BaseGraphStoreConfig +from .oceanbase import OceanBaseConfig, OceanBaseGraphConfig +from .pgvector import PGVectorConfig +from .sqlite import SQLiteConfig + +__all__ = [ + "BaseVectorStoreConfig", + "BaseGraphStoreConfig", + "OceanBaseConfig", + "OceanBaseGraphConfig", + "PGVectorConfig", + "SQLiteConfig", +] diff --git a/src/powermem/storage/config/base.py b/src/powermem/storage/config/base.py index 392f827..0e6dfde 100644 --- a/src/powermem/storage/config/base.py +++ b/src/powermem/storage/config/base.py @@ -1,13 +1,265 @@ -from abc import ABC -from typing import Any, Dict +from typing import Any, ClassVar, Dict, Optional, Union +from pydantic import AliasChoices, Field +from pydantic_settings import BaseSettings +from powermem.settings import settings_config -from pydantic import BaseModel, model_validator - -class BaseVectorStoreConfig(BaseModel, ABC): +class BaseVectorStoreConfig(BaseSettings): """ Base configuration class for all vector store providers. - This class provides common validation logic that is shared + This class provides common fields and validation logic shared across all vector store implementations. - """ \ No newline at end of file + """ + + # Model config + model_config = settings_config("VECTOR_STORE_", extra="allow", env_file=None) + + # Registry mechanism (same as LLM/Rerank) + _provider_name: ClassVar[Optional[str]] = None + _class_path: ClassVar[Optional[str]] = None + _registry: ClassVar[dict[str, type["BaseVectorStoreConfig"]]] = {} + _class_paths: ClassVar[dict[str, str]] = {} + + # Common fields across all providers + collection_name: str = Field( + default="memories", + description="Name of the collection/table" + ) + + embedding_model_dims: Optional[int] = Field( + default=None, + description="Dimension of embedding vectors" + ) + + @classmethod + def _register_provider(cls) -> None: + """Register provider in the global registry.""" + provider = getattr(cls, "_provider_name", None) + class_path = getattr(cls, "_class_path", None) + if provider: + BaseVectorStoreConfig._registry[provider] = cls + if class_path: + BaseVectorStoreConfig._class_paths[provider] = class_path + + def __init_subclass__(cls, **kwargs) -> None: + """Called when a class inherits from BaseVectorStoreConfig.""" + super().__init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs) -> None: + """Called by Pydantic when a class inherits from BaseVectorStoreConfig.""" + super().__pydantic_init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def get_provider_config_cls(cls, provider: str) -> Optional[type["BaseVectorStoreConfig"]]: + """Get the config class for a specific provider.""" + return cls._registry.get(provider) + + @classmethod + def get_provider_class_path(cls, provider: str) -> Optional[str]: + """Get the class path for a specific provider.""" + return cls._class_paths.get(provider) + + @classmethod + def has_provider(cls, provider: str) -> bool: + """Check if a provider is registered.""" + return provider in cls._registry + + def to_component_dict(self) -> Dict[str, Any]: + """ + Convert config to component dictionary format. + + Returns: + Dict with 'provider' and 'config' keys + """ + return { + "provider": self._provider_name, + "config": self.model_dump(exclude_none=True) + } + + +class BaseGraphStoreConfig(BaseVectorStoreConfig): + """ + Base configuration class for all graph store providers. 
+ + Inherits from BaseVectorStoreConfig to reuse connection and vector parameters. + Adds graph-specific fields like max_hops. + + Environment variable priority (via AliasChoices): + 1. GRAPH_STORE_* (highest priority) + 2. OCEANBASE_* (fallback to VectorStore config) + 3. Default values + """ + + model_config = settings_config("GRAPH_STORE_", extra="allow", env_file=None) + + _provider_name: ClassVar[Optional[str]] = None + _class_path: ClassVar[Optional[str]] = None + _registry: ClassVar[dict[str, type["BaseGraphStoreConfig"]]] = {} + _class_paths: ClassVar[dict[str, str]] = {} + + # Override connection fields with GRAPH_STORE_ fallback aliases + host: str = Field( + default="127.0.0.1", + validation_alias=AliasChoices( + "host", + "GRAPH_STORE_HOST", # Priority 1 + "OCEANBASE_HOST", # Priority 2 (fallback) + ), + description="Database server host" + ) + + port: str = Field( + default="2881", + validation_alias=AliasChoices( + "port", + "GRAPH_STORE_PORT", + "OCEANBASE_PORT", + ), + description="Database server port" + ) + + user: str = Field( + default="root@test", + validation_alias=AliasChoices( + "GRAPH_STORE_USER", + "OCEANBASE_USER", + "user", # avoid using system USER environment variable first + ), + description="Database username" + ) + + password: str = Field( + default="", + validation_alias=AliasChoices( + "password", + "GRAPH_STORE_PASSWORD", + "OCEANBASE_PASSWORD", + ), + description="Database password" + ) + + db_name: str = Field( + default="test", + validation_alias=AliasChoices( + "db_name", + "GRAPH_STORE_DB_NAME", + "OCEANBASE_DATABASE", + ), + description="Database name" + ) + + # Override vector configuration fields + vidx_metric_type: str = Field( + default="l2", + validation_alias=AliasChoices( + "vidx_metric_type", + "GRAPH_STORE_VECTOR_METRIC_TYPE", + "OCEANBASE_VECTOR_METRIC_TYPE", + ), + description="Distance metric (l2, inner_product, cosine)" + ) + + index_type: str = Field( + default="HNSW", + validation_alias=AliasChoices( + "index_type", + "GRAPH_STORE_INDEX_TYPE", + "OCEANBASE_INDEX_TYPE", + ), + description="Type of vector index (HNSW, IVF, FLAT, etc.)" + ) + + embedding_model_dims: Optional[int] = Field( + default=None, + validation_alias=AliasChoices( + "embedding_model_dims", + "GRAPH_STORE_EMBEDDING_MODEL_DIMS", + "OCEANBASE_EMBEDDING_MODEL_DIMS", + ), + description="Dimension of embedding vectors" + ) + + # Graph-specific fields + max_hops: int = Field( + default=3, + validation_alias=AliasChoices( + "max_hops", + "GRAPH_STORE_MAX_HOPS", + ), + description="Maximum number of hops for multi-hop graph search" + ) + + # GraphStore metadata fields (from GraphStoreConfig) + # Note: BaseLLMConfig is imported lazily to avoid circular imports + llm: Optional[Any] = Field( + default=None, + description="LLM configuration for querying the graph store (overrides global LLM)" + ) + custom_prompt: Optional[str] = Field( + default=None, + description="Custom prompt to fetch entities from the given text" + ) + custom_extract_relations_prompt: Optional[str] = Field( + default=None, + description="Custom prompt for extracting relations from text" + ) + custom_update_graph_prompt: Optional[str] = Field( + default=None, + description="Custom prompt for updating graph memories" + ) + custom_delete_relations_prompt: Optional[str] = Field( + default=None, + description="Custom prompt for deleting relations" + ) + + @classmethod + def _register_provider(cls) -> None: + """Register provider in the global registry.""" + provider = getattr(cls, "_provider_name", None) + 
class_path = getattr(cls, "_class_path", None) + if provider: + BaseGraphStoreConfig._registry[provider] = cls + if class_path: + BaseGraphStoreConfig._class_paths[provider] = class_path + + def __init_subclass__(cls, **kwargs) -> None: + """Called when a class inherits from BaseGraphStoreConfig.""" + super().__init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs) -> None: + """Called by Pydantic when a class inherits from BaseGraphStoreConfig.""" + super().__pydantic_init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def get_provider_config_cls(cls, provider: str) -> Optional[type["BaseGraphStoreConfig"]]: + """Get the config class for a specific provider.""" + return cls._registry.get(provider) + + @classmethod + def get_provider_class_path(cls, provider: str) -> Optional[str]: + """Get the class path for a specific provider.""" + return cls._class_paths.get(provider) + + @classmethod + def has_provider(cls, provider: str) -> bool: + """Check if a provider is registered.""" + return provider in cls._registry + + def to_component_dict(self) -> Dict[str, Any]: + """ + Convert config to component dictionary format. + + Returns: + Dict with 'provider' and 'config' keys + """ + return { + "provider": self._provider_name, + "config": self.model_dump(exclude_none=True) + } \ No newline at end of file diff --git a/src/powermem/storage/config/oceanbase.py b/src/powermem/storage/config/oceanbase.py index b9961c0..22c0adc 100644 --- a/src/powermem/storage/config/oceanbase.py +++ b/src/powermem/storage/config/oceanbase.py @@ -1,59 +1,234 @@ from typing import Any, ClassVar, Dict, Optional -from pydantic import Field, model_validator +from pydantic import AliasChoices, Field, model_validator +from powermem.settings import settings_config -from powermem.storage.config.base import BaseVectorStoreConfig +from powermem.storage.config.base import BaseVectorStoreConfig, BaseGraphStoreConfig class OceanBaseConfig(BaseVectorStoreConfig): + _provider_name = "oceanbase" + _class_path = "powermem.storage.oceanbase.oceanbase.OceanBaseVectorStore" + try: from pyobvector import ObVecClient except ImportError: raise ImportError("The 'pyobvector' library is required. 
Please install it using 'pip install pyobvector'.") ObVecClient: ClassVar[type] = ObVecClient - collection_name: str = Field("power_mem", description="Default name for the collection") + model_config = settings_config("VECTOR_STORE_", extra="forbid", env_file=None) + + collection_name: str = Field( + default="power_mem", + validation_alias=AliasChoices( + "collection_name", + "VECTOR_STORE_COLLECTION_NAME", + "OCEANBASE_COLLECTION", + ), + description="Default name for the collection" + ) # Connection parameters - host: str = Field("127.0.0.1", description="OceanBase server host") - port: str = Field("2881", description="OceanBase server port") - user: str = Field("root@test", description="OceanBase username") - password: str = Field("", description="OceanBase password") - db_name: str = Field("test", description="OceanBase database name") + host: str = Field( + default="127.0.0.1", + validation_alias=AliasChoices( + "host", + "OCEANBASE_HOST", + ), + description="OceanBase server host" + ) + + port: str = Field( + default="2881", + validation_alias=AliasChoices( + "port", + "OCEANBASE_PORT", + ), + description="OceanBase server port" + ) + + user: str = Field( + default="root@test", + validation_alias=AliasChoices( + "OCEANBASE_USER", + "user", # avoid using system USER environment variable first + ), + description="OceanBase username" + ) + + password: str = Field( + default="", + validation_alias=AliasChoices( + "password", + "OCEANBASE_PASSWORD", + ), + description="OceanBase password" + ) + + db_name: str = Field( + default="test", + validation_alias=AliasChoices( + "db_name", + "OCEANBASE_DATABASE", + ), + description="OceanBase database name" + ) + + connection_args: Optional[dict] = Field( + default=None, + validation_alias=AliasChoices( + "connection_args", + ), + description="OceanBase connection args" + ) # Vector index parameters - index_type: str = Field("HNSW", description="Type of vector index (HNSW, IVF, FLAT, etc.)") - vidx_metric_type: str = Field("l2", description="Distance metric (l2, inner_product, cosine)") - embedding_model_dims: Optional[int] = Field(None, description="Dimension of vectors") + index_type: str = Field( + default="HNSW", + validation_alias=AliasChoices( + "index_type", + "OCEANBASE_INDEX_TYPE", + ), + description="Type of vector index (HNSW, IVF, FLAT, etc.)" + ) + + vidx_metric_type: str = Field( + default="l2", + validation_alias=AliasChoices( + "vidx_metric_type", + "OCEANBASE_VECTOR_METRIC_TYPE", + ), + description="Distance metric (l2, inner_product, cosine)" + ) + + embedding_model_dims: Optional[int] = Field( + default=None, + validation_alias=AliasChoices( + "embedding_model_dims", + "OCEANBASE_EMBEDDING_MODEL_DIMS", + ), + description="Dimension of vectors" + ) # Advanced parameters - vidx_algo_params: Optional[Dict[str, Any]] = Field(None, description="Index algorithm parameters") - normalize: bool = Field(False, description="Whether to normalize vectors") - include_sparse: bool = Field(False, description="Whether to include sparse vector support") - hybrid_search: bool = Field(True, description="Whether to enable hybrid search") - auto_configure_vector_index: bool = Field(True, - description="Whether to automatically configure vector index settings") + vidx_algo_params: Optional[Dict[str, Any]] = Field( + default=None, + description="Index algorithm parameters" + ) + + normalize: bool = Field( + default=False, + description="Whether to normalize vectors" + ) + + include_sparse: bool = Field( + default=False, + validation_alias=AliasChoices( 
+ "include_sparse", + "OCEANBASE_INCLUDE_SPARSE", + "SPARSE_VECTOR_ENABLE", + ), + description="Whether to include sparse vector support" + ) + + hybrid_search: bool = Field( + default=True, + validation_alias=AliasChoices( + "hybrid_search", + ), + description="Whether to enable hybrid search" + ) + + enable_native_hybrid: bool = Field( + default=False, + validation_alias=AliasChoices( + "enable_native_hybrid", + "OCEANBASE_ENABLE_NATIVE_HYBRID", + ), + description="Whether to enable OceanBase native hybrid search" + ) + + auto_configure_vector_index: bool = Field( + default=True, + description="Whether to automatically configure vector index settings" + ) # Fulltext search parameters - fulltext_parser: str = Field("ik", description="Fulltext parser type (ik, ngram, ngram2, beng, space)") + fulltext_parser: str = Field( + default="ik", + description="Fulltext parser type (ik, ngram, ngram2, beng, space)" + ) # Field names - primary_field: str = Field("id", description="Primary key field name") - vector_field: str = Field("embedding", description="Vector field name") - text_field: str = Field("document", description="Text field name") - metadata_field: str = Field("metadata", description="Metadata field name") - vidx_name: str = Field("vidx", description="Vector index name") - - vector_weight: float = Field(0.5, description="Weight for vector search") - fts_weight: float = Field(0.5, description="Weight for fulltext search") - sparse_weight: Optional[float] = Field(None, description="Weight for sparse vector search") + primary_field: str = Field( + default="id", + validation_alias=AliasChoices( + "primary_field", + "OCEANBASE_PRIMARY_FIELD", + ), + description="Primary key field name" + ) + + vector_field: str = Field( + default="embedding", + validation_alias=AliasChoices( + "vector_field", + "OCEANBASE_VECTOR_FIELD", + ), + description="Vector field name" + ) + + text_field: str = Field( + default="document", + validation_alias=AliasChoices( + "text_field", + "OCEANBASE_TEXT_FIELD", + ), + description="Text field name" + ) + + metadata_field: str = Field( + default="metadata", + validation_alias=AliasChoices( + "metadata_field", + "OCEANBASE_METADATA_FIELD", + ), + description="Metadata field name" + ) + + vidx_name: str = Field( + default="vidx", + validation_alias=AliasChoices( + "vidx_name", + "OCEANBASE_VIDX_NAME", + ), + description="Vector index name" + ) - model_config = { - "arbitrary_types_allowed": True, - } + vector_weight: float = Field( + default=0.5, + description="Weight for vector search" + ) + + fts_weight: float = Field( + default=0.5, + description="Weight for fulltext search" + ) + + sparse_weight: Optional[float] = Field( + default=None, + description="Weight for sparse vector search" + ) -class OceanBaseGraphConfig(OceanBaseConfig): - # Graph search parameters - max_hops: int = Field(3, description="Maximum number of hops for multi-hop graph search") +class OceanBaseGraphConfig(BaseGraphStoreConfig): + """Configuration for OceanBase graph store.""" + + _provider_name = "oceanbase" + _class_path = "powermem.storage.oceanbase.oceanbase_graph.MemoryGraph" + + model_config = settings_config("GRAPH_STORE_", extra="forbid", env_file=None) + + # All fields (connection, vector, max_hops) are inherited from BaseGraphStoreConfig + # No additional fields needed for OceanBase GraphStore at this time diff --git a/src/powermem/storage/config/pgvector.py b/src/powermem/storage/config/pgvector.py index 772b322..f777c0c 100644 --- a/src/powermem/storage/config/pgvector.py +++ 
b/src/powermem/storage/config/pgvector.py @@ -1,52 +1,140 @@ from typing import Any, Optional -from pydantic import Field, model_validator +from pydantic import AliasChoices, Field, model_validator +from powermem.settings import settings_config from powermem.storage.config.base import BaseVectorStoreConfig class PGVectorConfig(BaseVectorStoreConfig): - dbname: str = Field("postgres", description="Default name for the database") - collection_name: str = Field("power_mem", description="Default name for the collection") - embedding_model_dims: Optional[int] = Field(1536, description="Dimensions of the embedding model") - user: Optional[str] = Field(None, description="Database user") - password: Optional[str] = Field(None, description="Database password") - host: Optional[str] = Field(None, description="Database host. Default is 127.0.0.1") - port: Optional[int] = Field(None, description="Database port. Default is 1536") - diskann: Optional[bool] = Field(False, description="Use diskann for approximate nearest neighbors search") - hnsw: Optional[bool] = Field(True, description="Use hnsw for faster search") - minconn: Optional[int] = Field(1, description="Minimum number of connections in the pool") - maxconn: Optional[int] = Field(5, description="Maximum number of connections in the pool") - # New SSL and connection options - sslmode: Optional[str] = Field(None, - description="SSL mode for PostgreSQL connection (e.g., 'require', 'prefer', 'disable')") - connection_string: Optional[str] = Field(None, - description="PostgreSQL connection string (overrides individual connection parameters)") - connection_pool: Optional[Any] = Field(None, - description="psycopg connection pool object (overrides connection string and individual parameters)") + _provider_name = "pgvector" + _class_path = "powermem.storage.pgvector.pgvector.PGVectorStore" + + model_config = settings_config("VECTOR_STORE_", extra="forbid", env_file=None) + + dbname: str = Field( + default="postgres", + validation_alias=AliasChoices( + "dbname", + "POSTGRES_DATABASE", + ), + description="Default name for the database" + ) + + collection_name: str = Field( + default="power_mem", + validation_alias=AliasChoices( + "collection_name", + "POSTGRES_COLLECTION", + ), + description="Default name for the collection" + ) + + embedding_model_dims: Optional[int] = Field( + default=1536, + validation_alias=AliasChoices( + "embedding_model_dims", + "POSTGRES_EMBEDDING_MODEL_DIMS", + ), + description="Dimensions of the embedding model" + ) + + user: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "POSTGRES_USER", + "user", # avoid using system USER environment variable first + ), + description="Database user" + ) + + password: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "password", + "POSTGRES_PASSWORD", + ), + description="Database password" + ) + + host: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "host", + "POSTGRES_HOST", + ), + description="Database host. Default is 127.0.0.1" + ) + + port: Optional[int] = Field( + default=None, + validation_alias=AliasChoices( + "port", + "POSTGRES_PORT", + ), + description="Database port. 
Default is 5432" + ) + + diskann: Optional[bool] = Field( + default=False, + validation_alias=AliasChoices( + "diskann", + "POSTGRES_DISKANN", + ), + description="Use diskann for approximate nearest neighbors search" + ) + + hnsw: Optional[bool] = Field( + default=True, + validation_alias=AliasChoices( + "hnsw", + "POSTGRES_HNSW", + ), + description="Use hnsw for faster search" + ) + + minconn: Optional[int] = Field( + default=1, + description="Minimum number of connections in the pool" + ) + + maxconn: Optional[int] = Field( + default=5, + description="Maximum number of connections in the pool" + ) + + sslmode: Optional[str] = Field( + default=None, + validation_alias=AliasChoices( + "sslmode", + "DATABASE_SSLMODE", + ), + description="SSL mode for PostgreSQL connection" + ) + + connection_string: Optional[str] = Field( + default=None, + description="PostgreSQL connection string" + ) + + connection_pool: Optional[Any] = Field( + default=None, + description="psycopg connection pool object" + ) @model_validator(mode="before") @classmethod def check_auth_and_connection(cls, values): - # If connection_pool is provided, skip validation of individual connection parameters if values.get("connection_pool") is not None: return values - - # If connection_string is provided, skip validation of individual connection parameters if values.get("connection_string") is not None: return values - - # Otherwise, validate individual connection parameters user, password = values.get("user"), values.get("password") host, port = values.get("host"), values.get("port") - - # Only validate if user explicitly provided values (not using defaults) if user is not None or password is not None: if not user or not password: - raise ValueError("Both 'user' and 'password' must be provided when not using connection_string.") - + raise ValueError("Both 'user' and 'password' must be provided.") if host is not None or port is not None: if not host or not port: - raise ValueError("Both 'host' and 'port' must be provided when not using connection_string.") - + raise ValueError("Both 'host' and 'port' must be provided.") return values diff --git a/src/powermem/storage/config/sqlite.py b/src/powermem/storage/config/sqlite.py index 1c988a5..a8c7aba 100644 --- a/src/powermem/storage/config/sqlite.py +++ b/src/powermem/storage/config/sqlite.py @@ -1,6 +1,7 @@ from typing import Optional -from pydantic import Field +from pydantic import AliasChoices, Field +from powermem.settings import settings_config from powermem.storage.config.base import BaseVectorStoreConfig @@ -8,20 +9,44 @@ class SQLiteConfig(BaseVectorStoreConfig): """Configuration for SQLite vector store.""" + _provider_name = "sqlite" + _class_path = "powermem.storage.sqlite.sqlite_vector_store.SQLiteVectorStore" + + model_config = settings_config("VECTOR_STORE_", extra="forbid", env_file=None) + database_path: str = Field( default="./data/powermem_dev.db", + validation_alias=AliasChoices( + "database_path", + "SQLITE_PATH", + ), description="Path to SQLite database file" ) + collection_name: str = Field( default="memories", + validation_alias=AliasChoices( + "collection_name", + "SQLITE_COLLECTION", + ), description="Name of the collection/table" ) + enable_wal: bool = Field( default=True, + validation_alias=AliasChoices( + "enable_wal", + "SQLITE_ENABLE_WAL", + ), description="Enable Write-Ahead Logging for better concurrency" ) + timeout: int = Field( default=30, + validation_alias=AliasChoices( + "timeout", + "SQLITE_TIMEOUT", + ), description="Connection timeout in 
seconds" ) diff --git a/src/powermem/storage/configs.py b/src/powermem/storage/configs.py deleted file mode 100644 index f5c3e8f..0000000 --- a/src/powermem/storage/configs.py +++ /dev/null @@ -1,162 +0,0 @@ -""" -Storage configuration management - -This module handles storage configuration and validation. -""" - -from typing import Dict, Optional, Union - -from pydantic import BaseModel, Field, model_validator - -from powermem.integrations.llm.config.base import BaseLLMConfig -from powermem.storage.config.oceanbase import OceanBaseGraphConfig -from powermem.storage.factory import VectorStoreFactory - - -class VectorStoreConfig(BaseModel): - provider: str = Field( - description="Provider of the vector store (e.g., 'oceanbase', 'pgvector')", - default="oceanbase", - ) - config: Optional[Dict] = Field( - description="Configuration for the specific vector store", - default=None - ) - - _provider_configs: Dict[str, str] = { - "oceanbase": "OceanBaseConfig", - "pgvector": "PGVectorConfig", - "sqlite": "SQLiteConfig", - } - - @model_validator(mode="after") - def validate_config(self) -> "VectorStoreConfig": - """ - Validate the configuration without converting to provider-specific config class. - The conversion is handled by VectorStoreFactory.create() when needed. - """ - provider = self.provider - config = self.config - - if provider is not None and provider == "postgres": - provider = "pgvector" - - # Check both initialized providers and Factory-registered providers - if provider not in self._provider_configs and provider not in VectorStoreFactory.provider_to_class: - raise ValueError(f"Unsupported vector store provider: {provider}") - - if config is None: - self.config = {} - return self - - if not isinstance(config, dict): - raise ValueError(f"Config must be a dictionary, got {type(config)}") - - # Handle connection_args for backward compatibility - # If connection_args exists, flatten it into the main config - if "connection_args" in config: - connection_args = config.pop("connection_args") - if isinstance(connection_args, dict): - # Merge connection_args into config (connection_args values take precedence) - for key, value in connection_args.items(): - if key not in config: - # Convert port to string if it's an int (for OceanBase compatibility) - if key == "port" and isinstance(value, int): - config[key] = str(value) - else: - config[key] = value - self.config = config - - # Convert port to string if it's an int (for OceanBase compatibility) - # This handles both direct port field and port from connection_args - if "port" in config and isinstance(config["port"], int): - config["port"] = str(config["port"]) - self.config = config - - # Validate config by attempting to create provider-specific config instance - # This ensures the config has valid fields, but we don't store the converted object - if provider in self._provider_configs: - module = __import__( - f"powermem.storage.config.{provider}", - fromlist=[self._provider_configs[provider]], - ) - config_class = getattr(module, self._provider_configs[provider]) - - # Add default path if needed - if "path" not in config and "path" in config_class.__annotations__: - config["path"] = f"/tmp/{provider}" - self.config = config - - # Validate by creating instance (throws error if invalid) - try: - config_class(**config) - except Exception as e: - raise ValueError(f"Invalid configuration for {provider}: {e}") - - # Keep config as dict, don't convert to config_class instance - return self - -class GraphStoreConfig(BaseModel): - enabled: bool = 
Field( - description="Whether to enable graph store", - default=False, - ) - provider: str = Field( - description="Provider of the data store (e.g., 'oceanbase')", - default="oceanbase", - ) - config: Optional[Union[Dict, OceanBaseGraphConfig]] = Field( - description="Configuration for the specific data store", - default=None - ) - llm: Optional[BaseLLMConfig] = Field( - description="LLM configuration for querying the graph store", - default=None - ) - custom_prompt: Optional[str] = Field( - description="Custom prompt to fetch entities from the given text", - default=None - ) - custom_extract_relations_prompt: Optional[str] = Field( - description="Custom prompt for extracting relations from text", - default=None - ) - custom_update_graph_prompt: Optional[str] = Field( - description="Custom prompt for updating graph memories", - default=None - ) - custom_delete_relations_prompt: Optional[str] = Field( - description="Custom prompt for deleting relations", - default=None - ) - - @model_validator(mode="after") - def validate_config(self) -> "GraphStoreConfig": - """ - Validate the configuration without converting to provider-specific config class. - Keep config as dict for consistency. - """ - if self.config is None: - self.config = {} - return self - - # If config is a Pydantic BaseModel instance, convert it to dict - if isinstance(self.config, BaseModel): - self.config = self.config.model_dump() - - if not isinstance(self.config, dict): - raise ValueError(f"Config must be a dictionary or BaseModel instance, got {type(self.config)}") - - - # Validate config based on provider - provider = self.provider - if provider == "oceanbase": - try: - OceanBaseGraphConfig(**self.config) - except Exception as e: - raise ValueError(f"Invalid configuration for {provider}: {e}") - else: - raise ValueError(f"Unsupported graph store provider: {provider}") - - # Keep config as dict, don't convert - return self \ No newline at end of file diff --git a/src/powermem/storage/factory.py b/src/powermem/storage/factory.py index f0e95e3..5e15252 100644 --- a/src/powermem/storage/factory.py +++ b/src/powermem/storage/factory.py @@ -6,6 +6,12 @@ import importlib +# Import all provider configs to trigger auto-registration +from powermem.storage.config.base import BaseVectorStoreConfig, BaseGraphStoreConfig +from powermem.storage.config.oceanbase import OceanBaseConfig, OceanBaseGraphConfig +from powermem.storage.config.pgvector import PGVectorConfig +from powermem.storage.config.sqlite import SQLiteConfig + def load_class(class_type): module_path, class_name = class_type.rsplit(".", 1) @@ -13,23 +19,76 @@ def load_class(class_type): return getattr(module, class_name) class VectorStoreFactory: - provider_to_class = { - "oceanbase": "powermem.storage.oceanbase.oceanbase.OceanBaseVectorStore", - "sqlite": "powermem.storage.sqlite.sqlite_vector_store.SQLiteVectorStore", - "pgvector": "powermem.storage.pgvector.pgvector.PGVectorStore", - "postgres": "powermem.storage.pgvector.pgvector.PGVectorStore", # Alias for pgvector - } - @classmethod def create(cls, provider_name, config): - class_type = cls.provider_to_class.get(provider_name) - if class_type: - if not isinstance(config, dict): - config = config.model_dump() - vector_store_instance = load_class(class_type) - return vector_store_instance(**config) - else: + """ + Create a VectorStore instance with the appropriate configuration. + + Args: + provider_name (str): The provider name (e.g., 'oceanbase', 'pgvector', 'sqlite') + config: Configuration object or dict. 
If dict, will convert to provider config + + Returns: + Configured VectorStore instance + + Raises: + ValueError: If provider is not supported + """ + # Handle postgres alias + if provider_name == "postgres": + provider_name = "pgvector" + + # 1. Get class_path from registry + class_path = BaseVectorStoreConfig.get_provider_class_path(provider_name) + if not class_path: raise ValueError(f"Unsupported VectorStore provider: {provider_name}") + + # 2. Get config_cls from registry + config_cls = BaseVectorStoreConfig.get_provider_config_cls(provider_name) or BaseVectorStoreConfig + + # 3. Handle config parameter + if isinstance(config, dict): + # Convert dict to provider config instance + provider_config = config_cls(**config) + elif isinstance(config, BaseVectorStoreConfig): + # Use config instance directly + provider_config = config + else: + raise TypeError(f"config must be BaseVectorStoreConfig or dict, got {type(config)}") + + # 4. Export to dict for VectorStore constructor + config_dict = provider_config.model_dump(exclude_none=True) + + # 5. Create VectorStore instance + vector_store_class = load_class(class_path) + return vector_store_class(**config_dict) + + @classmethod + def register_provider(cls, name: str, class_path: str, config_class=None): + """ + Register a new vector store provider. + + Args: + name (str): Provider name + class_path (str): Full path to VectorStore class + config_class: Configuration class for the provider (defaults to BaseVectorStoreConfig) + """ + if config_class is None: + config_class = BaseVectorStoreConfig + + # Register directly in BaseVectorStoreConfig registry + BaseVectorStoreConfig._registry[name] = config_class + BaseVectorStoreConfig._class_paths[name] = class_path + + @classmethod + def get_supported_providers(cls) -> list: + """ + Get list of supported providers. + + Returns: + list: List of supported provider names + """ + return list(BaseVectorStoreConfig._registry.keys()) @classmethod def reset(cls, instance): @@ -43,17 +102,70 @@ class GraphStoreFactory: Usage: GraphStoreFactory.create(provider_name, config) """ - provider_to_class = { - "oceanbase": "powermem.storage.oceanbase.oceanbase_graph.MemoryGraph", - "default": "powermem.storage.oceanbase.oceanbase_graph.MemoryGraph", - } - @classmethod def create(cls, provider_name, config): - class_type = cls.provider_to_class.get(provider_name, cls.provider_to_class["default"]) - try: - GraphClass = load_class(class_type) - except (ImportError, AttributeError) as e: - raise ImportError(f"Could not import MemoryGraph for provider '{provider_name}': {e}") - return GraphClass(config) + """ + Create a GraphStore instance with the appropriate configuration. + + Args: + provider_name (str): The provider name (e.g., 'oceanbase') + config: Configuration object or dict. If dict, will convert to provider config + + Returns: + Configured GraphStore instance + + Raises: + ValueError: If provider is not supported + """ + # 1. Get class_path from registry + class_path = BaseGraphStoreConfig.get_provider_class_path(provider_name) + if not class_path: + raise ValueError(f"Unsupported GraphStore provider: {provider_name}") + + # 2. Get config_cls from registry + config_cls = BaseGraphStoreConfig.get_provider_config_cls(provider_name) or BaseGraphStoreConfig + + # 3. 
Handle config parameter + if isinstance(config, dict): + # Convert dict to provider config instance + provider_config = config_cls(**config) + elif isinstance(config, BaseGraphStoreConfig): + # Use config instance directly + provider_config = config + else: + raise TypeError(f"config must be BaseGraphStoreConfig or dict, got {type(config)}") + + # 4. Export to dict for GraphStore constructor + config_dict = provider_config.model_dump(exclude_none=True) + + # 5. Create GraphStore instance + graph_store_class = load_class(class_path) + return graph_store_class(config_dict) + + @classmethod + def register_provider(cls, name: str, class_path: str, config_class=None): + """ + Register a new graph store provider. + + Args: + name (str): Provider name + class_path (str): Full path to GraphStore class + config_class: Configuration class for the provider (defaults to BaseGraphStoreConfig) + """ + if config_class is None: + config_class = BaseGraphStoreConfig + + # Register directly in BaseGraphStoreConfig registry + BaseGraphStoreConfig._registry[name] = config_class + BaseGraphStoreConfig._class_paths[name] = class_path + + @classmethod + def get_supported_providers(cls) -> list: + """ + Get list of supported providers. + + Returns: + list: List of supported provider names + """ + return list(BaseGraphStoreConfig._registry.keys()) diff --git a/src/powermem/storage/oceanbase/oceanbase.py b/src/powermem/storage/oceanbase/oceanbase.py index 8873f58..931b860 100644 --- a/src/powermem/storage/oceanbase/oceanbase.py +++ b/src/powermem/storage/oceanbase/oceanbase.py @@ -1166,6 +1166,7 @@ def _native_hybrid_search( ) # 3. Build search parameters JSON + safe_query = query.replace("'", "''") if query else query search_params = { "query": { "bool": { @@ -1173,7 +1174,7 @@ def _native_hybrid_search( { "query_string": { "fields": [self.fulltext_field], - "query": query + "query": safe_query } } ] diff --git a/src/powermem/storage/oceanbase/oceanbase_graph.py b/src/powermem/storage/oceanbase/oceanbase_graph.py index 03ab8f4..afa7ce2 100644 --- a/src/powermem/storage/oceanbase/oceanbase_graph.py +++ b/src/powermem/storage/oceanbase/oceanbase_graph.py @@ -73,7 +73,16 @@ def __init__(self, config: Any) -> None: self.config = config # Get OceanBase config - ob_config = self.config.graph_store.config + # Support both old format (graph_store.config) and new format (graph_store is the config) + if self.config.graph_store: + if hasattr(self.config.graph_store, 'config'): + # Old format: GraphStoreConfig with .config field + ob_config = self.config.graph_store.config + else: + # New format: graph_store is BaseGraphStoreConfig itself + ob_config = self.config.graph_store + else: + ob_config = {} # Helper function to get config value (supports both dict and object) def get_config_value(key: str, default: Any = None) -> Any: @@ -141,7 +150,7 @@ def get_config_value(key: str, default: Any = None) -> Any: # Pass graph_store config or full config to prompts graph_config = {} if self.config.graph_store: - # Convert GraphStoreConfig to dict if needed + # Convert BaseGraphStoreConfig to dict if needed if hasattr(self.config.graph_store, 'model_dump'): graph_config = self.config.graph_store.model_dump() elif isinstance(self.config.graph_store, dict): diff --git a/src/powermem/user_memory/storage/base.py b/src/powermem/user_memory/storage/base.py index 2dc7e8e..e3d0e06 100644 --- a/src/powermem/user_memory/storage/base.py +++ b/src/powermem/user_memory/storage/base.py @@ -5,7 +5,7 @@ """ from abc import ABC, abstractmethod -from 
typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, ClassVar class UserProfileStoreBase(ABC): @@ -14,6 +14,35 @@ class UserProfileStoreBase(ABC): This class defines the interface that all user profile storage backends must implement. """ + + # Registry mechanism (same as VectorStore/GraphStore) + _provider_name: ClassVar[Optional[str]] = None + _class_path: ClassVar[Optional[str]] = None + _registry: ClassVar[Dict[str, type["UserProfileStoreBase"]]] = {} + _class_paths: ClassVar[Dict[str, str]] = {} + + def __init_subclass__(cls, **kwargs) -> None: + """Called when a class inherits from UserProfileStoreBase.""" + super().__init_subclass__(**kwargs) + cls._register_provider() + + @classmethod + def _register_provider(cls) -> None: + """Register provider in the global registry.""" + provider = getattr(cls, "_provider_name", None) + class_path = getattr(cls, "_class_path", None) + if provider: + UserProfileStoreBase._registry[provider] = cls + if class_path: + UserProfileStoreBase._class_paths[provider] = class_path + + @classmethod + def get_provider_class_path(cls, provider: str) -> Optional[str]: + """Get the class path for a specific provider.""" + provider = provider.lower() + if provider == "postgres": + provider = "pgvector" + return cls._class_paths.get(provider) @abstractmethod def save_profile( diff --git a/src/powermem/user_memory/storage/factory.py b/src/powermem/user_memory/storage/factory.py index ec54947..f0d5af8 100644 --- a/src/powermem/user_memory/storage/factory.py +++ b/src/powermem/user_memory/storage/factory.py @@ -6,7 +6,12 @@ import importlib import logging -from typing import Dict +from typing import Dict, Optional + +from .base import UserProfileStoreBase +# Import provider classes to trigger auto-registration via __init_subclass__ +from .user_profile import OceanBaseUserProfileStore # noqa: F401 +from .user_profile_sqlite import SQLiteUserProfileStore # noqa: F401 logger = logging.getLogger(__name__) @@ -23,11 +28,6 @@ class UserProfileStoreFactory: Factory for creating UserProfileStore instances for different storage providers. Usage: UserProfileStoreFactory.create(provider_name, config) """ - - provider_to_class = { - "oceanbase": "powermem.user_memory.storage.user_profile.OceanBaseUserProfileStore", - "sqlite": "powermem.user_memory.storage.user_profile_sqlite.SQLiteUserProfileStore", - } @classmethod def create(cls, provider_name: str, config: Dict): @@ -45,10 +45,14 @@ def create(cls, provider_name: str, config: Dict): ValueError: If the provider is not supported """ provider_name = provider_name.lower() - class_type = cls.provider_to_class.get(provider_name) + if provider_name == "postgres": + provider_name = "pgvector" + + # Get class path from registry (auto-registered via __init_subclass__) + class_path = UserProfileStoreBase.get_provider_class_path(provider_name) - if not class_type: - supported_providers = ", ".join(cls.provider_to_class.keys()) + if not class_path: + supported_providers = ", ".join(sorted(UserProfileStoreBase._class_paths.keys())) raise ValueError( f"Unsupported UserProfileStore provider: {provider_name}. " f"Currently supported providers are: {supported_providers}. 
" @@ -57,7 +61,7 @@ def create(cls, provider_name: str, config: Dict): ) try: - ProfileStoreClass = load_class(class_type) + ProfileStoreClass = load_class(class_path) return ProfileStoreClass(**config) except (ImportError, AttributeError) as e: raise ImportError( diff --git a/src/powermem/user_memory/storage/user_profile.py b/src/powermem/user_memory/storage/user_profile.py index d891b34..8e2ca98 100644 --- a/src/powermem/user_memory/storage/user_profile.py +++ b/src/powermem/user_memory/storage/user_profile.py @@ -28,6 +28,9 @@ class OceanBaseUserProfileStore(UserProfileStoreBase): """OceanBase-based user profile storage implementation""" + + _provider_name = "oceanbase" + _class_path = "powermem.user_memory.storage.user_profile.OceanBaseUserProfileStore" def __init__( self, diff --git a/src/powermem/user_memory/storage/user_profile_sqlite.py b/src/powermem/user_memory/storage/user_profile_sqlite.py index 7e0f36c..a7cbda2 100644 --- a/src/powermem/user_memory/storage/user_profile_sqlite.py +++ b/src/powermem/user_memory/storage/user_profile_sqlite.py @@ -20,6 +20,9 @@ class SQLiteUserProfileStore(UserProfileStoreBase): """SQLite-based user profile storage implementation""" + + _provider_name = "sqlite" + _class_path = "powermem.user_memory.storage.user_profile_sqlite.SQLiteUserProfileStore" def __init__( self, diff --git a/tests/regression/test_scenario_5_custom_integration.py b/tests/regression/test_scenario_5_custom_integration.py index 6833ff1..acfe53a 100644 --- a/tests/regression/test_scenario_5_custom_integration.py +++ b/tests/regression/test_scenario_5_custom_integration.py @@ -196,9 +196,7 @@ def test_step1_custom_llm_provider() -> None: LLMFactory.register_provider("custom", f"{__name__}.CustomLLM", CustomLLMConfig) # Also register custom vector store for testing - VectorStoreFactory.provider_to_class.update({ - "custom": f"{__name__}.CustomVectorStore" - }) + VectorStoreFactory.register_provider("custom", f"{__name__}.CustomVectorStore") print("✓ CustomLLM class defined") print("✓ Custom LLM provider registered successfully") @@ -495,9 +493,7 @@ def test_step3_custom_vector_store() -> None: from powermem.storage.factory import VectorStoreFactory - VectorStoreFactory.provider_to_class.update({ - "custom": f"{__name__}.CustomVectorStore" - }) + VectorStoreFactory.register_provider("custom", f"{__name__}.CustomVectorStore") print("✓ CustomVectorStore class defined") print("✓ Custom Vector Store provider registered successfully") @@ -925,14 +921,12 @@ def test_step5_fastapi_integration() -> None: from powermem.storage.factory import VectorStoreFactory # Ensure custom providers are registered - if 'custom' not in LLMFactory.provider_to_class: + if 'custom' not in LLMFactory.get_supported_providers(): LLMFactory.register_provider("custom", f"{__name__}.CustomLLM", CustomLLMConfig) if not BaseEmbedderConfig.has_provider("custom"): print("⚠ Custom embedder config is not registered") - if 'custom' not in VectorStoreFactory.provider_to_class: - VectorStoreFactory.provider_to_class.update({ - "custom": f"{__name__}.CustomVectorStore" - }) + if 'custom' not in VectorStoreFactory.get_supported_providers(): + VectorStoreFactory.register_provider("custom", f"{__name__}.CustomVectorStore") # Define Pydantic models for request/response class MemoryRequest(BaseModel): From 5711e22c6af66cd1974217441fc8a688473453a8 Mon Sep 17 00:00:00 2001 From: Chifang <40140008+Ripcord55@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:15:24 +0800 Subject: [PATCH 06/23] oceanbase native language case (#220) 
--- tests/regression/test_native_language.py | 806 +++++++++++++++++++++++ 1 file changed, 806 insertions(+) create mode 100644 tests/regression/test_native_language.py diff --git a/tests/regression/test_native_language.py b/tests/regression/test_native_language.py new file mode 100644 index 0000000..16e4081 --- /dev/null +++ b/tests/regression/test_native_language.py @@ -0,0 +1,806 @@ +""" +User Profile Native Language Support - Test Script + +Test case coverage: +- Basic functionality tests (TC-001 ~ TC-006) +- Language coverage tests (TC-007 ~ TC-011) +- Boundary condition tests (TC-012 ~ TC-015) +- Compatibility tests (TC-016 ~ TC-018) +- API endpoint tests (TC-019 ~ TC-022) + +Usage: + pytest test_native_language.py -v + pytest test_native_language.py -v -k "TC001" # Run single test case + pytest test_native_language.py -v -m "api" # Run API tests only +""" + +import os +import sys +import json +import logging +import uuid +import requests +import pytest +from typing import Dict, Any, Optional, List + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "src")) + +from powermem import auto_config +from powermem.user_memory import UserMemory + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +# ==================== Fixtures ==================== + +@pytest.fixture(scope="module") +def config(): + """Provide shared configuration for all tests with qwen provider.""" + # Get base config from auto_config + base_config = auto_config() + + # Get QWEN_API_KEY from environment + qwen_api_key = os.getenv("QWEN_API_KEY") + + # Override LLM config with qwen provider + base_config["llm"] = { + "provider": "qwen", + "config": { + "api_key": qwen_api_key, + "model": "qwen3-max", + "temperature": 0.7, + "max_tokens": 1000, + } + } + + # Override embedder config with qwen provider + base_config["embedder"] = { + "provider": "qwen", + "config": { + "api_key": qwen_api_key, + "model": "text-embedding-v4", + "embedding_dims": 1536, + } + } + + return base_config + + +@pytest.fixture(scope="module") +def user_memory(config): + """Module-scoped fixture providing a shared UserMemory instance.""" + um = UserMemory(config=config, agent_id="test_native_language_agent") + yield um + + +@pytest.fixture(scope="module") +def api_client(): + """Provide API client for HTTP tests.""" + base_url = os.getenv("POWERMEM_API_URL", "http://localhost:8000") + api_key = os.getenv("POWERMEM_API_KEY", "key1") + return APIClient(base_url=base_url, api_key=api_key) + + +class APIClient: + """Simple API client for testing HTTP endpoints.""" + + def __init__(self, base_url: str = "http://localhost:8000", api_key: str = "key1"): + self.base_url = base_url.rstrip('/') + self.api_base = f"{self.base_url}/api/v1" + self.api_key = api_key + self.headers = { + "X-API-Key": api_key, + "Content-Type": "application/json" + } + + def post(self, endpoint: str, data: Dict[str, Any], timeout: int = 60) -> requests.Response: + """Send POST request.""" + url = f"{self.api_base}{endpoint}" + return requests.post(url, headers=self.headers, json=data, timeout=timeout) + + def get(self, endpoint: str, timeout: int = 30) -> requests.Response: + """Send GET request.""" + url = f"{self.api_base}{endpoint}" + return requests.get(url, headers=self.headers, timeout=timeout) + + def delete(self, endpoint: str, timeout: int = 30) -> requests.Response: + """Send DELETE request.""" + url = f"{self.api_base}{endpoint}" + return requests.delete(url, headers=self.headers, timeout=timeout) + 
+ +# ==================== Helper Functions ==================== + +def print_test_result(test_id: str, messages: Any, params: Dict[str, Any], result: Dict[str, Any]): + """Print detailed test results""" + print(f"\n{'='*60}") + print(f"Test Case: {test_id}") + print(f"{'='*60}") + + # Input parameters + print(f"\n📥 Input Parameters:") + print(f" - native_language: {params.get('native_language', 'not specified')}") + print(f" - profile_type: {params.get('profile_type', 'content')}") + if params.get('include_roles'): + print(f" - include_roles: {params.get('include_roles')}") + if params.get('exclude_roles'): + print(f" - exclude_roles: {params.get('exclude_roles')}") + + # Input messages + print(f"\n📝 Input Messages:") + if isinstance(messages, list): + for msg in messages: + role = msg.get('role', 'unknown') + content = msg.get('content', '') + print(f" [{role}]: {content}") + else: + print(f" {messages}") + + # Output results + print(f"\n📤 Output Results:") + print(f" - profile_extracted: {result.get('profile_extracted', False)}") + + if result.get('profile_content'): + print(f" - profile_content: {result['profile_content']}") + + if result.get('topics'): + print(f" - topics: {json.dumps(result['topics'], ensure_ascii=False, indent=4)}") + + # Memory results + memory_results = result.get('results', []) + print(f"\n💾 Memory Storage Results (total {len(memory_results)} items):") + if memory_results: + for i, mem in enumerate(memory_results, 1): + print(f" [{i}] ID: {mem.get('id', 'N/A')}") + print(f" Memory: {mem.get('memory', 'N/A')}") + if mem.get('metadata'): + print(f" Metadata: {mem.get('metadata')}") + else: + print(" (No new memories)") + + print(f"\n{'='*60}") + print(f"✓ {test_id} Test Passed") + print(f"{'='*60}\n") + + +def has_chinese_chars(text: str) -> bool: + """Check if text contains Chinese characters.""" + return any('\u4e00' <= char <= '\u9fff' for char in text) + + +def has_japanese_chars(text: str) -> bool: + """Check if text contains Japanese characters (Hiragana, Katakana, or Kanji).""" + for char in text: + # Hiragana: U+3040 - U+309F + # Katakana: U+30A0 - U+30FF + # Kanji (CJK): U+4E00 - U+9FFF (shared with Chinese) + if '\u3040' <= char <= '\u309f' or '\u30a0' <= char <= '\u30ff': + return True + return False + + +def has_korean_chars(text: str) -> bool: + """Check if text contains Korean characters (Hangul).""" + return any('\uac00' <= char <= '\ud7a3' or '\u1100' <= char <= '\u11ff' for char in text) + + +def has_cyrillic_chars(text: str) -> bool: + """Check if text contains Cyrillic characters (Russian etc.).""" + return any('\u0400' <= char <= '\u04ff' for char in text) + + +def check_topics_keys_english(topics: Dict[str, Any]) -> bool: + """Check if all topic keys are in English (ASCII).""" + def _check_keys(d): + for key, value in d.items(): + if not key.replace('_', '').isascii(): + return False + if isinstance(value, dict): + if not _check_keys(value): + return False + return True + return _check_keys(topics) + + +def flatten_topics_values(topics: Dict[str, Any]) -> List[str]: + """Flatten all values from nested topics dict to a list.""" + values = [] + def _extract_values(d): + for value in d.values(): + if isinstance(value, dict): + _extract_values(value) + elif isinstance(value, str): + values.append(value) + elif isinstance(value, list): + for item in value: + if isinstance(item, str): + values.append(item) + _extract_values(topics) + return values + + +# ==================== Section 1: Basic Functionality Tests ==================== + +class 
TestBasicFunctionality: + """Basic functionality test cases TC-001 ~ TC-006""" + + def test_TC001_chinese_native_language_content(self, user_memory): + """TC-001: Extract unstructured profile with Chinese native language""" + user_id = "tc001_zh_content_user" + messages = [ + {"role": "user", "content": "I work in Beijing as a software engineer."}, + {"role": "assistant", "content": "That's great! What kind of projects do you work on?"} + ] + params = {"native_language": "zh", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + assert has_chinese_chars(profile_content), f"Profile should be in Chinese, actual: {profile_content}" + print_test_result("TC-001", messages, params, result) + + def test_TC002_chinese_native_language_topics(self, user_memory): + """TC-002: Extract structured profile (topics) with Chinese native language""" + user_id = "tc002_zh_topics_user" + messages = [ + {"role": "user", "content": "My name is John and I live in Shanghai."}, + {"role": "assistant", "content": "Nice to meet you, John!"} + ] + params = {"native_language": "zh", "profile_type": "topics"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + topics = result.get("topics", {}) + assert topics, "topics should not be empty" + + # Check keys are English + assert check_topics_keys_english(topics), f"Topic keys should be in English: {topics}" + + # Check values contain Chinese + values = flatten_topics_values(topics) + has_chinese_value = any(has_chinese_chars(v) for v in values if v) + assert has_chinese_value, f"Topic values should contain Chinese: {topics}" + print_test_result("TC-002", messages, params, result) + + def test_TC003_japanese_native_language(self, user_memory): + """TC-003: Extract profile with Japanese native language""" + user_id = "tc003_ja_user" + messages = [ + {"role": "user", "content": "我叫测试007。"}, + {"role": "assistant", "content": "嘿,测试007,你好呀!"} + ] + params = {"native_language": "ja", "profile_type": "topics"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + topics = result.get("topics", {}) + assert topics, "topics should not be empty" + + # Check values - may contain Japanese or transliterated content + values = flatten_topics_values(topics) + print_test_result("TC-003", messages, params, result) + + def test_TC004_english_native_language(self, user_memory): + """TC-004: Extract profile with English native language""" + user_id = "tc004_en_user" + messages = [ + {"role": "user", "content": "我是一名来自北京的程序员"}, + {"role": "assistant", "content": "很高兴认识你!"} + ] + params = {"native_language": "en", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + + # English content should be mostly ASCII + ascii_ratio = sum(1 for c in profile_content if c.isascii()) / max(len(profile_content), 1) + assert ascii_ratio > 0.7, f"Profile should be mostly 
in English, actual: {profile_content}" + print_test_result("TC-004", messages, params, result) + + def test_TC005_mixed_language_conversation(self, user_memory): + """TC-005: Mixed language conversation with specified native language""" + user_id = "tc005_mixed_lang_user" + messages = [ + {"role": "user", "content": "我在 Google 工作,做 machine learning"}, + {"role": "assistant", "content": "ML is a great field!"}, + {"role": "user", "content": "是的,我专注于 NLP 领域"} + ] + params = {"native_language": "zh", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + assert has_chinese_chars(profile_content), f"Profile should be unified in Chinese, actual: {profile_content}" + print_test_result("TC-005", messages, params, result) + + def test_TC006_multi_round_conversation(self, user_memory): + """TC-006: Multi-round conversation accumulating profile""" + user_id = "tc006_multi_round_user" + params = {"native_language": "zh", "profile_type": "content"} + + # Round 1 + messages_1 = [{"role": "user", "content": "I'm a teacher"}] + result_1 = user_memory.add( + messages=messages_1, + user_id=user_id, + **params + ) + + assert result_1.get("profile_extracted") == True + profile_1 = result_1.get("profile_content", "") + assert has_chinese_chars(profile_1), f"Round 1 profile should be in Chinese: {profile_1}" + + print(f"\n{'='*60}") + print(f"Test Case: TC-006 (Multi-round Conversation)") + print(f"{'='*60}") + print(f"\n📥 Round 1 Input:") + print(f" [user]: {messages_1[0]['content']}") + print(f"\n📤 Round 1 Result:") + print(f" - profile_content: {profile_1}") + + # Round 2 + messages_2 = [{"role": "user", "content": "I live in Tokyo"}] + result_2 = user_memory.add( + messages=messages_2, + user_id=user_id, + **params + ) + + assert result_2.get("profile_extracted") == True + profile_2 = result_2.get("profile_content", "") + assert has_chinese_chars(profile_2), f"Round 2 profile should be in Chinese: {profile_2}" + + print(f"\n📥 Round 2 Input:") + print(f" [user]: {messages_2[0]['content']}") + print(f"\n📤 Round 2 Result:") + print(f" - profile_content: {profile_2}") + print(f"\n{'='*60}") + print(f"✓ TC-006 Test Passed") + print(f"{'='*60}\n") + + +# ==================== Section 2: Language Coverage Tests ==================== + +class TestLanguageCoverage: + """Language coverage test cases TC-007 ~ TC-011""" + + @pytest.mark.parametrize("lang_code,test_id,message,check_func", [ + ("ko", "TC-007", "I'm from Seoul and I love K-pop music. My favorite food is kimchi and bibimbap.", has_korean_chars), + ("ru", "TC-011", "I live in Moscow and I work as a ballet dancer at the Bolshoi Theatre. 
I love Russian literature.", has_cyrillic_chars), + ]) + def test_language_with_special_chars(self, user_memory, lang_code, test_id, message, check_func): + """Test languages with special characters (Korean, Russian)""" + user_id = f"{test_id.lower().replace('-', '_')}_user" + messages = [{"role": "user", "content": message}] + params = {"native_language": lang_code, "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, f"{test_id}: Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, f"{test_id}: profile_content should not be empty" + # Note: LLM may not always output in target language, so we just log + print_test_result(test_id, messages, params, result) + + def test_TC008_french_native_language(self, user_memory): + """TC-008: French native language test""" + user_id = "tc008_french_user" + messages = [ + {"role": "user", "content": "I live in Paris and I love French cuisine. My favorite food is croissant and café au lait."}, + {"role": "assistant", "content": "That sounds wonderful! Paris is a beautiful city."}, + {"role": "user", "content": "Yes, I work as a chef at a restaurant near the Eiffel Tower."} + ] + params = {"native_language": "fr", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-008", messages, params, result) + + def test_TC009_german_native_language(self, user_memory): + """TC-009: German native language test""" + user_id = "tc009_german_user" + messages = [ + {"role": "user", "content": "I work in Berlin as an engineer at Volkswagen. I love German beer and Oktoberfest."}, + {"role": "assistant", "content": "Das klingt toll!"}, + {"role": "user", "content": "Ja, I also enjoy hiking in the Alps on weekends."} + ] + params = {"native_language": "de", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-009", messages, params, result) + + def test_TC010_spanish_native_language(self, user_memory): + """TC-010: Spanish native language test""" + user_id = "tc010_spanish_user" + messages = [ + {"role": "user", "content": "I'm a doctor from Madrid. 
I love flamenco dancing and tapas."}, + {"role": "assistant", "content": "¡Qué interesante!"}, + {"role": "user", "content": "Sí, I also enjoy watching Real Madrid football matches."} + ] + params = {"native_language": "es", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-010", messages, params, result) + + +# ==================== Section 3: Boundary Condition Tests ==================== + +class TestBoundaryConditions: + """Boundary condition test cases TC-012 ~ TC-015""" + + def test_TC012_no_native_language_param(self, user_memory): + """TC-012: Without native_language parameter""" + user_id = "tc012_no_lang_user" + messages = [ + {"role": "user", "content": "My name is Bob and I'm a developer from San Francisco."}, + {"role": "assistant", "content": "Nice to meet you, Bob!"} + ] + params = {"profile_type": "content"} # Without native_language + + # Call without native_language parameter + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-012 (without native_language)", messages, params, result) + + def test_TC013_native_language_empty_string(self, user_memory): + """TC-013: native_language as empty string""" + user_id = "tc013_empty_lang_user" + messages = [ + {"role": "user", "content": "My name is Alice and I'm a software engineer from New York."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"} + ] + params = {"native_language": "", "profile_type": "content"} # Empty string + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-013 (native_language=empty string)", messages, params, result) + + def test_TC014_unmapped_language_code(self, user_memory): + """TC-014: Unmapped language code (Polish pl)""" + user_id = "tc014_polish_user" + messages = [ + {"role": "user", "content": "I'm from Warsaw and I work as a pianist. I love Chopin's music and Polish pierogi."}, + {"role": "assistant", "content": "That sounds wonderful!"}, + {"role": "user", "content": "Tak, I also enjoy visiting the historic Old Town."} + ] + params = {"native_language": "pl", "profile_type": "content"} # pl not in standard mapping + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + # Should not raise error + assert result.get("profile_extracted") == True, "Profile should be extracted (even if language code is unmapped)" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-014 (unmapped code pl)", messages, params, result) + + def test_TC015_non_standard_language_description(self, user_memory): + """TC-015: Non-standard language description (français)""" + user_id = "tc015_francais_user" + messages = [ + {"role": "user", "content": "Bonjour! I live in Lyon and I'm a sommelier. 
I love wine tasting and French gastronomy."}, + {"role": "assistant", "content": "Magnifique! Lyon is known for its cuisine."}, + {"role": "user", "content": "Oui, I work at a Michelin star restaurant. My specialty is pairing wine with French dishes like coq au vin and bouillabaisse."} + ] + params = {"native_language": "français", "profile_type": "content"} # Using French word instead of ISO code + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + # Should not raise error - LLM should understand + assert result.get("profile_extracted") == True, "Profile should be extracted (even with non-standard language description)" + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + print_test_result("TC-015 (non-standard description français)", messages, params, result) + + +# ==================== Section 4: Compatibility Tests ==================== + +class TestCompatibility: + """Compatibility test cases TC-016 ~ TC-018""" + + def test_TC016_backward_compatibility(self, user_memory): + """TC-016: Backward compatibility with old code""" + user_id = "tc016_backward_compat_user" + messages = "Hello, I'm a developer named Charlie from Boston" + params = {} # Old-style call without any extra parameters + + # Use old-style call without native_language + result = user_memory.add( + messages=messages, + user_id=user_id + ) + + assert "profile_extracted" in result, "Should return standard structure" + assert isinstance(result.get("profile_extracted"), bool), "profile_extracted should be bool" + print_test_result("TC-016 (backward compatibility)", messages, params, result) + + def test_TC017_with_role_filters(self, user_memory): + """TC-017: Combined with include_roles/exclude_roles""" + user_id = "tc017_role_filter_user" + messages = [ + {"role": "user", "content": "I'm a data scientist from California"}, + {"role": "assistant", "content": "Your mother is a design engineer working at Google"}, + {"role": "user", "content": "I work with machine learning models"} + ] + params = { + "include_roles": ["user"], + "exclude_roles": ["assistant"], + "native_language": "zh", + "profile_type": "content" + } + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True, "Profile should be extracted" + profile_content = result.get("profile_content", "") + assert has_chinese_chars(profile_content), f"Profile should be in Chinese: {profile_content}" + print_test_result("TC-017 (role filter + native_language)", messages, params, result) + + def test_TC018_with_profile_type_content(self, user_memory): + """TC-018a: Combined with profile_type=content""" + user_id = "tc018a_content_user" + messages = [{"role": "user", "content": "I love hiking and photography. 
I often go to Yosemite National Park."}] + params = {"native_language": "zh", "profile_type": "content"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True + profile_content = result.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + assert has_chinese_chars(profile_content), f"Content type profile should be in Chinese: {profile_content}" + print_test_result("TC-018a (profile_type=content)", messages, params, result) + + def test_TC018_with_profile_type_topics(self, user_memory): + """TC-018b: Combined with profile_type=topics""" + user_id = "tc018b_topics_user" + messages = [{"role": "user", "content": "I love hiking and photography. My name is David and I live in Seattle."}] + params = {"native_language": "zh", "profile_type": "topics"} + + result = user_memory.add( + messages=messages, + user_id=user_id, + **params + ) + + assert result.get("profile_extracted") == True + topics = result.get("topics", {}) + assert topics, "topics should not be empty" + assert check_topics_keys_english(topics), f"Topic keys should be in English: {topics}" + print_test_result("TC-018b (profile_type=topics)", messages, params, result) + + +# ==================== Section 5: API Endpoint Tests ==================== + +@pytest.mark.api +class TestAPIEndpoints: + """API endpoint test cases TC-019 ~ TC-022""" + + def _print_api_result(self, test_id: str, endpoint: str, request_data: Dict, response_data: Dict): + """Print detailed API test results""" + print(f"\n{'='*60}") + print(f"Test Case: {test_id}") + print(f"{'='*60}") + print(f"\n🌐 API Request:") + print(f" - Endpoint: POST {endpoint}") + print(f" - Request Body:") + print(f" {json.dumps(request_data, ensure_ascii=False, indent=4)}") + print(f"\n📤 API Response:") + print(f" - Response:") + print(f" {json.dumps(response_data, ensure_ascii=False, indent=4)}") + print(f"\n{'='*60}") + print(f"✓ {test_id} Test Passed") + print(f"{'='*60}\n") + + def test_TC019_api_with_native_language(self, api_client): + """TC-019: HTTP API with native_language parameter""" + user_id = f"api_test_{uuid.uuid4().hex[:8]}" + data = { + "messages": [{"role": "user", "content": "I am a developer from Shanghai"}], + "native_language": "zh", + "profile_type": "content", + "agent_id": "test_native_lang_agent", + "infer": True + } + endpoint = f"/users/{user_id}/profile" + + try: + response = api_client.post(endpoint, data=data) + + # Print debug info + print(f"\n🌐 Request URL: {api_client.api_base}{endpoint}") + print(f"📤 Response Status: {response.status_code}") + if response.status_code != 200: + print(f"📄 Response Content: {response.text[:500]}") + + assert response.status_code == 200, f"Should return 200, actual: {response.status_code}, response: {response.text[:200]}" + result = response.json() + assert result.get("success") == True, f"Request should succeed: {result}" + + profile_data = result.get("data", {}) + profile_content = profile_data.get("profile_content", "") + assert has_chinese_chars(profile_content), f"API returned profile should be in Chinese: {profile_content}" + self._print_api_result("TC-019 (API with native_language)", endpoint, data, result) + except requests.exceptions.ConnectionError: + pytest.skip("API server not running") + + def test_TC020_api_without_native_language(self, api_client): + """TC-020: HTTP API without native_language parameter""" + user_id = f"api_test_{uuid.uuid4().hex[:8]}" + data = { + "messages": [{"role": 
"user", "content": "I am a developer from Beijing"}], + "profile_type": "content", + "agent_id": "test_native_lang_agent", + "infer": True + } + endpoint = f"/users/{user_id}/profile" + + try: + response = api_client.post(endpoint, data=data) + + # Print debug info + print(f"\n🌐 Request URL: {api_client.api_base}{endpoint}") + print(f"📤 Response Status: {response.status_code}") + if response.status_code != 200: + print(f"📄 Response Content: {response.text[:500]}") + + assert response.status_code == 200, f"Should return 200, actual: {response.status_code}, response: {response.text[:200]}" + result = response.json() + assert result.get("success") == True, f"Request should succeed (backward compatible): {result}" + self._print_api_result("TC-020 (API without native_language)", endpoint, data, result) + except requests.exceptions.ConnectionError: + pytest.skip("API server not running") + + def test_TC021_api_native_language_null(self, api_client): + """TC-021: HTTP API with native_language field as null""" + user_id = f"api_test_{uuid.uuid4().hex[:8]}" + data = { + "messages": [{"role": "user", "content": "I am a developer from Tokyo"}], + "native_language": None, + "profile_type": "content", + "agent_id": "test_native_lang_agent", + "infer": True + } + endpoint = f"/users/{user_id}/profile" + + try: + response = api_client.post(endpoint, data=data) + + # Print debug info + print(f"\n🌐 Request URL: {api_client.api_base}{endpoint}") + print(f"📤 Response Status: {response.status_code}") + if response.status_code != 200: + print(f"📄 Response Content: {response.text[:500]}") + + assert response.status_code == 200, f"Should return 200, actual: {response.status_code}, response: {response.text[:200]}" + result = response.json() + assert result.get("success") == True, f"Request should succeed (null equals not passing): {result}" + self._print_api_result("TC-021 (API native_language=null)", endpoint, data, result) + except requests.exceptions.ConnectionError: + pytest.skip("API server not running") + + def test_TC022_api_non_standard_language_description(self, api_client): + """TC-022: HTTP API with non-standard language description""" + user_id = f"api_test_{uuid.uuid4().hex[:8]}" + data = { + "messages": [{"role": "user", "content": "I live in Paris and work as a chef. 
I love French cuisine and wine."}], + "native_language": "français", # Non-standard: full language name instead of ISO code + "profile_type": "content", + "agent_id": "test_native_lang_agent", + "infer": True + } + endpoint = f"/users/{user_id}/profile" + + try: + response = api_client.post(endpoint, data=data) + + # Print debug info + print(f"\n🌐 Request URL: {api_client.api_base}{endpoint}") + print(f"📤 Response Status: {response.status_code}") + if response.status_code != 200: + print(f"📄 Response Content: {response.text[:500]}") + + assert response.status_code == 200, f"Should return 200 (should not error on non-standard language description), actual: {response.status_code}, response: {response.text[:200]}" + result = response.json() + assert result.get("success") == True, f"Request should succeed: {result}" + + profile_data = result.get("data", {}) + profile_content = profile_data.get("profile_content", "") + assert profile_content, "profile_content should not be empty" + self._print_api_result("TC-022 (API non-standard language français)", endpoint, data, result) + except requests.exceptions.ConnectionError: + pytest.skip("API server not running") + + +# ==================== Entry Point ==================== + +if __name__ == "__main__": + # Run all tests with verbose output + pytest.main([__file__, "-v", "--tb=short"]) + From 05fd24474d6f474680d13c812ce21520753aefc7 Mon Sep 17 00:00:00 2001 From: Chifang <40140008+Ripcord55@users.noreply.github.com> Date: Tue, 3 Feb 2026 16:38:35 +0800 Subject: [PATCH 07/23] Oceanbase Native Hybrid Search Cases (#223) * oceanbase native language case * Oceanbase Native Hybrid Search Cases --- tests/regression/test_native_hybrid_search.py | 710 ++++++++++++++++++ 1 file changed, 710 insertions(+) create mode 100644 tests/regression/test_native_hybrid_search.py diff --git a/tests/regression/test_native_hybrid_search.py b/tests/regression/test_native_hybrid_search.py new file mode 100644 index 0000000..a4c9b47 --- /dev/null +++ b/tests/regression/test_native_hybrid_search.py @@ -0,0 +1,710 @@ +""" +Comprehensive test cases for OceanBase Native Hybrid Search + +This test suite covers all test cases from the test document: +- TC-001: Enable native hybrid search +- TC-005: Hybrid search fusion effect +- TC-006: Table column field filtering +- TC-007: JSON field filtering (auto fallback) +- TC-008: Empty result handling +- TC-009: Large data search +- TC-010: Limit parameter test +- TC-012: Threshold parameter triggers fallback +- TC-013: API compatibility +- TC-014: Old table compatibility (InnoDB to HEAP migration) +- Performance comparison test +""" + +from ast import Tuple +import logging +import os +import sys +import time +import pytest +from typing import Dict, Any, List, Optional + +# Add project root to Python path +project_root = os.path.join(os.path.dirname(__file__), "..", "..") +project_root = os.path.abspath(project_root) +sys.path.insert(0, project_root) + +from powermem import auto_config, Memory + +# Configure logging +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + force=True +) +logger = logging.getLogger(__name__) + +# Helper functions for logging +def log_info(msg): + """Log and print info message""" + logger.info(msg) + print(msg) + +def log_warning(msg): + """Log and print warning message""" + logger.warning(msg) + print(f"WARNING: {msg}") + +def log_error(msg): + """Log and print error message""" + logger.error(msg) + print(f"ERROR: {msg}") + + +class 
NativeHybridSearchTester: + """Native Hybrid Search test class""" + + def __init__(self, enable_native_hybrid: bool = True): + """Initialize tester""" + self.enable_native_hybrid = enable_native_hybrid + self.config = auto_config() + + # Enable native hybrid search if requested + if 'vector_store' not in self.config: + self.config['vector_store'] = {} + if 'config' not in self.config['vector_store']: + self.config['vector_store']['config'] = {} + self.config['vector_store']['config']['enable_native_hybrid'] = enable_native_hybrid + + self.memory = Memory(config=self.config) + log_info(f"Native Hybrid Search tester initialized (enable_native_hybrid={enable_native_hybrid})") + + def cleanup_all(self): + """Cleanup all test data""" + try: + # Note: This is a simplified cleanup, actual implementation may vary + log_info("Cleaning up test data...") + except Exception as e: + log_warning(f"Failed to cleanup: {e}") + + +@pytest.fixture(scope="class") +def native_hybrid_tester(request): + """Fixture to create NativeHybridSearchTester instance""" + enable_native = getattr(request.cls, 'enable_native_hybrid', True) + tester = NativeHybridSearchTester(enable_native_hybrid=enable_native) + yield tester + # tester.cleanup_all() + + +@pytest.mark.usefixtures("native_hybrid_tester") +class TestNativeHybridSearch: + """Test class for Native Hybrid Search functionality""" + + enable_native_hybrid = True + + @pytest.fixture(autouse=True) + def setup_tester(self, native_hybrid_tester): + """Setup tester instance for each test""" + self.tester = native_hybrid_tester + + def test_tc001_enable_native_hybrid_search(self): + """ + TC-001: Enable native hybrid search + + Test purpose: Verify native hybrid search can be enabled normally + """ + log_info("=" * 80) + log_info("TC-001: Enable Native Hybrid Search") + log_info("=" * 80) + + user_id = "tc001_user" + + # Step 1: Initialize Memory instance + log_info("\n[Step 1] Initializing Memory instance...") + memory = self.tester.memory + assert memory is not None, "Memory instance should be created" + log_info("✓ Memory instance created") + + # Step 2: Add test data + log_info("\n[Step 2] Adding test data...") + result = memory.add(messages="Zhang San lives in Hangzhou", user_id=user_id) + assert result is not None, "Add operation should succeed" + log_info("✓ Test data added") + + # Step 3: Execute search query + log_info("\n[Step 3] Executing search query...") + search_results = memory.search(query="Where does Zhang San live", user_id=user_id, limit=10) + assert search_results is not None, "Search should return results" + log_info(f"✓ Search completed, found {len(search_results.get('results', []))} results") + + # Step 4: Check log output (we can't directly check logs, but we verify functionality) + log_info("\n[Step 4] Verifying search results...") + memories = search_results.get('results', []) + assert len(memories) > 0, "Should return at least one result" + + # Verify result content + found_relevant = False + for mem in memories: + content = mem.get('memory', '') + if 'Zhang San' in content or 'Hangzhou' in content: + found_relevant = True + log_info(f"✓ Found relevant result: {content}...") + break + log_info(f"✓ Top Result: {memories[0].get('memory', '') if memories else 'No results'}") + assert found_relevant, "Should find relevant results" + log_info("\n✓ TC-001 passed: Native hybrid search enabled successfully") + + def test_tc005_hybrid_search_fusion_effect(self): + """ + TC-005: Hybrid search fusion effect + + Test purpose: Verify the fusion effect of 
vector search, full-text search, and sparse vector search + """ + log_info("=" * 80) + log_info("TC-005: Hybrid Search Fusion Effect") + log_info("=" * 80) + + user_id = "tc005_user" + memory = self.tester.memory + + # Step 1: Add diverse test data + log_info("\n[Step 1] Adding diverse test data...") + test_messages = [ + "Zhang San lives in Hangzhou and is a software engineer", + "Li Si is a product manager in Beijing", + "Wang Wu works in Shenzhen and likes running" + ] + + for msg in test_messages: + result = memory.add(messages=msg, user_id=user_id) + assert result is not None, f"Failed to add message: {msg}" + log_info(f"✓ Added: {msg}") + + # Step 2: Execute complex query + log_info("\n[Step 2] Executing complex query...") + query = "Zhang San's workplace and occupation" + results = memory.search(query=query, user_id=user_id, limit=10) + + assert results is not None, "Search should return results" + memories = results.get('results', []) + log_info(f"✓ Found {len(memories)} results") + + # Step 3: Verify fusion results accuracy + log_info("\n[Step 3] Verifying fusion results...") + assert len(memories) > 0, "Should return at least one result" + + # Check if most relevant result is at the top + if memories: + top_result = memories[0] + content = top_result.get('memory', '') + score = top_result.get('score', 0) + log_info(f"✓ Top result: {content}... (score: {score:.4f})") + + # Verify relevance + assert 'Zhang San' in content, "Top result should contain 'Zhang San'" + assert score > 0, "Score should be positive" + log_info(f"✓ Top Result: {top_result.get('memory', '') if top_result else 'No results'}") + log_info("\n✓ TC-005 passed: Hybrid search fusion effect verified") + + def test_tc006_table_column_field_filtering(self): + """ + TC-006: Table column field filtering + + Test purpose: Verify native hybrid search supports table column field filtering + """ + log_info("=" * 80) + log_info("TC-006: Table Column Field Filtering") + log_info("=" * 80) + + memory = self.tester.memory + + # Use specific user_ids for filter testing + user_id_1 = "tc006_filter_user1" + user_id_2 = "tc006_filter_user2" + + # Step 1: Add memories with different user_id + log_info("\n[Step 1] Adding memories with different user_id...") + memory.add(messages="Zhang San lives in Hangzhou", user_id=user_id_1) + memory.add(messages="Li Si is in Beijing", user_id=user_id_2) + log_info(f"✓ Added memories for {user_id_1} and {user_id_2}") + + # Step 2: Search with table column field filter + log_info("\n[Step 2] Searching with table column field filter...") + results = memory.search( + query="Where does Li Si live", + filters={"user_id": user_id_2}, + limit=10 + ) + + assert results is not None, "Search should return results" + memories = results.get('results', []) + log_info(f"✓ Found {len(memories)} results") + + # Step 3: Verify filtering effect + log_info("\n[Step 3] Verifying filtering effect...") + if memories: + for mem in memories: + # Check if all results belong to the filtered user + # Note: This depends on how metadata is stored + content = mem.get('memory', '') + log_info(f" Result: {content}...") + + # Verify that results are filtered (all should be user_id_2's data) + assert len(memories) >= 0, "Should return filtered results" + log_info("✓ Filtering applied successfully") + log_info(f"✓ Top Result: {memories[0].get('memory', '') if memories else 'No results'}") + log_info("\n✓ TC-006 passed: Table column field filtering verified") + + def test_tc007_json_field_filtering_auto_fallback(self): + """ + 
TC-007: JSON field filtering (auto fallback) + + Test purpose: Verify automatic fallback to application-level hybrid search when using JSON field filtering + """ + log_info("=" * 80) + log_info("TC-007: JSON Field Filtering (Auto Fallback)") + log_info("=" * 80) + + user_id = "tc007_user" + memory = self.tester.memory + + # Step 1: Add memory with JSON metadata + log_info("\n[Step 1] Adding memory with JSON metadata...") + memory.add( + messages="Zhang San lives in Hangzhou", + user_id=user_id, + metadata={"custom_field": "Hangzhou", "city": "Hangzhou", "province": "Zhejiang"} + ) + log_info("✓ Memory added with metadata: {'custom_field': 'Hangzhou', 'city': 'Hangzhou', 'province': 'Zhejiang'}") + + # Step 2: Search with JSON field filter (not supported) + log_info("\n[Step 2] Searching with JSON field filter (should trigger fallback)...") + results = memory.search( + query="Where does Zhang San live", + user_id=user_id, + filters={"custom_field": "Hangzhou"}, + limit=10 + ) + + # Step 3: Check if auto fallback occurred + log_info("\n[Step 3] Verifying auto fallback...") + assert results is not None, "Search should still return results (via fallback)" + memories = results.get('results', []) + log_info(f"✓ Found {len(memories)} results (via fallback)") + + # Verify search results are still correct + assert len(memories) >= 0, "Should return results even with fallback" + log_info("✓ Auto fallback mechanism verified") + log_info(f"✓ Top Result: {memories[0].get('memory', '') if memories else 'No results'}") + log_info("\n✓ TC-007 passed: JSON field filtering auto fallback verified") + + def test_tc008_empty_result_handling(self): + """ + TC-008: Empty result handling + + Test purpose: Verify handling when query returns no results + """ + log_info("=" * 80) + log_info("TC-008: Empty Result Handling") + log_info("=" * 80) + + user_id = "tc008_user" + memory = self.tester.memory + + # Step 1: Add small amount of test data + log_info("\n[Step 1] Adding test data...") + memory.add(messages="Zhang San lives in Hangzhou", user_id=user_id) + log_info("✓ Test data added") + + # Step 2: Query irrelevant content + log_info("\n[Step 2] Querying irrelevant content...") + query = "Completely irrelevant content xyz123" + results = memory.search(query=query, user_id=user_id, limit=10) + + assert results is not None, "Search should return results object" + memories = results.get('results', []) + log_info(f"✓ Search completed, found {len(memories)} results") + + # Step 3: Verify empty result handling + log_info("\n[Step 3] Verifying empty result handling...") + # Empty results or low relevance results are acceptable + assert isinstance(memories, list), "Results should be a list" + log_info("✓ Empty result handling verified (no exception thrown)") + log_info(f"✓ Top Result: {memories[0].get('memory', '') if memories else 'No results'}") + log_info("\n✓ TC-008 passed: Empty result handling verified") + + def test_tc009_large_data_search(self): + """ + TC-009: Large data search and performance comparison + + Test purpose: Verify search performance with large amount of data and compare + native hybrid search vs application-level hybrid search + """ + log_info("=" * 80) + log_info("TC-009: Large Data Search and Performance Comparison") + log_info("=" * 80) + + user_id = "tc009_user" + memory = self.tester.memory + + # Step 1: Add large amount of test data (100+ records) + log_info("\n[Step 1] Adding large amount of test data (100 records)...") + start_time = time.time() + + for i in range(10): + 
memory.add(messages=f"user{i + 43} is {i + 43} years old", user_id=user_id) + if (i + 1) % 20 == 0: + log_info(f" Added {i + 1} records...") + + add_time = time.time() - start_time + log_info(f"✓ Added 100 records in {add_time:.2f} seconds") + + # Step 2: Execute search query + log_info("\n[Step 2] Executing search query...") + start_time = time.time() + results = memory.search(query="user50", user_id=user_id, limit=10) + search_time = time.time() - start_time + + assert results is not None, "Search should return results" + memories = results.get('results', []) + log_info(f"✓ Search completed in {search_time:.3f} seconds, found {len(memories)} results") + + # Step 3: Verify performance and result accuracy + log_info("\n[Step 3] Verifying performance and accuracy...") + assert search_time < 1.0, f"Search should complete within 1 second, took {search_time:.3f}s" + assert len(memories) > 0, "Should return relevant results" + + # Verify result accuracy + if memories: + top_result = memories[0] + content = top_result.get('memory', '') + log_info(f"✓ Top result: {content}") + # Check if result contains user50 or related information + assert 'user50' in content or '50' in content, f"Top result should be relevant to 'user50', got: {content}" + + # Step 4: Test native hybrid search performance (enabled) + log_info("\n[Step 4] Testing native hybrid search performance (enabled)...") + start = time.time() + for i in range(50): + results_native = memory.search(query="user", user_id=user_id, limit=10) + native_time = time.time() - start + + assert results_native is not None, "Native search should return results" + native_count = len(results_native.get('results', [])) + log_info(f"✓ Native hybrid search: {native_time:.3f}s for 50 queries, {native_count} results") + + # Step 5: Test application-level hybrid search (disabled) + log_info("\n[Step 5] Testing application-level hybrid search (disabled)...") + config_app = auto_config() + if 'vector_store' not in config_app: + config_app['vector_store'] = {} + if 'config' not in config_app['vector_store']: + config_app['vector_store']['config'] = {} + config_app['vector_store']['config']['enable_native_hybrid'] = False + + memory_app = Memory(config=config_app) + + start = time.time() + for i in range(50): + results_app = memory_app.search(query="user", limit=10) + app_time = time.time() - start + + assert results_app is not None, "Application-level search should return results" + app_count = len(results_app.get('results', [])) + log_info(f"✓ Application-level hybrid search: {app_time:.3f}s for 50 queries, {app_count} results") + + # Step 6: Compare performance + log_info("\n[Step 6] Performance comparison:") + log_info(f" Native hybrid search: {native_time:.3f}s") + log_info(f" Application-level search: {app_time:.3f}s") + + if app_time > 0: + improvement = ((app_time - native_time) / app_time) * 100 + log_info(f" Performance improvement: {improvement:.1f}%") + + # Both should return results + assert native_count >= 0, "Native search should return results" + assert app_count >= 0, "Application-level search should return results" + + log_info(f"\n✓ TC-009 passed: Large data search and performance comparison verified") + + def test_tc010_limit_parameter(self): + """ + TC-010: Limit parameter test + + Test purpose: Verify the effect of different limit parameters + """ + log_info("=" * 80) + log_info("TC-010: Limit Parameter Test") + log_info("=" * 80) + + user_id = "tc010_user" + memory = self.tester.memory + + # Step 0: Add test data for this test case (30+ 
records to test limit properly) + log_info("\n[Step 0] Adding test data for limit testing...") + for i in range(30): + memory.add(messages=f"tc010_user{i} is {i} years old", user_id=user_id) + log_info("✓ Added 30 records for limit testing") + + # Step 1: Search with different limit values + log_info("\n[Step 1] Searching with different limit values...") + log_info("Note: Testing limit parameter with query 'tc010_user' (should match many records)") + + # Test limit=5 + log_info("\n[Test 1] Testing limit=5...") + results_5 = memory.search(query="tc010_user", user_id=user_id, limit=5) + memories_5 = results_5.get('results', []) + log_info(f"✓ limit=5: returned {len(memories_5)} results") + if len(memories_5) < 5: + log_warning(f" ⚠ Expected up to 5 results, but got {len(memories_5)}. This may indicate a limit issue.") + + # Test limit=10 + log_info("\n[Test 2] Testing limit=10...") + results_10 = memory.search(query="tc010_user", user_id=user_id, limit=10) + memories_10 = results_10.get('results', []) + log_info(f"✓ limit=10: returned {len(memories_10)} results") + if len(memories_10) < 10: + log_warning(f" ⚠ Expected up to 10 results, but got {len(memories_10)}. This may indicate a limit issue.") + + # Test limit=20 + log_info("\n[Test 3] Testing limit=20...") + results_20 = memory.search(query="tc010_user", user_id=user_id, limit=20) + memories_20 = results_20.get('results', []) + log_info(f"✓ limit=20: returned {len(memories_20)} results") + if len(memories_20) < 20: + log_warning(f" ⚠ Expected up to 20 results, but got {len(memories_20)}. This may indicate a limit issue.") + log_info(f" Database should have 30 user-related records, but only {len(memories_20)} were returned.") + + # Step 2: Verify result counts + log_info("\n[Step 2] Verifying result counts...") + # Note: We use <= instead of == because there might not be enough matching records + assert len(memories_5) <= 5, f"limit=5 should return at most 5 results, got {len(memories_5)}" + assert len(memories_10) <= 10, f"limit=10 should return at most 10 results, got {len(memories_10)}" + assert len(memories_20) <= 20, f"limit=20 should return at most 20 results, got {len(memories_20)}" + + # Additional check: if limit=20 returns only 10, there might be a bug + if len(memories_20) == 10 and len(memories_10) == 10: + log_warning(" ⚠ limit=20 and limit=10 both returned 10 results. 
This suggests rank_window_size might be capped at 10.") + log_warning(" This could be a bug in the native hybrid search implementation.") + + # Verify ordering (results should be sorted by relevance) + if len(memories_5) > 1: + scores = [m.get('score', 0) for m in memories_5] + assert scores == sorted(scores, reverse=True), "Results should be sorted by score (descending)" + + log_info(f"✓ memories_20: {memories_20}") + log_info("✓ Limit parameter verified") + log_info("\n✓ TC-010 passed: Limit parameter test verified") + + def test_tc012_threshold_parameter_triggers_fallback(self): + """ + TC-012: Threshold parameter triggers fallback + + Test purpose: Verify automatic fallback when threshold parameter is used + """ + log_info("=" * 80) + log_info("TC-012: Threshold Parameter Triggers Fallback") + log_info("=" * 80) + + user_id = "tc012_user" + memory = self.tester.memory + + # Step 0: Add test data for this test case + log_info("\n[Step 0] Adding test data for threshold testing...") + memory.add(messages="tc012 test user data for threshold testing", user_id=user_id) + log_info("✓ Test data added") + + # Step 1: Enable native hybrid search (already enabled in fixture) + log_info("\n[Step 1] Native hybrid search is enabled") + + # Step 2: Search with threshold parameter (should trigger fallback) + log_info("\n[Step 2] Searching with threshold parameter (should trigger fallback)...") + try: + results = memory.search( + query="tc012 test", + user_id=user_id, + limit=10, + threshold=0.8 # This should trigger fallback + ) + + # Step 3: Check if auto fallback occurred + log_info("\n[Step 3] Verifying auto fallback...") + assert results is not None, "Search should still return results (via fallback)" + memories = results.get('results', []) + log_info(f"✓ Found {len(memories)} results (via fallback)") + + log_info(f"✓ Top Result: {memories[0].get('memory', '') if memories else 'No results'}") + # Verify search results are still correct + assert len(memories) >= 0, "Should return results even with fallback" + log_info("✓ Auto fallback mechanism verified") + + except TypeError as e: + # If threshold parameter is not supported in the API + log_warning(f"Threshold parameter may not be supported in API: {e}") + pytest.skip("Threshold parameter not supported in current API") + + log_info("\n✓ TC-012 passed: Threshold parameter triggers fallback verified") + + + def test_tc014_old_table_compatibility(self): + """ + TC-014: Old table compatibility test + + Test purpose: Verify compatibility when switching from old logic to new logic (native hybrid search) + + Test steps: + 1. Drop existing table to ensure clean state + 2. Disable native hybrid search, initialize Memory (auto create table) + 3. Add some data and search with old logic + 4. Enable native hybrid search and search with new logic + 5. Output results (count and content) + 6. Output test results + + Note: This test does NOT use the fixture to avoid creating a HEAP table before the test starts. 
+ """ + import pymysql + + log_info("=" * 80) + log_info("TC-014: Old Table Compatibility Test") + log_info("=" * 80) + + # Database connection info + db_host = "127.0.0.1" + db_port = 10001 + db_name = "powermem" + table_name = "memories_old_table_test" + + # Step 0: Drop existing table to ensure clean state + log_info("\n[Step 0] Dropping existing table to ensure clean state...") + try: + conn = pymysql.connect( + host=db_host, + port=db_port, + database=db_name, + user="root", + password="", + charset="utf8mb4" + ) + cursor = conn.cursor() + drop_sql = f"DROP TABLE IF EXISTS `{table_name}`" + cursor.execute(drop_sql) + conn.commit() + cursor.close() + conn.close() + log_info(f"✓ Table '{table_name}' dropped successfully (or did not exist)") + except Exception as e: + log_warning(f"Failed to drop table: {e}") + + # Step 1: Disable native hybrid search, initialize Memory (auto create table) + log_info("\n[Step 1] Disabling native hybrid search and initializing Memory (old mode)...") + config_old = auto_config() + if 'vector_store' not in config_old: + config_old['vector_store'] = {} + if 'config' not in config_old['vector_store']: + config_old['vector_store']['config'] = {} + config_old['vector_store']['config']['enable_native_hybrid'] = False + config_old['vector_store']['config']['collection_name'] = table_name + + memory_old = Memory(config=config_old) + log_info("✓ Memory initialized with native hybrid search DISABLED (old logic)") + + # Step 2: Add some data to the table + log_info("\n[Step 2] Adding data to table...") + user_id = "tc014_user" + test_messages = [ + "user1 is 1 years old", + "user2 is 2 years old", + "user3 is 3 years old" + ] + + for msg in test_messages: + result = memory_old.add(messages=msg, user_id=user_id) + assert result is not None, f"Failed to add message: {msg}" + log_info(f"✓ Added: {msg}") + + # Step 3: Search with old logic (native hybrid search disabled) + log_info("\n[Step 3] Searching with OLD logic (native hybrid search DISABLED)...") + old_results = memory_old.search(query="user", user_id=user_id, limit=10) + old_memories = old_results.get('results', []) + + log_info(f"\n--- OLD Logic Search Results ---") + log_info(f"Result count: {len(old_memories)}") + log_info(f"Results content:") + for i, mem in enumerate(old_memories): + memory_content = mem.get('memory', 'N/A') + score = mem.get('score', 'N/A') + log_info(f" [{i+1}] score={score}, memory={memory_content}") + + # Step 4: Enable native hybrid search and search with new logic + log_info("\n[Step 4] Enabling native hybrid search and searching with NEW logic...") + config_new = auto_config() + if 'vector_store' not in config_new: + config_new['vector_store'] = {} + if 'config' not in config_new['vector_store']: + config_new['vector_store']['config'] = {} + config_new['vector_store']['config']['enable_native_hybrid'] = True + config_new['vector_store']['config']['collection_name'] = table_name + + memory_new = Memory(config=config_new) + log_info("✓ Memory reinitialized with native hybrid search ENABLED (new logic)") + + new_results = memory_new.search(query="user", user_id=user_id, limit=10) + new_memories = new_results.get('results', []) + + log_info(f"\n--- NEW Logic Search Results ---") + log_info(f"Result count: {len(new_memories)}") + log_info(f"Results content:") + for i, mem in enumerate(new_memories): + memory_content = mem.get('memory', 'N/A') + score = mem.get('score', 'N/A') + log_info(f" [{i+1}] score={score}, memory={memory_content}") + + # Step 5: Output comparison summary + 
log_info("\n" + "=" * 80) + log_info("Test Results Summary") + log_info("=" * 80) + log_info(f"OLD logic (native hybrid DISABLED): {len(old_memories)} results") + log_info(f"NEW logic (native hybrid ENABLED): {len(new_memories)} results") + + # Verify both searches returned results + assert len(old_memories) > 0, "Old logic should return results" + assert len(new_memories) > 0, "New logic should return results" + + log_info("\n✓ TC-014 passed: Old table compatibility verified") + log_info(" - Old logic search works correctly") + log_info(" - New logic search works correctly after enabling native hybrid search") + + + +def run_all_tests(): + """Run tests with case selection in code""" + # Define available test cases + test_cases = { + "1": ("test_tc001_enable_native_hybrid_search", "TC-001: Enable native hybrid search"), + "2": ("test_tc005_hybrid_search_fusion_effect", "TC-005: Hybrid search fusion effect"), + "3": ("test_tc006_table_column_field_filtering", "TC-006: Table column field filtering"), + "4": ("test_tc007_json_field_filtering_auto_fallback", "TC-007: JSON field filtering (auto fallback)"), + "5": ("test_tc008_empty_result_handling", "TC-008: Empty result handling"), + "6": ("test_tc009_large_data_search", "TC-009: Large data search and performance comparison"), + "7": ("test_tc010_limit_parameter", "TC-010: Limit parameter test"), + "8": ("test_tc012_threshold_parameter_triggers_fallback", "TC-012: Threshold parameter triggers fallback"), + "9": ("test_tc014_old_table_compatibility", "TC-014: Old table compatibility test"), + } + + log_info("=" * 80) + log_info("Starting Native Hybrid Search Comprehensive Tests") + log_info("=" * 80) + + # Build pytest arguments - run all test cases defined in test_cases + pytest_args = ["-v", "-s"] + + log_info(f"Running {len(test_cases)} test cases from test_cases list:") + for test_key, (test_method, desc) in sorted(test_cases.items(), key=lambda x: int(x[0])): + # Each test case needs to be a complete path: file::Class::method + test_path = f"{__file__}::TestNativeHybridSearch::{test_method}" + pytest_args.append(test_path) + log_info(f" - {desc}") + + log_info("=" * 80) + pytest.main(pytest_args) + + +if __name__ == "__main__": + run_all_tests() + From d2b4d129cf2c930d12c6040fc80f01542577b7c3 Mon Sep 17 00:00:00 2001 From: Even Date: Tue, 3 Feb 2026 19:41:38 +0800 Subject: [PATCH 08/23] Optimise searching in Intelligent mode And fix SILICONFLOW_LLM_BASE_URL bug (#224) * Enhance memory operations with background threading support - Added a global background thread pool for asynchronous memory updates and deletions in the Memory class. - Updated the handling of memory updates and deletions to submit tasks to the background executor, improving performance and responsiveness. * format * Enhance SiliconFlowConfig API key handling - Updated `SiliconFlowConfig` to improve API key and base URL handling by adding new validation aliases for better compatibility. 
--- src/powermem/core/memory.py | 22 ++++++++++--------- .../integrations/llm/config/siliconflow.py | 4 ++-- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/powermem/core/memory.py b/src/powermem/core/memory.py index 9ccc55b..c98280b 100644 --- a/src/powermem/core/memory.py +++ b/src/powermem/core/memory.py @@ -8,6 +8,7 @@ import warnings import hashlib import json +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional, Union from datetime import datetime from powermem.utils.utils import get_current_datetime @@ -36,6 +37,9 @@ logger = logging.getLogger(__name__) +# Global background thread pool for async memory operations +_BACKGROUND_EXECUTOR = ThreadPoolExecutor(max_workers=3) + def _auto_convert_config(config: Dict[str, Any]) -> Dict[str, Any]: """ @@ -1180,16 +1184,14 @@ def search( # Intelligent plugin lifecycle management on search if self._intelligence_plugin and self._intelligence_plugin.enabled: updates, deletes = self._intelligence_plugin.on_search(processed_results) - for mem_id, upd in updates: - try: - self.storage.update_memory(mem_id, {**upd}, user_id, agent_id) - except Exception: - continue - for mem_id in deletes: - try: - self.storage.delete_memory(mem_id, user_id, agent_id) - except Exception: - continue + if updates: + for mem_id, upd in updates: + _BACKGROUND_EXECUTOR.submit(self.storage.update_memory,mem_id,{**upd},user_id,agent_id) + logger.info(f"Submitted {len(updates)} update operations to background executor") + if deletes: + for mem_id in deletes: + _BACKGROUND_EXECUTOR.submit(self.storage.delete_memory,mem_id,user_id,agent_id) + logger.info(f"Submitted {len(deletes)} delete operations to background executor") # Transform results to match benchmark expected format # Benchmark expects: {"results": [{"memory": ..., "metadata": {...}, "score": ...}], "relations": [...]} diff --git a/src/powermem/integrations/llm/config/siliconflow.py b/src/powermem/integrations/llm/config/siliconflow.py index dd6db61..5996399 100644 --- a/src/powermem/integrations/llm/config/siliconflow.py +++ b/src/powermem/integrations/llm/config/siliconflow.py @@ -22,10 +22,10 @@ class SiliconFlowConfig(OpenAIConfig): api_key: Optional[str] = Field( default=None, validation_alias=AliasChoices( + "SILICONFLOW_API_KEY", "api_key", "LLM_API_KEY", "OPENAI_API_KEY", - "SILICONFLOW_API_KEY", ), description="SiliconFlow API key" ) @@ -34,9 +34,9 @@ class SiliconFlowConfig(OpenAIConfig): openai_base_url: Optional[str] = Field( default="https://api.siliconflow.cn/v1", validation_alias=AliasChoices( + "SILICONFLOW_LLM_BASE_URL", "openai_base_url", "OPENAI_LLM_BASE_URL", - "SILICONFLOW_LLM_BASE_URL", ), description="SiliconFlow API base URL (OpenAI-compatible)" ) \ No newline at end of file From 408c967db1efadc4c2d7e905cc3d69610b01d625 Mon Sep 17 00:00:00 2001 From: Even Date: Wed, 4 Feb 2026 15:35:51 +0800 Subject: [PATCH 09/23] Fix unit test issues caused by setting changes (#228) * Enhance configuration management for OceanBase in config_loader.py - Added backward compatibility for OceanBase by constructing connection arguments from vector store configuration. - Updated unit tests to verify the inclusion of internal settings in the configuration. 
* disable env file --- src/powermem/config_loader.py | 11 +++++++++++ tests/unit/test_config_loader.py | 18 ++++++++++++++---- tests/unit/test_qwen.py | 23 +++++++++++++---------- 3 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/powermem/config_loader.py b/src/powermem/config_loader.py index f8ccfc4..479f3f7 100644 --- a/src/powermem/config_loader.py +++ b/src/powermem/config_loader.py @@ -143,6 +143,17 @@ def to_config(self) -> Dict[str, Any]: # 3. Export to dict vector_store_config = provider_settings.model_dump(exclude_none=True) + # 4. For OceanBase, build connection_args for backward compatibility + if db_provider == "oceanbase": + connection_args = {} + for key in ["host", "port", "user", "password", "db_name"]: + if key in vector_store_config: + connection_args[key] = vector_store_config[key] + + # Only add connection_args if we have connection parameters + if connection_args: + vector_store_config["connection_args"] = connection_args + return {"provider": db_provider, "config": vector_store_config} diff --git a/tests/unit/test_config_loader.py b/tests/unit/test_config_loader.py index 946226f..d1976fc 100644 --- a/tests/unit/test_config_loader.py +++ b/tests/unit/test_config_loader.py @@ -10,6 +10,9 @@ def _reset_env(monkeypatch, keys): def _disable_env_file(monkeypatch): monkeypatch.setattr(config_loader, "_DEFAULT_ENV_FILE", None, raising=False) monkeypatch.setattr(settings, "_DEFAULT_ENV_FILE", None, raising=False) + new_config = dict(config_loader.EmbeddingSettings.model_config) + new_config["env_file"] = None + monkeypatch.setattr(config_loader.EmbeddingSettings, "model_config", new_config) def test_load_config_from_env_builds_core_config(monkeypatch): @@ -90,7 +93,7 @@ def test_load_config_from_env_graph_store_fallback(monkeypatch): assert graph_store["config"]["max_hops"] == 3 -def test_load_config_from_env_does_not_expose_internal_settings(monkeypatch): +def test_load_config_from_env_loads_internal_settings(monkeypatch): _reset_env( monkeypatch, [ @@ -110,9 +113,16 @@ def test_load_config_from_env_does_not_expose_internal_settings(monkeypatch): config = config_loader.load_config_from_env() - assert "performance" not in config - assert "security" not in config - assert "memory_decay" not in config + # These settings should be included in the config + assert "performance" in config + assert config["performance"]["memory_batch_size"] == 200 + + assert "security" in config + assert config["security"]["encryption_enabled"] is True + assert config["security"]["access_control_enabled"] is False + + assert "memory_decay" in config + assert config["memory_decay"]["enabled"] is False def test_load_config_from_env_telemetry_aliases(monkeypatch): diff --git a/tests/unit/test_qwen.py b/tests/unit/test_qwen.py index fd36495..913f4bc 100644 --- a/tests/unit/test_qwen.py +++ b/tests/unit/test_qwen.py @@ -343,9 +343,10 @@ def test_api_key_from_environment(): config = QwenConfig(model="qwen-turbo") llm = QwenLLM(config) - # Verify API key is set in dashscope module (not in config) - # The config.api_key remains None, but dashscope.api_key is set - assert llm.config.api_key is None + # Verify API key is set from environment variable + # When using validation_alias with AliasChoices, pydantic will read the env var + # and set it to the api_key field + assert llm.config.api_key == "env_api_key" # Clean up del os.environ["DASHSCOPE_API_KEY"] @@ -353,13 +354,15 @@ def test_api_key_from_environment(): def test_dashscope_import_error(): # Test when dashscope is not installed - 
with patch("builtins.__import__", side_effect=ImportError("No module named 'dashscope'")): - config = QwenConfig(model="qwen-turbo", api_key="test_key") - - with pytest.raises(ImportError) as exc_info: - QwenLLM(config) - - assert "DashScope SDK is not installed" in str(exc_info.value) + config = QwenConfig(model="qwen-turbo", api_key="test_key") + + with patch("powermem.integrations.llm.qwen.Generation", None), \ + patch("powermem.integrations.llm.qwen.DashScopeAPIResponse", None), \ + patch.dict('sys.modules', {'dashscope': None, 'dashscope.api_entities.dashscope_response': None}), \ + pytest.raises(ImportError) as exc_info: + QwenLLM(config) + + assert "DashScope SDK is not installed" in str(exc_info.value) def test_model_default_value(): From 5fa2d292018b9d60f6f02cf5754e37d3019cf7fb Mon Sep 17 00:00:00 2001 From: Even Date: Wed, 4 Feb 2026 15:55:10 +0800 Subject: [PATCH 10/23] Fixed run failure caused by incorrect folder name (#229) * Enhance configuration management for OceanBase in config_loader.py - Added backward compatibility for OceanBase by constructing connection arguments from vector store configuration. - Updated unit tests to verify the inclusion of internal settings in the configuration. * disable env file * Fixed run failure caused by incorrect folder name --- .github/workflows/test.yml | 2 +- benchmark/README.md | 2 +- docs/benchmark/overview.md | 32 ++++++++++++++++---------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 493a8a5..0fad5cd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -83,7 +83,7 @@ jobs: pip install --upgrade setuptools wheel - name: Install and verify dependencies - working-directory: benchmark/lomoco + working-directory: benchmark/locomo run: | pip install -r requirements.txt pip check diff --git a/benchmark/README.md b/benchmark/README.md index 750ae7a..958b437 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -10,4 +10,4 @@ This directory contains the benchmarking suite for PowerMem, including a REST AP ## Quick Links - **Benchmark Server**: FastAPI-based REST API server for memory operations (`server/`) -- **LOCOMO Load Testing**: Comprehensive benchmarking tool using the LOCOMO dataset (`lomoco/`) +- **LOCOMO Load Testing**: Comprehensive benchmarking tool using the LOCOMO dataset (`locomo/`) diff --git a/docs/benchmark/overview.md b/docs/benchmark/overview.md index e679872..b8f2654 100644 --- a/docs/benchmark/overview.md +++ b/docs/benchmark/overview.md @@ -13,7 +13,7 @@ The PowerMem Benchmark suite consists of two main components: - Token usage tracking - Support for multiple database backends (OceanBase, PostgreSQL) -2. **Load Testing Tool** (`benchmark/lomoco/`): A comprehensive benchmarking tool that: +2. 
**Load Testing Tool** (`benchmark/locomo/`): A comprehensive benchmarking tool that: - Tests memory addition and search performance - Evaluates response quality using multiple metrics - Measures latency and token consumption @@ -39,10 +39,10 @@ uvicorn benchmark.server.main:app --host 0.0.0.0 --port 8000 --reload ```bash # Install load testing dependencies -pip install -r benchmark/lomoco/requirements.txt +pip install -r benchmark/locomo/requirements.txt # Configure environment -cd benchmark/lomoco +cd benchmark/locomo cp .env.example .env # Edit .env with your API keys and server URL @@ -181,12 +181,12 @@ The LOCOMO benchmark tool performs comprehensive evaluations of memory systems u From the project root: ```bash - pip install -r benchmark/lomoco/requirements.txt + pip install -r benchmark/locomo/requirements.txt ``` - Or from the lomoco directory: + Or from the locomo directory: ```bash - cd benchmark/lomoco + cd benchmark/locomo pip install -r requirements.txt ``` @@ -195,13 +195,13 @@ The LOCOMO benchmark tool performs comprehensive evaluations of memory systems u 1. **Create environment configuration file** ```bash - cd benchmark/lomoco + cd benchmark/locomo cp .env.example .env ``` 2. **Edit the `.env` file** - Open `benchmark/lomoco/.env` and configure the following variables: + Open `benchmark/locomo/.env` and configure the following variables: ```bash # OpenAI API configuration @@ -234,15 +234,15 @@ The LOCOMO benchmark tool performs comprehensive evaluations of memory systems u 2. **Run the complete test script** - From the `benchmark/lomoco` directory: + From the `benchmark/locomo` directory: ```bash - cd benchmark/lomoco + cd benchmark/locomo bash run.sh [output_folder] ``` Or from the project root: ```bash - cd benchmark/lomoco && bash run.sh results + cd benchmark/locomo && bash run.sh results ``` The `output_folder` parameter is optional (defaults to `results`). 
@@ -264,13 +264,13 @@ You can also run individual test methods manually: **Memory Addition Test:** ```bash -cd benchmark/lomoco +cd benchmark/locomo python3 run_experiments.py --method add --output_folder results ``` **Memory Search Test:** ```bash -cd benchmark/lomoco +cd benchmark/locomo python3 run_experiments.py --method search --output_folder results --top_k 30 ``` @@ -380,7 +380,7 @@ The benchmark evaluates performance using multiple metrics: #### "api_base_url is not set" - **Solution**: - - Create `.env` file in `benchmark/lomoco/` directory + - Create `.env` file in `benchmark/locomo/` directory - Verify that `API_BASE_URL` is set correctly - Ensure the URL matches your running server address @@ -392,13 +392,13 @@ The benchmark evaluates performance using multiple metrics: #### "model is not set" or "openai_api_key is not set" - **Solution**: - - Check that `MODEL` and `OPENAI_API_KEY` are set in `benchmark/lomoco/.env` + - Check that `MODEL` and `OPENAI_API_KEY` are set in `benchmark/locomo/.env` - Verify the API key is valid - Ensure no extra quotes or spaces in the values #### Import errors - **Solution**: - - Install all dependencies: `pip install -r benchmark/lomoco/requirements.txt` + - Install all dependencies: `pip install -r benchmark/locomo/requirements.txt` - Ensure you're running from the correct directory - Check Python version (requires 3.10+) From 215cf8ecbf69f44c5a5318872673f2d5d188333e Mon Sep 17 00:00:00 2001 From: "jingshun.tq" <35712518+Teingi@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:29:42 +0800 Subject: [PATCH 11/23] fixed: timezone config parsing in set_timezone & update version 0.5.0 (#233) * docs:Add Moltbot(clawdbot) memory Plugin * docs:Add Moltbot(clawdbot) memory Plugin (#204) * docs:Add Moltbot(clawdbot) memory Plugin * fixed timezone config parsing in set_timezone * fixed docs * update version to 0.5.0 --- examples/langchain/requirements.txt | 2 +- examples/langgraph/requirements.txt | 2 +- pyproject.toml | 2 +- src/powermem/core/audit.py | 2 +- src/powermem/core/telemetry.py | 4 ++-- src/powermem/utils/utils.py | 13 +++++++++++-- src/powermem/version.py | 2 +- 7 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/langchain/requirements.txt b/examples/langchain/requirements.txt index fa16b5a..dceea6a 100644 --- a/examples/langchain/requirements.txt +++ b/examples/langchain/requirements.txt @@ -2,7 +2,7 @@ # Install with: pip install -r requirements.txt # Core dependencies -powermem>=0.3.0 +powermem>=0.5.0 python-dotenv>=1.0.0 openai>=1.109.1,<3.0.0 diff --git a/examples/langgraph/requirements.txt b/examples/langgraph/requirements.txt index e3a8422..21a4d0c 100644 --- a/examples/langgraph/requirements.txt +++ b/examples/langgraph/requirements.txt @@ -2,7 +2,7 @@ # Install with: pip install -r requirements.txt # Core dependencies -powermem>=0.3.0 +powermem>=0.5.0 python-dotenv>=1.0.0 # LangGraph and LangChain dependencies diff --git a/pyproject.toml b/pyproject.toml index 5a4823c..e9fdada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "powermem" -version = "0.4.0" +version = "0.5.0" description = "Intelligent Memory System - Persistent memory layer for LLM applications" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/src/powermem/core/audit.py b/src/powermem/core/audit.py index 7cdec99..22579c3 100644 --- a/src/powermem/core/audit.py +++ b/src/powermem/core/audit.py @@ -92,7 +92,7 @@ def log_event( "user_id": user_id, "agent_id": 
agent_id, "details": details, - "version": "0.4.0", + "version": "0.5.0", } # Log to file diff --git a/src/powermem/core/telemetry.py b/src/powermem/core/telemetry.py index 45a33c7..6971c5d 100644 --- a/src/powermem/core/telemetry.py +++ b/src/powermem/core/telemetry.py @@ -82,7 +82,7 @@ def capture_event( "user_id": user_id, "agent_id": agent_id, "timestamp": get_current_datetime().isoformat(), - "version": "0.4.0", + "version": "0.5.0", } self.events.append(event) @@ -182,7 +182,7 @@ def set_user_properties(self, user_id: str, properties: Dict[str, Any]) -> None: "properties": properties, "user_id": user_id, "timestamp": get_current_datetime().isoformat(), - "version": "0.4.0", + "version": "0.5.0", } self.events.append(event) diff --git a/src/powermem/utils/utils.py b/src/powermem/utils/utils.py index 6ee5335..b9bdee1 100644 --- a/src/powermem/utils/utils.py +++ b/src/powermem/utils/utils.py @@ -34,7 +34,7 @@ _timezone_lock = threading.Lock() -def set_timezone(timezone_str: str) -> None: +def set_timezone(timezone_str: Any) -> None: """ Set the timezone from configuration. @@ -47,7 +47,16 @@ def set_timezone(timezone_str: str) -> None: global _timezone_cache, _timezone_str with _timezone_lock: - _timezone_str = timezone_str + tz = timezone_str + if isinstance(tz, dict): + tz = tz.get("timezone") or tz.get("tz") + + # Only apply when we have a valid non-empty string + if not isinstance(tz, str) or not tz.strip(): + logger.warning("Invalid timezone config: %r", timezone_str) + return + + _timezone_str = tz _timezone_cache = None # Reset cache to force re-initialization diff --git a/src/powermem/version.py b/src/powermem/version.py index f3501e9..6311ad6 100644 --- a/src/powermem/version.py +++ b/src/powermem/version.py @@ -2,7 +2,7 @@ Version information management """ -__version__ = "0.4.0" +__version__ = "0.5.0" __version_info__ = tuple(map(int, __version__.split("."))) # Version history From 4afeda90bf90755be77e7cb3343591903b4205b6 Mon Sep 17 00:00:00 2001 From: Even Date: Thu, 5 Feb 2026 15:36:18 +0800 Subject: [PATCH 12/23] Optimize prompt content when extracting user profiles (#234) * Enhance configuration management for OceanBase in config_loader.py - Added backward compatibility for OceanBase by constructing connection arguments from vector store configuration. - Updated unit tests to verify the inclusion of internal settings in the configuration. * disable env file * Fixed run failure caused by incorrect folder name * upgrade pyobvector version for fixing create ob table's bug * Refactor user profile extraction prompts to return a single user prompt instead of a tuple. 
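
Illustrative before/after of the calling convention this refactor changes (the prompt builder and LLM call here are simplified stand-ins; only the message shape mirrors the diff below):

```python
# Before: the prompt builder returned a (system_prompt, user_message) tuple and the
# extractor sent two messages, e.g.
#   llm.generate_response(messages=[
#       {"role": "system", "content": system_prompt},
#       {"role": "user", "content": user_message},
#   ])

# After: a single user prompt that already embeds the conversation text
def build_user_prompt(conversation: str) -> str:
    """Simplified stand-in for get_user_profile_extraction_prompt()."""
    return (
        "Extract the user profile from the conversation below.\n\n"
        f"[Conversation]:\n{conversation}"
    )

def generate_response(messages):
    """Stub LLM call, only here to keep the sketch runnable."""
    return "profile: ..."

user_prompt = build_user_prompt("user: My name is Alice and I live in Berlin.")
response = generate_response(messages=[{"role": "user", "content": user_prompt}])
print(response)
```

Folding the instructions and the conversation into one user message keeps the builder's output self-contained, so callers such as `_call_llm_for_extraction` only pass a single message, as the diff below shows.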
--- pyproject.toml | 2 +- src/powermem/prompts/user_profile_prompts.py | 35 ++++++++++---------- src/powermem/user_memory/user_memory.py | 21 +++++------- 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e9fdada..39f8273 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ "slowapi>=0.1.9", "click>=8.0.0", "rank-bm25>=0.2.2", - "pyobvector>=0.2.22,<0.3.0", + "pyobvector>=0.2.24,<0.3.0", "jieba>=0.42.1", "azure-identity>=1.24.0", "psycopg2-binary>=2.9.0", diff --git a/src/powermem/prompts/user_profile_prompts.py b/src/powermem/prompts/user_profile_prompts.py index d36672d..0e43cfb 100644 --- a/src/powermem/prompts/user_profile_prompts.py +++ b/src/powermem/prompts/user_profile_prompts.py @@ -123,9 +123,9 @@ def get_user_profile_extraction_prompt( conversation: str, existing_profile: Optional[str] = None, native_language: Optional[str] = None, -) -> Tuple[str, str]: +) -> str: """ - Generate the system prompt and user message for user profile extraction. + Generate the user prompt for user profile extraction. Args: conversation: The conversation text to analyze @@ -135,9 +135,7 @@ def get_user_profile_extraction_prompt( regardless of the languages used in the conversation. Returns: - Tuple of (system_prompt, user_message): - - system_prompt: Fixed instructions and context for the LLM - - user_message: The conversation text to analyze + str: The complete user prompt containing instructions and conversation text """ # Build the prompt with optional Current User Profile section current_profile_section = "" @@ -158,13 +156,15 @@ def get_user_profile_extraction_prompt( [Language Requirement]: You MUST extract and write the profile content in {target_language}, regardless of what languages are used in the conversation.""" - system_prompt = f"""{USER_PROFILE_EXTRACTION_PROMPT}{current_profile_section}{language_instruction} + user_prompt = f"""{USER_PROFILE_EXTRACTION_PROMPT}{current_profile_section}{language_instruction} [Target]: -Extract and return the user profile information as a text description:""" - user_message = conversation +Extract and return the user profile information as a text description: + +[Conversation]: +{conversation}""" - return system_prompt, user_message + return user_prompt @@ -174,9 +174,9 @@ def get_user_profile_topics_extraction_prompt( custom_topics: Optional[str] = None, strict_mode: bool = False, native_language: Optional[str] = None, -) -> Tuple[str, str]: +) -> str: """ - Generate the system prompt and user message for structured topic extraction. + Generate the user prompt for structured topic extraction. Args: conversation: The conversation text to analyze @@ -196,9 +196,7 @@ def get_user_profile_topics_extraction_prompt( language regardless of the languages used in the conversation. Returns: - Tuple of (system_prompt, user_message): - - system_prompt: Fixed instructions and context for the LLM - - user_message: The conversation text to analyze + str: The complete user prompt containing instructions and conversation text """ # Use custom topics if provided, otherwise use default if custom_topics: @@ -275,7 +273,7 @@ def get_user_profile_topics_extraction_prompt( [Language Requirement]: You MUST extract and write all topic values in {target_language}, regardless of what languages are used in the conversation. Keep the topic keys in snake_case English format, but write the values in {target_language}.""" - system_prompt = f"""You are a user profile topic extraction specialist. 
Your task is to analyze conversations and extract user profile information as structured topics. + user_prompt = f"""You are a user profile topic extraction specialist. Your task is to analyze conversations and extract user profile information as structured topics. {topics_section}{description_warning} @@ -307,9 +305,10 @@ def get_user_profile_topics_extraction_prompt( }} All keys must be in snake_case (lowercase with underscores). Values can be strings, numbers, or nested objects as needed. -Remember: Use only the topic names as keys, NOT the descriptions.""" +Remember: Use only the topic names as keys, NOT the descriptions. - user_message = conversation +[Conversation]: +{conversation}""" - return system_prompt, user_message + return user_prompt diff --git a/src/powermem/user_memory/user_memory.py b/src/powermem/user_memory/user_memory.py index 9e8307b..b9f383b 100644 --- a/src/powermem/user_memory/user_memory.py +++ b/src/powermem/user_memory/user_memory.py @@ -328,23 +328,20 @@ def _get_existing_profile_data( def _call_llm_for_extraction( self, - system_prompt: str, - user_message: str, + user_prompt: str, ) -> str: """ Call LLM to extract profile information. Args: - system_prompt: System prompt for LLM - user_message: User message for LLM + user_prompt: User prompt for LLM Returns: LLM response text """ response = self.memory.llm.generate_response( messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_message}, + {"role": "user", "content": user_prompt}, ], ) return remove_code_blocks(response).strip() @@ -380,8 +377,8 @@ def _extract_profile( data_key="profile_content", ) - # Generate system prompt and user message - system_prompt, user_message = get_user_profile_extraction_prompt( + # Generate user prompt + user_prompt = get_user_profile_extraction_prompt( conversation_text, existing_profile=existing_profile, native_language=native_language, @@ -389,7 +386,7 @@ def _extract_profile( # Call LLM to extract profile try: - profile_content = self._call_llm_for_extraction(system_prompt, user_message) + profile_content = self._call_llm_for_extraction(user_prompt) # Return empty string if response is empty or indicates no profile if not profile_content or profile_content.lower() in ["","\"\"", "none", "no profile information", "no relevant information"]: @@ -436,8 +433,8 @@ def _extract_topics( data_key="topics", ) - # Generate system prompt and user message - system_prompt, user_message = get_user_profile_topics_extraction_prompt( + # Generate user prompt + user_prompt = get_user_profile_topics_extraction_prompt( conversation_text, existing_topics=existing_topics, custom_topics=custom_topics, @@ -447,7 +444,7 @@ def _extract_topics( # Call LLM to extract topics try: - topics_text = self._call_llm_for_extraction(system_prompt, user_message) + topics_text = self._call_llm_for_extraction(user_prompt) # Return None if response is empty or indicates no topics if not topics_text or topics_text.lower() in ["", "none", "no profile information", "no relevant information", "{}"]: From 1809eb30e5ee78aa7e263a44666089ea4adb6cff Mon Sep 17 00:00:00 2001 From: Even Date: Thu, 5 Feb 2026 17:43:16 +0800 Subject: [PATCH 13/23] Enhance Memory class query handling (#236) * Enhance configuration management for OceanBase in config_loader.py - Added backward compatibility for OceanBase by constructing connection arguments from vector store configuration. - Updated unit tests to verify the inclusion of internal settings in the configuration. 
* disable env file * Fixed run failure caused by incorrect folder name * upgrade pyobvector version for fixing create ob table's bug * Refactor user profile extraction prompts to return a single user prompt instead of a tuple. * Enhance Memory class query handling * Enhance Memory class query handling --- src/powermem/core/async_memory.py | 6 ++++++ src/powermem/core/memory.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/src/powermem/core/async_memory.py b/src/powermem/core/async_memory.py index 503a04d..e091b6a 100644 --- a/src/powermem/core/async_memory.py +++ b/src/powermem/core/async_memory.py @@ -994,6 +994,12 @@ async def search( - "relations" (List, optional): Graph relations if graph store is enabled """ try: + if not query or not query.strip(): + return { + "results": [], + "relations": [] + } + # Select embedding service based on filters (for sub-store routing) embedding_service = self._get_embedding_service(filters) diff --git a/src/powermem/core/memory.py b/src/powermem/core/memory.py index c98280b..7c63849 100644 --- a/src/powermem/core/memory.py +++ b/src/powermem/core/memory.py @@ -1156,6 +1156,12 @@ def search( - "relations" (List, optional): Graph relations if graph store is enabled """ try: + if not query or not query.strip(): + return { + "results": [], + "relations": [] + } + # Select embedding service based on filters (for sub-store routing) embedding_service = self._get_embedding_service(filters) From 6bce38d2ebf26de1c78963bff7278408e67452b3 Mon Sep 17 00:00:00 2001 From: Chifang <40140008+Ripcord55@users.noreply.github.com> Date: Fri, 6 Feb 2026 10:12:30 +0800 Subject: [PATCH 14/23] tests: Action case adjust (#237) * oceanbase native language case * Oceanbase Native Hybrid Search Cases * Action case adjust --- .../test_scenario_5_custom_integration.py | 22 +++++++++---------- .../regression/test_scenario_7_multimodal.py | 1 + 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tests/regression/test_scenario_5_custom_integration.py b/tests/regression/test_scenario_5_custom_integration.py index acfe53a..5476bca 100644 --- a/tests/regression/test_scenario_5_custom_integration.py +++ b/tests/regression/test_scenario_5_custom_integration.py @@ -85,7 +85,7 @@ class CustomEmbedderConfig(BaseEmbedderConfig): def __init__( self, - dims: int = 768, + dims: int = 1536, embedding_dims: Optional[int] = None, model: Optional[str] = None, api_key: Optional[str] = None, @@ -217,7 +217,7 @@ def test_step1_custom_llm_provider() -> None: 'config': { 'api_key': 'test_key', 'model': 'test_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } } } @@ -273,7 +273,7 @@ def __init__(self, config): # Access config attributes self.api_key = getattr(self.config, 'api_key', '') self.model = getattr(self.config, 'model', 'default') - self.dims = getattr(self.config, 'dims', 768) + self.dims = getattr(self.config, 'dims', 1536) def embed(self, text, memory_action=None) -> List[float]: """Generate embedding for text""" @@ -305,7 +305,7 @@ def test_step2_custom_embedder_provider() -> None: 'config': { 'api_key': 'your_key', 'model': 'your_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } } @@ -377,7 +377,7 @@ def insert(self, vectors, payloads=None, ids=None): """Insert vectors into a collection""" col_name = self.collection_name if col_name not in self._storage: - self.create_col(col_name, len(vectors[0]) if vectors else 768, "cosine") + self.create_col(col_name, len(vectors[0]) if vectors else 1536, "cosine") if ids is None: ids = 
[f"mem_{len(self._vectors[col_name]) + i}" for i in range(len(vectors))] @@ -483,7 +483,7 @@ def reset(self): """Reset by delete the collection and recreate it""" col_name = self.collection_name self.delete_col() - self.create_col(col_name, 768, "cosine") + self.create_col(col_name, 1536, "cosine") return True # Register custom Vector Store Provider @@ -513,7 +513,7 @@ def test_step3_custom_vector_store() -> None: 'config': { 'api_key': 'test_key', 'model': 'test_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } }, 'vector_store': { @@ -653,7 +653,7 @@ def get_context(self, query: str) -> str: 'config': { 'api_key': 'test_key', 'model': 'test_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } }, 'vector_store': { @@ -785,7 +785,7 @@ def _generate( 'config': { 'api_key': 'test_key', 'model': 'test_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } }, 'vector_store': { @@ -958,7 +958,7 @@ class SearchRequest(BaseModel): 'config': { 'api_key': 'test_key', 'model': 'test_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } }, 'vector_store': { @@ -1386,7 +1386,7 @@ def test_complete_example() -> None: 'config': { 'api_key': 'test_key', 'model': 'test_model', - 'embedding_dims': 768 + 'embedding_dims': 1536 } }, 'vector_store': { diff --git a/tests/regression/test_scenario_7_multimodal.py b/tests/regression/test_scenario_7_multimodal.py index 26e45a6..aaf5480 100644 --- a/tests/regression/test_scenario_7_multimodal.py +++ b/tests/regression/test_scenario_7_multimodal.py @@ -106,6 +106,7 @@ "config": { "model": "qwen3-asr-flash", # ASR model for speech-to-text "api_key": dashscope_api_key, + "dashscope_base_url": "https://dashscope.aliyuncs.com/api/v1", } }, } From 9a3600c9335a33daa740f60403331dd496398a4c Mon Sep 17 00:00:00 2001 From: "jingshun.tq" <35712518+Teingi@users.noreply.github.com> Date: Fri, 6 Feb 2026 11:07:00 +0800 Subject: [PATCH 15/23] release 0.5.0 (#238) --- README.md | 1 + README_CN.md | 3 ++- README_JP.md | 1 + src/powermem/version.py | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 7855dc1..4f28102 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,7 @@ The MCP server provides tools for memory management including adding, searching, | Version | Release Date | Function | |---------|--------------|---------| +| 0.5.0 | 2026.02.06 |
  • Unified configuration governance across SDK/API Server (pydantic-settings based)
  • Added OceanBase native hybrid search support
  • Enhanced Memory query handling and added sorting support for memory list operations
  • Added user profile support for custom native-language output
| | 0.4.0 | 2026.01.20 |
  • Sparse vector support for enhanced hybrid retrieval, combining dense vector, full-text, and sparse vector search
  • User memory query rewriting - automatically enhances search queries based on user profiles for improved recall
  • Schema upgrade and data migration tools for existing tables
| | 0.3.0 | 2026.01.09 |
  • Production-ready HTTP API Server with RESTful endpoints for all memory operations
  • Docker support for easy deployment and containerization
| | 0.2.0 | 2025.12.16 |
  • Advanced user profile management, supporting "personalized experience" for AI applications
  • Expanded multimodal support, including text, image, and audio memory
| diff --git a/README_CN.md b/README_CN.md index b54d8b0..67727b1 100644 --- a/README_CN.md +++ b/README_CN.md @@ -210,8 +210,9 @@ MCP Server提供记忆管理工具,包括添加、搜索、更新和删除记 | Version | Release Date | Function | |---------|-------|---------| +| 0.5.0 | 2026.02.06 |
  • 统一 SDK/API Server 配置治理(基于 pydantic-settings)
  • 新增 OceanBase native hybrid search 支持
  • 增强 Memory 查询处理并支持记忆列表排序
  • 新增用户画像支持自定义原生语言
| | 0.4.0 | 2026.01.20 |
  • 稀疏向量支持,增强混合检索能力,融合密集向量、全文检索和稀疏向量三种检索方式
  • 用户画像查询改写功能,基于用户画像自动改写查询以提升搜索召回率
  • 表结构升级和数据迁移工具,支持现有表的平滑升级
| -| 0.3.0 | 2026.01.09 |
  • 生产就绪的 HTTP API Server,提供所有记忆操作的 RESTful 端点
  • Docker 支持,便于部署和容器化
  • >
| +| 0.3.0 | 2026.01.09 |
  • 生产就绪的 HTTP API Server,提供所有记忆操作的 RESTful 接口
  • Docker 支持,便于部署和容器化
  • >
| | 0.2.0 | 2025.12.16 |
  • 高级用户画像管理,支持 AI 应用的"千人千面"
  • 扩展多模态支持,包括文本、图像和音频记忆
| | 0.1.0 | 2025.11.14 |
  • 核心记忆管理功能,支持持久化存储记忆
  • 支持向量、全文和图的混合检索
  • 基于 LLM 的事实提取智能记忆
  • 支持基于艾宾浩斯遗忘曲线的全生命周期记忆管理
  • 支持 Multi-Agent 记忆管理
  • 多存储后端支持(OceanBase、PostgreSQL、SQLite)
  • 支持通过多跳图检索的方式处理知识图谱的检索
| diff --git a/README_JP.md b/README_JP.md index ba25990..ea29269 100644 --- a/README_JP.md +++ b/README_JP.md @@ -209,6 +209,7 @@ MCP サーバーは、メモリの追加、検索、更新、削除を含むメ | Version | Release Date | Function | |---------|-------|---------| +| 0.5.0 | 2026.02.06 |
  • SDK/API Server の設定ガバナンスを統一(pydantic-settings ベース)
  • OceanBase の native hybrid search を追加
  • Memory のクエリ処理を強化し、メモリ一覧のソートに対応
  • ユーザープロフィールでカスタムのネイティブ言語出力をサポート
| | 0.4.0 | 2026.01.20 |
  • スパースベクトルサポート、高密度ベクトル、全文検索、スパースベクトルの3つの検索方式を融合したハイブリッド検索機能の強化
  • ユーザーメモリクエリ書き換え機能、ユーザープロフィールに基づいてクエリを自動的に書き換え、検索の再現率を向上
  • 既存テーブルのスキーマアップグレードとデータ移行ツール
| | 0.3.0 | 2026.01.09 |
  • 本番環境対応の HTTP API サーバー、すべてのメモリ操作の RESTful エンドポイントを提供
  • Docker サポート、簡単なデプロイとコンテナ化を実現
| | 0.2.0 | 2025.12.16 |
  • 高度なユーザープロフィール管理、AI アプリケーションの「千人千面」をサポート
  • テキスト、画像、音声メモリを含む拡張マルチモーダルサポート
| diff --git a/src/powermem/version.py b/src/powermem/version.py index 6311ad6..cfea505 100644 --- a/src/powermem/version.py +++ b/src/powermem/version.py @@ -7,6 +7,7 @@ # Version history VERSION_HISTORY = { + "0.5.0": "2026-02-06 - Version 0.5.0 release", "0.4.0": "2026-01-20 - Version 0.4.0 release", "0.3.1": "2026-01-13 - Version 0.3.1 release", "0.3.0": "2026-01-09 - Version 0.3.0 release", From 45b855bedf61f1526486e1bd64bf9f83665c475e Mon Sep 17 00:00:00 2001 From: Even Date: Fri, 6 Feb 2026 14:07:49 +0800 Subject: [PATCH 16/23] resolve conflicts (#240) --- src/powermem/user_memory/user_memory.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/powermem/user_memory/user_memory.py b/src/powermem/user_memory/user_memory.py index b9f383b..da4e080 100644 --- a/src/powermem/user_memory/user_memory.py +++ b/src/powermem/user_memory/user_memory.py @@ -151,8 +151,8 @@ def add( profile_type: str = "content", custom_topics: Optional[str] = None, strict_mode: bool = False, - include_roles: Optional[List[str]] = ["user"], - exclude_roles: Optional[List[str]] = ["assistant"], + include_roles: Optional[List[str]] = None, + exclude_roles: Optional[List[str]] = None, native_language: Optional[str] = None, ) -> Dict[str, Any]: """ @@ -186,9 +186,9 @@ def add( - Descriptions are for reference only and should NOT be used as keys in the output strict_mode: If True, only output topics from the provided list. Only used when profile_type="topics". Default: False include_roles: List of roles to include when filtering messages for profile extraction. - Defaults to ["user"]. If explicitly set to None or [], no include filter is applied. + Defaults to None. If explicitly set to None or [], no include filter is applied. exclude_roles: List of roles to exclude when filtering messages for profile extraction. - Defaults to ["assistant"]. If explicitly set to None or [], no exclude filter is applied. + Defaults to None. If explicitly set to None or [], no exclude filter is applied. native_language: Optional ISO 639-1 language code (e.g., "zh", "en") to specify the target language for profile extraction. If specified, the extracted profile will be written in this language regardless of the languages used in the conversation. 
If not specified, the profile language From 082e6175d8f76ab5333e9f253635dbd0ae11ad31 Mon Sep 17 00:00:00 2001 From: "jingshun.tq" <35712518+Teingi@users.noreply.github.com> Date: Fri, 6 Feb 2026 14:41:38 +0800 Subject: [PATCH 17/23] release 0.5.1 (#242) * release 0.5.0 * release 0.5.1 --- examples/langchain/requirements.txt | 2 +- examples/langgraph/requirements.txt | 2 +- pyproject.toml | 2 +- src/powermem/core/audit.py | 2 +- src/powermem/core/telemetry.py | 4 ++-- src/powermem/version.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/langchain/requirements.txt b/examples/langchain/requirements.txt index dceea6a..7eff799 100644 --- a/examples/langchain/requirements.txt +++ b/examples/langchain/requirements.txt @@ -2,7 +2,7 @@ # Install with: pip install -r requirements.txt # Core dependencies -powermem>=0.5.0 +powermem>=0.5.1 python-dotenv>=1.0.0 openai>=1.109.1,<3.0.0 diff --git a/examples/langgraph/requirements.txt b/examples/langgraph/requirements.txt index 21a4d0c..a8de602 100644 --- a/examples/langgraph/requirements.txt +++ b/examples/langgraph/requirements.txt @@ -2,7 +2,7 @@ # Install with: pip install -r requirements.txt # Core dependencies -powermem>=0.5.0 +powermem>=0.5.1 python-dotenv>=1.0.0 # LangGraph and LangChain dependencies diff --git a/pyproject.toml b/pyproject.toml index 39f8273..36f0c4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "powermem" -version = "0.5.0" +version = "0.5.1" description = "Intelligent Memory System - Persistent memory layer for LLM applications" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/src/powermem/core/audit.py b/src/powermem/core/audit.py index 22579c3..a2061ae 100644 --- a/src/powermem/core/audit.py +++ b/src/powermem/core/audit.py @@ -92,7 +92,7 @@ def log_event( "user_id": user_id, "agent_id": agent_id, "details": details, - "version": "0.5.0", + "version": "0.5.1", } # Log to file diff --git a/src/powermem/core/telemetry.py b/src/powermem/core/telemetry.py index 6971c5d..938eaab 100644 --- a/src/powermem/core/telemetry.py +++ b/src/powermem/core/telemetry.py @@ -82,7 +82,7 @@ def capture_event( "user_id": user_id, "agent_id": agent_id, "timestamp": get_current_datetime().isoformat(), - "version": "0.5.0", + "version": "0.5.1", } self.events.append(event) @@ -182,7 +182,7 @@ def set_user_properties(self, user_id: str, properties: Dict[str, Any]) -> None: "properties": properties, "user_id": user_id, "timestamp": get_current_datetime().isoformat(), - "version": "0.5.0", + "version": "0.5.1", } self.events.append(event) diff --git a/src/powermem/version.py b/src/powermem/version.py index cfea505..c188be0 100644 --- a/src/powermem/version.py +++ b/src/powermem/version.py @@ -2,7 +2,7 @@ Version information management """ -__version__ = "0.5.0" +__version__ = "0.5.1" __version_info__ = tuple(map(int, __version__.split("."))) # Version history From 0c02460ee3272bdff79f31de0a19887ed569bc6e Mon Sep 17 00:00:00 2001 From: Chifang <40140008+Ripcord55@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:37:06 +0800 Subject: [PATCH 18/23] test: Regression repair (#245) * regression update * Hybrid search time adjust --- .github/workflows/regression.yml | 9 ++------- tests/regression/test_native_hybrid_search.py | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.github/workflows/regression.yml b/.github/workflows/regression.yml index c499acb..b81742f 100644 --- a/.github/workflows/regression.yml +++ 
b/.github/workflows/regression.yml @@ -88,18 +88,12 @@ jobs: run: | pip install -e ".[dev,test]" - # - name: Install Docker - # run: | - # curl -fsSL https://get.docker.com -o get-docker.sh - # sudo sh get-docker.sh - - - name: Deploy SeekDB (OceanBase) run: | # Remove existing container if it exists sudo docker rm -f seekdb 2>/dev/null || true # Start SeekDB container - sudo docker run -d -p 10001:2881 --name seekdb oceanbase/seekdb + sudo docker run -d -p 10001:2881 -e MEMORY_LIMIT=4G -e LOG_DISK_SIZE=4G -e DATAFILE_SIZE=4G -e DATAFILE_NEXT=4G -e DATAFILE_MAXSIZE=100G --name seekdb oceanbase/seekdb # Wait for database to be ready echo "Waiting for SeekDB to be ready..." timeout=60 @@ -149,6 +143,7 @@ jobs: sed -i 's|^GRAPH_STORE_PASSWORD=.*|GRAPH_STORE_PASSWORD=|' .env sed -i "s|^LLM_API_KEY=.*|LLM_API_KEY=${SILICONFLOW_CN_API_KEY}|" .env sed -i "s|^EMBEDDING_API_KEY=.*|EMBEDDING_API_KEY=${QWEN_API_KEY}|" .env + sed -i "s|^POWERMEM_SERVER_API_KEYS=.*|POWERMEM_SERVER_API_KEYS=key1,key2,key3|" .env - name: Run regression tests env: diff --git a/tests/regression/test_native_hybrid_search.py b/tests/regression/test_native_hybrid_search.py index a4c9b47..8f20bbd 100644 --- a/tests/regression/test_native_hybrid_search.py +++ b/tests/regression/test_native_hybrid_search.py @@ -366,7 +366,7 @@ def test_tc009_large_data_search(self): # Step 3: Verify performance and result accuracy log_info("\n[Step 3] Verifying performance and accuracy...") - assert search_time < 1.0, f"Search should complete within 1 second, took {search_time:.3f}s" + assert search_time < 2.0, f"Search should complete within 2 seconds, took {search_time:.3f}s" assert len(memories) > 0, "Should return relevant results" # Verify result accuracy From fe7380ea13100d32cdc789b6469daadd4a957d5f Mon Sep 17 00:00:00 2001 From: "jingshun.tq" <35712518+Teingi@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:22:13 +0800 Subject: [PATCH 19/23] prompts update: LANGUAGE DO NOT translate (#248) --- src/powermem/prompts/intelligent_memory_prompts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/powermem/prompts/intelligent_memory_prompts.py b/src/powermem/prompts/intelligent_memory_prompts.py index f4362b8..9fa011e 100644 --- a/src/powermem/prompts/intelligent_memory_prompts.py +++ b/src/powermem/prompts/intelligent_memory_prompts.py @@ -25,6 +25,7 @@ 2. COMPLETE: Extract self-contained facts with who/what/when/where when available. 3. SEPARATE: Extract distinct facts separately, especially when they have different time periods. 4. INTENTIONS & NEEDS: ALWAYS extract user intentions, needs, and requests even without time information. Examples: "Want to book a doctor appointment", "Need to call someone", "Plan to visit a place". +5. LANGUAGE: DO NOT translate. Preserve the original language of the source text for each extracted fact. If the input is Chinese, output facts in Chinese; if English, output in English; if mixed-language, keep each fact in the language it appears in. Examples: Input: Hi. @@ -51,7 +52,7 @@ - Extract from user/assistant messages only - Extract intentions, needs, and requests even without time information - If no relevant facts, return empty list -- Preserve input language +- Output must preserve the input language (no translation) Extract facts from the conversation below:""" @@ -87,6 +88,7 @@ Delete: Only clear contradictions (e.g., "Loves pizza" vs "Dislikes pizza"). Prefer UPDATE for time conflicts. Important: Use existing IDs only. Keep same ID when updating.
Always preserve temporal information. +LANGUAGE (CRITICAL): Do NOT translate memory text. Keep the same language as the incoming fact(s) and the original memory whenever possible. """ # Alias for compatibility From 7764d1702fe29bb574d1f272eddd4418f314f6d3 Mon Sep 17 00:00:00 2001 From: Even Date: Tue, 10 Feb 2026 11:24:07 +0800 Subject: [PATCH 20/23] Enhance memory listing functionality with pagination and sorting support (#246) - Updated the `list` method in various vector store classes to include optional parameters for filtering, pagination (offset and limit), and sorting (order by and order direction). - Refactored the `StorageAdapter` class to streamline memory retrieval with the new parameters. --- src/powermem/storage/adapter.py | 83 ++----------------- src/powermem/storage/base.py | 12 ++- src/powermem/storage/oceanbase/oceanbase.py | 44 +++++++--- src/powermem/storage/pgvector/pgvector.py | 46 ++++++++-- .../storage/sqlite/sqlite_vector_store.py | 18 +++- 5 files changed, 102 insertions(+), 101 deletions(-) diff --git a/src/powermem/storage/adapter.py b/src/powermem/storage/adapter.py index edd1f38..cfccf43 100644 --- a/src/powermem/storage/adapter.py +++ b/src/powermem/storage/adapter.py @@ -484,14 +484,13 @@ def get_all_memories( if run_id: filters["run_id"] = run_id - # Get memories from vector store with filters (if supported) - if filters and hasattr(self.vector_store, 'list'): - # Pass filters to vector store's list method for database-level filtering - # Request more records to support offset - results = self.vector_store.list(filters=filters, limit=limit + offset) - else: - # Fallback: get all and filter in memory - results = self.vector_store.list(limit=limit + offset) + results = self.vector_store.list( + filters=filters if filters else None, + limit=limit, + offset=offset, + order_by=sort_by, + order=order + ) # OceanBase returns [memories], SQLite/PGVector return memories directly if results and isinstance(results[0], list): @@ -549,73 +548,7 @@ def get_all_memories( memories.append(memory) - # Apply sorting if specified - if sort_by: - memories = self._sort_memories(memories, sort_by, order) - - # Apply offset and limit - return memories[offset:offset + limit] - - def _sort_memories( - self, - memories: List[Dict[str, Any]], - sort_by: str, - order: str = "desc" - ) -> List[Dict[str, Any]]: - """ - Sort memories by specified field. - - Args: - memories: List of memory dictionaries - sort_by: Field to sort by. Options: "created_at", "updated_at", "id" - order: Sort order. 
"desc" for descending (default), "asc" for ascending - - Returns: - Sorted list of memories - """ - if not memories or not sort_by: - return memories - - reverse = (order.lower() == "desc") - - def get_sort_key(memory: Dict[str, Any]) -> Any: - """Get the sort key value from memory.""" - if sort_by == "created_at": - created_at = memory.get("created_at") - if created_at is None: - return datetime.min if reverse else datetime.max - # Handle both datetime objects and ISO format strings - if isinstance(created_at, str): - try: - from datetime import datetime as dt - return dt.fromisoformat(created_at.replace('Z', '+00:00')) - except (ValueError, AttributeError): - return datetime.min if reverse else datetime.max - return created_at if isinstance(created_at, datetime) else datetime.min - elif sort_by == "updated_at": - updated_at = memory.get("updated_at") - if updated_at is None: - return datetime.min if reverse else datetime.max - # Handle both datetime objects and ISO format strings - if isinstance(updated_at, str): - try: - from datetime import datetime as dt - return dt.fromisoformat(updated_at.replace('Z', '+00:00')) - except (ValueError, AttributeError): - return datetime.min if reverse else datetime.max - return updated_at if isinstance(updated_at, datetime) else datetime.min - elif sort_by == "id": - return memory.get("id", 0) - else: - # Unknown sort field, return original order - return None - - try: - sorted_memories = sorted(memories, key=get_sort_key, reverse=reverse) - return sorted_memories - except Exception as e: - logger.warning(f"Failed to sort memories by {sort_by}: {e}, returning original order") - return memories + return memories def clear_memories( self, diff --git a/src/powermem/storage/base.py b/src/powermem/storage/base.py index acef81c..055f1a7 100644 --- a/src/powermem/storage/base.py +++ b/src/powermem/storage/base.py @@ -68,8 +68,16 @@ def col_info(self): pass @abstractmethod - def list(self, filters=None, limit=None): - """List all memories.""" + def list(self, filters=None, limit=None, offset=None, order_by=None, order="desc"): + """List all memories with optional filtering, pagination and sorting. 
+ + Args: + filters: Optional filters to apply + limit: Maximum number of results to return + offset: Number of results to skip + order_by: Field to sort by (e.g., "created_at", "updated_at", "id") + order: Sort order, "desc" for descending or "asc" for ascending + """ pass @abstractmethod diff --git a/src/powermem/storage/oceanbase/oceanbase.py b/src/powermem/storage/oceanbase/oceanbase.py index 931b860..646aa95 100644 --- a/src/powermem/storage/oceanbase/oceanbase.py +++ b/src/powermem/storage/oceanbase/oceanbase.py @@ -1986,7 +1986,8 @@ def col_info(self): logger.error(f"Failed to get collection info for '{self.collection_name}': {e}", exc_info=True) raise - def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None): + def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None, + offset: Optional[int] = None, order_by: Optional[str] = None, order: str = "desc"): """List all memories.""" try: table = Table(self.collection_name, self.obvector.metadata_obj, autoload_with=self.obvector.engine) @@ -1995,18 +1996,38 @@ def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None): where_clause = self._generate_where_clause(filters, table=table) # Build output column name list - output_columns = self._get_standard_column_names(include_vector_field=True) + output_columns_names = self._get_standard_column_names(include_vector_field=True) + + # Build select statement with columns + output_columns = [table.c[col_name] for col_name in output_columns_names if col_name in table.c] + stmt = select(*output_columns) + + # Apply WHERE clause + if where_clause is not None: + stmt = stmt.where(where_clause) + + # Apply ORDER BY clause for sorting + if order_by: + if order_by in table.c: + order_column = table.c[order_by] + if order.lower() == "desc": + stmt = stmt.order_by(order_column.desc()) + else: + stmt = stmt.order_by(order_column.asc()) + + # Apply OFFSET and LIMIT for pagination + if offset is not None: + stmt = stmt.offset(offset) + if limit is not None: + stmt = stmt.limit(limit) - # Get all records - results = self.obvector.get( - table_name=self.collection_name, - ids=None, - output_column_name=output_columns, - where_clause=where_clause - ) + # Execute query + with self.obvector.engine.connect() as conn: + results = conn.execute(stmt) + rows = results.fetchall() memories = [] - for row in results.fetchall(): + for row in rows: parsed = self._parse_row_to_dict(row, include_vector=True, extract_score=False) memories.append(self._create_output_data( @@ -2016,9 +2037,6 @@ def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None): parsed["metadata"] )) - if limit: - memories = memories[:limit] - logger.debug(f"Successfully listed {len(memories)} memories from collection '{self.collection_name}'") return [memories] diff --git a/src/powermem/storage/pgvector/pgvector.py b/src/powermem/storage/pgvector/pgvector.py index 2a22539..25f95c4 100644 --- a/src/powermem/storage/pgvector/pgvector.py +++ b/src/powermem/storage/pgvector/pgvector.py @@ -365,7 +365,10 @@ def col_info(self) -> dict[str, Any]: def list( self, filters: Optional[dict] = None, - limit: Optional[int] = 100 + limit: Optional[int] = 100, + offset: Optional[int] = None, + order_by: Optional[str] = None, + order: str = "desc" ) -> List[OutputData]: """ List all vectors in a collection. @@ -373,6 +376,9 @@ def list( Args: filters (Dict, optional): Filters to apply to the list. limit (int, optional): Number of vectors to return. Defaults to 100. 
+ offset (int, optional): Number of results to skip. + order_by (str, optional): Field to sort by (e.g., "created_at", "updated_at", "id"). + order (str, optional): Sort order, "desc" for descending or "asc" for ascending. Returns: List[OutputData]: List of vectors. @@ -386,16 +392,38 @@ def list( filter_params.extend([k, str(v)]) filter_clause = "WHERE " + " AND ".join(filter_conditions) if filter_conditions else "" - - query = f""" - SELECT id, vector, payload - FROM {self.collection_name} - {filter_clause} - LIMIT %s - """ + + # Build ORDER BY clause for sorting + order_clause = "" + if order_by: + order_upper = order.upper() + if order_by in ["created_at", "updated_at"]: + # Sort by JSON field in payload + order_clause = f"ORDER BY payload->>'{order_by}' {order_upper}" + elif order_by == "id": + # Sort by id column + order_clause = f"ORDER BY id {order_upper}" + + # Build query with all clauses + query_parts = [ + f"SELECT id, vector, payload", + f"FROM {self.collection_name}", + filter_clause, + order_clause, + ] + + # Add OFFSET and LIMIT + if offset is not None: + query_parts.append("OFFSET %s") + filter_params.append(offset) + + query_parts.append("LIMIT %s") + filter_params.append(limit) + + query = "\n".join(part for part in query_parts if part) with self._get_cursor() as cur: - cur.execute(query, (*filter_params, limit)) + cur.execute(query, tuple(filter_params)) results = cur.fetchall() return [OutputData(id=r[0], score=None, payload=r[2]) for r in results] diff --git a/src/powermem/storage/sqlite/sqlite_vector_store.py b/src/powermem/storage/sqlite/sqlite_vector_store.py index 7658409..45f4051 100644 --- a/src/powermem/storage/sqlite/sqlite_vector_store.py +++ b/src/powermem/storage/sqlite/sqlite_vector_store.py @@ -240,8 +240,8 @@ def col_info(self) -> Dict[str, Any]: "db_path": self.db_path } - def list(self, filters=None, limit=None) -> List[OutputData]: - """List all memories with optional filtering.""" + def list(self, filters=None, limit=None, offset=None, order_by=None, order="desc") -> List[OutputData]: + """List all memories with optional filtering, pagination and sorting.""" query = f"SELECT id, vector, payload FROM {self.collection_name}" query_params = [] @@ -256,8 +256,22 @@ def list(self, filters=None, limit=None) -> List[OutputData]: if conditions: query += " WHERE " + " AND ".join(conditions) + # Add ORDER BY clause for sorting + if order_by: + order_upper = order.upper() + if order_by in ["created_at", "updated_at"]: + # Sort by JSON field in payload + query += f" ORDER BY json_extract(payload, '$.{order_by}') {order_upper}" + elif order_by == "id": + # Sort by id column + query += f" ORDER BY id {order_upper}" + + # Add LIMIT and OFFSET for pagination + # Note: In SQLite, LIMIT must come after ORDER BY and before OFFSET if limit: query += f" LIMIT {limit}" + if offset: + query += f" OFFSET {offset}" results = [] with self._lock: From b1667d7e7ae738976dde80cc66fdd7cba97bbb16 Mon Sep 17 00:00:00 2001 From: Even Date: Wed, 11 Feb 2026 11:48:00 +0800 Subject: [PATCH 21/23] fix search bug (#253) * Enhance memory listing functionality with pagination and sorting support - Updated the `list` method in various vector store classes to include optional parameters for filtering, pagination (offset and limit), and sorting (order by and order direction). - Refactored the `StorageAdapter` class to streamline memory retrieval with the new parameters. 
* fix search bug --- src/powermem/storage/oceanbase/oceanbase.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/powermem/storage/oceanbase/oceanbase.py b/src/powermem/storage/oceanbase/oceanbase.py index 646aa95..f128805 100644 --- a/src/powermem/storage/oceanbase/oceanbase.py +++ b/src/powermem/storage/oceanbase/oceanbase.py @@ -477,7 +477,7 @@ def insert(self, logger.error(f"Failed to insert vectors into collection '{self.collection_name}': {e}", exc_info=True) raise - def _generate_where_clause(self, filters: Optional[Dict] = None, table = None) -> Optional[List]: + def _generate_where_clause(self, filters: Optional[Dict] = None, table = None): """ Generate a properly formatted where clause for OceanBase. @@ -497,7 +497,7 @@ def _generate_where_clause(self, filters: Optional[Dict] = None, table = None) - table: SQLAlchemy Table object to use for column references. If None, uses self.table. Returns: - Optional[List]: List of SQLAlchemy ColumnElement objects for where clause. + SQLAlchemy ColumnElement or None: A single SQLAlchemy expression for where clause, or None if no filters. """ # Use provided table or fall back to self.table if table is None: @@ -582,7 +582,7 @@ def process_condition(cond): # Handle complex filters with AND/OR result = process_condition(filters) - return [result] if result is not None else None + return result def _row_to_model(self, row): """ @@ -929,8 +929,8 @@ def _fulltext_search(self, query: str, limit: int = 5, filters: Optional[Dict] = # Combine FTS condition with filter conditions where_conditions = [fts_condition] - if filter_where_clause: - where_conditions.extend(filter_where_clause) + if filter_where_clause is not None: + where_conditions.append(filter_where_clause) # Build custom query to include score field try: @@ -1064,9 +1064,8 @@ def _sparse_search(self, sparse_embedding: Dict[int, float], limit: int = 5, fil stmt = select(*columns) # Add where conditions - if filter_where_clause: - for condition in filter_where_clause: - stmt = stmt.where(condition) + if filter_where_clause is not None: + stmt = stmt.where(filter_where_clause) # Order by score ASC (lower negative_inner_product means higher similarity) stmt = stmt.order_by(text('score ASC')) From 0421b1a227dc8aab3a7a736a895c589044e5e500 Mon Sep 17 00:00:00 2001 From: Even Date: Wed, 11 Feb 2026 14:40:32 +0800 Subject: [PATCH 22/23] Enhance PGVectorConfig for flexible database connection settings (#257) * Enhance memory listing functionality with pagination and sorting support - Updated the `list` method in various vector store classes to include optional parameters for filtering, pagination (offset and limit), and sorting (order by and order direction). - Refactored the `StorageAdapter` class to streamline memory retrieval with the new parameters. * fix search bug * Enhance PGVectorConfig for flexible database connection settings - Modified PGVectorConfig to allow for alternative environment variable aliases for database connection parameters (user, password, host, port). - Improved unit tests for memory listing to support sorting and pagination through a mock list function. 
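For illustration, both value dicts below should pass the relaxed credential check introduced here; only the user/password/host/port keys and their POSTGRES_* aliases are taken from the diff, and whether the full pydantic model tolerates the alias keys outside this validator is not shown in this patch:

    # Canonical keys, unchanged behaviour.
    values = {"user": "postgres", "password": "secret", "host": "127.0.0.1", "port": 5432}

    # POSTGRES_* aliases now satisfy the same presence check in check_auth_and_connection().
    values_aliased = {
        "POSTGRES_USER": "postgres",
        "POSTGRES_PASSWORD": "secret",
        "POSTGRES_HOST": "127.0.0.1",
        "POSTGRES_PORT": 5432,
    }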
--- src/powermem/storage/config/pgvector.py | 6 +- tests/unit/test_list_memories_sorting.py | 106 +++++++++++++++-------- 2 files changed, 74 insertions(+), 38 deletions(-) diff --git a/src/powermem/storage/config/pgvector.py b/src/powermem/storage/config/pgvector.py index f777c0c..bc1c8a0 100644 --- a/src/powermem/storage/config/pgvector.py +++ b/src/powermem/storage/config/pgvector.py @@ -129,8 +129,10 @@ def check_auth_and_connection(cls, values): return values if values.get("connection_string") is not None: return values - user, password = values.get("user"), values.get("password") - host, port = values.get("host"), values.get("port") + user = values.get("user") or values.get("POSTGRES_USER") + password = values.get("password") or values.get("POSTGRES_PASSWORD") + host = values.get("host") or values.get("POSTGRES_HOST") + port = values.get("port") or values.get("POSTGRES_PORT") if user is not None or password is not None: if not user or not password: raise ValueError("Both 'user' and 'password' must be provided.") diff --git a/tests/unit/test_list_memories_sorting.py b/tests/unit/test_list_memories_sorting.py index 5fddb93..e5529ea 100644 --- a/tests/unit/test_list_memories_sorting.py +++ b/tests/unit/test_list_memories_sorting.py @@ -43,6 +43,55 @@ def _create_output_data_list(self, memories_data, default_user_id="test_user"): output_data_list.append(output_data) return output_data_list + def _create_mock_list_with_sorting(self, output_data_list): + """Create a mock list function that supports sorting and pagination.""" + def list_side_effect(filters=None, limit=None, offset=None, order_by=None, order="desc"): + # Start with all data + result = output_data_list[:] + + # Apply sorting if order_by is specified + if order_by: + # Extract sort key from payload or object attributes + def get_sort_key(item): + # Special handling for 'id' field - it's on the object itself + if order_by == 'id': + value = item.id if hasattr(item, 'id') else item.get('id') + # For other fields, check payload first + elif hasattr(item, 'payload'): + value = item.payload.get(order_by) + else: + value = item.get(order_by) + + # Handle None values - put them at the end for both asc and desc + if value is None: + # Use a very large/small value to push None to the end + from datetime import datetime + if order == "desc": + return datetime.min # None goes to end (smallest) + else: + return datetime.max # None goes to end (largest) + + return value + + # Sort the results + reverse = (order == "desc") + try: + result = sorted(result, key=get_sort_key, reverse=reverse) + except Exception as e: + # If sorting fails, return unsorted + print(f"Sorting failed: {e}") + pass + + # Apply pagination (offset and limit) + if offset is not None: + result = result[offset:] + if limit is not None: + result = result[:limit] + + return result + + return list_side_effect + def test_get_all_with_sort_by_updated_at_desc(self, mock_memory): """Test get_all with sorting by updated_at in descending order.""" # Create test data with different update times @@ -71,12 +120,9 @@ def test_get_all_with_sort_by_updated_at_desc(self, mock_memory): # Mock vector_store.list to return OutputData objects # Need to handle both with filters and without filters calls output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - # Return the mock data regardless of filters (filtering happens in get_all_memories) - return output_data_list - - mock_memory.storage.vector_store.list = 
MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) result = mock_memory.get_all( user_id="test_user", @@ -120,11 +166,9 @@ def test_get_all_with_sort_by_updated_at_asc(self, mock_memory): ] output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - return output_data_list - - mock_memory.storage.vector_store.list = MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) result = mock_memory.get_all( user_id="test_user", @@ -168,11 +212,9 @@ def test_get_all_with_sort_by_created_at_desc(self, mock_memory): ] output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - return output_data_list - - mock_memory.storage.vector_store.list = MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) result = mock_memory.get_all( user_id="test_user", @@ -199,11 +241,9 @@ def test_get_all_with_sort_by_id_desc(self, mock_memory): ] output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - return output_data_list - - mock_memory.storage.vector_store.list = MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) result = mock_memory.get_all( user_id="test_user", @@ -230,11 +270,9 @@ def test_get_all_without_sorting(self, mock_memory): ] output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - return output_data_list - - mock_memory.storage.vector_store.list = MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) result = mock_memory.get_all( user_id="test_user", @@ -276,11 +314,9 @@ def test_get_all_with_filtering_and_sorting(self, mock_memory): ] output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - return output_data_list - - mock_memory.storage.vector_store.list = MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) result = mock_memory.get_all( user_id="test_user", @@ -316,11 +352,9 @@ def test_get_all_with_pagination_and_sorting(self, mock_memory): ] output_data_list = self._create_output_data_list(test_memories_data) - - def list_side_effect(filters=None, limit=None): - return output_data_list - - mock_memory.storage.vector_store.list = MagicMock(side_effect=list_side_effect) + mock_memory.storage.vector_store.list = MagicMock( + side_effect=self._create_mock_list_with_sorting(output_data_list) + ) # Get first page result1 = mock_memory.get_all( From e3a6986457cd552ff8f22247594a7741144971ba Mon Sep 17 00:00:00 2001 From: Teingi Date: Thu, 12 Feb 2026 22:06:58 +0800 Subject: [PATCH 23/23] release v0.5.2 --- examples/langchain/requirements.txt | 2 +- examples/langgraph/requirements.txt | 2 +- pyproject.toml | 2 +- src/powermem/core/audit.py | 2 +- src/powermem/core/telemetry.py | 4 ++-- src/powermem/version.py | 2 +- 6 files 
changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/langchain/requirements.txt b/examples/langchain/requirements.txt index 7eff799..b9c7e28 100644 --- a/examples/langchain/requirements.txt +++ b/examples/langchain/requirements.txt @@ -2,7 +2,7 @@ # Install with: pip install -r requirements.txt # Core dependencies -powermem>=0.5.1 +powermem>=0.5.2 python-dotenv>=1.0.0 openai>=1.109.1,<3.0.0 diff --git a/examples/langgraph/requirements.txt b/examples/langgraph/requirements.txt index a8de602..7c4759d 100644 --- a/examples/langgraph/requirements.txt +++ b/examples/langgraph/requirements.txt @@ -2,7 +2,7 @@ # Install with: pip install -r requirements.txt # Core dependencies -powermem>=0.5.1 +powermem>=0.5.2 python-dotenv>=1.0.0 # LangGraph and LangChain dependencies diff --git a/pyproject.toml b/pyproject.toml index 36f0c4c..eef0172 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "powermem" -version = "0.5.1" +version = "0.5.2" description = "Intelligent Memory System - Persistent memory layer for LLM applications" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/src/powermem/core/audit.py b/src/powermem/core/audit.py index a2061ae..5aff697 100644 --- a/src/powermem/core/audit.py +++ b/src/powermem/core/audit.py @@ -92,7 +92,7 @@ def log_event( "user_id": user_id, "agent_id": agent_id, "details": details, - "version": "0.5.1", + "version": "0.5.2", } # Log to file diff --git a/src/powermem/core/telemetry.py b/src/powermem/core/telemetry.py index 938eaab..46619b5 100644 --- a/src/powermem/core/telemetry.py +++ b/src/powermem/core/telemetry.py @@ -82,7 +82,7 @@ def capture_event( "user_id": user_id, "agent_id": agent_id, "timestamp": get_current_datetime().isoformat(), - "version": "0.5.1", + "version": "0.5.2", } self.events.append(event) @@ -182,7 +182,7 @@ def set_user_properties(self, user_id: str, properties: Dict[str, Any]) -> None: "properties": properties, "user_id": user_id, "timestamp": get_current_datetime().isoformat(), - "version": "0.5.1", + "version": "0.5.2", } self.events.append(event) diff --git a/src/powermem/version.py b/src/powermem/version.py index c188be0..a4f8bd5 100644 --- a/src/powermem/version.py +++ b/src/powermem/version.py @@ -2,7 +2,7 @@ Version information management """ -__version__ = "0.5.1" +__version__ = "0.5.2" __version_info__ = tuple(map(int, __version__.split("."))) # Version history
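Taken together, patches 20-22 above give every vector store the same list() pagination and sorting contract. A minimal usage sketch (parameter names follow the VectorStoreBase.list() docstring added in patch 20; the helper and its vector_store argument are hypothetical, not part of the library):

    def newest_memories(vector_store, user_id, page=0, page_size=20):
        """Fetch one page of a user's memories, newest first, via the new list() parameters."""
        return vector_store.list(
            filters={"user_id": user_id},
            limit=page_size,
            offset=page * page_size,
            order_by="created_at",  # also supported: "updated_at", "id"
            order="desc",           # "desc" (default) or "asc"
        )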