Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions integrations/langchain/src/databricks_langchain/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ class CheckpointSaver(PostgresSaver):
"""
LangGraph PostgresSaver using a Lakebase connection pool.

instance_name: Name of Lakebase Instance
Supports two modes: Lakebase Provisioned VS Autoscaling
https://docs.databricks.com/aws/en/oltp/#feature-comparison
"""

def __init__(
self,
*,
instance_name: str,
instance_name: str | None = None,
project: str | None = None,
branch: str | None = None,
workspace_client: WorkspaceClient | None = None,
**pool_kwargs: Any,
) -> None:
Expand All @@ -40,13 +43,16 @@ def __init__(

self._lakebase: LakebasePool = LakebasePool(
instance_name=instance_name,
project=project,
branch=branch,
workspace_client=workspace_client,
**dict(pool_kwargs),
)
super().__init__(self._lakebase.pool)

def __enter__(self):
"""Enter context manager."""
"""Enter context manager and create checkpoint tables."""
self.setup()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -59,13 +65,18 @@ class AsyncCheckpointSaver(AsyncPostgresSaver):
"""
Async LangGraph PostgresSaver using a Lakebase connection pool.

instance_name: Name of Lakebase Instance
Supports two modes: Lakebase Provisioned VS Autoscaling
https://docs.databricks.com/aws/en/oltp/#feature-comparison

Checkpoint tables are created automatically when entering the context manager.
"""

def __init__(
self,
*,
instance_name: str,
instance_name: str | None = None,
project: str | None = None,
branch: str | None = None,
workspace_client: WorkspaceClient | None = None,
**pool_kwargs: Any,
) -> None:
Expand All @@ -78,14 +89,17 @@ def __init__(

self._lakebase: AsyncLakebasePool = AsyncLakebasePool(
instance_name=instance_name,
project=project,
branch=branch,
workspace_client=workspace_client,
**dict(pool_kwargs),
)
super().__init__(self._lakebase.pool)

async def __aenter__(self):
"""Enter async context manager and open the connection pool."""
"""Enter async context manager, open the connection pool, and create checkpoint tables."""
await self._lakebase.open()
await self.setup()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
20 changes: 16 additions & 4 deletions integrations/langchain/src/databricks_langchain/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class DatabricksStore(BaseStore):
def __init__(
self,
*,
instance_name: str,
instance_name: str | None = None,
project: str | None = None,
branch: str | None = None,
workspace_client: WorkspaceClient | None = None,
embedding_endpoint: str | None = None,
embedding_dims: int | None = None,
Expand All @@ -44,7 +46,9 @@ def __init__(
"""Initialize DatabricksStore with embedding support.

Args:
instance_name: The name of the Lakebase instance to connect to.
instance_name: The name of the Lakebase provisioned instance.
project: Lakebase autoscaling project name. Also requires ``branch``.
branch: Lakebase autoscaling branch name. Also requires ``project``.
workspace_client: Optional Databricks WorkspaceClient for authentication.
embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings
(e.g., "databricks-gte-large-en"). If provided, enables semantic search.
Expand All @@ -64,6 +68,8 @@ def __init__(

self._lakebase: LakebasePool = LakebasePool(
instance_name=instance_name,
project=project,
branch=branch,
workspace_client=workspace_client,
**pool_kwargs,
)
Expand Down Expand Up @@ -148,7 +154,9 @@ class AsyncDatabricksStore(AsyncBatchedBaseStore):
def __init__(
self,
*,
instance_name: str,
instance_name: str | None = None,
project: str | None = None,
branch: str | None = None,
workspace_client: WorkspaceClient | None = None,
embedding_endpoint: str | None = None,
embedding_dims: int | None = None,
Expand All @@ -159,7 +167,9 @@ def __init__(
"""Initialize AsyncDatabricksStore with embedding support.

Args:
instance_name: The name of the Lakebase instance to connect to.
instance_name: The name of the Lakebase provisioned instance.
project: Lakebase autoscaling project name. Also requires ``branch``.
branch: Lakebase autoscaling branch name. Also requires ``project``.
workspace_client: Optional Databricks WorkspaceClient for authentication.
embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings
(e.g., "databricks-gte-large-en"). If provided, enables semantic search.
Expand All @@ -181,6 +191,8 @@ def __init__(

self._lakebase: AsyncLakebasePool = AsyncLakebasePool(
instance_name=instance_name,
project=project,
branch=branch,
workspace_client=workspace_client,
**pool_kwargs,
)
Expand Down
116 changes: 116 additions & 0 deletions integrations/langchain/tests/unit_tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,119 @@ async def test_async_checkpoint_saver_connection(monkeypatch):
) as saver:
async with saver._lakebase.connection() as conn:
assert conn == "async-lake-conn"


# =============================================================================
# Autoscaling (project/branch) Tests
# =============================================================================


def _create_autoscaling_workspace():
"""Helper to create a mock workspace client for autoscaling mode."""
workspace = MagicMock()
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com")
workspace.postgres.generate_database_credential.return_value = MagicMock(
token="autoscaling-token"
)
rw_endpoint = MagicMock()
rw_endpoint.name = "projects/p/branches/b/endpoints/rw"
rw_endpoint.status.endpoint_type = "READ_WRITE"
rw_endpoint.status.hosts.host = "auto-db-host"
workspace.postgres.list_endpoints.return_value = [rw_endpoint]
return workspace


def test_checkpoint_saver_autoscaling_configures_lakebase(monkeypatch):
    """A CheckpointSaver built from project/branch resolves the autoscaling host."""
    pool_stub = TestConnectionPool(connection_value="lake-conn")
    monkeypatch.setattr(lakebase, "ConnectionPool", pool_stub)
    ws = _create_autoscaling_workspace()

    saver = CheckpointSaver(
        project="my-project",
        branch="my-branch",
        workspace_client=ws,
    )

    # The pool must target the endpoint host and the autoscaling path must be taken.
    assert saver._lakebase._is_autoscaling is True
    assert "host=auto-db-host" in pool_stub.conninfo
    ws.postgres.list_endpoints.assert_called_once_with(
        parent="projects/my-project/branches/my-branch"
    )


@pytest.mark.asyncio
async def test_async_checkpoint_saver_autoscaling_configures_lakebase(monkeypatch):
    """An AsyncCheckpointSaver built from project/branch resolves the autoscaling host."""
    pool_stub = TestAsyncConnectionPool(connection_value="async-lake-conn")
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", pool_stub)
    ws = _create_autoscaling_workspace()

    saver = AsyncCheckpointSaver(
        project="my-project",
        branch="my-branch",
        workspace_client=ws,
    )

    assert saver._lakebase._is_autoscaling is True
    assert "host=auto-db-host" in pool_stub.conninfo


@pytest.mark.asyncio
async def test_async_checkpoint_saver_autoscaling_context_manager(monkeypatch):
    """The async context manager opens and closes the pool in autoscaling mode."""
    pool_stub = TestAsyncConnectionPool(connection_value="async-lake-conn")
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", pool_stub)
    ws = _create_autoscaling_workspace()

    async with AsyncCheckpointSaver(
        project="my-project",
        branch="my-branch",
        workspace_client=ws,
    ) as saver:
        # Pool is opened on entry; autoscaling path is active inside the block.
        assert pool_stub._opened
        assert saver._lakebase._is_autoscaling is True

    # Pool is closed again on exit.
    assert pool_stub._closed


# =============================================================================
# Validation: missing parameters
# =============================================================================


def test_checkpoint_saver_no_params_raises_error(monkeypatch):
    """CheckpointSaver with no connection parameters raises ValueError."""
    monkeypatch.setattr(lakebase, "ConnectionPool", TestConnectionPool())

    ws = MagicMock()
    ws.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

    # Neither instance_name nor project/branch given -> must fail fast.
    with pytest.raises(ValueError, match="Must provide either 'instance_name'"):
        CheckpointSaver(workspace_client=ws)


def test_checkpoint_saver_only_project_raises_error(monkeypatch):
    """CheckpointSaver with only project (no branch) raises ValueError."""
    monkeypatch.setattr(lakebase, "ConnectionPool", TestConnectionPool())

    ws = MagicMock()
    ws.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

    # Autoscaling mode needs both halves of the project/branch pair.
    with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"):
        CheckpointSaver(project="my-project", workspace_client=ws)


@pytest.mark.asyncio
async def test_async_checkpoint_saver_no_params_raises_error(monkeypatch):
    """AsyncCheckpointSaver with no connection parameters raises ValueError."""
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", TestAsyncConnectionPool())

    ws = MagicMock()
    ws.current_user.me.return_value = MagicMock(user_name="test@databricks.com")

    with pytest.raises(ValueError, match="Must provide either 'instance_name'"):
        AsyncCheckpointSaver(workspace_client=ws)
137 changes: 137 additions & 0 deletions integrations/langchain/tests/unit_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,140 @@ async def test_async_databricks_store_connection(monkeypatch):
) as store:
async with store._lakebase.connection() as conn:
assert conn == mock_conn


# =============================================================================
# Autoscaling (project/branch) Tests
# =============================================================================


def _create_autoscaling_workspace():
"""Helper to create a mock workspace client for autoscaling mode."""
workspace = MagicMock()
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com")
workspace.postgres.generate_database_credential.return_value = MagicMock(
token="autoscaling-token"
)
rw_endpoint = MagicMock()
rw_endpoint.name = "projects/p/branches/b/endpoints/rw"
rw_endpoint.status.endpoint_type = "READ_WRITE"
rw_endpoint.status.hosts.host = "auto-db-host"
workspace.postgres.list_endpoints.return_value = [rw_endpoint]
return workspace


def test_databricks_store_autoscaling_configures_lakebase(monkeypatch):
    """DatabricksStore constructed with project/branch takes the autoscaling path."""
    conn = MagicMock()
    pool_stub = TestConnectionPool(connection_value=conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", pool_stub)
    ws = _create_autoscaling_workspace()

    store = DatabricksStore(
        project="my-project",
        branch="my-branch",
        workspace_client=ws,
    )

    assert store._lakebase._is_autoscaling is True
    assert "host=auto-db-host" in pool_stub.conninfo
    ws.postgres.list_endpoints.assert_called_once_with(
        parent="projects/my-project/branches/my-branch"
    )


def test_databricks_store_provisioned_uses_provisioned_path(monkeypatch):
    """DatabricksStore constructed with instance_name takes the provisioned path."""
    conn = MagicMock()
    pool_stub = TestConnectionPool(connection_value=conn)
    monkeypatch.setattr(lakebase, "ConnectionPool", pool_stub)
    ws = _create_mock_workspace()

    store = DatabricksStore(
        instance_name="lakebase-instance",
        workspace_client=ws,
    )

    # Provisioned mode resolves the host via the database-instance lookup.
    assert store._lakebase._is_autoscaling is False
    assert "host=db-host" in pool_stub.conninfo
    ws.database.get_database_instance.assert_called_once()


@pytest.mark.asyncio
async def test_async_databricks_store_autoscaling_configures_lakebase(monkeypatch):
    """AsyncDatabricksStore constructed with project/branch takes the autoscaling path."""
    conn = MagicMock()
    pool_stub = TestAsyncConnectionPool(connection_value=conn)
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", pool_stub)
    ws = _create_autoscaling_workspace()

    store = AsyncDatabricksStore(
        project="my-project",
        branch="my-branch",
        workspace_client=ws,
    )

    assert store._lakebase._is_autoscaling is True
    assert "host=auto-db-host" in pool_stub.conninfo


@pytest.mark.asyncio
async def test_async_databricks_store_autoscaling_context_manager(monkeypatch):
    """The async store context manager opens and closes the autoscaling pool."""
    conn = MagicMock()
    pool_stub = TestAsyncConnectionPool(connection_value=conn)
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", pool_stub)
    ws = _create_autoscaling_workspace()

    async with AsyncDatabricksStore(
        project="my-project",
        branch="my-branch",
        workspace_client=ws,
    ) as store:
        # Pool is opened on entry; autoscaling path is active inside the block.
        assert pool_stub._opened
        assert store._lakebase._is_autoscaling is True

    # Pool is closed again on exit.
    assert pool_stub._closed


# =============================================================================
# Validation: missing parameters
# =============================================================================


def test_databricks_store_no_params_raises_error(monkeypatch):
    """DatabricksStore with no connection parameters raises ValueError."""
    monkeypatch.setattr(lakebase, "ConnectionPool", TestConnectionPool())
    ws = _create_mock_workspace()

    # Neither instance_name nor project/branch given -> must fail fast.
    with pytest.raises(ValueError, match="Must provide either 'instance_name'"):
        DatabricksStore(workspace_client=ws)


def test_databricks_store_only_branch_raises_error(monkeypatch):
    """DatabricksStore with only branch (no project) raises ValueError."""
    monkeypatch.setattr(lakebase, "ConnectionPool", TestConnectionPool())
    ws = _create_mock_workspace()

    # Autoscaling mode needs both halves of the project/branch pair.
    with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"):
        DatabricksStore(branch="my-branch", workspace_client=ws)


@pytest.mark.asyncio
async def test_async_databricks_store_no_params_raises_error(monkeypatch):
    """AsyncDatabricksStore with no connection parameters raises ValueError."""
    monkeypatch.setattr(lakebase, "AsyncConnectionPool", TestAsyncConnectionPool())
    ws = _create_mock_workspace()

    with pytest.raises(ValueError, match="Must provide either 'instance_name'"):
        AsyncDatabricksStore(workspace_client=ws)
Loading
Loading