Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion integrations/langchain/src/databricks_langchain/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
embedding_dims: int | None = None,
embedding_fields: list[str] | None = None,
embeddings: DatabricksEmbeddings | None = None,
auto_setup: bool = True,
**pool_kwargs: Any,
) -> None:
"""Initialize DatabricksStore with embedding support.
Expand All @@ -54,6 +55,8 @@ def __init__(
vectorizes the entire JSON value.
embeddings: Optional pre-configured DatabricksEmbeddings instance. If provided,
takes precedence over embedding_endpoint.
auto_setup: If True (default), automatically call setup() when entering
the context manager. Set to False to skip automatic setup.
**pool_kwargs: Additional keyword arguments passed to LakebasePool.
"""
if not _store_imports_available:
Expand All @@ -67,6 +70,7 @@ def __init__(
workspace_client=workspace_client,
**pool_kwargs,
)
self._auto_setup = auto_setup

# Initialize embeddings and index configuration for semantic search
self.embeddings: DatabricksEmbeddings | None = None
Expand Down Expand Up @@ -134,6 +138,16 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
"""
return self.batch(ops)

def __enter__(self):
    """Enter the context manager, running ``setup()`` unless auto-setup is disabled.

    Returns:
        The store itself, so it can be bound in a ``with ... as store`` clause.
    """
    if not self._auto_setup:
        return self
    self.setup()
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Exit the context manager.

    Performs no cleanup of its own; returning a falsy value ensures any
    in-flight exception propagates to the caller rather than being suppressed.
    """
    return False


class AsyncDatabricksStore(AsyncBatchedBaseStore):
"""Async version of DatabricksStore for working with long-term memory on Databricks.
Expand All @@ -154,6 +168,7 @@ def __init__(
embedding_dims: int | None = None,
embedding_fields: list[str] | None = None,
embeddings: DatabricksEmbeddings | None = None,
auto_setup: bool = True,
**pool_kwargs: Any,
) -> None:
"""Initialize AsyncDatabricksStore with embedding support.
Expand All @@ -169,6 +184,8 @@ def __init__(
vectorizes the entire JSON value.
embeddings: Optional pre-configured DatabricksEmbeddings instance. If provided,
takes precedence over embedding_endpoint.
auto_setup: If True (default), automatically call setup() when entering
the context manager. Set to False to skip automatic setup.
**pool_kwargs: Additional keyword arguments passed to AsyncLakebasePool.
"""
if not _store_imports_available:
Expand All @@ -184,6 +201,7 @@ def __init__(
workspace_client=workspace_client,
**pool_kwargs,
)
self._auto_setup = auto_setup

# Initialize embeddings and index configuration for semantic search
self.embeddings: DatabricksEmbeddings | None = None
Expand Down Expand Up @@ -243,8 +261,10 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
return await self._with_store(lambda s: s.abatch(ops))

async def __aenter__(self):
    """Enter the async context manager.

    Opens the Lakebase connection pool first, then runs ``setup()`` unless
    auto-setup was disabled at construction time.

    Returns:
        The store itself, for ``async with ... as store`` binding.
    """
    await self._lakebase.open()
    if not self._auto_setup:
        return self
    await self.setup()
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
121 changes: 121 additions & 0 deletions integrations/langchain/tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,65 @@
assert store.index_config["embed"] is mock_embeddings


def test_databricks_store_context_manager_auto_setup_true(monkeypatch):
    """Test that context manager calls setup() when auto_setup=True (default)."""
    mock_conn = MagicMock()
    test_pool = TestConnectionPool(connection_value=mock_conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", test_pool)

    workspace = _create_mock_workspace()

    mock_pg_store = MagicMock()
    mock_pg_store.setup = MagicMock()

    # Patch the PostgresStore symbol where the store module looks it up; the
    # returned mock lets us observe whether setup() was invoked on entry.
    with patch(
        "databricks_langchain.store.PostgresStore", return_value=mock_pg_store
    ):
        store = DatabricksStore(
            instance_name="lakebase-instance",
            workspace_client=workspace,
        )

        # Verify auto_setup defaults to True
        assert store._auto_setup is True

        with store:
            # setup() should have been called
            mock_pg_store.setup.assert_called_once()


def test_databricks_store_context_manager_auto_setup_false(monkeypatch):
    """Test that context manager does NOT call setup() when auto_setup=False."""
    mock_conn = MagicMock()
    test_pool = TestConnectionPool(connection_value=mock_conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", test_pool)

    workspace = _create_mock_workspace()

    mock_pg_store = MagicMock()
    mock_pg_store.setup = MagicMock()

    # Patch the PostgresStore symbol where the store module looks it up; the
    # returned mock lets us observe whether setup() was invoked on entry.
    with patch(
        "databricks_langchain.store.PostgresStore", return_value=mock_pg_store
    ):
        store = DatabricksStore(
            instance_name="lakebase-instance",
            workspace_client=workspace,
            auto_setup=False,
        )

        # Verify auto_setup is False
        assert store._auto_setup is False

        with store:
            # setup() should NOT have been called
            mock_pg_store.setup.assert_not_called()


# =============================================================================
# AsyncDatabricksStore Tests
# =============================================================================
Expand Down Expand Up @@ -287,13 +346,75 @@
async with AsyncDatabricksStore(
instance_name="lakebase-instance",
workspace_client=workspace,
auto_setup=False, # Disable auto_setup to avoid mocking setup
) as store:
assert test_pool._opened
assert not test_pool._closed

assert test_pool._closed


@pytest.mark.asyncio
async def test_async_databricks_store_context_manager_auto_setup_true(monkeypatch):
    """Test that async context manager calls setup() when auto_setup=True (default)."""
    from unittest.mock import AsyncMock

    mock_conn = MagicMock()
    test_pool = TestAsyncConnectionPool(connection_value=mock_conn)
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool)

    workspace = _create_mock_workspace()

    mock_pg_store = MagicMock()
    # AsyncMock is awaitable on every call and needs no running event loop at
    # construction time, unlike a hand-built one-shot asyncio.Future.
    mock_pg_store.setup = AsyncMock()

    with patch(
        "databricks_langchain.store.AsyncPostgresStore", return_value=mock_pg_store
    ):
        store = AsyncDatabricksStore(
            instance_name="lakebase-instance",
            workspace_client=workspace,
        )

        # Verify auto_setup defaults to True
        assert store._auto_setup is True

        async with store:
            # setup() should have been awaited exactly once
            mock_pg_store.setup.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_databricks_store_context_manager_auto_setup_false(monkeypatch):
    """Test that async context manager does NOT call setup() when auto_setup=False."""
    from unittest.mock import AsyncMock

    mock_conn = MagicMock()
    test_pool = TestAsyncConnectionPool(connection_value=mock_conn)
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool)

    workspace = _create_mock_workspace()

    mock_pg_store = MagicMock()
    # Use AsyncMock so that, if a regression ever DID await setup(), the test
    # fails on the assertion below instead of raising a confusing TypeError.
    mock_pg_store.setup = AsyncMock()

    with patch(
        "databricks_langchain.store.AsyncPostgresStore", return_value=mock_pg_store
    ):
        store = AsyncDatabricksStore(
            instance_name="lakebase-instance",
            workspace_client=workspace,
            auto_setup=False,
        )

        # Verify auto_setup is False
        assert store._auto_setup is False

        async with store:
            # setup() should NOT have been called
            mock_pg_store.setup.assert_not_called()


@pytest.mark.asyncio
async def test_async_databricks_store_with_embedding_endpoint(monkeypatch):
"""Test that embedding_endpoint creates embeddings and index_config."""
Expand Down
Loading