From 88b4ecd5bcc4ccd262539452f278c543b4cb3ef6 Mon Sep 17 00:00:00 2001 From: Jenny Date: Fri, 30 Jan 2026 10:34:38 -0800 Subject: [PATCH] call setup within databricks store sdk so user doesn't need to call it in agent code --- .../src/databricks_langchain/store.py | 22 +++- .../langchain/tests/unit_tests/test_store.py | 121 ++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 05aed3b8..0b8b1f2e 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -39,6 +39,7 @@ def __init__( embedding_dims: int | None = None, embedding_fields: list[str] | None = None, embeddings: DatabricksEmbeddings | None = None, + auto_setup: bool = True, **pool_kwargs: Any, ) -> None: """Initialize DatabricksStore with embedding support. @@ -54,6 +55,8 @@ def __init__( vectorizes the entire JSON value. embeddings: Optional pre-configured DatabricksEmbeddings instance. If provided, takes precedence over embedding_endpoint. + auto_setup: If True (default), automatically call setup() when entering + the context manager. Set to False to skip automatic setup. **pool_kwargs: Additional keyword arguments passed to LakebasePool. 
""" if not _store_imports_available: @@ -67,6 +70,7 @@ def __init__( workspace_client=workspace_client, **pool_kwargs, ) + self._auto_setup = auto_setup # Initialize embeddings and index configuration for semantic search self.embeddings: DatabricksEmbeddings | None = None @@ -134,6 +138,16 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: """ return self.batch(ops) + def __enter__(self): + """Enter context manager and optionally set up the store.""" + if self._auto_setup: + self.setup() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit context manager.""" + return False + class AsyncDatabricksStore(AsyncBatchedBaseStore): """Async version of DatabricksStore for working with long-term memory on Databricks. @@ -154,6 +168,7 @@ def __init__( embedding_dims: int | None = None, embedding_fields: list[str] | None = None, embeddings: DatabricksEmbeddings | None = None, + auto_setup: bool = True, **pool_kwargs: Any, ) -> None: """Initialize AsyncDatabricksStore with embedding support. @@ -169,6 +184,8 @@ def __init__( vectorizes the entire JSON value. embeddings: Optional pre-configured DatabricksEmbeddings instance. If provided, takes precedence over embedding_endpoint. + auto_setup: If True (default), automatically call setup() when entering + the context manager. Set to False to skip automatic setup. **pool_kwargs: Additional keyword arguments passed to AsyncLakebasePool. 
""" if not _store_imports_available: @@ -184,6 +201,7 @@ def __init__( workspace_client=workspace_client, **pool_kwargs, ) + self._auto_setup = auto_setup # Initialize embeddings and index configuration for semantic search self.embeddings: DatabricksEmbeddings | None = None @@ -243,8 +261,10 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]: return await self._with_store(lambda s: s.abatch(ops)) async def __aenter__(self): - """Enter async context manager and open the connection pool.""" + """Enter async context manager, open the connection pool, and optionally set up the store.""" await self._lakebase.open() + if self._auto_setup: + await self.setup() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/integrations/langchain/tests/unit_tests/test_store.py b/integrations/langchain/tests/unit_tests/test_store.py index a10c5eda..2f7a1cfb 100644 --- a/integrations/langchain/tests/unit_tests/test_store.py +++ b/integrations/langchain/tests/unit_tests/test_store.py @@ -203,6 +203,65 @@ def test_databricks_store_warns_when_both_embeddings_and_endpoint_specified(monk assert store.index_config["embed"] is mock_embeddings +def test_databricks_store_context_manager_auto_setup_true(monkeypatch): + """Test that context manager calls setup() when auto_setup=True (default).""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + from langgraph.store.postgres import PostgresStore + + workspace = _create_mock_workspace() + + mock_pg_store = MagicMock() + mock_pg_store.setup = MagicMock() + + with patch( + "databricks_langchain.store.PostgresStore", return_value=mock_pg_store + ) as mock_pg_class: + store = DatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + ) + + # Verify auto_setup defaults to True + assert store._auto_setup is True + + with store: + # setup() should have been called + 
mock_pg_store.setup.assert_called_once() + + +def test_databricks_store_context_manager_auto_setup_false(monkeypatch): + """Test that context manager does NOT call setup() when auto_setup=False.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + from langgraph.store.postgres import PostgresStore + + workspace = _create_mock_workspace() + + mock_pg_store = MagicMock() + mock_pg_store.setup = MagicMock() + + with patch( + "databricks_langchain.store.PostgresStore", return_value=mock_pg_store + ): + store = DatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + auto_setup=False, + ) + + # Verify auto_setup is False + assert store._auto_setup is False + + with store: + # setup() should NOT have been called + mock_pg_store.setup.assert_not_called() + + # ============================================================================= # AsyncDatabricksStore Tests # ============================================================================= @@ -287,6 +346,7 @@ async def test_async_databricks_store_context_manager(monkeypatch): async with AsyncDatabricksStore( instance_name="lakebase-instance", workspace_client=workspace, + auto_setup=False, # Disable auto_setup to avoid mocking setup ) as store: assert test_pool._opened assert not test_pool._closed @@ -294,6 +354,67 @@ async def test_async_databricks_store_context_manager(monkeypatch): assert test_pool._closed +@pytest.mark.asyncio +async def test_async_databricks_store_context_manager_auto_setup_true(monkeypatch): + """Test that async context manager calls setup() when auto_setup=True (default).""" + import asyncio + + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + mock_pg_store = MagicMock() + future = asyncio.Future() + 
future.set_result(None) + mock_pg_store.setup = MagicMock(return_value=future) + + with patch( + "databricks_langchain.store.AsyncPostgresStore", return_value=mock_pg_store + ): + store = AsyncDatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + ) + + # Verify auto_setup defaults to True + assert store._auto_setup is True + + async with store: + # setup() should have been called + mock_pg_store.setup.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_databricks_store_context_manager_auto_setup_false(monkeypatch): + """Test that async context manager does NOT call setup() when auto_setup=False.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + mock_pg_store = MagicMock() + mock_pg_store.setup = MagicMock() + + with patch( + "databricks_langchain.store.AsyncPostgresStore", return_value=mock_pg_store + ): + store = AsyncDatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + auto_setup=False, + ) + + # Verify auto_setup is False + assert store._auto_setup is False + + async with store: + # setup() should NOT have been called + mock_pg_store.setup.assert_not_called() + + @pytest.mark.asyncio async def test_async_databricks_store_with_embedding_endpoint(monkeypatch): """Test that embedding_endpoint creates embeddings and index_config."""