From 1c7793fee7bec5bec46b0d25423ff27d61a8adb0 Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 2 Mar 2026 15:27:58 -0800 Subject: [PATCH 1/8] initial commit lakebase autoscaling --- .../src/databricks_langchain/checkpoint.py | 18 +- .../src/databricks_langchain/store.py | 20 +- .../tests/unit_tests/test_checkpoint.py | 116 ++++++ .../langchain/tests/unit_tests/test_store.py | 137 +++++++ .../src/databricks_openai/agents/session.py | 60 ++- .../openai/tests/unit_tests/test_session.py | 291 ++++++++++++++ src/databricks_ai_bridge/lakebase.py | 196 ++++++++-- tests/databricks_ai_bridge/test_lakebase.py | 363 +++++++++++++++++- 8 files changed, 1154 insertions(+), 47 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index eb0ac7d8..5a8df35f 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -21,13 +21,16 @@ class CheckpointSaver(PostgresSaver): """ LangGraph PostgresSaver using a Lakebase connection pool. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison """ def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -40,6 +43,8 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **dict(pool_kwargs), ) @@ -59,13 +64,16 @@ class AsyncCheckpointSaver(AsyncPostgresSaver): """ Async LangGraph PostgresSaver using a Lakebase connection pool. 
- instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison """ def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -78,6 +86,8 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **dict(pool_kwargs), ) diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 05aed3b8..78f8ae98 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -33,7 +33,9 @@ class DatabricksStore(BaseStore): def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -44,7 +46,9 @@ def __init__( """Initialize DatabricksStore with embedding support. Args: - instance_name: The name of the Lakebase instance to connect to. + instance_name: The name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. 
@@ -64,6 +68,8 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **pool_kwargs, ) @@ -148,7 +154,9 @@ class AsyncDatabricksStore(AsyncBatchedBaseStore): def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -159,7 +167,9 @@ def __init__( """Initialize AsyncDatabricksStore with embedding support. Args: - instance_name: The name of the Lakebase instance to connect to. + instance_name: The name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. 
@@ -181,6 +191,8 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **pool_kwargs, ) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index a64b6c88..3d273b48 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -168,3 +168,119 @@ async def test_async_checkpoint_saver_connection(monkeypatch): ) as saver: async with saver._lakebase.connection() as conn: assert conn == "async-lake-conn" + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +def _create_autoscaling_workspace(): + """Helper to create a mock workspace client for autoscaling mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + workspace.postgres.generate_database_credential.return_value = MagicMock( + token="autoscaling-token" + ) + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/p/branches/b/endpoints/rw" + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = "auto-db-host" + workspace.postgres.list_endpoints.return_value = [rw_endpoint] + return workspace + + +def test_checkpoint_saver_autoscaling_configures_lakebase(monkeypatch): + test_pool = TestConnectionPool(connection_value="lake-conn") + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + saver = CheckpointSaver( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + 
workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_autoscaling_configures_lakebase(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + saver = AsyncCheckpointSaver( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_autoscaling_context_manager(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + async with AsyncCheckpointSaver( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) as saver: + assert test_pool._opened + assert saver._lakebase._is_autoscaling is True + + assert test_pool._closed + + +# ============================================================================= +# Validation: missing parameters +# ============================================================================= + + +def test_checkpoint_saver_no_params_raises_error(monkeypatch): + """CheckpointSaver with no connection parameters raises ValueError.""" + test_pool = TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + CheckpointSaver(workspace_client=workspace) + + +def test_checkpoint_saver_only_project_raises_error(monkeypatch): + """CheckpointSaver with only project (no branch) raises ValueError.""" + test_pool = 
TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + CheckpointSaver(project="my-project", workspace_client=workspace) + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_no_params_raises_error(monkeypatch): + """AsyncCheckpointSaver with no connection parameters raises ValueError.""" + test_pool = TestAsyncConnectionPool() + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncCheckpointSaver(workspace_client=workspace) diff --git a/integrations/langchain/tests/unit_tests/test_store.py b/integrations/langchain/tests/unit_tests/test_store.py index a10c5eda..cee740d3 100644 --- a/integrations/langchain/tests/unit_tests/test_store.py +++ b/integrations/langchain/tests/unit_tests/test_store.py @@ -426,3 +426,140 @@ async def test_async_databricks_store_connection(monkeypatch): ) as store: async with store._lakebase.connection() as conn: assert conn == mock_conn + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +def _create_autoscaling_workspace(): + """Helper to create a mock workspace client for autoscaling mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + workspace.postgres.generate_database_credential.return_value = MagicMock( + token="autoscaling-token" + ) + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/p/branches/b/endpoints/rw" + rw_endpoint.status.endpoint_type = 
"READ_WRITE" + rw_endpoint.status.hosts.host = "auto-db-host" + workspace.postgres.list_endpoints.return_value = [rw_endpoint] + return workspace + + +def test_databricks_store_autoscaling_configures_lakebase(monkeypatch): + """Test that DatabricksStore with project/branch uses autoscaling path.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + store = DatabricksStore( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_databricks_store_provisioned_uses_provisioned_path(monkeypatch): + """Test that DatabricksStore with instance_name uses provisioned path.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + store = DatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + ) + + assert "host=db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is False + workspace.database.get_database_instance.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_databricks_store_autoscaling_configures_lakebase(monkeypatch): + """Test that AsyncDatabricksStore with project/branch uses autoscaling path.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + store = AsyncDatabricksStore( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + 
assert store._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_databricks_store_autoscaling_context_manager(monkeypatch): + """Test autoscaling async store context manager opens and closes the pool.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + async with AsyncDatabricksStore( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) as store: + assert test_pool._opened + assert store._lakebase._is_autoscaling is True + + assert test_pool._closed + + +# ============================================================================= +# Validation: missing parameters +# ============================================================================= + + +def test_databricks_store_no_params_raises_error(monkeypatch): + """DatabricksStore with no connection parameters raises ValueError.""" + test_pool = TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + DatabricksStore(workspace_client=workspace) + + +def test_databricks_store_only_branch_raises_error(monkeypatch): + """DatabricksStore with only branch (no project) raises ValueError.""" + test_pool = TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + DatabricksStore(branch="my-branch", workspace_client=workspace) + + +@pytest.mark.asyncio +async def test_async_databricks_store_no_params_raises_error(monkeypatch): + """AsyncDatabricksStore with no connection parameters raises ValueError.""" + test_pool = TestAsyncConnectionPool() + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + 
workspace = _create_mock_workspace() + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncDatabricksStore(workspace_client=workspace) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index c2642b68..bb3f94bd 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -98,7 +98,9 @@ def __init__( self, session_id: str, *, - instance_name: str, + instance_name: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, workspace_client: Optional[WorkspaceClient] = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, create_tables: bool = True, @@ -112,7 +114,9 @@ def __init__( Args: session_id: Unique identifier for the conversation session. - instance_name: Name of the Lakebase instance. + instance_name: Name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -124,8 +128,8 @@ def __init__( messages_table: Name of the messages table. Defaults to "agent_messages". use_cached_engine: Whether to reuse a cached engine for the same - instance_name and engine_kwargs combination. Set to False to - always create a new engine. Defaults to True. + connection parameters and engine_kwargs combination. Set to False + to always create a new engine. Defaults to True. **engine_kwargs: Additional keyword arguments passed to SQLAlchemy's create_async_engine(). 
""" @@ -135,8 +139,23 @@ def __init__( "Please install with: pip install databricks-openai[memory]" ) + # Validate connection parameters early (before cache key creation) + is_autoscaling = project is not None or branch is not None + if is_autoscaling and not (project and branch): + raise ValueError( + "Both 'project' and 'branch' are required to use a Lakebase " + "autoscaling instance. Please specify both parameters." + ) + if not is_autoscaling and instance_name is None: + raise ValueError( + "Must provide either 'instance_name' (provisioned) or both " + "'project' and 'branch' (autoscaling)." + ) + self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), @@ -154,23 +173,32 @@ def __init__( ) logger.info( - "AsyncDatabricksSession initialized: instance=%s session_id=%s", - instance_name, + "AsyncDatabricksSession initialized: session_id=%s", session_id, ) @classmethod - def _build_cache_key(cls, instance_name: str, **engine_kwargs: Any) -> str: - """Build a cache key from instance_name and engine_kwargs.""" + def _build_cache_key( + cls, + instance_name: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, + **engine_kwargs: Any, + ) -> str: + """Build a cache key from connection parameters and engine_kwargs.""" # Sort kwargs for deterministic key; use JSON for serializable values kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str) - return f"{instance_name}::{kwargs_key}" + if project and branch: + return f"autoscaling::{project}::{branch}::{kwargs_key}" + return f"provisioned::{instance_name}::{kwargs_key}" @classmethod def _get_or_create_lakebase( cls, *, - instance_name: str, + instance_name: Optional[str], + project: Optional[str], + branch: Optional[str], workspace_client: 
Optional[WorkspaceClient], token_cache_duration_seconds: int, pool_recycle: int, @@ -178,9 +206,15 @@ def _get_or_create_lakebase( **engine_kwargs, ) -> AsyncLakebaseSQLAlchemy: """Get cached AsyncLakebaseSQLAlchemy or create a new one. - The cache key uses both instance_name and engine_kwargs + The cache key uses connection parameters and engine_kwargs. """ - cache_key = cls._build_cache_key(instance_name, pool_recycle=pool_recycle, **engine_kwargs) + cache_key = cls._build_cache_key( + instance_name=instance_name, + project=project, + branch=branch, + pool_recycle=pool_recycle, + **engine_kwargs, + ) if use_cached_engine: with cls._lakebase_sql_alchemy_cache_lock: @@ -190,6 +224,8 @@ def _get_or_create_lakebase( lakebase = AsyncLakebaseSQLAlchemy( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=pool_recycle, diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 9fe66ce1..69effe49 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -855,3 +855,294 @@ def test_methods_are_coroutine_functions( assert inspect.iscoroutinefunction(session.add_items) assert inspect.iscoroutinefunction(session.pop_item) assert inspect.iscoroutinefunction(session.clear_session) + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +@pytest.fixture +def mock_autoscaling_workspace_client(): + """Create a mock WorkspaceClient for autoscaling mode.""" + mock_client = MagicMock() + mock_client.config.host = "https://test.databricks.com" + + # Mock current_user.me() for username inference + mock_user = MagicMock() + mock_user.user_name = "test_user@databricks.com" + 
mock_client.current_user.me.return_value = mock_user + + # Mock postgres.list_endpoints → returns one READ_WRITE endpoint + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/my-project/branches/my-branch/endpoints/rw-ep" + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = "autoscaling-instance.lakebase.databricks.com" + mock_client.postgres.list_endpoints.return_value = [rw_endpoint] + + # Mock postgres.generate_database_credential for autoscaling token minting + mock_credential = MagicMock() + mock_credential.token = "autoscaling-oauth-token" + mock_client.postgres.generate_database_credential.return_value = mock_credential + + return mock_client + + +class TestAsyncDatabricksSessionAutoscaling: + """Tests for AsyncDatabricksSession with autoscaling (project/branch).""" + + def test_init_autoscaling_resolves_host( + self, mock_autoscaling_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization with project/branch resolves host via autoscaling API.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + # Verify engine URL uses autoscaling host + call_args = mock_create_engine.call_args + url = call_args[0][0] + assert url.host == "autoscaling-instance.lakebase.databricks.com" + + # Verify autoscaling API was called + mock_autoscaling_workspace_client.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + def 
test_init_autoscaling_injects_correct_token( + self, mock_autoscaling_workspace_client, mock_engine + ): + """Test that do_connect injects autoscaling token.""" + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + # Simulate do_connect event + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + + # Verify autoscaling token was injected + assert cparams["password"] == "autoscaling-oauth-token" + mock_autoscaling_workspace_client.postgres.generate_database_credential.assert_called() + + def test_autoscaling_sessions_share_engine( + self, mock_autoscaling_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that autoscaling sessions with same project/branch share an engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + session2 = 
AsyncDatabricksSession( + session_id="session-2", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + assert session1._engine is session2._engine + + def test_different_branches_get_different_engines( + self, mock_autoscaling_workspace_client, mock_event_listens_for + ): + """Test that sessions with different branches get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + project="my-project", + branch="branch-a", + workspace_client=mock_autoscaling_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + project="my-project", + branch="branch-b", + workspace_client=mock_autoscaling_workspace_client, + ) + + assert mock_create_engine.call_count == 2 + assert session1._engine is not session2._engine + + def test_provisioned_and_autoscaling_get_different_engines( + self, mock_workspace_client, mock_autoscaling_workspace_client, mock_event_listens_for + ): + """Test that a provisioned session and an autoscaling session get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + 
"databricks_ai_bridge.lakebase.WorkspaceClient", + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session_provisioned = AsyncDatabricksSession( + session_id="session-prov", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + session_autoscaling = AsyncDatabricksSession( + session_id="session-auto", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + assert mock_create_engine.call_count == 2 + assert session_provisioned._engine is not session_autoscaling._engine + + +# ============================================================================= +# Validation: missing parameters +# ============================================================================= + + +class TestAsyncDatabricksSessionValidation: + """Tests for parameter validation in AsyncDatabricksSession.""" + + def test_no_params_raises_error(self): + """AsyncDatabricksSession with no connection parameters raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncDatabricksSession( + session_id="test-session", + workspace_client=workspace, + ) + + def test_only_project_raises_error(self): + """AsyncDatabricksSession with only project (no branch) raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are 
required"): + AsyncDatabricksSession( + session_id="test-session", + project="my-project", + workspace_client=workspace, + ) + + def test_only_branch_raises_error(self): + """AsyncDatabricksSession with only branch (no project) raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncDatabricksSession( + session_id="test-session", + branch="my-branch", + workspace_client=workspace, + ) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 98a1b0e2..32fb9036 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -93,26 +93,72 @@ class _LakebaseBase: Base class for Lakebase connections: resolve host, infer username, token cache + minting, and conninfo building. + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + - **Provisioned**: Pass ``instance_name``. + - **Autoscaling**: Pass ``project`` and ``branch``. + + When both ``instance_name`` *and* ``project``/``branch`` are provided, the + autoscaling path takes precedence. + Subclasses implement specific initialization and lifecycle methods. 
""" def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: self.workspace_client: WorkspaceClient = workspace_client or WorkspaceClient() - self.instance_name: str = instance_name self.token_cache_duration_seconds: int = token_cache_duration_seconds - # Resolve host from the Lakebase name + # --- Parameter validation --- + is_autoscaling = project is not None or branch is not None + if is_autoscaling and not (project and branch): + raise ValueError( + "Both 'project' and 'branch' are required to use a Lakebase " + "autoscaling instance. Please specify both parameters." + ) + + if not is_autoscaling and instance_name is None: + raise ValueError( + "Must provide either 'instance_name' (provisioned) or both " + "'project' and 'branch' (autoscaling)." + ) + + # Autoscaling takes precedence when both are provided + self._is_autoscaling: bool = is_autoscaling + + self.instance_name: str | None = instance_name + self.project: str | None = project + self.branch: str | None = branch + + if self._is_autoscaling: + self._endpoint_name: str | None = None + self.host = self._resolve_autoscaling_host() + else: + self._endpoint_name = None + self.host = self._resolve_provisioned_host() + + self.username: str = self._infer_username() + + self._cached_token: str | None = None + self._cache_ts: float | None = None + + # --- Host resolution --- + + def _resolve_provisioned_host(self) -> str: + """Resolve host via the Lakebase provisioned database API.""" try: - instance = self.workspace_client.database.get_database_instance(instance_name) + instance = self.workspace_client.database.get_database_instance(self.instance_name) except Exception as exc: raise ValueError( - f"Unable to resolve Lakebase instance '{instance_name}'. 
" + f"Unable to resolve Lakebase instance '{self.instance_name}'. " "Ensure the instance name is correct." ) from exc @@ -122,15 +168,58 @@ def __init__( if not resolved_host: raise ValueError( - f"Lakebase host not found for instance '{instance_name}'. " + f"Lakebase host not found for instance '{self.instance_name}'. " "Ensure the instance is running and in AVAILABLE state." ) - self.host: str = resolved_host - self.username: str = self._infer_username() + return resolved_host - self._cached_token: str | None = None - self._cache_ts: float | None = None + def _resolve_autoscaling_host(self) -> str: + """Resolve host via the Lakebase autoscaling postgres API. + + Constructs the branch parent path, lists endpoints, finds the + READ_WRITE endpoint, and extracts the host and endpoint name. + """ + branch_parent = f"projects/{self.project}/branches/{self.branch}" + + try: + endpoints = list(self.workspace_client.postgres.list_endpoints(parent=branch_parent)) + except Exception as exc: + raise ValueError( + f"Unable to list endpoints for '{branch_parent}'. " + "Ensure the project and branch names are correct." + ) from exc + + # Find the READ_WRITE endpoint + rw_endpoint = None + for ep in endpoints: + ep_status = getattr(ep, "status", None) + ep_type = getattr(ep_status, "endpoint_type", None) + if ep_type and "READ_WRITE" in str(ep_type): + rw_endpoint = ep + break + + if rw_endpoint is None: + raise ValueError( + f"No READ_WRITE endpoint found for '{branch_parent}'. " + "Ensure the branch has an active READ_WRITE endpoint." + ) + + # Extract host from endpoint status + ep_status = rw_endpoint.status + hosts = getattr(ep_status, "hosts", None) + resolved_host = getattr(hosts, "host", None) if hosts else None + + if not resolved_host: + raise ValueError( + f"Host not found on READ_WRITE endpoint for '{branch_parent}'. " + "Ensure the endpoint is in AVAILABLE state." 
+ ) + + self._endpoint_name = rw_endpoint.name + return resolved_host + + # --- Token caching --- def _get_cached_token(self) -> str | None: """Check if the cached token is still valid.""" @@ -141,6 +230,11 @@ def _get_cached_token(self) -> str | None: return None def _mint_token(self) -> str: + if self._is_autoscaling: + return self._mint_token_autoscaling() + return self._mint_token_provisioned() + + def _mint_token_provisioned(self) -> str: try: cred = self.workspace_client.database.generate_database_credential( request_id=str(uuid.uuid4()), @@ -157,6 +251,22 @@ def _mint_token(self) -> str: return cred.token + def _mint_token_autoscaling(self) -> str: + try: + cred = self.workspace_client.postgres.generate_database_credential( + endpoint=self._endpoint_name, + ) + except Exception as exc: + raise ConnectionError( + f"Failed to obtain credential for Lakebase autoscaling endpoint " + f"'{self._endpoint_name}'. Ensure the caller has access." + ) from exc + + if not cred.token: + raise RuntimeError("Failed to generate database credential: no token received") + + return cred.token + def _conninfo(self) -> str: """Build the connection info string.""" return ( @@ -178,19 +288,27 @@ def _infer_username(self) -> str: class LakebasePool(_LakebaseBase): """Sync Lakebase connection pool built on psycopg with rotating credentials. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + - **Provisioned**: Pass ``instance_name``. + - **Autoscaling**: Pass ``project`` and ``branch``. 
""" def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: dict[str, Any], ) -> None: super().__init__( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -272,19 +390,27 @@ def close(self) -> None: class AsyncLakebasePool(_LakebaseBase): """Async Lakebase connection pool built on psycopg with rotating credentials. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + - **Provisioned**: Pass ``instance_name``. + - **Autoscaling**: Pass ``project`` and ``branch``. """ def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: object, ) -> None: super().__init__( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -444,6 +570,8 @@ def __init__( *, pool: LakebasePool | None = None, instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, **pool_kwargs: Any, ) -> None: """ @@ -451,18 +579,27 @@ def __init__( Provide EITHER: - pool: An existing LakebasePool instance (advanced usage where multiple clients can connect to same pool) - - instance_name: Name of the Lakebase instance (creates pool internally) + - instance_name: Name of the Lakebase provisioned instance + - project + branch: Lakebase autoscaling project and branch names :param pool: Existing LakebasePool to 
use for connections. - :param instance_name: Name of the Lakebase instance (used to create pool if pool not provided). - :param workspace_client: Optional WorkspaceClient (only used when creating pool internally). + :param instance_name: Name of the Lakebase provisioned instance. + :param project: Lakebase autoscaling project name. Also requires ``branch``. + :param branch: Lakebase autoscaling branch name. Also requires ``project``. :param pool_kwargs: Additional kwargs passed to LakebasePool (only used when creating pool internally). """ - if pool is not None and instance_name is not None: - raise ValueError("Provide either 'pool' or 'instance_name', not both.") + has_connection_params = instance_name is not None or project is not None or branch is not None + if pool is not None and has_connection_params: + raise ValueError( + "Provide either 'pool' or connection parameters " + "('instance_name' or 'project'/'branch'), not both." + ) - if pool is None and instance_name is None: - raise ValueError("Must provide either 'pool' or 'instance_name'.") + if pool is None and not has_connection_params: + raise ValueError( + "Must provide 'pool', 'instance_name' (provisioned), " + "or both 'project' and 'branch' (autoscaling)." + ) self._owns_pool = pool is None @@ -470,7 +607,9 @@ def __init__( self._pool = pool else: self._pool = LakebasePool( - instance_name=instance_name, # type: ignore[arg-type] + instance_name=instance_name, + project=project, + branch=branch, **pool_kwargs, ) @@ -912,7 +1051,9 @@ class AsyncLakebaseSQLAlchemy(_LakebaseBase): def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, pool_recycle: int = DEFAULT_POOL_RECYCLE_SECONDS, @@ -922,7 +1063,9 @@ def __init__( Initialize AsyncLakebaseSQLAlchemy for Databricks Lakebase. 
Args: - instance_name: Name of the Lakebase instance. + instance_name: Name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -934,6 +1077,8 @@ def __init__( """ super().__init__( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -944,8 +1089,7 @@ def __init__( self._engine = self._create_engine(**engine_kwargs) logger.info( - "AsyncLakebaseSQLAlchemy initialized: instance=%s host=%s", - instance_name, + "AsyncLakebaseSQLAlchemy initialized: host=%s", self.host, ) diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index a75ead1f..44812a33 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -444,7 +444,10 @@ class TestLakebaseClientInit: def test_client_requires_pool_or_instance_name(self): """Client must be given either pool or instance_name.""" - with pytest.raises(ValueError, match="Must provide either 'pool' or 'instance_name'"): + with pytest.raises( + ValueError, + match="Must provide 'pool', 'instance_name' .provisioned., or both 'project' and 'branch' .autoscaling.", + ): LakebaseClient() @@ -1294,3 +1297,361 @@ def test_async_lakebase_sqlalchemy_invalid_instance_raises(): instance_name="bad-instance", workspace_client=workspace, ) + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +def _make_autoscaling_workspace( + *, + user_name: str = "test@databricks.com", + credential_token: 
str = "autoscaling-token-1", + host: str = "autoscaling.db.host", + endpoint_name: str = "projects/my-project/branches/my-branch/endpoints/rw-ep", +): + """Create a mock workspace client for autoscaling (project/branch) mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name=user_name) + + # Mock postgres.generate_database_credential + workspace.postgres.generate_database_credential.return_value = MagicMock( + token=credential_token + ) + + # Mock postgres.list_endpoints → returns one READ_WRITE endpoint + rw_endpoint = MagicMock() + rw_endpoint.name = endpoint_name + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = host + workspace.postgres.list_endpoints.return_value = [rw_endpoint] + + return workspace + + +# --- Parameter validation tests --- + + +def test_autoscaling_requires_both_project_and_branch(): + """Passing only project without branch raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + LakebasePool( + project="my-project", + workspace_client=workspace, + ) + + +def test_autoscaling_requires_both_branch_and_project(): + """Passing only branch without project raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + LakebasePool( + branch="my-branch", + workspace_client=workspace, + ) + + +def test_no_params_raises_error(): + """Passing no connection parameters raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + LakebasePool(workspace_client=workspace) + + +def test_async_pool_no_params_raises_error(): + """AsyncLakebasePool with no connection parameters raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): 
+ AsyncLakebasePool(workspace_client=workspace) + + +def test_async_pool_only_project_raises_error(): + """AsyncLakebasePool with only project raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncLakebasePool(project="my-project", workspace_client=workspace) + + +def test_lakebase_client_no_params_raises_error(): + """LakebaseClient with no pool or connection parameters raises ValueError.""" + with pytest.raises( + ValueError, + match="Must provide 'pool', 'instance_name' .provisioned., or both 'project' and 'branch' .autoscaling.", + ): + LakebaseClient() + + +def test_lakebase_client_only_branch_raises_error(monkeypatch): + """LakebaseClient with only branch (no project) raises ValueError.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + LakebaseClient(branch="my-branch", workspace_client=workspace) + + +def test_async_sqlalchemy_no_params_raises_error(): + """AsyncLakebaseSQLAlchemy with no connection parameters raises ValueError.""" + workspace = _make_autoscaling_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncLakebaseSQLAlchemy(workspace_client=workspace) + + +def test_async_sqlalchemy_only_project_raises_error(): + """AsyncLakebaseSQLAlchemy with only project (no branch) raises ValueError.""" + workspace = _make_autoscaling_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncLakebaseSQLAlchemy(project="my-project", 
workspace_client=workspace) + + +def test_autoscaling_takes_precedence_over_provisioned(monkeypatch): + """When both instance_name and project/branch are provided, autoscaling takes precedence.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="autoscaling.db.host") + # Also set up provisioned mocks (should NOT be used) + instance = MagicMock() + instance.read_write_dns = "provisioned.db.host" + workspace.database.get_database_instance.return_value = instance + + pool = LakebasePool( + instance_name="my-instance", + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + # Should use autoscaling host, not provisioned + assert pool.host == "autoscaling.db.host" + assert pool._is_autoscaling is True + # Provisioned API should NOT have been called + workspace.database.get_database_instance.assert_not_called() + + +# --- LakebasePool autoscaling tests --- + + +def test_lakebase_pool_autoscaling_configures_connection_pool(monkeypatch): + """LakebasePool with project/branch resolves host via autoscaling API.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="auto.db.host") + + pool = LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "auto.db.host" + assert pool._is_autoscaling is True + assert pool.username == "test@databricks.com" + assert "host=auto.db.host" in pool.pool.conninfo + + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_lakebase_pool_autoscaling_mints_token(monkeypatch): + """Autoscaling pool uses postgres.generate_database_credential for tokens.""" + TestConnectionPool = _make_connection_pool_class() + 
monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(credential_token="auto-token") + + pool = LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + token = pool._get_token() + assert token == "auto-token" + workspace.postgres.generate_database_credential.assert_called_once_with( + endpoint="projects/my-project/branches/my-branch/endpoints/rw-ep" + ) + # Provisioned credential API should NOT be called + workspace.database.generate_database_credential.assert_not_called() + + +def test_lakebase_pool_provisioned_mints_token(monkeypatch): + """Provisioned pool uses database.generate_database_credential for tokens.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace(credential_token="provisioned-token") + + pool = LakebasePool( + instance_name="lake-instance", + workspace_client=workspace, + ) + + token = pool._get_token() + assert token == "provisioned-token" + workspace.database.generate_database_credential.assert_called_once() + # Autoscaling credential API should NOT be called + workspace.postgres.generate_database_credential.assert_not_called() + + +def test_lakebase_pool_autoscaling_no_rw_endpoint_raises(monkeypatch): + """Raises ValueError when no READ_WRITE endpoint is found.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace() + # Return only a READ_ONLY endpoint + ro_endpoint = MagicMock() + ro_endpoint.status.endpoint_type = "READ_ONLY" + ro_endpoint.status.hosts.host = "ro.host" + workspace.postgres.list_endpoints.return_value = [ro_endpoint] + + with pytest.raises(ValueError, match="No READ_WRITE endpoint found"): + LakebasePool( + project="my-project", + branch="my-branch", + 
workspace_client=workspace, + ) + + +def test_lakebase_pool_autoscaling_list_endpoints_fails_raises(monkeypatch): + """Raises ValueError when list_endpoints fails.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace() + workspace.postgres.list_endpoints.side_effect = Exception("Not found") + + with pytest.raises(ValueError, match="Unable to list endpoints"): + LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + +# --- AsyncLakebasePool autoscaling tests --- + + +@pytest.mark.asyncio +async def test_async_lakebase_pool_autoscaling_configures_pool(monkeypatch): + """AsyncLakebasePool with project/branch resolves host via autoscaling API.""" + TestAsyncConnectionPool = _make_async_connection_pool_class() + monkeypatch.setattr( + "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool + ) + + workspace = _make_autoscaling_workspace(host="async-auto.db.host") + + pool = AsyncLakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "async-auto.db.host" + assert pool._is_autoscaling is True + assert "host=async-auto.db.host" in pool.pool.conninfo + + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +# --- LakebaseClient autoscaling tests --- + + +def test_lakebase_client_autoscaling_creates_pool(monkeypatch): + """LakebaseClient with project/branch creates an autoscaling pool internally.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="client-auto.db.host") + + client = LakebaseClient( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert client.pool.host == 
"client-auto.db.host" + assert client.pool._is_autoscaling is True + assert client._owns_pool is True + + +def test_lakebase_client_rejects_pool_plus_autoscaling_params(monkeypatch): + """LakebaseClient rejects passing both pool and project/branch.""" + pool = MagicMock(spec=LakebasePool) + + with pytest.raises(ValueError, match="Provide either 'pool' or connection parameters"): + LakebaseClient( + pool=pool, + project="my-project", + branch="my-branch", + ) + + +# --- AsyncLakebaseSQLAlchemy autoscaling tests --- + + +def test_async_lakebase_sqlalchemy_autoscaling_resolves_host(): + """AsyncLakebaseSQLAlchemy with project/branch resolves via autoscaling API.""" + workspace = _make_autoscaling_workspace(host="sa-auto.db.host") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert sa.host == "sa-auto.db.host" + assert sa._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_async_lakebase_sqlalchemy_autoscaling_mints_correct_token(): + """AsyncLakebaseSQLAlchemy in autoscaling mode uses postgres credential API.""" + workspace = _make_autoscaling_workspace(credential_token="sa-auto-token") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + token = sa.get_token() + assert token == "sa-auto-token" + workspace.postgres.generate_database_credential.assert_called_once() + + +def test_async_lakebase_sqlalchemy_provisioned_mints_correct_token(): + """AsyncLakebaseSQLAlchemy in provisioned mode uses database credential API.""" + workspace = _make_workspace(credential_token="sa-prov-token") + patch_engine, patch_event, _ = 
_make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + token = sa.get_token() + assert token == "sa-prov-token" + workspace.database.generate_database_credential.assert_called_once() From eda7923c3208263ff0e7ccb06c1c804e09adb427 Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 2 Mar 2026 16:53:50 -0800 Subject: [PATCH 2/8] typecheck fixes --- .../src/databricks_openai/agents/session.py | 6 ++++ src/databricks_ai_bridge/lakebase.py | 13 +++++-- tests/databricks_ai_bridge/test_lakebase.py | 34 ++++++++----------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index bb3f94bd..7e8a7dd4 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -151,6 +151,12 @@ def __init__( "Must provide either 'instance_name' (provisioned) or both " "'project' and 'branch' (autoscaling)." ) + if is_autoscaling and instance_name is not None: + raise ValueError( + "Cannot provide both 'instance_name' (provisioned) and " + "'project'/'branch' (autoscaling). Pass in the set of parameters " + "that correspond to your Lakebase instance." + ) self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 32fb9036..7c82a28f 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -99,8 +99,8 @@ class _LakebaseBase: - **Provisioned**: Pass ``instance_name``. - **Autoscaling**: Pass ``project`` and ``branch``. - When both ``instance_name`` *and* ``project``/``branch`` are provided, the - autoscaling path takes precedence. 
+ Providing both ``instance_name`` *and* ``project``/``branch`` raises a + ``ValueError``; choose one mode. Subclasses implement specific initialization and lifecycle methods. """ @@ -131,7 +131,12 @@ def __init__( "'project' and 'branch' (autoscaling)." ) - # Autoscaling takes precedence when both are provided + if is_autoscaling and instance_name is not None: + raise ValueError( + "Cannot provide both 'instance_name' (provisioned) and " + "'project'/'branch' (autoscaling). Choose one mode." + ) + self._is_autoscaling: bool = is_autoscaling self.instance_name: str | None = instance_name @@ -236,6 +241,7 @@ def _mint_token(self) -> str: def _mint_token_provisioned(self) -> str: try: + assert self.instance_name is not None cred = self.workspace_client.database.generate_database_credential( request_id=str(uuid.uuid4()), instance_names=[self.instance_name], @@ -253,6 +259,7 @@ def _mint_token_provisioned(self) -> str: def _mint_token_autoscaling(self) -> str: try: + assert self._endpoint_name is not None cred = self.workspace_client.postgres.generate_database_credential( endpoint=self._endpoint_name, ) diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index 44812a33..3599c751 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1413,29 +1413,23 @@ def test_async_sqlalchemy_only_project_raises_error(): AsyncLakebaseSQLAlchemy(project="my-project", workspace_client=workspace) -def test_autoscaling_takes_precedence_over_provisioned(monkeypatch): - """When both instance_name and project/branch are provided, autoscaling takes precedence.""" +def test_both_provisioned_and_autoscaling_raises_error(monkeypatch): + """Providing both instance_name and project/branch raises ValueError.""" TestConnectionPool = _make_connection_pool_class() monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) workspace = 
_make_autoscaling_workspace(host="autoscaling.db.host") - # Also set up provisioned mocks (should NOT be used) - instance = MagicMock() - instance.read_write_dns = "provisioned.db.host" - workspace.database.get_database_instance.return_value = instance - pool = LakebasePool( - instance_name="my-instance", - project="my-project", - branch="my-branch", - workspace_client=workspace, - ) - - # Should use autoscaling host, not provisioned - assert pool.host == "autoscaling.db.host" - assert pool._is_autoscaling is True - # Provisioned API should NOT have been called - workspace.database.get_database_instance.assert_not_called() + with pytest.raises( + ValueError, + match="Cannot provide both 'instance_name' .provisioned. and 'project'/'branch' .autoscaling.", + ): + LakebasePool( + instance_name="my-instance", + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) # --- LakebasePool autoscaling tests --- @@ -1457,7 +1451,7 @@ def test_lakebase_pool_autoscaling_configures_connection_pool(monkeypatch): assert pool.host == "auto.db.host" assert pool._is_autoscaling is True assert pool.username == "test@databricks.com" - assert "host=auto.db.host" in pool.pool.conninfo + assert "host=auto.db.host" in str(pool.pool.conninfo) workspace.postgres.list_endpoints.assert_called_once_with( parent="projects/my-project/branches/my-branch" @@ -1562,7 +1556,7 @@ async def test_async_lakebase_pool_autoscaling_configures_pool(monkeypatch): assert pool.host == "async-auto.db.host" assert pool._is_autoscaling is True - assert "host=async-auto.db.host" in pool.pool.conninfo + assert "host=async-auto.db.host" in str(pool.pool.conninfo) workspace.postgres.list_endpoints.assert_called_once_with( parent="projects/my-project/branches/my-branch" From 68b0460e39d6a329d57257ca82c33efa9af9d53a Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 2 Mar 2026 16:56:57 -0800 Subject: [PATCH 3/8] lint fixes --- src/databricks_ai_bridge/lakebase.py | 5 ++++- 
tests/databricks_ai_bridge/test_lakebase.py | 4 +--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 7c82a28f..ab94981b 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -160,6 +160,7 @@ def __init__( def _resolve_provisioned_host(self) -> str: """Resolve host via the Lakebase provisioned database API.""" try: + assert self.instance_name is not None instance = self.workspace_client.database.get_database_instance(self.instance_name) except Exception as exc: raise ValueError( @@ -595,7 +596,9 @@ def __init__( :param branch: Lakebase autoscaling branch name. Also requires ``project``. :param pool_kwargs: Additional kwargs passed to LakebasePool (only used when creating pool internally). """ - has_connection_params = instance_name is not None or project is not None or branch is not None + has_connection_params = ( + instance_name is not None or project is not None or branch is not None + ) if pool is not None and has_connection_params: raise ValueError( "Provide either 'pool' or connection parameters " diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index 3599c751..5967f8c7 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1316,9 +1316,7 @@ def _make_autoscaling_workspace( workspace.current_user.me.return_value = MagicMock(user_name=user_name) # Mock postgres.generate_database_credential - workspace.postgres.generate_database_credential.return_value = MagicMock( - token=credential_token - ) + workspace.postgres.generate_database_credential.return_value = MagicMock(token=credential_token) # Mock postgres.list_endpoints → returns one READ_WRITE endpoint rw_endpoint = MagicMock() From ac6510260273e6585337abdca0136ea0ea105f1f Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 2 Mar 2026 19:21:06 -0800 Subject: [PATCH 4/8] 
include setup in context manager to make sure checkpoint tables are ready --- .../langchain/src/databricks_langchain/checkpoint.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 5a8df35f..93ff381c 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -51,7 +51,8 @@ def __init__( super().__init__(self._lakebase.pool) def __enter__(self): - """Enter context manager.""" + """Enter context manager and create checkpoint tables.""" + self.setup() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -66,6 +67,8 @@ class AsyncCheckpointSaver(AsyncPostgresSaver): Supports two modes: Lakebase Provisioned VS Autoscaling https://docs.databricks.com/aws/en/oltp/#feature-comparison + + Checkpoint tables are created automatically when entering the context manager. 
""" def __init__( @@ -94,8 +97,9 @@ def __init__( super().__init__(self._lakebase.pool) async def __aenter__(self): - """Enter async context manager and open the connection pool.""" + """Enter async context manager, open the connection pool, and create checkpoint tables.""" await self._lakebase.open() + await self.setup() return self async def __aexit__(self, exc_type, exc_val, exc_tb): From af4dcf76240e99bacaf76073cf2bffa8f460f741 Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Mar 2026 14:16:11 -0800 Subject: [PATCH 5/8] better error messaging w/ instructions on incorrect lakebase instances + fix langchain tests --- .../tests/unit_tests/test_checkpoint.py | 8 +++++++ .../openai/tests/unit_tests/test_session.py | 2 +- src/databricks_ai_bridge/lakebase.py | 23 ++++++++++++------- tests/databricks_ai_bridge/test_lakebase.py | 2 +- 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index 3d273b48..16b1c375 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -12,6 +12,11 @@ from databricks_langchain import AsyncCheckpointSaver, CheckpointSaver +async def _async_noop(): + """No-op coroutine used to mock async setup() in tests.""" + pass + + class TestConnectionPool: def __init__(self, connection_value="conn"): self.connection_value = connection_value @@ -134,6 +139,7 @@ async def test_async_checkpoint_saver_configures_lakebase(monkeypatch): async def test_async_checkpoint_saver_context_manager(monkeypatch): test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + monkeypatch.setattr(AsyncCheckpointSaver, "setup", lambda self: _async_noop()) workspace = MagicMock() workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") @@ -155,6 
+161,7 @@ async def test_async_checkpoint_saver_context_manager(monkeypatch): async def test_async_checkpoint_saver_connection(monkeypatch): test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + monkeypatch.setattr(AsyncCheckpointSaver, "setup", lambda self: _async_noop()) workspace = MagicMock() workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") @@ -230,6 +237,7 @@ async def test_async_checkpoint_saver_autoscaling_configures_lakebase(monkeypatc async def test_async_checkpoint_saver_autoscaling_context_manager(monkeypatch): test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + monkeypatch.setattr(AsyncCheckpointSaver, "setup", lambda self: _async_noop()) workspace = _create_autoscaling_workspace() diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 69effe49..3d50ae98 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -220,7 +220,7 @@ def test_init_raises_on_invalid_instance(self, mock_workspace_client): ): from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy - with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): + with pytest.raises(ValueError, match="Unable to resolve Lakebase provisioned instance"): AsyncLakebaseSQLAlchemy( instance_name="invalid-instance", workspace_client=mock_workspace_client, diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index ab94981b..e5c113eb 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -164,8 +164,10 @@ def _resolve_provisioned_host(self) -> str: instance = self.workspace_client.database.get_database_instance(self.instance_name) except Exception as exc: 
raise ValueError( - f"Unable to resolve Lakebase instance '{self.instance_name}'. " - "Ensure the instance name is correct." + f"Unable to resolve Lakebase provisioned instance '{self.instance_name}'. " + "Verify the instance name is correct.\n" + "To list available instances, use:\n" + " workspace_client.database.list_database_instances()" ) from exc resolved_host = getattr(instance, "read_write_dns", None) or getattr( @@ -192,8 +194,11 @@ def _resolve_autoscaling_host(self) -> str: endpoints = list(self.workspace_client.postgres.list_endpoints(parent=branch_parent)) except Exception as exc: raise ValueError( - f"Unable to list endpoints for '{branch_parent}'. " - "Ensure the project and branch names are correct." + f"Unable to list endpoints for project='{self.project}', " + f"branch='{self.branch}'. Verify your project and branch names.\n" + "To find available projects and branches, use:\n" + " workspace_client.postgres.list_projects()\n" + ' workspace_client.postgres.list_branches(parent="projects/")' ) from exc # Find the READ_WRITE endpoint @@ -207,8 +212,10 @@ def _resolve_autoscaling_host(self) -> str: if rw_endpoint is None: raise ValueError( - f"No READ_WRITE endpoint found for '{branch_parent}'. " - "Ensure the branch has an active READ_WRITE endpoint." + f"No READ_WRITE endpoint found for project='{self.project}', " + f"branch='{self.branch}'. Ensure the branch has an active endpoint.\n" + "To check endpoints, use:\n" + f' workspace_client.postgres.list_endpoints(parent="{branch_parent}")' ) # Extract host from endpoint status @@ -218,8 +225,8 @@ def _resolve_autoscaling_host(self) -> str: if not resolved_host: raise ValueError( - f"Host not found on READ_WRITE endpoint for '{branch_parent}'. " - "Ensure the endpoint is in AVAILABLE state." + f"Host not found on READ_WRITE endpoint for project='{self.project}', " + f"branch='{self.branch}'. Ensure the endpoint is in AVAILABLE state." 
) self._endpoint_name = rw_endpoint.name diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index 5967f8c7..bd8c4b95 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1292,7 +1292,7 @@ def test_async_lakebase_sqlalchemy_invalid_instance_raises(): patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) with patch_engine, patch_event: - with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): + with pytest.raises(ValueError, match="Unable to resolve Lakebase provisioned instance"): AsyncLakebaseSQLAlchemy( instance_name="bad-instance", workspace_client=workspace, From 05cef0e712be1c5d96fef2e6796f904dd14e9a2a Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Mar 2026 16:54:49 -0800 Subject: [PATCH 6/8] support endpoint and parent parameters --- .../src/databricks_langchain/checkpoint.py | 8 + .../src/databricks_langchain/store.py | 16 + .../tests/unit_tests/test_checkpoint.py | 127 +++++ .../langchain/tests/unit_tests/test_store.py | 138 +++++ .../src/databricks_openai/agents/session.py | 62 ++- .../openai/tests/unit_tests/test_session.py | 513 ++++++++++++++++++ src/databricks_ai_bridge/lakebase.py | 147 ++++- tests/databricks_ai_bridge/test_lakebase.py | 462 +++++++++++++++- 8 files changed, 1440 insertions(+), 33 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 93ff381c..baf283bb 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -31,6 +31,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -45,6 +47,8 @@ def __init__( 
instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, **dict(pool_kwargs), ) @@ -77,6 +81,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -91,6 +97,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, **dict(pool_kwargs), ) diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 78f8ae98..60da3b9a 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -36,6 +36,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -49,6 +51,10 @@ def __init__( instance_name: The name of the Lakebase provisioned instance. project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. + endpoint: Lakebase autoscaling endpoint name. + See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint + parent: Lakebase autoscaling branch parent string + (e.g., ``"projects/{project_id}/branches/{branch_id}"``). workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. 
@@ -70,6 +76,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, **pool_kwargs, ) @@ -157,6 +165,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -170,6 +180,10 @@ def __init__( instance_name: The name of the Lakebase provisioned instance. project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. + endpoint: Lakebase autoscaling endpoint name. + See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint + parent: Lakebase autoscaling branch parent string + (e.g., ``"projects/{project_id}/branches/{branch_id}"``). workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. 
@@ -193,6 +207,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, **pool_kwargs, ) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index 16b1c375..c6790e25 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -292,3 +292,130 @@ async def test_async_checkpoint_saver_no_params_raises_error(monkeypatch): with pytest.raises(ValueError, match="Must provide either 'instance_name'"): AsyncCheckpointSaver(workspace_client=workspace) + + +# ============================================================================= +# Autoscaling (direct endpoint) Tests +# ============================================================================= + + +def _create_endpoint_workspace(): + """Helper to create a mock workspace client for direct endpoint mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + workspace.postgres.generate_database_credential.return_value = MagicMock( + token="endpoint-token" + ) + ep = MagicMock() + ep.host = "ep-db-host" + workspace.postgres.get_endpoint.return_value = ep + return workspace + + +def test_checkpoint_saver_endpoint_configures_lakebase(monkeypatch): + test_pool = TestConnectionPool(connection_value="lake-conn") + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_endpoint_workspace() + + saver = CheckpointSaver( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert "host=ep-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + workspace.postgres.get_endpoint.assert_called_once_with( + name="projects/p/branches/b/endpoints/rw" + ) + + +@pytest.mark.asyncio +async def 
test_async_checkpoint_saver_endpoint_configures_lakebase(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_endpoint_workspace() + + saver = AsyncCheckpointSaver( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert "host=ep-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_endpoint_context_manager(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + monkeypatch.setattr(AsyncCheckpointSaver, "setup", lambda self: _async_noop()) + + workspace = _create_endpoint_workspace() + + async with AsyncCheckpointSaver( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) as saver: + assert test_pool._opened + assert saver._lakebase._is_autoscaling is True + + assert test_pool._closed + + +# ============================================================================= +# Autoscaling (parent) Tests +# ============================================================================= + + +def test_checkpoint_saver_parent_configures_lakebase(monkeypatch): + test_pool = TestConnectionPool(connection_value="lake-conn") + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + saver = CheckpointSaver( + parent="projects/p/branches/b", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/p/branches/b" + ) + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_parent_configures_lakebase(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + 
monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + saver = AsyncCheckpointSaver( + parent="projects/p/branches/b", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_parent_context_manager(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + monkeypatch.setattr(AsyncCheckpointSaver, "setup", lambda self: _async_noop()) + + workspace = _create_autoscaling_workspace() + + async with AsyncCheckpointSaver( + parent="projects/p/branches/b", + workspace_client=workspace, + ) as saver: + assert test_pool._opened + assert saver._lakebase._is_autoscaling is True + + assert test_pool._closed diff --git a/integrations/langchain/tests/unit_tests/test_store.py b/integrations/langchain/tests/unit_tests/test_store.py index cee740d3..fe49a782 100644 --- a/integrations/langchain/tests/unit_tests/test_store.py +++ b/integrations/langchain/tests/unit_tests/test_store.py @@ -563,3 +563,141 @@ async def test_async_databricks_store_no_params_raises_error(monkeypatch): with pytest.raises(ValueError, match="Must provide either 'instance_name'"): AsyncDatabricksStore(workspace_client=workspace) + + +# ============================================================================= +# Autoscaling (direct endpoint) Tests +# ============================================================================= + + +def _create_endpoint_workspace(): + """Helper to create a mock workspace client for direct endpoint mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.postgres.generate_database_credential.return_value = 
MagicMock( + token="endpoint-token" + ) + ep = MagicMock() + ep.host = "ep-db-host" + workspace.postgres.get_endpoint.return_value = ep + return workspace + + +def test_databricks_store_endpoint_configures_lakebase(monkeypatch): + """Test that DatabricksStore with endpoint uses direct endpoint path.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_endpoint_workspace() + + store = DatabricksStore( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert "host=ep-db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is True + workspace.postgres.get_endpoint.assert_called_once_with( + name="projects/p/branches/b/endpoints/rw" + ) + + +@pytest.mark.asyncio +async def test_async_databricks_store_endpoint_configures_lakebase(monkeypatch): + """Test that AsyncDatabricksStore with endpoint uses direct endpoint path.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_endpoint_workspace() + + store = AsyncDatabricksStore( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert "host=ep-db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_databricks_store_endpoint_context_manager(monkeypatch): + """Test endpoint async store context manager opens and closes the pool.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_endpoint_workspace() + + async with AsyncDatabricksStore( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) as store: + assert test_pool._opened + assert 
store._lakebase._is_autoscaling is True + + assert test_pool._closed + + +# ============================================================================= +# Autoscaling (parent) Tests +# ============================================================================= + + +def test_databricks_store_parent_configures_lakebase(monkeypatch): + """Test that DatabricksStore with parent uses autoscaling path.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + store = DatabricksStore( + parent="projects/p/branches/b", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/p/branches/b" + ) + + +@pytest.mark.asyncio +async def test_async_databricks_store_parent_configures_lakebase(monkeypatch): + """Test that AsyncDatabricksStore with parent uses autoscaling path.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + store = AsyncDatabricksStore( + parent="projects/p/branches/b", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_databricks_store_parent_context_manager(monkeypatch): + """Test parent async store context manager opens and closes the pool.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + async with AsyncDatabricksStore( + parent="projects/p/branches/b", + workspace_client=workspace, + ) as store: + assert 
test_pool._opened + assert store._lakebase._is_autoscaling is True + + assert test_pool._closed diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 7e8a7dd4..58ed3bba 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -101,6 +101,8 @@ def __init__( instance_name: Optional[str] = None, project: Optional[str] = None, branch: Optional[str] = None, + endpoint: Optional[str] = None, + parent: Optional[str] = None, workspace_client: Optional[WorkspaceClient] = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, create_tables: bool = True, @@ -117,6 +119,10 @@ def __init__( instance_name: Name of the Lakebase provisioned instance. project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. + endpoint: Lakebase autoscaling endpoint name. + See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint + parent: Lakebase autoscaling branch parent string + (e.g., ``"projects/{project_id}/branches/{branch_id}"``). workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. 
@@ -140,18 +146,48 @@ def __init__( ) # Validate connection parameters early (before cache key creation) - is_autoscaling = project is not None or branch is not None - if is_autoscaling and not (project and branch): + is_autoscaling_branch = project is not None or branch is not None + is_autoscaling_parent = parent is not None + is_autoscaling_endpoint = endpoint is not None + + # Log when multiple autoscaling params given (higher priority wins) + if is_autoscaling_endpoint and (is_autoscaling_parent or is_autoscaling_branch): + logger.info( + "endpoint given alongside other autoscaling parameters " + "- using endpoint value" + ) + elif is_autoscaling_parent and is_autoscaling_branch: + logger.info( + "parent given alongside project/branch " + "- using parent value" + ) + if is_autoscaling_endpoint and instance_name is not None: + raise ValueError( + "Cannot provide both 'endpoint' and 'instance_name'. " + "Use 'endpoint' for autoscaling or 'instance_name' for provisioned." + ) + if is_autoscaling_parent and instance_name is not None: + raise ValueError( + "Cannot provide both 'parent' and 'instance_name'. " + "Use 'parent' for autoscaling or 'instance_name' for provisioned." + ) + if is_autoscaling_branch and not (project and branch): raise ValueError( "Both 'project' and 'branch' are required to use a Lakebase " "autoscaling instance. Please specify both parameters." ) - if not is_autoscaling and instance_name is None: + if ( + not is_autoscaling_branch + and not is_autoscaling_parent + and not is_autoscaling_endpoint + and instance_name is None + ): raise ValueError( - "Must provide either 'instance_name' (provisioned) or both " - "'project' and 'branch' (autoscaling)." + "Must provide either 'instance_name' (provisioned), both " + "'project' and 'branch' (autoscaling), 'parent' (autoscaling), " + "or 'endpoint' (autoscaling)." 
) - if is_autoscaling and instance_name is not None: + if is_autoscaling_branch and instance_name is not None: raise ValueError( "Cannot provide both 'instance_name' (provisioned) and " "'project'/'branch' (autoscaling). Pass in the set of parameters " @@ -162,6 +198,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), @@ -189,11 +227,17 @@ def _build_cache_key( instance_name: Optional[str] = None, project: Optional[str] = None, branch: Optional[str] = None, + endpoint: Optional[str] = None, + parent: Optional[str] = None, **engine_kwargs: Any, ) -> str: """Build a cache key from connection parameters and engine_kwargs.""" # Sort kwargs for deterministic key; use JSON for serializable values kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str) + if endpoint: + return f"endpoint::{endpoint}::{kwargs_key}" + if parent: + return f"parent::{parent}::{kwargs_key}" if project and branch: return f"autoscaling::{project}::{branch}::{kwargs_key}" return f"provisioned::{instance_name}::{kwargs_key}" @@ -205,6 +249,8 @@ def _get_or_create_lakebase( instance_name: Optional[str], project: Optional[str], branch: Optional[str], + endpoint: Optional[str] = None, + parent: Optional[str] = None, workspace_client: Optional[WorkspaceClient], token_cache_duration_seconds: int, pool_recycle: int, @@ -218,6 +264,8 @@ def _get_or_create_lakebase( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, pool_recycle=pool_recycle, **engine_kwargs, ) @@ -232,6 +280,8 @@ def _get_or_create_lakebase( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, 
pool_recycle=pool_recycle, diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 3d50ae98..b2217185 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1146,3 +1146,516 @@ def test_only_branch_raises_error(self): branch="my-branch", workspace_client=workspace, ) + + def test_endpoint_plus_project_uses_endpoint( + self, mock_endpoint_workspace_client, mock_engine, mock_event_listens_for + ): + """AsyncDatabricksSession with endpoint + project/branch uses endpoint.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_endpoint_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session", + endpoint="projects/p/branches/b/endpoints/rw", + project="my-project", + branch="my-branch", + workspace_client=mock_endpoint_workspace_client, + ) + + # Should use get_endpoint, not list_endpoints + mock_endpoint_workspace_client.postgres.get_endpoint.assert_called() + mock_endpoint_workspace_client.postgres.list_endpoints.assert_not_called() + + def test_endpoint_plus_instance_raises_error(self): + """AsyncDatabricksSession with endpoint + instance_name raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Cannot provide both 'endpoint' and 'instance_name'"): + AsyncDatabricksSession( + session_id="test-session", + endpoint="projects/p/branches/b/endpoints/rw", + instance_name="my-instance", + workspace_client=workspace, + ) + + +# 
============================================================================= +# Autoscaling (direct endpoint) Tests +# ============================================================================= + + +@pytest.fixture +def mock_endpoint_workspace_client(): + """Create a mock WorkspaceClient for direct endpoint mode.""" + mock_client = MagicMock() + mock_client.config.host = "https://test.databricks.com" + + # Mock current_user.me() for username inference + mock_user = MagicMock() + mock_user.user_name = "test_user@databricks.com" + mock_client.current_user.me.return_value = mock_user + + # Mock postgres.get_endpoint → returns endpoint with top-level host + ep = MagicMock() + ep.host = "endpoint-instance.lakebase.databricks.com" + mock_client.postgres.get_endpoint.return_value = ep + + # Mock postgres.generate_database_credential for autoscaling token minting + mock_credential = MagicMock() + mock_credential.token = "endpoint-oauth-token" + mock_client.postgres.generate_database_credential.return_value = mock_credential + + return mock_client + + +class TestAsyncDatabricksSessionEndpoint: + """Tests for AsyncDatabricksSession with direct endpoint.""" + + def test_init_endpoint_resolves_host( + self, mock_endpoint_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization with endpoint resolves host via get_endpoint API.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_endpoint_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=mock_endpoint_workspace_client, + ) + + # Verify engine URL uses endpoint host + call_args = 
mock_create_engine.call_args + url = call_args[0][0] + assert url.host == "endpoint-instance.lakebase.databricks.com" + + # Verify get_endpoint API was called + mock_endpoint_workspace_client.postgres.get_endpoint.assert_called_once_with( + name="projects/p/branches/b/endpoints/rw" + ) + + def test_init_endpoint_injects_correct_token( + self, mock_endpoint_workspace_client, mock_engine + ): + """Test that do_connect injects endpoint token.""" + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_endpoint_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=mock_endpoint_workspace_client, + ) + + # Simulate do_connect event + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + + # Verify endpoint token was injected + assert cparams["password"] == "endpoint-oauth-token" + mock_endpoint_workspace_client.postgres.generate_database_credential.assert_called() + + def test_endpoint_sessions_share_engine( + self, mock_endpoint_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that endpoint sessions with same endpoint share an engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_endpoint_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from 
databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=mock_endpoint_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=mock_endpoint_workspace_client, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + assert session1._engine is session2._engine + + def test_different_endpoints_get_different_engines( + self, mock_endpoint_workspace_client, mock_event_listens_for + ): + """Test that sessions with different endpoints get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_endpoint_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + endpoint="projects/p/branches/b/endpoints/ep-1", + workspace_client=mock_endpoint_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + endpoint="projects/p/branches/b/endpoints/ep-2", + workspace_client=mock_endpoint_workspace_client, + ) + + assert mock_create_engine.call_count == 2 + assert session1._engine is not session2._engine + + +# ============================================================================= +# Autoscaling (parent) Tests +# ============================================================================= + + +@pytest.fixture +def 
mock_parent_workspace_client(): + """Create a mock WorkspaceClient for parent mode.""" + mock_client = MagicMock() + mock_client.config.host = "https://test.databricks.com" + + # Mock current_user.me() for username inference + mock_user = MagicMock() + mock_user.user_name = "test_user@databricks.com" + mock_client.current_user.me.return_value = mock_user + + # Mock postgres.list_endpoints → returns one READ_WRITE endpoint + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/my-project/branches/my-branch/endpoints/rw-ep" + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = "parent-instance.lakebase.databricks.com" + mock_client.postgres.list_endpoints.return_value = [rw_endpoint] + + # Mock postgres.generate_database_credential for autoscaling token minting + mock_credential = MagicMock() + mock_credential.token = "parent-oauth-token" + mock_client.postgres.generate_database_credential.return_value = mock_credential + + return mock_client + + +class TestAsyncDatabricksSessionParent: + """Tests for AsyncDatabricksSession with parent string.""" + + def test_init_parent_resolves_host( + self, mock_parent_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization with parent resolves host via list_endpoints API.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_parent_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + parent="projects/my-project/branches/my-branch", + workspace_client=mock_parent_workspace_client, + ) + + # Verify engine URL uses parent-resolved host + call_args = mock_create_engine.call_args + url = call_args[0][0] + assert url.host == 
"parent-instance.lakebase.databricks.com" + + # Verify list_endpoints API was called with parent + mock_parent_workspace_client.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + def test_init_parent_injects_correct_token( + self, mock_parent_workspace_client, mock_engine + ): + """Test that do_connect injects parent token.""" + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_parent_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + parent="projects/my-project/branches/my-branch", + workspace_client=mock_parent_workspace_client, + ) + + # Simulate do_connect event + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + + # Verify parent token was injected + assert cparams["password"] == "parent-oauth-token" + mock_parent_workspace_client.postgres.generate_database_credential.assert_called() + + def test_parent_sessions_share_engine( + self, mock_parent_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that parent sessions with same parent share an engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_parent_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = 
AsyncDatabricksSession( + session_id="session-1", + parent="projects/my-project/branches/my-branch", + workspace_client=mock_parent_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + parent="projects/my-project/branches/my-branch", + workspace_client=mock_parent_workspace_client, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + assert session1._engine is session2._engine + + def test_different_parents_get_different_engines( + self, mock_parent_workspace_client, mock_event_listens_for + ): + """Test that sessions with different parents get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_parent_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + parent="projects/p1/branches/b1", + workspace_client=mock_parent_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + parent="projects/p2/branches/b2", + workspace_client=mock_parent_workspace_client, + ) + + assert mock_create_engine.call_count == 2 + assert session1._engine is not session2._engine + + +class TestAsyncDatabricksSessionParentValidation: + """Validation tests for parent parameter.""" + + def test_parent_plus_instance_raises_error(self): + """AsyncDatabricksSession with parent + instance_name raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + 
workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Cannot provide both 'parent' and 'instance_name'"): + AsyncDatabricksSession( + session_id="test-session", + parent="projects/p/branches/b", + instance_name="my-instance", + workspace_client=workspace, + ) + + def test_parent_plus_project_branch_uses_parent( + self, mock_parent_workspace_client, mock_engine, mock_event_listens_for + ): + """AsyncDatabricksSession with parent + project/branch uses parent.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_parent_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session", + parent="projects/p/branches/b", + project="my-project", + branch="my-branch", + workspace_client=mock_parent_workspace_client, + ) + + # Should use list_endpoints (parent path), not get_endpoint + mock_parent_workspace_client.postgres.list_endpoints.assert_called() + mock_parent_workspace_client.postgres.get_endpoint.assert_not_called() + + def test_endpoint_plus_parent_uses_endpoint( + self, mock_endpoint_workspace_client, mock_engine, mock_event_listens_for + ): + """AsyncDatabricksSession with endpoint + parent uses endpoint.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_endpoint_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session", + endpoint="projects/p/branches/b/endpoints/rw", + 
parent="projects/p/branches/b", + workspace_client=mock_endpoint_workspace_client, + ) + + # Should use get_endpoint, not list_endpoints + mock_endpoint_workspace_client.postgres.get_endpoint.assert_called() + mock_endpoint_workspace_client.postgres.list_endpoints.assert_not_called() diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index e5c113eb..e4c4abe8 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -97,10 +97,10 @@ class _LakebaseBase: https://docs.databricks.com/aws/en/oltp/#feature-comparison - **Provisioned**: Pass ``instance_name``. - - **Autoscaling**: Pass ``project`` and ``branch``. + - **Autoscaling**: Pass ``endpoint``, ``parent``, or ``project`` and ``branch``. - Providing both ``instance_name`` *and* ``project``/``branch`` raises a - ``ValueError``; choose one mode. + Provisioned and autoscaling are mutually exclusive. + Within autoscaling, priority is: ``endpoint`` > ``parent`` > ``project``/``branch``. Subclasses implement specific initialization and lifecycle methods. 
""" @@ -111,6 +111,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: @@ -118,20 +120,51 @@ def __init__( self.token_cache_duration_seconds: int = token_cache_duration_seconds # --- Parameter validation --- - is_autoscaling = project is not None or branch is not None - if is_autoscaling and not (project and branch): + # Autoscaling priority: endpoint > parent > project+branch + is_autoscaling_branch = project is not None or branch is not None + is_autoscaling_parent = parent is not None + is_autoscaling_endpoint = endpoint is not None + is_autoscaling = is_autoscaling_endpoint or is_autoscaling_parent or is_autoscaling_branch + + # Log when multiple autoscaling params given (higher priority wins) + if is_autoscaling_endpoint and (is_autoscaling_parent or is_autoscaling_branch): + logger.info( + "endpoint given alongside other autoscaling parameters " + "- using endpoint value" + ) + elif is_autoscaling_parent and is_autoscaling_branch: + logger.info( + "parent given alongside project/branch " + "- using parent value" + ) + + # Autoscaling vs provisioned conflicts + if is_autoscaling_endpoint and instance_name is not None: + raise ValueError( + "Cannot provide both 'endpoint' and 'instance_name'. " + "Use 'endpoint' for autoscaling or 'instance_name' for provisioned." + ) + + if is_autoscaling_parent and instance_name is not None: raise ValueError( - "Both 'project' and 'branch' are required to use a Lakebase " - "autoscaling instance. Please specify both parameters." + "Cannot provide both 'parent' and 'instance_name'. " + "Use 'parent' for autoscaling or 'instance_name' for provisioned." 
) + if is_autoscaling_branch and not is_autoscaling_endpoint and not is_autoscaling_parent: + if not (project and branch): + raise ValueError( + "Both 'project' and 'branch' are required to use a Lakebase " + "autoscaling instance. Please specify both parameters." + ) + if not is_autoscaling and instance_name is None: raise ValueError( - "Must provide either 'instance_name' (provisioned) or both " - "'project' and 'branch' (autoscaling)." + "Must provide either 'instance_name' (provisioned), " + "'endpoint', 'parent', or both 'project' and 'branch' (autoscaling)." ) - if is_autoscaling and instance_name is not None: + if is_autoscaling_branch and instance_name is not None and not is_autoscaling_endpoint and not is_autoscaling_parent: raise ValueError( "Cannot provide both 'instance_name' (provisioned) and " "'project'/'branch' (autoscaling). Choose one mode." @@ -142,9 +175,13 @@ def __init__( self.instance_name: str | None = instance_name self.project: str | None = project self.branch: str | None = branch + self.parent: str | None = parent - if self._is_autoscaling: - self._endpoint_name: str | None = None + if is_autoscaling_endpoint: + self._endpoint_name: str | None = endpoint + self.host = self._resolve_endpoint_host() + elif is_autoscaling_parent or is_autoscaling_branch: + self._endpoint_name = None self.host = self._resolve_autoscaling_host() else: self._endpoint_name = None @@ -185,17 +222,23 @@ def _resolve_provisioned_host(self) -> str: def _resolve_autoscaling_host(self) -> str: """Resolve host via the Lakebase autoscaling postgres API. - Constructs the branch parent path, lists endpoints, finds the - READ_WRITE endpoint, and extracts the host and endpoint name. + Uses ``self.parent`` directly if set, otherwise constructs the branch + parent path from ``self.project`` and ``self.branch``. Lists endpoints, + finds the READ_WRITE endpoint, and extracts the host and endpoint name. 
+
+        See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints
         """
-        branch_parent = f"projects/{self.project}/branches/{self.branch}"
+        if self.parent:
+            branch_parent = self.parent
+        else:
+            branch_parent = f"projects/{self.project}/branches/{self.branch}"
 
         try:
             endpoints = list(self.workspace_client.postgres.list_endpoints(parent=branch_parent))
         except Exception as exc:
             raise ValueError(
-                f"Unable to list endpoints for project='{self.project}', "
-                f"branch='{self.branch}'. Verify your project and branch names.\n"
+                f"Unable to list endpoints for parent='{branch_parent}'. "
+                "Verify the parent path is correct.\n"
                 "To find available projects and branches, use:\n"
                 "    workspace_client.postgres.list_projects()\n"
                 '    workspace_client.postgres.list_branches(parent="projects/<project>")'
             ) from exc
@@ -212,8 +255,8 @@ def _resolve_autoscaling_host(self) -> str:
 
         if rw_endpoint is None:
             raise ValueError(
-                f"No READ_WRITE endpoint found for project='{self.project}', "
-                f"branch='{self.branch}'. Ensure the branch has an active endpoint.\n"
+                f"No READ_WRITE endpoint found for parent='{branch_parent}'. "
+                "Ensure the branch has an active endpoint.\n"
                 "To check endpoints, use:\n"
                 f'    workspace_client.postgres.list_endpoints(parent="{branch_parent}")'
             )
@@ -232,6 +275,32 @@ def _resolve_autoscaling_host(self) -> str:
         self._endpoint_name = rw_endpoint.name
         return resolved_host
+
+    def _resolve_endpoint_host(self) -> str:
+        """Resolve host via endpoint name using the Lakebase autoscaling API.
+
+        Calls ``get_endpoint(name=...)`` and extracts the top-level ``host`` field.
+        """
+        try:
+            assert self._endpoint_name is not None
+            ep = self.workspace_client.postgres.get_endpoint(name=self._endpoint_name)
+        except Exception as exc:
+            raise ValueError(
+                f"Unable to resolve Lakebase autoscaling endpoint '{self._endpoint_name}'. "
+                "Verify the endpoint name is correct.\n"
+                "To list available endpoints, use:\n"
+                '    workspace_client.postgres.list_endpoints(parent="projects/<project>/branches/<branch>")'
+            ) from exc
+
+        resolved_host = getattr(ep, "host", None)
+
+        if not resolved_host:
+            raise ValueError(
+                f"Host not found on endpoint '{self._endpoint_name}'. "
+                "Ensure the endpoint is in AVAILABLE state."
+            )
+
+        return resolved_host
+
     # --- Token caching ---
 
     def _get_cached_token(self) -> str | None:
@@ -307,7 +376,7 @@ class LakebasePool(_LakebaseBase):
     https://docs.databricks.com/aws/en/oltp/#feature-comparison
 
     - **Provisioned**: Pass ``instance_name``.
-    - **Autoscaling**: Pass ``project`` and ``branch``.
+    - **Autoscaling**: Pass ``endpoint``, ``parent``, or ``project`` and ``branch``.
     """
 
     def __init__(
@@ -316,6 +385,8 @@ def __init__(
         instance_name: str | None = None,
         project: str | None = None,
         branch: str | None = None,
+        endpoint: str | None = None,
+        parent: str | None = None,
         workspace_client: WorkspaceClient | None = None,
         token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS,
         **pool_kwargs: dict[str, Any],
@@ -324,6 +395,8 @@ def __init__(
             instance_name=instance_name,
             project=project,
             branch=branch,
+            endpoint=endpoint,
+            parent=parent,
             workspace_client=workspace_client,
             token_cache_duration_seconds=token_cache_duration_seconds,
         )
@@ -409,7 +482,7 @@ class AsyncLakebasePool(_LakebaseBase):
    https://docs.databricks.com/aws/en/oltp/#feature-comparison
 
    - **Provisioned**: Pass ``instance_name``.
-    - **Autoscaling**: Pass ``project`` and ``branch``.
+    - **Autoscaling**: Pass ``endpoint``, ``parent``, or ``project`` and ``branch``.
""" def __init__( @@ -418,6 +491,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: object, @@ -426,6 +501,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -587,6 +664,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, **pool_kwargs: Any, ) -> None: """ @@ -595,27 +674,35 @@ def __init__( Provide EITHER: - pool: An existing LakebasePool instance (advanced usage where multiple clients can connect to same pool) - instance_name: Name of the Lakebase provisioned instance - - project + branch: Lakebase autoscaling project and branch names + - endpoint, parent, or project + branch: Lakebase autoscaling :param pool: Existing LakebasePool to use for connections. :param instance_name: Name of the Lakebase provisioned instance. :param project: Lakebase autoscaling project name. Also requires ``branch``. :param branch: Lakebase autoscaling branch name. Also requires ``project``. + :param endpoint: Lakebase autoscaling endpoint name. + See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint + :param parent: Lakebase autoscaling branch parent path (e.g., ``"projects/{project}/branches/{branch}"``). + See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints :param pool_kwargs: Additional kwargs passed to LakebasePool (only used when creating pool internally). 
""" has_connection_params = ( - instance_name is not None or project is not None or branch is not None + instance_name is not None + or project is not None + or branch is not None + or endpoint is not None + or parent is not None ) if pool is not None and has_connection_params: raise ValueError( "Provide either 'pool' or connection parameters " - "('instance_name' or 'project'/'branch'), not both." + "('instance_name', 'endpoint', 'parent', or 'project'/'branch'), not both." ) if pool is None and not has_connection_params: raise ValueError( "Must provide 'pool', 'instance_name' (provisioned), " - "or both 'project' and 'branch' (autoscaling)." + "'endpoint', 'parent', or both 'project' and 'branch' (autoscaling)." ) self._owns_pool = pool is None @@ -627,6 +714,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, **pool_kwargs, ) @@ -1071,6 +1160,8 @@ def __init__( instance_name: str | None = None, project: str | None = None, branch: str | None = None, + endpoint: str | None = None, + parent: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, pool_recycle: int = DEFAULT_POOL_RECYCLE_SECONDS, @@ -1083,6 +1174,10 @@ def __init__( instance_name: Name of the Lakebase provisioned instance. project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. + endpoint: Lakebase autoscaling endpoint name. + See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint + parent: Lakebase autoscaling branch parent path (e.g., ``"projects/{project}/branches/{branch}"``). + See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints workspace_client: Optional WorkspaceClient for authentication. 
If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -1096,6 +1191,8 @@ def __init__( instance_name=instance_name, project=project, branch=branch, + endpoint=endpoint, + parent=parent, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index bd8c4b95..a372f8a7 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -446,7 +446,7 @@ def test_client_requires_pool_or_instance_name(self): """Client must be given either pool or instance_name.""" with pytest.raises( ValueError, - match="Must provide 'pool', 'instance_name' .provisioned., or both 'project' and 'branch' .autoscaling.", + match="Must provide 'pool', 'instance_name' .provisioned.", ): LakebaseClient() @@ -1376,7 +1376,7 @@ def test_lakebase_client_no_params_raises_error(): """LakebaseClient with no pool or connection parameters raises ValueError.""" with pytest.raises( ValueError, - match="Must provide 'pool', 'instance_name' .provisioned., or both 'project' and 'branch' .autoscaling.", + match="Must provide 'pool', 'instance_name' .provisioned.", ): LakebaseClient() @@ -1647,3 +1647,461 @@ def test_async_lakebase_sqlalchemy_provisioned_mints_correct_token(): token = sa.get_token() assert token == "sa-prov-token" workspace.database.generate_database_credential.assert_called_once() + + +# ============================================================================= +# Autoscaling (direct endpoint) Tests +# ============================================================================= + + +def _make_endpoint_workspace( + *, + user_name: str = "test@databricks.com", + credential_token: str = "endpoint-token-1", + host: str = "endpoint.db.host", + endpoint_name: str = "projects/my-project/branches/my-branch/endpoints/rw-ep", +): + """Create a mock 
workspace client for direct endpoint mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name=user_name) + + # Mock postgres.generate_database_credential + workspace.postgres.generate_database_credential.return_value = MagicMock(token=credential_token) + + # Mock postgres.get_endpoint → returns endpoint with top-level host + ep = MagicMock() + ep.host = host + workspace.postgres.get_endpoint.return_value = ep + + return workspace + + +# --- LakebasePool endpoint tests --- + + +def test_lakebase_pool_endpoint_configures_connection_pool(monkeypatch): + """LakebasePool with endpoint resolves host via get_endpoint API.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_endpoint_workspace(host="ep.db.host") + + pool = LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert pool.host == "ep.db.host" + assert pool._is_autoscaling is True + assert pool._endpoint_name == "projects/p/branches/b/endpoints/rw" + assert pool.username == "test@databricks.com" + assert "host=ep.db.host" in str(pool.pool.conninfo) + + workspace.postgres.get_endpoint.assert_called_once_with( + name="projects/p/branches/b/endpoints/rw" + ) + # list_endpoints should NOT be called + workspace.postgres.list_endpoints.assert_not_called() + + +def test_lakebase_pool_endpoint_mints_token(monkeypatch): + """Endpoint pool uses postgres.generate_database_credential for tokens.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_endpoint_workspace(credential_token="ep-token") + + pool = LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + token = pool._get_token() + assert token == "ep-token" + 
workspace.postgres.generate_database_credential.assert_called_once_with( + endpoint="projects/p/branches/b/endpoints/rw" + ) + # Provisioned credential API should NOT be called + workspace.database.generate_database_credential.assert_not_called() + + +# --- AsyncLakebasePool endpoint tests --- + + +@pytest.mark.asyncio +async def test_async_lakebase_pool_endpoint_configures_pool(monkeypatch): + """AsyncLakebasePool with endpoint resolves host via get_endpoint API.""" + TestAsyncConnectionPool = _make_async_connection_pool_class() + monkeypatch.setattr( + "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool + ) + + workspace = _make_endpoint_workspace(host="async-ep.db.host") + + pool = AsyncLakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert pool.host == "async-ep.db.host" + assert pool._is_autoscaling is True + assert "host=async-ep.db.host" in str(pool.pool.conninfo) + + workspace.postgres.get_endpoint.assert_called_once_with( + name="projects/p/branches/b/endpoints/rw" + ) + + +# --- LakebaseClient endpoint tests --- + + +def test_lakebase_client_endpoint_creates_pool(monkeypatch): + """LakebaseClient with endpoint creates an endpoint pool internally.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_endpoint_workspace(host="client-ep.db.host") + + client = LakebaseClient( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert client.pool.host == "client-ep.db.host" + assert client.pool._is_autoscaling is True + assert client._owns_pool is True + + +# --- AsyncLakebaseSQLAlchemy endpoint tests --- + + +def test_async_lakebase_sqlalchemy_endpoint_resolves_host(): + """AsyncLakebaseSQLAlchemy with endpoint resolves via get_endpoint API.""" + workspace = _make_endpoint_workspace(host="sa-ep.db.host") + patch_engine, patch_event, _ = 
_make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + assert sa.host == "sa-ep.db.host" + assert sa._is_autoscaling is True + workspace.postgres.get_endpoint.assert_called_once_with( + name="projects/p/branches/b/endpoints/rw" + ) + + +def test_async_lakebase_sqlalchemy_endpoint_mints_correct_token(): + """AsyncLakebaseSQLAlchemy in endpoint mode uses postgres credential API.""" + workspace = _make_endpoint_workspace(credential_token="sa-ep-token") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + token = sa.get_token() + assert token == "sa-ep-token" + workspace.postgres.generate_database_credential.assert_called_once() + + +# --- Validation: endpoint conflicts --- + + +def test_endpoint_plus_project_branch_uses_endpoint(monkeypatch, caplog): + """Passing both endpoint and project/branch logs info and uses endpoint.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_endpoint_workspace(host="ep.db.host") + + with caplog.at_level(logging.INFO): + pool = LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "ep.db.host" + assert pool._is_autoscaling is True + # Should use get_endpoint, not list_endpoints + workspace.postgres.get_endpoint.assert_called_once() + workspace.postgres.list_endpoints.assert_not_called() + assert any("using endpoint value" in record.message for record in caplog.records) + + +def test_endpoint_plus_instance_name_raises_error(): + """Passing both endpoint and instance_name raises ValueError.""" + workspace = 
_make_endpoint_workspace() + with pytest.raises( + ValueError, + match="Cannot provide both 'endpoint' and 'instance_name'", + ): + LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + instance_name="my-instance", + workspace_client=workspace, + ) + + +def test_endpoint_plus_pool_raises_error(): + """LakebaseClient rejects passing both pool and endpoint.""" + pool = MagicMock(spec=LakebasePool) + + with pytest.raises(ValueError, match="Provide either 'pool' or connection parameters"): + LakebaseClient( + pool=pool, + endpoint="projects/p/branches/b/endpoints/rw", + ) + + +def test_endpoint_get_endpoint_fails_raises(): + """Raises ValueError when get_endpoint fails.""" + workspace = _make_endpoint_workspace() + workspace.postgres.get_endpoint.side_effect = Exception("Not found") + + with pytest.raises(ValueError, match="Unable to resolve Lakebase autoscaling endpoint"): + LakebasePool( + endpoint="bad-endpoint", + workspace_client=workspace, + ) + + +def test_endpoint_no_host_raises(): + """Raises ValueError when endpoint has no host field.""" + workspace = _make_endpoint_workspace() + ep = MagicMock() + ep.host = None + workspace.postgres.get_endpoint.return_value = ep + + with pytest.raises(ValueError, match="Host not found on endpoint"): + LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + workspace_client=workspace, + ) + + +# ============================================================================= +# Autoscaling (parent) Tests +# ============================================================================= + + +# --- LakebasePool parent tests --- + + +def test_lakebase_pool_parent_configures_connection_pool(monkeypatch): + """LakebasePool with parent resolves host via list_endpoints with parent directly.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="parent.db.host") + + pool = 
LakebasePool( + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + assert pool.host == "parent.db.host" + assert pool._is_autoscaling is True + assert pool.username == "test@databricks.com" + assert "host=parent.db.host" in str(pool.pool.conninfo) + + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + # get_endpoint should NOT be called + workspace.postgres.get_endpoint.assert_not_called() + + +def test_lakebase_pool_parent_mints_token(monkeypatch): + """Parent pool uses postgres.generate_database_credential for tokens.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(credential_token="parent-token") + + pool = LakebasePool( + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + token = pool._get_token() + assert token == "parent-token" + workspace.postgres.generate_database_credential.assert_called() + # Provisioned credential API should NOT be called + workspace.database.generate_database_credential.assert_not_called() + + +# --- AsyncLakebasePool parent tests --- + + +@pytest.mark.asyncio +async def test_async_lakebase_pool_parent_configures_pool(monkeypatch): + """AsyncLakebasePool with parent resolves host via list_endpoints.""" + TestAsyncConnectionPool = _make_async_connection_pool_class() + monkeypatch.setattr( + "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool + ) + + workspace = _make_autoscaling_workspace(host="async-parent.db.host") + + pool = AsyncLakebasePool( + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + assert pool.host == "async-parent.db.host" + assert pool._is_autoscaling is True + assert "host=async-parent.db.host" in str(pool.pool.conninfo) + + workspace.postgres.list_endpoints.assert_called_once_with( + 
parent="projects/my-project/branches/my-branch" + ) + + +# --- LakebaseClient parent tests --- + + +def test_lakebase_client_parent_creates_pool(monkeypatch): + """LakebaseClient with parent creates a pool internally.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="client-parent.db.host") + + client = LakebaseClient( + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + assert client.pool.host == "client-parent.db.host" + assert client.pool._is_autoscaling is True + assert client._owns_pool is True + + +# --- AsyncLakebaseSQLAlchemy parent tests --- + + +def test_async_lakebase_sqlalchemy_parent_resolves_host(): + """AsyncLakebaseSQLAlchemy with parent resolves via list_endpoints.""" + workspace = _make_autoscaling_workspace(host="sa-parent.db.host") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + assert sa.host == "sa-parent.db.host" + assert sa._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_async_lakebase_sqlalchemy_parent_mints_correct_token(): + """AsyncLakebaseSQLAlchemy in parent mode uses postgres credential API.""" + workspace = _make_autoscaling_workspace(credential_token="sa-parent-token") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + token = sa.get_token() + assert token == "sa-parent-token" + workspace.postgres.generate_database_credential.assert_called() + + +# --- Validation: parent conflicts --- + + +def 
test_parent_plus_project_branch_uses_parent(monkeypatch, caplog): + """Passing both parent and project/branch logs info and uses parent.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="parent.db.host") + + with caplog.at_level(logging.INFO): + pool = LakebasePool( + parent="projects/my-project/branches/my-branch", + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "parent.db.host" + assert pool._is_autoscaling is True + # Should use list_endpoints with parent string directly + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + assert any("using parent value" in record.message for record in caplog.records) + + +def test_parent_plus_instance_name_raises_error(): + """Passing both parent and instance_name raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises( + ValueError, + match="Cannot provide both 'parent' and 'instance_name'", + ): + LakebasePool( + parent="projects/my-project/branches/my-branch", + instance_name="my-instance", + workspace_client=workspace, + ) + + +def test_endpoint_plus_parent_uses_endpoint(monkeypatch, caplog): + """Passing both endpoint and parent logs info and uses endpoint.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_endpoint_workspace(host="ep.db.host") + + with caplog.at_level(logging.INFO): + pool = LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + parent="projects/my-project/branches/my-branch", + workspace_client=workspace, + ) + + assert pool.host == "ep.db.host" + assert pool._is_autoscaling is True + # Should use get_endpoint, not list_endpoints + workspace.postgres.get_endpoint.assert_called_once() + 
workspace.postgres.list_endpoints.assert_not_called() + assert any("using endpoint value" in record.message for record in caplog.records) + + +def test_endpoint_plus_parent_plus_project_branch_uses_endpoint(monkeypatch, caplog): + """Passing all three autoscaling params logs info and uses endpoint.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_endpoint_workspace(host="ep.db.host") + + with caplog.at_level(logging.INFO): + pool = LakebasePool( + endpoint="projects/p/branches/b/endpoints/rw", + parent="projects/my-project/branches/my-branch", + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "ep.db.host" + assert pool._is_autoscaling is True + workspace.postgres.get_endpoint.assert_called_once() + workspace.postgres.list_endpoints.assert_not_called() + assert any("using endpoint value" in record.message for record in caplog.records) From 6ccafe8d4c32aac01b393cb8d9630e260f9f4399 Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Mar 2026 17:58:20 -0800 Subject: [PATCH 7/8] nit: rearrange params --- .../src/databricks_langchain/checkpoint.py | 16 +++---- .../src/databricks_langchain/store.py | 24 +++++----- .../src/databricks_openai/agents/session.py | 28 ++++++------ src/databricks_ai_bridge/lakebase.py | 44 +++++++++---------- 4 files changed, 56 insertions(+), 56 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index baf283bb..0038a5e1 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -29,10 +29,10 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, 
+ branch: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -45,10 +45,10 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, **dict(pool_kwargs), ) @@ -79,10 +79,10 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -95,10 +95,10 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, **dict(pool_kwargs), ) diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 60da3b9a..83a425ac 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -34,10 +34,10 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -49,12 +49,12 @@ def __init__( Args: instance_name: The name of the Lakebase provisioned instance. - project: Lakebase autoscaling project name. Also requires ``branch``. - branch: Lakebase autoscaling branch name. Also requires ``project``. endpoint: Lakebase autoscaling endpoint name. 
See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint parent: Lakebase autoscaling branch parent string (e.g., ``"projects/{project_id}/branches/{branch_id}"``). + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. @@ -74,10 +74,10 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, **pool_kwargs, ) @@ -163,10 +163,10 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -178,12 +178,12 @@ def __init__( Args: instance_name: The name of the Lakebase provisioned instance. - project: Lakebase autoscaling project name. Also requires ``branch``. - branch: Lakebase autoscaling branch name. Also requires ``project``. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint parent: Lakebase autoscaling branch parent string (e.g., ``"projects/{project_id}/branches/{branch_id}"``). + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. 
embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. @@ -205,10 +205,10 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, **pool_kwargs, ) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 58ed3bba..7e8bf37a 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -99,10 +99,10 @@ def __init__( session_id: str, *, instance_name: Optional[str] = None, - project: Optional[str] = None, - branch: Optional[str] = None, endpoint: Optional[str] = None, parent: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, workspace_client: Optional[WorkspaceClient] = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, create_tables: bool = True, @@ -117,12 +117,12 @@ def __init__( Args: session_id: Unique identifier for the conversation session. instance_name: Name of the Lakebase provisioned instance. - project: Lakebase autoscaling project name. Also requires ``branch``. - branch: Lakebase autoscaling branch name. Also requires ``project``. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint parent: Lakebase autoscaling branch parent string (e.g., ``"projects/{project_id}/branches/{branch_id}"``). + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. 
If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -196,10 +196,10 @@ def __init__( self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), @@ -225,10 +225,10 @@ def __init__( def _build_cache_key( cls, instance_name: Optional[str] = None, - project: Optional[str] = None, - branch: Optional[str] = None, endpoint: Optional[str] = None, parent: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, **engine_kwargs: Any, ) -> str: """Build a cache key from connection parameters and engine_kwargs.""" @@ -247,10 +247,10 @@ def _get_or_create_lakebase( cls, *, instance_name: Optional[str], - project: Optional[str], - branch: Optional[str], endpoint: Optional[str] = None, parent: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, workspace_client: Optional[WorkspaceClient], token_cache_duration_seconds: int, pool_recycle: int, @@ -262,10 +262,10 @@ def _get_or_create_lakebase( """ cache_key = cls._build_cache_key( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, pool_recycle=pool_recycle, **engine_kwargs, ) @@ -278,10 +278,10 @@ def _get_or_create_lakebase( lakebase = AsyncLakebaseSQLAlchemy( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=pool_recycle, diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index e4c4abe8..c115bbb4 100644 
--- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -109,10 +109,10 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: @@ -383,20 +383,20 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: dict[str, Any], ) -> None: super().__init__( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -489,20 +489,20 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: object, ) -> None: super().__init__( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -662,10 +662,10 @@ def __init__( *, pool: LakebasePool | None = None, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, 
parent: str | None = None, + project: str | None = None, + branch: str | None = None, **pool_kwargs: Any, ) -> None: """ @@ -678,12 +678,12 @@ def __init__( :param pool: Existing LakebasePool to use for connections. :param instance_name: Name of the Lakebase provisioned instance. - :param project: Lakebase autoscaling project name. Also requires ``branch``. - :param branch: Lakebase autoscaling branch name. Also requires ``project``. :param endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint :param parent: Lakebase autoscaling branch parent path (e.g., ``"projects/{project}/branches/{branch}"``). See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints + :param project: Lakebase autoscaling project name. Also requires ``branch``. + :param branch: Lakebase autoscaling branch name. Also requires ``project``. :param pool_kwargs: Additional kwargs passed to LakebasePool (only used when creating pool internally). """ has_connection_params = ( @@ -712,10 +712,10 @@ def __init__( else: self._pool = LakebasePool( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, **pool_kwargs, ) @@ -1158,10 +1158,10 @@ def __init__( self, *, instance_name: str | None = None, - project: str | None = None, - branch: str | None = None, endpoint: str | None = None, parent: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, pool_recycle: int = DEFAULT_POOL_RECYCLE_SECONDS, @@ -1172,12 +1172,12 @@ def __init__( Args: instance_name: Name of the Lakebase provisioned instance. - project: Lakebase autoscaling project name. Also requires ``branch``. 
- branch: Lakebase autoscaling branch name. Also requires ``project``. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint parent: Lakebase autoscaling branch parent path (e.g., ``"projects/{project}/branches/{branch}"``). See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -1189,10 +1189,10 @@ def __init__( """ super().__init__( instance_name=instance_name, - project=project, - branch=branch, endpoint=endpoint, parent=parent, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) From 274daec06265edfd8a32c6398735ae469621528f Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Mar 2026 18:49:59 -0800 Subject: [PATCH 8/8] rm parent param --- .../src/databricks_langchain/checkpoint.py | 4 - .../src/databricks_langchain/store.py | 8 - .../tests/unit_tests/test_checkpoint.py | 57 ---- .../langchain/tests/unit_tests/test_store.py | 62 ---- .../src/databricks_openai/agents/session.py | 36 +-- .../openai/tests/unit_tests/test_session.py | 275 ------------------ src/databricks_ai_bridge/lakebase.py | 67 ++--- tests/databricks_ai_bridge/test_lakebase.py | 221 -------------- 8 files changed, 22 insertions(+), 708 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 0038a5e1..6f06b768 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ 
b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -30,7 +30,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -46,7 +45,6 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, @@ -80,7 +78,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -96,7 +93,6 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 83a425ac..42c3f4a5 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -35,7 +35,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -51,8 +50,6 @@ def __init__( instance_name: The name of the Lakebase provisioned instance. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint - parent: Lakebase autoscaling branch parent string - (e.g., ``"projects/{project_id}/branches/{branch_id}"``). project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. 
workspace_client: Optional Databricks WorkspaceClient for authentication. @@ -75,7 +72,6 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, @@ -164,7 +160,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -180,8 +175,6 @@ def __init__( instance_name: The name of the Lakebase provisioned instance. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint - parent: Lakebase autoscaling branch parent string - (e.g., ``"projects/{project_id}/branches/{branch_id}"``). project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. 
@@ -206,7 +199,6 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index c6790e25..5c3eb390 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -362,60 +362,3 @@ async def test_async_checkpoint_saver_endpoint_context_manager(monkeypatch): assert saver._lakebase._is_autoscaling is True assert test_pool._closed - - -# ============================================================================= -# Autoscaling (parent) Tests -# ============================================================================= - - -def test_checkpoint_saver_parent_configures_lakebase(monkeypatch): - test_pool = TestConnectionPool(connection_value="lake-conn") - monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) - - workspace = _create_autoscaling_workspace() - - saver = CheckpointSaver( - parent="projects/p/branches/b", - workspace_client=workspace, - ) - - assert "host=auto-db-host" in test_pool.conninfo - assert saver._lakebase._is_autoscaling is True - workspace.postgres.list_endpoints.assert_called_once_with( - parent="projects/p/branches/b" - ) - - -@pytest.mark.asyncio -async def test_async_checkpoint_saver_parent_configures_lakebase(monkeypatch): - test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") - monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) - - workspace = _create_autoscaling_workspace() - - saver = AsyncCheckpointSaver( - parent="projects/p/branches/b", - workspace_client=workspace, - ) - - assert "host=auto-db-host" in test_pool.conninfo - assert saver._lakebase._is_autoscaling is True - - -@pytest.mark.asyncio -async def 
test_async_checkpoint_saver_parent_context_manager(monkeypatch): - test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") - monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) - monkeypatch.setattr(AsyncCheckpointSaver, "setup", lambda self: _async_noop()) - - workspace = _create_autoscaling_workspace() - - async with AsyncCheckpointSaver( - parent="projects/p/branches/b", - workspace_client=workspace, - ) as saver: - assert test_pool._opened - assert saver._lakebase._is_autoscaling is True - - assert test_pool._closed diff --git a/integrations/langchain/tests/unit_tests/test_store.py b/integrations/langchain/tests/unit_tests/test_store.py index fe49a782..233b4430 100644 --- a/integrations/langchain/tests/unit_tests/test_store.py +++ b/integrations/langchain/tests/unit_tests/test_store.py @@ -639,65 +639,3 @@ async def test_async_databricks_store_endpoint_context_manager(monkeypatch): assert store._lakebase._is_autoscaling is True assert test_pool._closed - - -# ============================================================================= -# Autoscaling (parent) Tests -# ============================================================================= - - -def test_databricks_store_parent_configures_lakebase(monkeypatch): - """Test that DatabricksStore with parent uses autoscaling path.""" - mock_conn = MagicMock() - test_pool = TestConnectionPool(connection_value=mock_conn) - monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) - - workspace = _create_autoscaling_workspace() - - store = DatabricksStore( - parent="projects/p/branches/b", - workspace_client=workspace, - ) - - assert "host=auto-db-host" in test_pool.conninfo - assert store._lakebase._is_autoscaling is True - workspace.postgres.list_endpoints.assert_called_once_with( - parent="projects/p/branches/b" - ) - - -@pytest.mark.asyncio -async def test_async_databricks_store_parent_configures_lakebase(monkeypatch): - """Test that AsyncDatabricksStore with parent uses autoscaling 
path.""" - mock_conn = MagicMock() - test_pool = TestAsyncConnectionPool(connection_value=mock_conn) - monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) - - workspace = _create_autoscaling_workspace() - - store = AsyncDatabricksStore( - parent="projects/p/branches/b", - workspace_client=workspace, - ) - - assert "host=auto-db-host" in test_pool.conninfo - assert store._lakebase._is_autoscaling is True - - -@pytest.mark.asyncio -async def test_async_databricks_store_parent_context_manager(monkeypatch): - """Test parent async store context manager opens and closes the pool.""" - mock_conn = MagicMock() - test_pool = TestAsyncConnectionPool(connection_value=mock_conn) - monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) - - workspace = _create_autoscaling_workspace() - - async with AsyncDatabricksStore( - parent="projects/p/branches/b", - workspace_client=workspace, - ) as store: - assert test_pool._opened - assert store._lakebase._is_autoscaling is True - - assert test_pool._closed diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 7e8bf37a..7d5ecee4 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -100,7 +100,6 @@ def __init__( *, instance_name: Optional[str] = None, endpoint: Optional[str] = None, - parent: Optional[str] = None, project: Optional[str] = None, branch: Optional[str] = None, workspace_client: Optional[WorkspaceClient] = None, @@ -119,8 +118,6 @@ def __init__( instance_name: Name of the Lakebase provisioned instance. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint - parent: Lakebase autoscaling branch parent string - (e.g., ``"projects/{project_id}/branches/{branch_id}"``). project: Lakebase autoscaling project name. 
Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. @@ -147,45 +144,27 @@ def __init__( # Validate connection parameters early (before cache key creation) is_autoscaling_branch = project is not None or branch is not None - is_autoscaling_parent = parent is not None is_autoscaling_endpoint = endpoint is not None - # Log when multiple autoscaling params given (higher priority wins) - if is_autoscaling_endpoint and (is_autoscaling_parent or is_autoscaling_branch): + if is_autoscaling_endpoint and is_autoscaling_branch: logger.info( - "endpoint given alongside other autoscaling parameters " + "project, branch, and endpoint given for autoscaling instance " "- using endpoint value" ) - elif is_autoscaling_parent and is_autoscaling_branch: - logger.info( - "parent given alongside project/branch " - "- using parent value" - ) if is_autoscaling_endpoint and instance_name is not None: raise ValueError( "Cannot provide both 'endpoint' and 'instance_name'. " "Use 'endpoint' for autoscaling or 'instance_name' for provisioned." ) - if is_autoscaling_parent and instance_name is not None: - raise ValueError( - "Cannot provide both 'parent' and 'instance_name'. " - "Use 'parent' for autoscaling or 'instance_name' for provisioned." - ) if is_autoscaling_branch and not (project and branch): raise ValueError( "Both 'project' and 'branch' are required to use a Lakebase " "autoscaling instance. Please specify both parameters." ) - if ( - not is_autoscaling_branch - and not is_autoscaling_parent - and not is_autoscaling_endpoint - and instance_name is None - ): + if not is_autoscaling_branch and not is_autoscaling_endpoint and instance_name is None: raise ValueError( "Must provide either 'instance_name' (provisioned), both " - "'project' and 'branch' (autoscaling), 'parent' (autoscaling), " - "or 'endpoint' (autoscaling)." 
+ "'project' and 'branch' (autoscaling), or 'endpoint' (autoscaling)." ) if is_autoscaling_branch and instance_name is not None: raise ValueError( @@ -197,7 +176,6 @@ def __init__( self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, @@ -226,7 +204,6 @@ def _build_cache_key( cls, instance_name: Optional[str] = None, endpoint: Optional[str] = None, - parent: Optional[str] = None, project: Optional[str] = None, branch: Optional[str] = None, **engine_kwargs: Any, @@ -236,8 +213,6 @@ def _build_cache_key( kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str) if endpoint: return f"endpoint::{endpoint}::{kwargs_key}" - if parent: - return f"parent::{parent}::{kwargs_key}" if project and branch: return f"autoscaling::{project}::{branch}::{kwargs_key}" return f"provisioned::{instance_name}::{kwargs_key}" @@ -248,7 +223,6 @@ def _get_or_create_lakebase( *, instance_name: Optional[str], endpoint: Optional[str] = None, - parent: Optional[str] = None, project: Optional[str] = None, branch: Optional[str] = None, workspace_client: Optional[WorkspaceClient], @@ -263,7 +237,6 @@ def _get_or_create_lakebase( cache_key = cls._build_cache_key( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, pool_recycle=pool_recycle, @@ -279,7 +252,6 @@ def _get_or_create_lakebase( lakebase = AsyncLakebaseSQLAlchemy( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index b2217185..ea4de77d 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1384,278 +1384,3 @@ def test_different_endpoints_get_different_engines( assert 
mock_create_engine.call_count == 2 assert session1._engine is not session2._engine - - -# ============================================================================= -# Autoscaling (parent) Tests -# ============================================================================= - - -@pytest.fixture -def mock_parent_workspace_client(): - """Create a mock WorkspaceClient for parent mode.""" - mock_client = MagicMock() - mock_client.config.host = "https://test.databricks.com" - - # Mock current_user.me() for username inference - mock_user = MagicMock() - mock_user.user_name = "test_user@databricks.com" - mock_client.current_user.me.return_value = mock_user - - # Mock postgres.list_endpoints → returns one READ_WRITE endpoint - rw_endpoint = MagicMock() - rw_endpoint.name = "projects/my-project/branches/my-branch/endpoints/rw-ep" - rw_endpoint.status.endpoint_type = "READ_WRITE" - rw_endpoint.status.hosts.host = "parent-instance.lakebase.databricks.com" - mock_client.postgres.list_endpoints.return_value = [rw_endpoint] - - # Mock postgres.generate_database_credential for autoscaling token minting - mock_credential = MagicMock() - mock_credential.token = "parent-oauth-token" - mock_client.postgres.generate_database_credential.return_value = mock_credential - - return mock_client - - -class TestAsyncDatabricksSessionParent: - """Tests for AsyncDatabricksSession with parent string.""" - - def test_init_parent_resolves_host( - self, mock_parent_workspace_client, mock_engine, mock_event_listens_for - ): - """Test that initialization with parent resolves host via list_endpoints API.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_parent_workspace_client, - ), - patch( - "sqlalchemy.ext.asyncio.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, - patch( - "sqlalchemy.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - 
AsyncDatabricksSession( - session_id="test-session-123", - parent="projects/my-project/branches/my-branch", - workspace_client=mock_parent_workspace_client, - ) - - # Verify engine URL uses parent-resolved host - call_args = mock_create_engine.call_args - url = call_args[0][0] - assert url.host == "parent-instance.lakebase.databricks.com" - - # Verify list_endpoints API was called with parent - mock_parent_workspace_client.postgres.list_endpoints.assert_called_once_with( - parent="projects/my-project/branches/my-branch" - ) - - def test_init_parent_injects_correct_token( - self, mock_parent_workspace_client, mock_engine - ): - """Test that do_connect injects parent token.""" - captured_handler = None - - def capture_handler(engine, event_name): - def decorator(fn): - nonlocal captured_handler - captured_handler = fn - return fn - - return decorator - - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_parent_workspace_client, - ), - patch( - "sqlalchemy.ext.asyncio.create_async_engine", - return_value=mock_engine, - ), - patch( - "sqlalchemy.event.listens_for", - side_effect=capture_handler, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - AsyncDatabricksSession( - session_id="test-session-123", - parent="projects/my-project/branches/my-branch", - workspace_client=mock_parent_workspace_client, - ) - - # Simulate do_connect event - assert captured_handler is not None - cparams = {} - captured_handler(None, None, None, cparams) - - # Verify parent token was injected - assert cparams["password"] == "parent-oauth-token" - mock_parent_workspace_client.postgres.generate_database_credential.assert_called() - - def test_parent_sessions_share_engine( - self, mock_parent_workspace_client, mock_engine, mock_event_listens_for - ): - """Test that parent sessions with same parent share an engine.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - 
return_value=mock_parent_workspace_client, - ), - patch( - "sqlalchemy.ext.asyncio.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, - patch( - "sqlalchemy.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - session1 = AsyncDatabricksSession( - session_id="session-1", - parent="projects/my-project/branches/my-branch", - workspace_client=mock_parent_workspace_client, - ) - session2 = AsyncDatabricksSession( - session_id="session-2", - parent="projects/my-project/branches/my-branch", - workspace_client=mock_parent_workspace_client, - ) - - # Engine should only be created once - assert mock_create_engine.call_count == 1 - assert session1._engine is session2._engine - - def test_different_parents_get_different_engines( - self, mock_parent_workspace_client, mock_event_listens_for - ): - """Test that sessions with different parents get different engines.""" - engine1 = MagicMock() - engine1.sync_engine = MagicMock() - engine2 = MagicMock() - engine2.sync_engine = MagicMock() - - engines = [engine1, engine2] - engine_iter = iter(engines) - - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_parent_workspace_client, - ), - patch( - "sqlalchemy.ext.asyncio.create_async_engine", - side_effect=lambda *args, **kwargs: next(engine_iter), - ) as mock_create_engine, - patch( - "sqlalchemy.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - session1 = AsyncDatabricksSession( - session_id="session-1", - parent="projects/p1/branches/b1", - workspace_client=mock_parent_workspace_client, - ) - session2 = AsyncDatabricksSession( - session_id="session-2", - parent="projects/p2/branches/b2", - workspace_client=mock_parent_workspace_client, - ) - - assert mock_create_engine.call_count == 2 - assert session1._engine is not session2._engine - - -class 
TestAsyncDatabricksSessionParentValidation: - """Validation tests for parent parameter.""" - - def test_parent_plus_instance_raises_error(self): - """AsyncDatabricksSession with parent + instance_name raises ValueError.""" - from databricks_openai.agents.session import AsyncDatabricksSession - - workspace = MagicMock() - workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") - - with pytest.raises(ValueError, match="Cannot provide both 'parent' and 'instance_name'"): - AsyncDatabricksSession( - session_id="test-session", - parent="projects/p/branches/b", - instance_name="my-instance", - workspace_client=workspace, - ) - - def test_parent_plus_project_branch_uses_parent( - self, mock_parent_workspace_client, mock_engine, mock_event_listens_for - ): - """AsyncDatabricksSession with parent + project/branch uses parent.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_parent_workspace_client, - ), - patch( - "sqlalchemy.ext.asyncio.create_async_engine", - return_value=mock_engine, - ), - patch( - "sqlalchemy.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - AsyncDatabricksSession( - session_id="test-session", - parent="projects/p/branches/b", - project="my-project", - branch="my-branch", - workspace_client=mock_parent_workspace_client, - ) - - # Should use list_endpoints (parent path), not get_endpoint - mock_parent_workspace_client.postgres.list_endpoints.assert_called() - mock_parent_workspace_client.postgres.get_endpoint.assert_not_called() - - def test_endpoint_plus_parent_uses_endpoint( - self, mock_endpoint_workspace_client, mock_engine, mock_event_listens_for - ): - """AsyncDatabricksSession with endpoint + parent uses endpoint.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_endpoint_workspace_client, - ), - patch( - 
"sqlalchemy.ext.asyncio.create_async_engine", - return_value=mock_engine, - ), - patch( - "sqlalchemy.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - AsyncDatabricksSession( - session_id="test-session", - endpoint="projects/p/branches/b/endpoints/rw", - parent="projects/p/branches/b", - workspace_client=mock_endpoint_workspace_client, - ) - - # Should use get_endpoint, not list_endpoints - mock_endpoint_workspace_client.postgres.get_endpoint.assert_called() - mock_endpoint_workspace_client.postgres.list_endpoints.assert_not_called() diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index c115bbb4..114df915 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -97,10 +97,10 @@ class _LakebaseBase: https://docs.databricks.com/aws/en/oltp/#feature-comparison - **Provisioned**: Pass ``instance_name``. - - **Autoscaling**: Pass ``endpoint``, ``parent``, or ``project`` and ``branch``. + - **Autoscaling**: Pass ``endpoint``, or ``project`` and ``branch``. Provisioned and autoscaling are mutually exclusive. - Within autoscaling, priority is: ``endpoint`` > ``parent`` > ``project``/``branch``. + Within autoscaling, priority is: ``endpoint`` > ``project``/``branch``. Subclasses implement specific initialization and lifecycle methods. 
""" @@ -110,7 +110,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -120,23 +119,17 @@ def __init__( self.token_cache_duration_seconds: int = token_cache_duration_seconds # --- Parameter validation --- - # Autoscaling priority: endpoint > parent > project+branch + # Autoscaling priority: endpoint > project+branch is_autoscaling_branch = project is not None or branch is not None - is_autoscaling_parent = parent is not None is_autoscaling_endpoint = endpoint is not None - is_autoscaling = is_autoscaling_endpoint or is_autoscaling_parent or is_autoscaling_branch + is_autoscaling = is_autoscaling_endpoint or is_autoscaling_branch # Log when multiple autoscaling params given (higher priority wins) - if is_autoscaling_endpoint and (is_autoscaling_parent or is_autoscaling_branch): + if is_autoscaling_endpoint and is_autoscaling_branch: logger.info( - "endpoint given alongside other autoscaling parameters " + "project, branch, and endpoint given for autoscaling instance " "- using endpoint value" ) - elif is_autoscaling_parent and is_autoscaling_branch: - logger.info( - "parent given alongside project/branch " - "- using parent value" - ) # Autoscaling vs provisioned conflicts if is_autoscaling_endpoint and instance_name is not None: @@ -145,13 +138,7 @@ def __init__( "Use 'endpoint' for autoscaling or 'instance_name' for provisioned." ) - if is_autoscaling_parent and instance_name is not None: - raise ValueError( - "Cannot provide both 'parent' and 'instance_name'. " - "Use 'parent' for autoscaling or 'instance_name' for provisioned." 
- ) - - if is_autoscaling_branch and not is_autoscaling_endpoint and not is_autoscaling_parent: + if is_autoscaling_branch and not is_autoscaling_endpoint: if not (project and branch): raise ValueError( "Both 'project' and 'branch' are required to use a Lakebase " @@ -161,10 +148,10 @@ def __init__( if not is_autoscaling and instance_name is None: raise ValueError( "Must provide either 'instance_name' (provisioned), " - "'endpoint', 'parent', or both 'project' and 'branch' (autoscaling)." + "'endpoint', or both 'project' and 'branch' (autoscaling)." ) - if is_autoscaling_branch and instance_name is not None and not is_autoscaling_endpoint and not is_autoscaling_parent: + if is_autoscaling_branch and instance_name is not None and not is_autoscaling_endpoint: raise ValueError( "Cannot provide both 'instance_name' (provisioned) and " "'project'/'branch' (autoscaling). Choose one mode." @@ -175,12 +162,11 @@ def __init__( self.instance_name: str | None = instance_name self.project: str | None = project self.branch: str | None = branch - self.parent: str | None = parent if is_autoscaling_endpoint: self._endpoint_name: str | None = endpoint self.host = self._resolve_endpoint_host() - elif is_autoscaling_parent or is_autoscaling_branch: + elif is_autoscaling_branch: self._endpoint_name = None self.host = self._resolve_autoscaling_host() else: @@ -222,16 +208,12 @@ def _resolve_provisioned_host(self) -> str: def _resolve_autoscaling_host(self) -> str: """Resolve host via the Lakebase autoscaling postgres API. - Uses ``self.parent`` directly if set, otherwise constructs the branch - parent path from ``self.project`` and ``self.branch``. Lists endpoints, - finds the READ_WRITE endpoint, and extracts the host and endpoint name. + Constructs the branch parent path from ``self.project`` and ``self.branch``, + lists endpoints, finds the READ_WRITE endpoint, and extracts the host and endpoint name. 
See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints """ - if self.parent: - branch_parent = self.parent - else: - branch_parent = f"projects/{self.project}/branches/{self.branch}" + branch_parent = f"projects/{self.project}/branches/{self.branch}" try: endpoints = list(self.workspace_client.postgres.list_endpoints(parent=branch_parent)) @@ -376,7 +358,7 @@ class LakebasePool(_LakebaseBase): https://docs.databricks.com/aws/en/oltp/#feature-comparison - **Provisioned**: Pass ``instance_name``. - - **Autoscaling**: Pass ``endpoint``, ``parent``, or ``project`` and ``branch``. + - **Autoscaling**: Pass ``endpoint``, or ``project`` and ``branch``. """ def __init__( @@ -384,7 +366,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -394,7 +375,6 @@ def __init__( super().__init__( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, @@ -482,7 +462,7 @@ class AsyncLakebasePool(_LakebaseBase): https://docs.databricks.com/aws/en/oltp/#feature-comparison - **Provisioned**: Pass ``instance_name``. - - **Autoscaling**: Pass ``endpoint``, ``parent``, or ``project`` and ``branch``. + - **Autoscaling**: Pass ``endpoint``, or ``project`` and ``branch``. 
""" def __init__( @@ -490,7 +470,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -500,7 +479,6 @@ def __init__( super().__init__( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, @@ -663,7 +641,6 @@ def __init__( pool: LakebasePool | None = None, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, **pool_kwargs: Any, @@ -674,14 +651,12 @@ def __init__( Provide EITHER: - pool: An existing LakebasePool instance (advanced usage where multiple clients can connect to same pool) - instance_name: Name of the Lakebase provisioned instance - - endpoint, parent, or project + branch: Lakebase autoscaling + - endpoint, or project + branch: Lakebase autoscaling :param pool: Existing LakebasePool to use for connections. :param instance_name: Name of the Lakebase provisioned instance. :param endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint - :param parent: Lakebase autoscaling branch parent path (e.g., ``"projects/{project}/branches/{branch}"``). - See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints :param project: Lakebase autoscaling project name. Also requires ``branch``. :param branch: Lakebase autoscaling branch name. Also requires ``project``. :param pool_kwargs: Additional kwargs passed to LakebasePool (only used when creating pool internally). 
@@ -691,18 +666,17 @@ def __init__( or project is not None or branch is not None or endpoint is not None - or parent is not None ) if pool is not None and has_connection_params: raise ValueError( "Provide either 'pool' or connection parameters " - "('instance_name', 'endpoint', 'parent', or 'project'/'branch'), not both." + "('instance_name', 'endpoint', or 'project'/'branch'), not both." ) if pool is None and not has_connection_params: raise ValueError( "Must provide 'pool', 'instance_name' (provisioned), " - "'endpoint', 'parent', or both 'project' and 'branch' (autoscaling)." + "'endpoint', or both 'project' and 'branch' (autoscaling)." ) self._owns_pool = pool is None @@ -713,7 +687,6 @@ def __init__( self._pool = LakebasePool( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, **pool_kwargs, @@ -1159,7 +1132,6 @@ def __init__( *, instance_name: str | None = None, endpoint: str | None = None, - parent: str | None = None, project: str | None = None, branch: str | None = None, workspace_client: WorkspaceClient | None = None, @@ -1174,8 +1146,6 @@ def __init__( instance_name: Name of the Lakebase provisioned instance. endpoint: Lakebase autoscaling endpoint name. See https://databricks-sdk-py.readthedocs.io/en/latest/dbdataclasses/postgres.html#databricks.sdk.service.postgres.Endpoint - parent: Lakebase autoscaling branch parent path (e.g., ``"projects/{project}/branches/{branch}"``). - See https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html#databricks.sdk.service.postgres.PostgresAPI.list_endpoints project: Lakebase autoscaling project name. Also requires ``branch``. branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. 
@@ -1190,7 +1160,6 @@ def __init__( super().__init__( instance_name=instance_name, endpoint=endpoint, - parent=parent, project=project, branch=branch, workspace_client=workspace_client, diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index a372f8a7..153278f3 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1884,224 +1884,3 @@ def test_endpoint_no_host_raises(): endpoint="projects/p/branches/b/endpoints/rw", workspace_client=workspace, ) - - -# ============================================================================= -# Autoscaling (parent) Tests -# ============================================================================= - - -# --- LakebasePool parent tests --- - - -def test_lakebase_pool_parent_configures_connection_pool(monkeypatch): - """LakebasePool with parent resolves host via list_endpoints with parent directly.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_autoscaling_workspace(host="parent.db.host") - - pool = LakebasePool( - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - assert pool.host == "parent.db.host" - assert pool._is_autoscaling is True - assert pool.username == "test@databricks.com" - assert "host=parent.db.host" in str(pool.pool.conninfo) - - workspace.postgres.list_endpoints.assert_called_once_with( - parent="projects/my-project/branches/my-branch" - ) - # get_endpoint should NOT be called - workspace.postgres.get_endpoint.assert_not_called() - - -def test_lakebase_pool_parent_mints_token(monkeypatch): - """Parent pool uses postgres.generate_database_credential for tokens.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = 
_make_autoscaling_workspace(credential_token="parent-token") - - pool = LakebasePool( - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - token = pool._get_token() - assert token == "parent-token" - workspace.postgres.generate_database_credential.assert_called() - # Provisioned credential API should NOT be called - workspace.database.generate_database_credential.assert_not_called() - - -# --- AsyncLakebasePool parent tests --- - - -@pytest.mark.asyncio -async def test_async_lakebase_pool_parent_configures_pool(monkeypatch): - """AsyncLakebasePool with parent resolves host via list_endpoints.""" - TestAsyncConnectionPool = _make_async_connection_pool_class() - monkeypatch.setattr( - "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool - ) - - workspace = _make_autoscaling_workspace(host="async-parent.db.host") - - pool = AsyncLakebasePool( - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - assert pool.host == "async-parent.db.host" - assert pool._is_autoscaling is True - assert "host=async-parent.db.host" in str(pool.pool.conninfo) - - workspace.postgres.list_endpoints.assert_called_once_with( - parent="projects/my-project/branches/my-branch" - ) - - -# --- LakebaseClient parent tests --- - - -def test_lakebase_client_parent_creates_pool(monkeypatch): - """LakebaseClient with parent creates a pool internally.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_autoscaling_workspace(host="client-parent.db.host") - - client = LakebaseClient( - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - assert client.pool.host == "client-parent.db.host" - assert client.pool._is_autoscaling is True - assert client._owns_pool is True - - -# --- AsyncLakebaseSQLAlchemy parent tests --- - - -def 
test_async_lakebase_sqlalchemy_parent_resolves_host(): - """AsyncLakebaseSQLAlchemy with parent resolves via list_endpoints.""" - workspace = _make_autoscaling_workspace(host="sa-parent.db.host") - patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) - - with patch_engine, patch_event: - sa = AsyncLakebaseSQLAlchemy( - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - assert sa.host == "sa-parent.db.host" - assert sa._is_autoscaling is True - workspace.postgres.list_endpoints.assert_called_once_with( - parent="projects/my-project/branches/my-branch" - ) - - -def test_async_lakebase_sqlalchemy_parent_mints_correct_token(): - """AsyncLakebaseSQLAlchemy in parent mode uses postgres credential API.""" - workspace = _make_autoscaling_workspace(credential_token="sa-parent-token") - patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) - - with patch_engine, patch_event: - sa = AsyncLakebaseSQLAlchemy( - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - token = sa.get_token() - assert token == "sa-parent-token" - workspace.postgres.generate_database_credential.assert_called() - - -# --- Validation: parent conflicts --- - - -def test_parent_plus_project_branch_uses_parent(monkeypatch, caplog): - """Passing both parent and project/branch logs info and uses parent.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_autoscaling_workspace(host="parent.db.host") - - with caplog.at_level(logging.INFO): - pool = LakebasePool( - parent="projects/my-project/branches/my-branch", - project="my-project", - branch="my-branch", - workspace_client=workspace, - ) - - assert pool.host == "parent.db.host" - assert pool._is_autoscaling is True - # Should use list_endpoints with parent string directly - workspace.postgres.list_endpoints.assert_called_once_with( - 
parent="projects/my-project/branches/my-branch" - ) - assert any("using parent value" in record.message for record in caplog.records) - - -def test_parent_plus_instance_name_raises_error(): - """Passing both parent and instance_name raises ValueError.""" - workspace = _make_autoscaling_workspace() - with pytest.raises( - ValueError, - match="Cannot provide both 'parent' and 'instance_name'", - ): - LakebasePool( - parent="projects/my-project/branches/my-branch", - instance_name="my-instance", - workspace_client=workspace, - ) - - -def test_endpoint_plus_parent_uses_endpoint(monkeypatch, caplog): - """Passing both endpoint and parent logs info and uses endpoint.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_endpoint_workspace(host="ep.db.host") - - with caplog.at_level(logging.INFO): - pool = LakebasePool( - endpoint="projects/p/branches/b/endpoints/rw", - parent="projects/my-project/branches/my-branch", - workspace_client=workspace, - ) - - assert pool.host == "ep.db.host" - assert pool._is_autoscaling is True - # Should use get_endpoint, not list_endpoints - workspace.postgres.get_endpoint.assert_called_once() - workspace.postgres.list_endpoints.assert_not_called() - assert any("using endpoint value" in record.message for record in caplog.records) - - -def test_endpoint_plus_parent_plus_project_branch_uses_endpoint(monkeypatch, caplog): - """Passing all three autoscaling params logs info and uses endpoint.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_endpoint_workspace(host="ep.db.host") - - with caplog.at_level(logging.INFO): - pool = LakebasePool( - endpoint="projects/p/branches/b/endpoints/rw", - parent="projects/my-project/branches/my-branch", - project="my-project", - branch="my-branch", - workspace_client=workspace, - 
) - - assert pool.host == "ep.db.host" - assert pool._is_autoscaling is True - workspace.postgres.get_endpoint.assert_called_once() - workspace.postgres.list_endpoints.assert_not_called() - assert any("using endpoint value" in record.message for record in caplog.records)