From 8a0e0f1dfad9df61218422a0b7715d065c0d3344 Mon Sep 17 00:00:00 2001 From: "wei.liu" Date: Tue, 20 Jan 2026 19:28:49 +0800 Subject: [PATCH 1/3] feat(vector-store): add ClickZetta vector store support --- docs/components/vectordbs/dbs/clickzetta.mdx | 68 ++ mem0/configs/vector_stores/clickzetta.py | 36 + mem0/utils/factory.py | 1 + mem0/vector_stores/clickzetta.py | 494 +++++++++++++ mem0/vector_stores/configs.py | 1 + pyproject.toml | 1 + tests/vector_stores/test_clickzetta.py | 727 +++++++++++++++++++ 7 files changed, 1328 insertions(+) create mode 100644 docs/components/vectordbs/dbs/clickzetta.mdx create mode 100644 mem0/configs/vector_stores/clickzetta.py create mode 100644 mem0/vector_stores/clickzetta.py create mode 100644 tests/vector_stores/test_clickzetta.py diff --git a/docs/components/vectordbs/dbs/clickzetta.mdx b/docs/components/vectordbs/dbs/clickzetta.mdx new file mode 100644 index 0000000000..78a73c6d4d --- /dev/null +++ b/docs/components/vectordbs/dbs/clickzetta.mdx @@ -0,0 +1,68 @@ +[ClickZetta](https://www.yunqi.tech/) is a cloud-native data lakehouse platform developed by Yunqi Technology, supporting vector storage and search capabilities. + +### Usage + + +```python Python +import os +from mem0 import Memory + +config = { + "vector_store": { + "provider": "clickzetta", + "config": { + "collection_name": "mem0_memories", + "service": "your-service.clickzetta.com", + "instance": "your-instance", + "workspace": "your-workspace", + "schema": "public", + "username": "your-username", + "password": "your-password", + "vcluster": "default", + "protocol": "http", + } + } +} + +m = Memory.from_config(config) +messages = [ + {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"}, + {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."}, + {"role": "user", "content": "I'm not a big fan of thriller movies but I love sci-fi movies."}, + {"role": "assistant", "content": "Got it! 
class ClickzettaConfig(BaseModel):
    """Settings accepted by the ClickZetta vector store provider.

    The connection fields (service, instance, workspace, schema, username,
    password, vcluster) have no defaults and must be supplied with non-empty
    values; the remaining fields fall back to the defaults declared below.
    """

    model_config = ConfigDict(populate_by_name=True, protected_namespaces=())

    # Storage layout
    collection_name: str = Field(default="mem0", description="Collection/table name")
    embedding_model_dims: Optional[int] = Field(default=1536, description="Embedding vector dimensions")

    # Connection parameters (all required)
    service: str = Field(default=..., description="ClickZetta service address")
    instance: str = Field(default=..., description="Instance name")
    workspace: str = Field(default=..., description="Workspace name")
    schema: str = Field(default=..., description="Schema name")
    username: str = Field(default=..., description="Username")
    password: str = Field(default=..., description="Password")
    vcluster: str = Field(default=..., description="Virtual cluster name")

    # Optional tuning
    protocol: str = Field(default="http", description="Connection protocol (http/https)")
    distance_metric: Literal["cosine", "euclidean", "dot_product"] = Field(
        default="cosine", description="Distance metric"
    )

    @model_validator(mode="before")
    @classmethod
    def validate_required_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Reject the raw input when any connection field is absent or empty.

        Runs before field validation so that all missing names are reported
        together in a single error message.
        """
        absent = []
        for name in ("service", "instance", "workspace", "schema", "username", "password", "vcluster"):
            if not values.get(name):
                absent.append(name)

        if absent:
            raise ValueError(f"Missing required fields: {', '.join(absent)}")
        return values
class ClickZetta(VectorStoreBase):
    """
    ClickZetta Vector Store Implementation.

    Uses ClickZetta's SQL interface to store and query vector data. Each
    collection is one table ``<schema>.<collection_name>`` with the columns
    ``id``, ``vector``, ``payload`` (JSON) and ``created_at``. Vectors are
    written as ``cast("[...]" as vector(<dims>))`` and searched with cosine
    distance by default.

    Statements are assembled as SQL text (in this class the second argument of
    ``cursor.execute`` only carries hints), so every caller-supplied string is
    passed through ``_escape_literal`` before it is embedded in a statement.
    """

    # Range operators accepted inside a filter dictionary, in emission order.
    _RANGE_OPERATORS = (("gte", ">="), ("gt", ">"), ("lte", "<="), ("lt", "<"))

    def __init__(
        self,
        collection_name: str,
        embedding_model_dims: int,
        service: str,
        instance: str,
        workspace: str,
        schema: str,
        username: str,
        password: str,
        vcluster: str,
        distance_metric: str = "cosine",
        protocol: str = "http",
    ):
        """
        Initialize ClickZetta Vector Store.

        Opens the connection and makes sure the collection table exists.

        Args:
            collection_name: Collection/table name.
            embedding_model_dims: Embedding vector dimensions.
            service: ClickZetta service name.
            instance: Instance name.
            workspace: Workspace name.
            schema: Schema name.
            username: Username.
            password: Password.
            vcluster: Virtual cluster name.
            distance_metric: Distance metric (cosine, euclidean, dot_product).
            protocol: Gateway protocol.
        """
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        self.distance_metric = distance_metric

        # Connection configuration
        self.service = service
        self.instance = instance
        self.workspace = workspace
        self.schema = schema
        self.username = username
        self.password = password
        self.vcluster = vcluster
        self.protocol = protocol

        # Create connection
        self.connection = self._create_connection()

        # Create collection table
        self.create_col(embedding_model_dims, distance_metric)

    # ------------------------------------------------------------------ #
    # SQL text helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _escape_literal(value: Any) -> str:
        """
        Escape a value for use inside a single-quoted SQL string literal.

        Backslashes and single quotes are backslash-escaped.

        NOTE(review): this assumes ClickZetta follows the Hive/Spark-style
        literal rules (backslash is the escape character), which is what the
        double-quoted string literals used elsewhere in this class suggest.
        Confirm against the ClickZetta SQL reference.
        """
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    def _sql_value(self, value: Any) -> str:
        """Render a filter operand: numbers verbatim, everything else as a quoted literal."""
        if isinstance(value, (int, float)):
            return str(value)
        return f"'{self._escape_literal(value)}'"

    def _table_name(self) -> str:
        """Return the schema-qualified table name of the collection."""
        return f"{self.schema}.{self.collection_name}"

    def _vector_cast(self, vector: List[float]) -> str:
        """Render a vector as ``cast("[v1,v2,...]" as vector(<dims>))`` using the configured dimensions."""
        vector_str = "[" + ",".join(str(v) for v in vector) + "]"
        return f'cast("{vector_str}" as vector({self.embedding_model_dims}))'

    def _payload_sql(self, payload: Dict) -> str:
        """Render a payload dictionary as a ``json_parse('...')`` expression."""
        payload_json = json.dumps(payload, ensure_ascii=False)
        return f"json_parse('{self._escape_literal(payload_json)}')"

    @staticmethod
    def _parse_payload(raw: Any) -> Dict[str, Any]:
        """
        Decode a payload column value into a dictionary.

        Accepts a JSON string or an already-decoded dict; anything empty,
        undecodable or not a JSON object yields an empty dict.
        """
        if not raw:
            return {}
        if isinstance(raw, dict):
            return raw
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; TypeError covers non-text values.
            return {}
        return decoded if isinstance(decoded, dict) else {}

    # ------------------------------------------------------------------ #
    # Connection / execution
    # ------------------------------------------------------------------ #

    def _create_connection(self):
        """
        Create database connection.

        Returns:
            The connection object returned by ``clickzetta_dbapi.connect``.

        Raises:
            Exception: Whatever the connector raises; it is logged and re-raised.
        """
        try:
            conn = clickzetta_dbapi.connect(
                service=self.service,
                instance=self.instance,
                workspace=self.workspace,
                schema=self.schema,
                username=self.username,
                password=self.password,
                vcluster=self.vcluster,
                protocol=self.protocol,
            )
            logger.info("Successfully connected to ClickZetta")
            return conn
        except Exception as e:
            logger.error(f"Failed to connect to ClickZetta: {e}")
            raise

    def _execute_query(self, query: str, params: Optional[dict] = None) -> List[tuple]:
        """
        Execute SQL query.

        Args:
            query: SQL text to run.
            params: Optional second argument for ``cursor.execute``.

        Returns:
            The fetched rows for statements starting with SELECT, otherwise an
            empty list.

        Raises:
            Exception: Any execution error, after logging it with the SQL text.
        """
        cursor = self.connection.cursor()
        try:
            if params:
                cursor.execute(query, params)
            else:
                cursor.execute(query)

            # Only SELECT statements produce a result set worth fetching.
            if query.lstrip().upper().startswith("SELECT"):
                return cursor.fetchall()
            return []
        except Exception as e:
            logger.error(f"Query execution failed: {e}, SQL: {query}")
            raise
        finally:
            cursor.close()

    # ------------------------------------------------------------------ #
    # Collection management
    # ------------------------------------------------------------------ #

    def create_col(self, vector_size: int, distance: str = "cosine"):
        """
        Create vector storage table.

        The table is created only when ``information_schema.tables`` does not
        already list it; a failing existence check is logged and creation is
        still attempted (the statement uses ``IF NOT EXISTS``).

        Args:
            vector_size: Vector dimensions.
            distance: Distance metric (accepted for interface compatibility;
                not used by the CREATE TABLE statement).

        Raises:
            Exception: If the CREATE TABLE statement fails.
        """
        check_query = f"""
            SELECT COUNT(*) FROM information_schema.tables
            WHERE table_schema = '{self._escape_literal(self.schema)}'
            AND table_name = '{self._escape_literal(self.collection_name)}'
        """

        try:
            result = self._execute_query(check_query)
            if result and result[0][0] > 0:
                logger.debug(f"Collection {self.collection_name} already exists. Skipping creation.")
                return
        except Exception as e:
            # The check is best-effort: fall through to CREATE TABLE IF NOT EXISTS.
            logger.debug(f"Could not check whether collection {self.collection_name} exists: {e}")

        create_query = f"""
            CREATE TABLE IF NOT EXISTS {self._table_name()} (
                id VARCHAR(64) PRIMARY KEY,
                vector vector({vector_size}),
                payload JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """

        try:
            self._execute_query(create_query)
            logger.info(f"Created collection {self.collection_name}")
        except Exception as e:
            logger.error(f"Failed to create collection: {e}")
            raise

    def list_cols(self) -> List[str]:
        """
        List all collections (tables).

        Returns:
            Names of the tables listed for the configured schema.
        """
        query = f"""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = '{self._escape_literal(self.schema)}'
        """

        results = self._execute_query(query)
        return [row[0] for row in results]

    def delete_col(self):
        """Delete the current collection."""
        query = f"DROP TABLE IF EXISTS {self._table_name()}"
        self._execute_query(query)
        logger.info(f"Deleted collection {self.collection_name}")

    def col_info(self) -> Dict:
        """
        Get collection information.

        Returns:
            Dictionary with the collection name, schema, current row count,
            embedding dimensions and distance metric.
        """
        count_query = f"SELECT COUNT(*) FROM {self._table_name()}"
        count_result = self._execute_query(count_query)
        row_count = count_result[0][0] if count_result else 0

        return {
            "name": self.collection_name,
            "schema": self.schema,
            "row_count": row_count,
            "embedding_dims": self.embedding_model_dims,
            "distance_metric": self.distance_metric,
        }

    def reset(self):
        """Reset the collection (delete and recreate)."""
        logger.warning(f"Resetting collection {self.collection_name}...")
        self.delete_col()
        self.create_col(self.embedding_model_dims, self.distance_metric)

    # ------------------------------------------------------------------ #
    # Writes
    # ------------------------------------------------------------------ #

    def insert(self, vectors: List[List[float]], payloads: List[Dict] = None, ids: List[str] = None):
        """
        Insert vector data.

        One INSERT is issued per vector, with the hint
        ``cz.sql.insert.duplicate.key.policy = update`` passed to the cursor.

        Args:
            vectors: List of vectors.
            payloads: List of metadata; defaults to one empty dict per vector.
            ids: List of IDs; defaults to freshly generated UUID4 strings.

        Raises:
            ValueError: If vectors, payloads and ids differ in length.
            Exception: If an INSERT fails; the error is logged and re-raised.
        """
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in range(len(vectors))]

        if payloads is None:
            payloads = [{} for _ in range(len(vectors))]

        # zip() would silently drop the surplus entries, so fail loudly instead.
        if not (len(vectors) == len(ids) == len(payloads)):
            raise ValueError(
                f"vectors, payloads and ids must have the same length "
                f"(got {len(vectors)}, {len(payloads)}, {len(ids)})"
            )

        logger.info(f"Inserting {len(vectors)} vectors into collection {self.collection_name}")

        table = self._table_name()
        for vec_id, vector, payload in zip(ids, vectors, payloads):
            # A fresh dict per call, in case the connector keeps or alters it.
            param = {"hints": {"cz.sql.insert.duplicate.key.policy": "update"}}

            # The cast uses the configured dimensions rather than a fixed size,
            # and both the id and the JSON text are escaped before embedding.
            query = f"""
                INSERT INTO {table} (id, vector, payload)
                VALUES ('{self._escape_literal(vec_id)}', {self._vector_cast(vector)}, {self._payload_sql(payload)})
            """

            try:
                self._execute_query(query, param)
            except Exception as e:
                # The SQL text is already logged by _execute_query.
                logger.error(f"Failed to insert vector {vec_id}: {e}")
                raise

    def update(self, vector_id: str, vector: List[float] = None, payload: Dict = None):
        """
        Update vector and metadata.

        Only the supplied parts are written, using the same value expressions
        as ``insert``; with neither a vector nor a payload nothing is executed.

        Args:
            vector_id: Vector ID.
            vector: New vector.
            payload: New metadata.
        """
        updates = []

        if vector is not None:
            updates.append(f"vector = {self._vector_cast(vector)}")

        if payload is not None:
            updates.append(f"payload = {self._payload_sql(payload)}")

        if not updates:
            return

        query = f"""
            UPDATE {self._table_name()}
            SET {", ".join(updates)}
            WHERE id = '{self._escape_literal(vector_id)}'
        """
        self._execute_query(query)
        logger.debug(f"Updated vector {vector_id}")

    def delete(self, vector_id: str):
        """
        Delete a vector.

        Args:
            vector_id: Vector ID.
        """
        query = f"""
            DELETE FROM {self._table_name()}
            WHERE id = '{self._escape_literal(vector_id)}'
        """
        self._execute_query(query)
        logger.debug(f"Deleted vector {vector_id}")

    # ------------------------------------------------------------------ #
    # Reads
    # ------------------------------------------------------------------ #

    def _build_distance_expression(self, query_vector: List[float]) -> str:
        """
        Build distance calculation expression.

        Args:
            query_vector: Query vector.

        Returns:
            SQL expression where a smaller value means a closer match. The dot
            product is negated for that reason; unknown metrics fall back to
            cosine distance.
        """
        operand = self._vector_cast(query_vector)

        if self.distance_metric == "euclidean":
            return f"L2_distance(vector, {operand})"
        if self.distance_metric == "dot_product":
            return f"(-1 * dot_product(vector, {operand}))"
        return f"cosine_distance(vector, {operand})"

    def _build_filter_clause(self, filters: Dict) -> str:
        """
        Build filter condition SQL.

        Supported value shapes per key:
            * ``None``  -> ``IS NULL``
            * dict with any of ``gte``/``gt``/``lte``/``lt`` -> range comparison(s)
            * anything else -> equality (numbers verbatim, other values quoted)

        Args:
            filters: Filter condition dictionary.

        Returns:
            A string starting with ``" AND "`` to append to ``WHERE 1=1``, or an
            empty string when there are no filters.

        Raises:
            ValueError: If a dictionary value contains no supported operator.
        """
        if not filters:
            return ""

        conditions = []
        for key, value in filters.items():
            field = f"json_extract_string(payload, '$.{self._escape_literal(key)}')"

            if value is None:
                conditions.append(f"{field} IS NULL")
            elif isinstance(value, dict):
                bounds = [
                    f"{field} {operator} {self._sql_value(value[name])}"
                    for name, operator in self._RANGE_OPERATORS
                    if name in value
                ]
                if not bounds:
                    raise ValueError(f"Unsupported filter for '{key}': {value!r}")
                conditions.extend(bounds)
            else:
                conditions.append(f"{field} = {self._sql_value(value)}")

        return " AND " + " AND ".join(conditions) if conditions else ""

    def _score_from_distance(self, distance: Any) -> float:
        """Convert a value produced by the distance expression into a similarity score."""
        if self.distance_metric == "cosine":
            # cosine_distance range [0, 2] -> score [1, 0]
            return 1 - float(distance) / 2
        if self.distance_metric == "euclidean":
            # Smaller distance means higher score, range (0, 1]
            return 1 / (1 + float(distance))
        # dot_product: undo the negation applied in the distance expression
        return -float(distance)

    def search(self, query: str, vectors: List[float], limit: int = 5, filters: Dict = None) -> List[OutputData]:
        """
        Search for similar vectors.

        Args:
            query: Query text (unused, kept for interface compatibility).
            vectors: Query vector.
            limit: Number of results to return.
            filters: Filter conditions.

        Returns:
            List of search results ordered by ascending distance.

        Raises:
            Exception: Any failure while querying or converting rows; it is
                logged and re-raised.
        """
        distance_expr = self._build_distance_expression(vectors)
        filter_clause = self._build_filter_clause(filters)

        search_query = f"""
            SELECT id, payload, {distance_expr} AS distance
            FROM {self._table_name()}
            WHERE 1=1 {filter_clause}
            ORDER BY distance ASC
            LIMIT {int(limit)}
        """

        try:
            results = self._execute_query(search_query)
            return [
                OutputData(
                    id=row[0],
                    score=self._score_from_distance(row[2]),
                    payload=self._parse_payload(row[1]),
                )
                for row in results
            ]
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a single vector.

        Args:
            vector_id: Vector ID.

        Returns:
            Vector data (with a fixed score of 1.0), or None if not found.
        """
        query = f"""
            SELECT id, vector, payload
            FROM {self._table_name()}
            WHERE id = '{self._escape_literal(vector_id)}'
        """

        results = self._execute_query(query)
        if not results:
            return None

        row = results[0]
        # Column order is (id, vector, payload); the vector itself is not returned.
        return OutputData(id=row[0], score=1.0, payload=self._parse_payload(row[2]))

    def list(self, filters: Dict = None, limit: int = 100) -> List[OutputData]:
        """
        List all vectors in the collection.

        Args:
            filters: Filter conditions.
            limit: Maximum number of results.

        Returns:
            List of vectors, each with a fixed score of 1.0.
        """
        filter_clause = self._build_filter_clause(filters)

        query = f"""
            SELECT id, payload
            FROM {self._table_name()}
            WHERE 1=1 {filter_clause}
            LIMIT {int(limit)}
        """

        results = self._execute_query(query)
        return [OutputData(id=row[0], score=1.0, payload=self._parse_payload(row[1])) for row in results]

    def __del__(self):
        """Close database connection."""
        connection = getattr(self, "connection", None)
        if connection:
            try:
                connection.close()
            except Exception:
                # Best-effort cleanup: a finalizer must never raise.
                pass
---------------------- Fixtures ---------------------- # + + +@pytest.fixture +@patch("mem0.vector_stores.clickzetta.clickzetta_dbapi") +def clickzetta_fixture(mock_dbapi): + """Create a ClickZetta instance with mocked database connection.""" + mock_connection = MagicMock() + mock_cursor = MagicMock() + mock_connection.cursor.return_value = mock_cursor + mock_dbapi.connect.return_value = mock_connection + + # Mock table existence check to return empty (table doesn't exist) + mock_cursor.fetchall.return_value = [(0,)] + + clickzetta = ClickZetta( + collection_name="test_collection", + embedding_model_dims=384, + service="test-service", + instance="test-instance", + workspace="test-workspace", + schema="test_schema", + username="test-user", + password="test-password", + vcluster="test-vcluster", + protocol="http", + distance_metric="cosine", + ) + return clickzetta, mock_connection, mock_cursor + + +# ---------------------- Initialization Tests ---------------------- # + + +def test_initialization(clickzetta_fixture): + """Test that ClickZetta initializes correctly.""" + clickzetta, mock_connection, mock_cursor = clickzetta_fixture + + assert clickzetta.collection_name == "test_collection" + assert clickzetta.embedding_model_dims == 384 + assert clickzetta.distance_metric == "cosine" + assert clickzetta.schema == "test_schema" + assert clickzetta.protocol == "http" + + +def test_create_connection(clickzetta_fixture): + """Test that connection is created with correct parameters.""" + clickzetta, mock_connection, _ = clickzetta_fixture + assert clickzetta.connection is not None + + +def test_create_col(clickzetta_fixture): + """Test collection creation.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [(0,)] + + clickzetta.create_col(vector_size=384, distance="cosine") + + # Verify CREATE TABLE was called + calls = [str(c) for c in mock_cursor.execute.call_args_list] + create_calls = [c for c in calls 
if "CREATE TABLE" in c] + assert len(create_calls) > 0 + + +def test_create_col_already_exists(clickzetta_fixture): + """Test that collection creation is skipped if table exists.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [(1,)] # Table exists + + clickzetta.create_col(vector_size=384, distance="cosine") + + # Verify CREATE TABLE was NOT called + calls = [str(c) for c in mock_cursor.execute.call_args_list] + create_calls = [c for c in calls if "CREATE TABLE" in c] + assert len(create_calls) == 0 + + +# ---------------------- Insert Tests ---------------------- # + + +def test_insert(clickzetta_fixture): + """Test vector insertion.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + payloads = [{"user_id": "user1"}, {"user_id": "user2"}] + ids = [str(uuid.uuid4()), str(uuid.uuid4())] + + clickzetta.insert(vectors=vectors, payloads=payloads, ids=ids) + + # Verify INSERT was called twice + calls = [str(c) for c in mock_cursor.execute.call_args_list] + insert_calls = [c for c in calls if "INSERT INTO" in c] + assert len(insert_calls) == 2 + + +def test_insert_generates_ids(clickzetta_fixture): + """Test that IDs are generated if not provided.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"user_id": "user1"}] + + clickzetta.insert(vectors=vectors, payloads=payloads) + + # Verify INSERT was called + calls = [str(c) for c in mock_cursor.execute.call_args_list] + insert_calls = [c for c in calls if "INSERT INTO" in c] + assert len(insert_calls) == 1 + + +def test_insert_generates_payloads(clickzetta_fixture): + """Test that empty payloads are generated if not provided.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vectors = [[0.1, 0.2, 0.3]] + ids = ["test-id"] + + clickzetta.insert(vectors=vectors, 
ids=ids) + + calls = [str(c) for c in mock_cursor.execute.call_args_list] + insert_calls = [c for c in calls if "INSERT INTO" in c] + assert len(insert_calls) == 1 + + +def test_insert_with_special_characters(clickzetta_fixture): + """Test insertion with special characters in payload.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vectors = [[0.1, 0.2, 0.3]] + payloads = [{"data": "Hello 'world' with \"quotes\""}] + ids = ["test-id"] + + clickzetta.insert(vectors=vectors, payloads=payloads, ids=ids) + + calls = [str(c) for c in mock_cursor.execute.call_args_list] + insert_calls = [c for c in calls if "INSERT INTO" in c] + assert len(insert_calls) == 1 + + +def test_insert_failure_raises_exception(clickzetta_fixture): + """Test that insert failure raises exception.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.execute.side_effect = Exception("Insert failed") + + vectors = [[0.1, 0.2, 0.3]] + + with pytest.raises(Exception) as exc_info: + clickzetta.insert(vectors=vectors) + + assert "Insert failed" in str(exc_info.value) + + +# ---------------------- Search Tests ---------------------- # + + +def test_search(clickzetta_fixture): + """Test vector search.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [ + ("id1", '{"user_id": "user1", "data": "test data"}', 0.1), + ("id2", '{"user_id": "user1", "data": "test data 2"}', 0.2), + ] + mock_cursor.fetchall.return_value = mock_results + + vectors = [0.1, 0.2, 0.3] + results = clickzetta.search(query="test", vectors=vectors, limit=5) + + assert len(results) == 2 + assert isinstance(results[0], OutputData) + assert results[0].id == "id1" + assert results[0].payload["user_id"] == "user1" + + +def test_search_with_filters(clickzetta_fixture): + """Test search with filters.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [ + ("id1", 
'{"user_id": "user1", "agent_id": "agent1"}', 0.1), + ] + mock_cursor.fetchall.return_value = mock_results + + vectors = [0.1, 0.2, 0.3] + filters = {"user_id": "user1", "agent_id": "agent1"} + results = clickzetta.search(query="test", vectors=vectors, limit=5, filters=filters) + + # Verify filter clause was included in query + call_args = mock_cursor.execute.call_args[0][0] + assert "user_id" in call_args + assert "agent_id" in call_args + + assert len(results) == 1 + assert results[0].payload["user_id"] == "user1" + + +def test_search_with_single_filter(clickzetta_fixture): + """Test search with single filter.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [("id1", '{"user_id": "alice"}', 0.1)] + mock_cursor.fetchall.return_value = mock_results + + filters = {"user_id": "alice"} + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5, filters=filters) + + call_args = mock_cursor.execute.call_args[0][0] + assert "user_id" in call_args + assert len(results) == 1 + + +def test_search_with_no_filters(clickzetta_fixture): + """Test search with no filters.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [("id1", '{"key": "value"}', 0.1)] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5, filters=None) + + assert len(results) == 1 + + +def test_search_empty_results(clickzetta_fixture): + """Test search with no results.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [] + + vectors = [0.1, 0.2, 0.3] + results = clickzetta.search(query="test", vectors=vectors, limit=5) + + assert len(results) == 0 + + +def test_search_with_invalid_json_payload(clickzetta_fixture): + """Test search with invalid JSON payload.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results 
= [("id1", "invalid json {", 0.1)] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5) + + assert len(results) == 1 + assert results[0].payload == {} # Should default to empty dict + + +def test_search_with_null_payload(clickzetta_fixture): + """Test search with null payload.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [("id1", None, 0.1)] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5) + + assert len(results) == 1 + assert results[0].payload == {} + + +# ---------------------- Delete Tests ---------------------- # + + +def test_delete(clickzetta_fixture): + """Test vector deletion.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + clickzetta.delete(vector_id=vector_id) + + call_args = mock_cursor.execute.call_args[0][0] + assert "DELETE FROM" in call_args + assert vector_id in call_args + + +# ---------------------- Update Tests ---------------------- # + + +def test_update(clickzetta_fixture): + """Test vector update.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + new_vector = [0.2, 0.3, 0.4] + new_payload = {"user_id": "user2"} + + clickzetta.update(vector_id=vector_id, vector=new_vector, payload=new_payload) + + call_args = mock_cursor.execute.call_args[0][0] + assert "UPDATE" in call_args + assert vector_id in call_args + + +def test_update_vector_only(clickzetta_fixture): + """Test updating only the vector.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + new_vector = [0.2, 0.3, 0.4] + + clickzetta.update(vector_id=vector_id, vector=new_vector) + + call_args = mock_cursor.execute.call_args[0][0] + assert "vector =" in call_args + assert "payload =" 
not in call_args + + +def test_update_payload_only(clickzetta_fixture): + """Test updating only the payload.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + new_payload = {"user_id": "user2"} + + clickzetta.update(vector_id=vector_id, payload=new_payload) + + call_args = mock_cursor.execute.call_args[0][0] + assert "payload =" in call_args + assert "vector =" not in call_args + + +def test_update_nothing(clickzetta_fixture): + """Test update with no changes.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + clickzetta.update(vector_id=vector_id) + + # Should not execute any query + assert mock_cursor.execute.call_count == 0 + + +# ---------------------- Get Tests ---------------------- # + + +def test_get(clickzetta_fixture): + """Test getting a single vector.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + mock_result = [(vector_id, [0.1, 0.2, 0.3], '{"user_id": "user1"}')] + mock_cursor.fetchall.return_value = mock_result + + result = clickzetta.get(vector_id=vector_id) + + assert result is not None + assert result.id == vector_id + assert result.payload["user_id"] == "user1" + + +def test_get_not_found(clickzetta_fixture): + """Test getting a non-existent vector.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [] + + result = clickzetta.get(vector_id="non-existent-id") + + assert result is None + + +def test_get_with_invalid_payload(clickzetta_fixture): + """Test get with invalid JSON payload.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + vector_id = str(uuid.uuid4()) + mock_result = [(vector_id, [0.1, 0.2, 0.3], "invalid json")] + mock_cursor.fetchall.return_value = mock_result + + result = clickzetta.get(vector_id=vector_id) + + assert result is not None + 
assert result.payload == {} + + +# ---------------------- List Collections Tests ---------------------- # + + +def test_list_cols(clickzetta_fixture): + """Test listing collections.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_result = [("table1",), ("table2",), ("test_collection",)] + mock_cursor.fetchall.return_value = mock_result + + result = clickzetta.list_cols() + + assert len(result) == 3 + assert "test_collection" in result + + +# ---------------------- Delete Collection Tests ---------------------- # + + +def test_delete_col(clickzetta_fixture): + """Test collection deletion.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + clickzetta.delete_col() + + call_args = mock_cursor.execute.call_args[0][0] + assert "DROP TABLE" in call_args + assert "test_collection" in call_args + + +# ---------------------- Collection Info Tests ---------------------- # + + +def test_col_info(clickzetta_fixture): + """Test getting collection info.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [(100,)] + + result = clickzetta.col_info() + + assert result["name"] == "test_collection" + assert result["schema"] == "test_schema" + assert result["row_count"] == 100 + assert result["embedding_dims"] == 384 + assert result["distance_metric"] == "cosine" + + +# ---------------------- List Vectors Tests ---------------------- # + + +def test_list(clickzetta_fixture): + """Test listing vectors.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [ + ("id1", '{"user_id": "user1"}'), + ("id2", '{"user_id": "user2"}'), + ] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.list(limit=100) + + assert len(results) == 2 + assert results[0].id == "id1" + assert results[1].id == "id2" + + +def test_list_with_filters(clickzetta_fixture): + """Test listing vectors with filters.""" 
+ clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [("id1", '{"user_id": "user1"}')] + mock_cursor.fetchall.return_value = mock_results + + filters = {"user_id": "user1"} + results = clickzetta.list(filters=filters, limit=100) + + call_args = mock_cursor.execute.call_args[0][0] + assert "user_id" in call_args + + assert len(results) == 1 + + +def test_list_with_no_filters(clickzetta_fixture): + """Test listing vectors with no filters.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + + mock_results = [("id1", '{"key": "value"}')] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.list(filters=None, limit=100) + + assert len(results) == 1 + + +# ---------------------- Reset Tests ---------------------- # + + +def test_reset(clickzetta_fixture): + """Test resetting the collection.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [(0,)] + + clickzetta.reset() + + calls = [str(c) for c in mock_cursor.execute.call_args_list] + drop_calls = [c for c in calls if "DROP TABLE" in c] + create_calls = [c for c in calls if "CREATE TABLE" in c] + + assert len(drop_calls) > 0 + assert len(create_calls) > 0 + + +# ---------------------- Filter Clause Tests ---------------------- # + + +def test_build_filter_clause_empty(clickzetta_fixture): + """Test filter clause with no filters.""" + clickzetta, _, _ = clickzetta_fixture + + result = clickzetta._build_filter_clause(None) + assert result == "" + + result = clickzetta._build_filter_clause({}) + assert result == "" + + +def test_build_filter_clause_string_value(clickzetta_fixture): + """Test filter clause with string value.""" + clickzetta, _, _ = clickzetta_fixture + + filters = {"user_id": "user1"} + result = clickzetta._build_filter_clause(filters) + + assert "user_id" in result + assert "user1" in result + + +def 
test_build_filter_clause_numeric_value(clickzetta_fixture): + """Test filter clause with numeric value.""" + clickzetta, _, _ = clickzetta_fixture + + filters = {"count": 10} + result = clickzetta._build_filter_clause(filters) + + assert "count" in result + assert "10" in result + + +def test_build_filter_clause_range_value(clickzetta_fixture): + """Test filter clause with range value.""" + clickzetta, _, _ = clickzetta_fixture + + filters = {"score": {"gte": 0.5, "lte": 1.0}} + result = clickzetta._build_filter_clause(filters) + + assert "score" in result + assert ">=" in result + assert "<=" in result + + +def test_build_filter_clause_multiple_filters(clickzetta_fixture): + """Test filter clause with multiple filters.""" + clickzetta, _, _ = clickzetta_fixture + + filters = {"user_id": "user1", "agent_id": "agent1", "run_id": "run1"} + result = clickzetta._build_filter_clause(filters) + + assert "user_id" in result + assert "agent_id" in result + assert "run_id" in result + assert "AND" in result + + +# ---------------------- Distance Expression Tests ---------------------- # + + +def test_build_distance_expression_cosine(clickzetta_fixture): + """Test distance expression for cosine metric.""" + clickzetta, _, _ = clickzetta_fixture + clickzetta.distance_metric = "cosine" + + result = clickzetta._build_distance_expression([0.1, 0.2, 0.3]) + + assert "cosine_distance" in result + + +def test_build_distance_expression_euclidean(clickzetta_fixture): + """Test distance expression for euclidean metric.""" + clickzetta, _, _ = clickzetta_fixture + clickzetta.distance_metric = "euclidean" + + result = clickzetta._build_distance_expression([0.1, 0.2, 0.3]) + + assert "L2_distance" in result + + +def test_build_distance_expression_dot_product(clickzetta_fixture): + """Test distance expression for dot product metric.""" + clickzetta, _, _ = clickzetta_fixture + clickzetta.distance_metric = "dot_product" + + result = clickzetta._build_distance_expression([0.1, 0.2, 0.3]) + 
+ assert "dot_product" in result + + +def test_build_distance_expression_default(clickzetta_fixture): + """Test distance expression for unknown metric defaults to cosine.""" + clickzetta, _, _ = clickzetta_fixture + clickzetta.distance_metric = "unknown" + + result = clickzetta._build_distance_expression([0.1, 0.2, 0.3]) + + assert "cosine_distance" in result + + +# ---------------------- Score Conversion Tests ---------------------- # + + +def test_score_conversion_cosine(clickzetta_fixture): + """Test score conversion for cosine distance.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + clickzetta.distance_metric = "cosine" + + # cosine_distance = 0.2, score should be 1 - 0.2/2 = 0.9 + mock_results = [("id1", '{"user_id": "user1"}', 0.2)] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5) + + assert abs(results[0].score - 0.9) < 0.001 + + +def test_score_conversion_euclidean(clickzetta_fixture): + """Test score conversion for euclidean distance.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + clickzetta.distance_metric = "euclidean" + + # L2_distance = 1.0, score should be 1 / (1 + 1.0) = 0.5 + mock_results = [("id1", '{"user_id": "user1"}', 1.0)] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5) + + assert abs(results[0].score - 0.5) < 0.001 + + +def test_score_conversion_dot_product(clickzetta_fixture): + """Test score conversion for dot product.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + clickzetta.distance_metric = "dot_product" + + # distance = -0.8 (negated dot product), score should be 0.8 + mock_results = [("id1", '{"user_id": "user1"}', -0.8)] + mock_cursor.fetchall.return_value = mock_results + + results = clickzetta.search(query="test", vectors=[0.1, 0.2, 0.3], limit=5) + + assert 
abs(results[0].score - 0.8) < 0.001 + + +# ---------------------- Connection Tests ---------------------- # + + +@patch("mem0.vector_stores.clickzetta.clickzetta_dbapi") +def test_connection_failure(mock_dbapi): + """Test handling of connection failure.""" + mock_dbapi.connect.side_effect = Exception("Connection failed") + + with pytest.raises(Exception) as exc_info: + ClickZetta( + collection_name="test", + embedding_model_dims=384, + service="test", + instance="test", + workspace="test", + schema="test", + username="test", + password="test", + vcluster="test", + ) + + assert "Connection failed" in str(exc_info.value) + + +# ---------------------- Query Execution Tests ---------------------- # + + +def test_execute_query_select(clickzetta_fixture): + """Test SELECT query returns results.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.fetchall.return_value = [("row1",), ("row2",)] + + results = clickzetta._execute_query("SELECT * FROM test") + + assert len(results) == 2 + mock_cursor.fetchall.assert_called_once() + + +def test_execute_query_failure(clickzetta_fixture): + """Test query failure raises exception.""" + clickzetta, _, mock_cursor = clickzetta_fixture + mock_cursor.reset_mock() + mock_cursor.execute.side_effect = Exception("Query failed") + + with pytest.raises(Exception) as exc_info: + clickzetta._execute_query("SELECT * FROM test") + + assert "Query failed" in str(exc_info.value) From 6da3802ad871abeaac6fb31001c80da0dd2086ea Mon Sep 17 00:00:00 2001 From: "wei.liu" Date: Tue, 20 Jan 2026 19:51:56 +0800 Subject: [PATCH 2/3] fix doc --- docs/components/vectordbs/dbs/clickzetta.mdx | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/components/vectordbs/dbs/clickzetta.mdx b/docs/components/vectordbs/dbs/clickzetta.mdx index 78a73c6d4d..d04f93bf75 100644 --- a/docs/components/vectordbs/dbs/clickzetta.mdx +++ b/docs/components/vectordbs/dbs/clickzetta.mdx @@ -2,7 +2,6 @@ ### Usage - ```python Python import 
os from mem0 import Memory @@ -33,7 +32,6 @@ messages = [ ] m.add(messages, user_id="alice", metadata={"category": "movies"}) ``` - ### Config From 838fe5809cde7d2c705c8bfbea031d2de1769ee7 Mon Sep 17 00:00:00 2001 From: "wei.liu" Date: Tue, 20 Jan 2026 20:25:12 +0800 Subject: [PATCH 3/3] fix doc --- docs/components/vectordbs/dbs/clickzetta.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/components/vectordbs/dbs/clickzetta.mdx b/docs/components/vectordbs/dbs/clickzetta.mdx index d04f93bf75..83383d1232 100644 --- a/docs/components/vectordbs/dbs/clickzetta.mdx +++ b/docs/components/vectordbs/dbs/clickzetta.mdx @@ -1,4 +1,4 @@ -[ClickZetta](https://www.yunqi.tech/) is a cloud-native data lakehouse platform developed by Yunqi Technology, supporting vector storage and search capabilities. +[ClickZetta](https://www.singdata.com/) is a cloud-native data lakehouse platform developed by Singdata Technology, supporting vector storage and search capabilities. ### Usage