diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index eb0ac7d8..93ff381c 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -21,13 +21,16 @@ class CheckpointSaver(PostgresSaver): """ LangGraph PostgresSaver using a Lakebase connection pool. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison """ def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -40,13 +43,16 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **dict(pool_kwargs), ) super().__init__(self._lakebase.pool) def __enter__(self): - """Enter context manager.""" + """Enter context manager and create checkpoint tables.""" + self.setup() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -59,13 +65,18 @@ class AsyncCheckpointSaver(AsyncPostgresSaver): """ Async LangGraph PostgresSaver using a Lakebase connection pool. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + Checkpoint tables are created automatically when entering the context manager. 
""" def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, **pool_kwargs: Any, ) -> None: @@ -78,14 +89,17 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **dict(pool_kwargs), ) super().__init__(self._lakebase.pool) async def __aenter__(self): - """Enter async context manager and open the connection pool.""" + """Enter async context manager, open the connection pool, and create checkpoint tables.""" await self._lakebase.open() + await self.setup() return self async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/integrations/langchain/src/databricks_langchain/store.py b/integrations/langchain/src/databricks_langchain/store.py index 05aed3b8..78f8ae98 100644 --- a/integrations/langchain/src/databricks_langchain/store.py +++ b/integrations/langchain/src/databricks_langchain/store.py @@ -33,7 +33,9 @@ class DatabricksStore(BaseStore): def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -44,7 +46,9 @@ def __init__( """Initialize DatabricksStore with embedding support. Args: - instance_name: The name of the Lakebase instance to connect to. + instance_name: The name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. 
@@ -64,6 +68,8 @@ def __init__( self._lakebase: LakebasePool = LakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **pool_kwargs, ) @@ -148,7 +154,9 @@ class AsyncDatabricksStore(AsyncBatchedBaseStore): def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, embedding_endpoint: str | None = None, embedding_dims: int | None = None, @@ -159,7 +167,9 @@ def __init__( """Initialize AsyncDatabricksStore with embedding support. Args: - instance_name: The name of the Lakebase instance to connect to. + instance_name: The name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional Databricks WorkspaceClient for authentication. embedding_endpoint: Name of the Databricks Model Serving endpoint for embeddings (e.g., "databricks-gte-large-en"). If provided, enables semantic search. 
@@ -181,6 +191,8 @@ def __init__( self._lakebase: AsyncLakebasePool = AsyncLakebasePool( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, **pool_kwargs, ) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index a64b6c88..3d273b48 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -168,3 +168,119 @@ async def test_async_checkpoint_saver_connection(monkeypatch): ) as saver: async with saver._lakebase.connection() as conn: assert conn == "async-lake-conn" + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +def _create_autoscaling_workspace(): + """Helper to create a mock workspace client for autoscaling mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + workspace.postgres.generate_database_credential.return_value = MagicMock( + token="autoscaling-token" + ) + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/p/branches/b/endpoints/rw" + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = "auto-db-host" + workspace.postgres.list_endpoints.return_value = [rw_endpoint] + return workspace + + +def test_checkpoint_saver_autoscaling_configures_lakebase(monkeypatch): + test_pool = TestConnectionPool(connection_value="lake-conn") + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + saver = CheckpointSaver( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + 
workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_autoscaling_configures_lakebase(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + saver = AsyncCheckpointSaver( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert saver._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_autoscaling_context_manager(monkeypatch): + test_pool = TestAsyncConnectionPool(connection_value="async-lake-conn") + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + async with AsyncCheckpointSaver( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) as saver: + assert test_pool._opened + assert saver._lakebase._is_autoscaling is True + + assert test_pool._closed + + +# ============================================================================= +# Validation: missing parameters +# ============================================================================= + + +def test_checkpoint_saver_no_params_raises_error(monkeypatch): + """CheckpointSaver with no connection parameters raises ValueError.""" + test_pool = TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + CheckpointSaver(workspace_client=workspace) + + +def test_checkpoint_saver_only_project_raises_error(monkeypatch): + """CheckpointSaver with only project (no branch) raises ValueError.""" + test_pool = 
TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + CheckpointSaver(project="my-project", workspace_client=workspace) + + +@pytest.mark.asyncio +async def test_async_checkpoint_saver_no_params_raises_error(monkeypatch): + """AsyncCheckpointSaver with no connection parameters raises ValueError.""" + test_pool = TestAsyncConnectionPool() + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncCheckpointSaver(workspace_client=workspace) diff --git a/integrations/langchain/tests/unit_tests/test_store.py b/integrations/langchain/tests/unit_tests/test_store.py index a10c5eda..cee740d3 100644 --- a/integrations/langchain/tests/unit_tests/test_store.py +++ b/integrations/langchain/tests/unit_tests/test_store.py @@ -426,3 +426,140 @@ async def test_async_databricks_store_connection(monkeypatch): ) as store: async with store._lakebase.connection() as conn: assert conn == mock_conn + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +def _create_autoscaling_workspace(): + """Helper to create a mock workspace client for autoscaling mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + workspace.postgres.generate_database_credential.return_value = MagicMock( + token="autoscaling-token" + ) + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/p/branches/b/endpoints/rw" + rw_endpoint.status.endpoint_type = 
"READ_WRITE" + rw_endpoint.status.hosts.host = "auto-db-host" + workspace.postgres.list_endpoints.return_value = [rw_endpoint] + return workspace + + +def test_databricks_store_autoscaling_configures_lakebase(monkeypatch): + """Test that DatabricksStore with project/branch uses autoscaling path.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + store = DatabricksStore( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_databricks_store_provisioned_uses_provisioned_path(monkeypatch): + """Test that DatabricksStore with instance_name uses provisioned path.""" + mock_conn = MagicMock() + test_pool = TestConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + store = DatabricksStore( + instance_name="lakebase-instance", + workspace_client=workspace, + ) + + assert "host=db-host" in test_pool.conninfo + assert store._lakebase._is_autoscaling is False + workspace.database.get_database_instance.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_databricks_store_autoscaling_configures_lakebase(monkeypatch): + """Test that AsyncDatabricksStore with project/branch uses autoscaling path.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + store = AsyncDatabricksStore( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert "host=auto-db-host" in test_pool.conninfo + 
assert store._lakebase._is_autoscaling is True + + +@pytest.mark.asyncio +async def test_async_databricks_store_autoscaling_context_manager(monkeypatch): + """Test autoscaling async store context manager opens and closes the pool.""" + mock_conn = MagicMock() + test_pool = TestAsyncConnectionPool(connection_value=mock_conn) + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + workspace = _create_autoscaling_workspace() + + async with AsyncDatabricksStore( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) as store: + assert test_pool._opened + assert store._lakebase._is_autoscaling is True + + assert test_pool._closed + + +# ============================================================================= +# Validation: missing parameters +# ============================================================================= + + +def test_databricks_store_no_params_raises_error(monkeypatch): + """DatabricksStore with no connection parameters raises ValueError.""" + test_pool = TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + DatabricksStore(workspace_client=workspace) + + +def test_databricks_store_only_branch_raises_error(monkeypatch): + """DatabricksStore with only branch (no project) raises ValueError.""" + test_pool = TestConnectionPool() + monkeypatch.setattr(lakebase, "ConnectionPool", test_pool) + + workspace = _create_mock_workspace() + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + DatabricksStore(branch="my-branch", workspace_client=workspace) + + +@pytest.mark.asyncio +async def test_async_databricks_store_no_params_raises_error(monkeypatch): + """AsyncDatabricksStore with no connection parameters raises ValueError.""" + test_pool = TestAsyncConnectionPool() + monkeypatch.setattr(lakebase, "AsyncConnectionPool", test_pool) + + 
workspace = _create_mock_workspace() + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncDatabricksStore(workspace_client=workspace) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index c2642b68..7e8a7dd4 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -98,7 +98,9 @@ def __init__( self, session_id: str, *, - instance_name: str, + instance_name: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, workspace_client: Optional[WorkspaceClient] = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, create_tables: bool = True, @@ -112,7 +114,9 @@ def __init__( Args: session_id: Unique identifier for the conversation session. - instance_name: Name of the Lakebase instance. + instance_name: Name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -124,8 +128,8 @@ def __init__( messages_table: Name of the messages table. Defaults to "agent_messages". use_cached_engine: Whether to reuse a cached engine for the same - instance_name and engine_kwargs combination. Set to False to - always create a new engine. Defaults to True. + connection parameters and engine_kwargs combination. Set to False + to always create a new engine. Defaults to True. **engine_kwargs: Additional keyword arguments passed to SQLAlchemy's create_async_engine(). 
""" @@ -135,8 +139,29 @@ def __init__( "Please install with: pip install databricks-openai[memory]" ) + # Validate connection parameters early (before cache key creation) + is_autoscaling = project is not None or branch is not None + if is_autoscaling and not (project and branch): + raise ValueError( + "Both 'project' and 'branch' are required to use a Lakebase " + "autoscaling instance. Please specify both parameters." + ) + if not is_autoscaling and instance_name is None: + raise ValueError( + "Must provide either 'instance_name' (provisioned) or both " + "'project' and 'branch' (autoscaling)." + ) + if is_autoscaling and instance_name is not None: + raise ValueError( + "Cannot provide both 'instance_name' (provisioned) and " + "'project'/'branch' (autoscaling). Pass in the set of parameters " + "that correspond to your Lakebase instance." + ) + self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), @@ -154,23 +179,32 @@ def __init__( ) logger.info( - "AsyncDatabricksSession initialized: instance=%s session_id=%s", - instance_name, + "AsyncDatabricksSession initialized: session_id=%s", session_id, ) @classmethod - def _build_cache_key(cls, instance_name: str, **engine_kwargs: Any) -> str: - """Build a cache key from instance_name and engine_kwargs.""" + def _build_cache_key( + cls, + instance_name: Optional[str] = None, + project: Optional[str] = None, + branch: Optional[str] = None, + **engine_kwargs: Any, + ) -> str: + """Build a cache key from connection parameters and engine_kwargs.""" # Sort kwargs for deterministic key; use JSON for serializable values kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str) - return f"{instance_name}::{kwargs_key}" + if project and branch: + return 
f"autoscaling::{project}::{branch}::{kwargs_key}" + return f"provisioned::{instance_name}::{kwargs_key}" @classmethod def _get_or_create_lakebase( cls, *, - instance_name: str, + instance_name: Optional[str], + project: Optional[str], + branch: Optional[str], workspace_client: Optional[WorkspaceClient], token_cache_duration_seconds: int, pool_recycle: int, @@ -178,9 +212,15 @@ def _get_or_create_lakebase( **engine_kwargs, ) -> AsyncLakebaseSQLAlchemy: """Get cached AsyncLakebaseSQLAlchemy or create a new one. - The cache key uses both instance_name and engine_kwargs + The cache key uses connection parameters and engine_kwargs. """ - cache_key = cls._build_cache_key(instance_name, pool_recycle=pool_recycle, **engine_kwargs) + cache_key = cls._build_cache_key( + instance_name=instance_name, + project=project, + branch=branch, + pool_recycle=pool_recycle, + **engine_kwargs, + ) if use_cached_engine: with cls._lakebase_sql_alchemy_cache_lock: @@ -190,6 +230,8 @@ def _get_or_create_lakebase( lakebase = AsyncLakebaseSQLAlchemy( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=pool_recycle, diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 9fe66ce1..69effe49 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -855,3 +855,294 @@ def test_methods_are_coroutine_functions( assert inspect.iscoroutinefunction(session.add_items) assert inspect.iscoroutinefunction(session.pop_item) assert inspect.iscoroutinefunction(session.clear_session) + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +@pytest.fixture +def mock_autoscaling_workspace_client(): + 
"""Create a mock WorkspaceClient for autoscaling mode.""" + mock_client = MagicMock() + mock_client.config.host = "https://test.databricks.com" + + # Mock current_user.me() for username inference + mock_user = MagicMock() + mock_user.user_name = "test_user@databricks.com" + mock_client.current_user.me.return_value = mock_user + + # Mock postgres.list_endpoints → returns one READ_WRITE endpoint + rw_endpoint = MagicMock() + rw_endpoint.name = "projects/my-project/branches/my-branch/endpoints/rw-ep" + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = "autoscaling-instance.lakebase.databricks.com" + mock_client.postgres.list_endpoints.return_value = [rw_endpoint] + + # Mock postgres.generate_database_credential for autoscaling token minting + mock_credential = MagicMock() + mock_credential.token = "autoscaling-oauth-token" + mock_client.postgres.generate_database_credential.return_value = mock_credential + + return mock_client + + +class TestAsyncDatabricksSessionAutoscaling: + """Tests for AsyncDatabricksSession with autoscaling (project/branch).""" + + def test_init_autoscaling_resolves_host( + self, mock_autoscaling_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization with project/branch resolves host via autoscaling API.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + # Verify engine URL uses autoscaling host + call_args = mock_create_engine.call_args + url = call_args[0][0] + 
assert url.host == "autoscaling-instance.lakebase.databricks.com" + + # Verify autoscaling API was called + mock_autoscaling_workspace_client.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + def test_init_autoscaling_injects_correct_token( + self, mock_autoscaling_workspace_client, mock_engine + ): + """Test that do_connect injects autoscaling token.""" + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + # Simulate do_connect event + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + + # Verify autoscaling token was injected + assert cparams["password"] == "autoscaling-oauth-token" + mock_autoscaling_workspace_client.postgres.generate_database_credential.assert_called() + + def test_autoscaling_sessions_share_engine( + self, mock_autoscaling_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that autoscaling sessions with same project/branch share an engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from 
databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + assert session1._engine is session2._engine + + def test_different_branches_get_different_engines( + self, mock_autoscaling_workspace_client, mock_event_listens_for + ): + """Test that sessions with different branches get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_autoscaling_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + project="my-project", + branch="branch-a", + workspace_client=mock_autoscaling_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + project="my-project", + branch="branch-b", + workspace_client=mock_autoscaling_workspace_client, + ) + + assert mock_create_engine.call_count == 2 + assert session1._engine is not session2._engine + + def test_provisioned_and_autoscaling_get_different_engines( + self, mock_workspace_client, mock_autoscaling_workspace_client, mock_event_listens_for + ): + """Test that a provisioned session and an autoscaling session get 
different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session_provisioned = AsyncDatabricksSession( + session_id="session-prov", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + session_autoscaling = AsyncDatabricksSession( + session_id="session-auto", + project="my-project", + branch="my-branch", + workspace_client=mock_autoscaling_workspace_client, + ) + + assert mock_create_engine.call_count == 2 + assert session_provisioned._engine is not session_autoscaling._engine + + +# ============================================================================= +# Validation: missing parameters +# ============================================================================= + + +class TestAsyncDatabricksSessionValidation: + """Tests for parameter validation in AsyncDatabricksSession.""" + + def test_no_params_raises_error(self): + """AsyncDatabricksSession with no connection parameters raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncDatabricksSession( + session_id="test-session", + workspace_client=workspace, + ) + + def test_only_project_raises_error(self): + """AsyncDatabricksSession with only project (no branch) raises ValueError.""" + from databricks_openai.agents.session 
import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncDatabricksSession( + session_id="test-session", + project="my-project", + workspace_client=workspace, + ) + + def test_only_branch_raises_error(self): + """AsyncDatabricksSession with only branch (no project) raises ValueError.""" + from databricks_openai.agents.session import AsyncDatabricksSession + + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncDatabricksSession( + session_id="test-session", + branch="my-branch", + workspace_client=workspace, + ) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 98a1b0e2..ab94981b 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -93,26 +93,78 @@ class _LakebaseBase: Base class for Lakebase connections: resolve host, infer username, token cache + minting, and conninfo building. + Supports two modes: Lakebase Provisioned vs. Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + - **Provisioned**: Pass ``instance_name``. + - **Autoscaling**: Pass ``project`` and ``branch``. + + Providing both ``instance_name`` *and* ``project``/``branch`` raises a + ``ValueError``; choose one mode. + Subclasses implement specific initialization and lifecycle methods. 
""" def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: self.workspace_client: WorkspaceClient = workspace_client or WorkspaceClient() - self.instance_name: str = instance_name self.token_cache_duration_seconds: int = token_cache_duration_seconds - # Resolve host from the Lakebase name + # --- Parameter validation --- + is_autoscaling = project is not None or branch is not None + if is_autoscaling and not (project and branch): + raise ValueError( + "Both 'project' and 'branch' are required to use a Lakebase " + "autoscaling instance. Please specify both parameters." + ) + + if not is_autoscaling and instance_name is None: + raise ValueError( + "Must provide either 'instance_name' (provisioned) or both " + "'project' and 'branch' (autoscaling)." + ) + + if is_autoscaling and instance_name is not None: + raise ValueError( + "Cannot provide both 'instance_name' (provisioned) and " + "'project'/'branch' (autoscaling). Choose one mode." 
+ ) + + self._is_autoscaling: bool = is_autoscaling + + self.instance_name: str | None = instance_name + self.project: str | None = project + self.branch: str | None = branch + + if self._is_autoscaling: + self._endpoint_name: str | None = None + self.host = self._resolve_autoscaling_host() + else: + self._endpoint_name = None + self.host = self._resolve_provisioned_host() + + self.username: str = self._infer_username() + + self._cached_token: str | None = None + self._cache_ts: float | None = None + + # --- Host resolution --- + + def _resolve_provisioned_host(self) -> str: + """Resolve host via the Lakebase provisioned database API.""" try: - instance = self.workspace_client.database.get_database_instance(instance_name) + assert self.instance_name is not None + instance = self.workspace_client.database.get_database_instance(self.instance_name) except Exception as exc: raise ValueError( - f"Unable to resolve Lakebase instance '{instance_name}'. " + f"Unable to resolve Lakebase instance '{self.instance_name}'. " "Ensure the instance name is correct." ) from exc @@ -122,15 +174,58 @@ def __init__( if not resolved_host: raise ValueError( - f"Lakebase host not found for instance '{instance_name}'. " + f"Lakebase host not found for instance '{self.instance_name}'. " "Ensure the instance is running and in AVAILABLE state." ) - self.host: str = resolved_host - self.username: str = self._infer_username() + return resolved_host - self._cached_token: str | None = None - self._cache_ts: float | None = None + def _resolve_autoscaling_host(self) -> str: + """Resolve host via the Lakebase autoscaling postgres API. + + Constructs the branch parent path, lists endpoints, finds the + READ_WRITE endpoint, and extracts the host and endpoint name. 
+ """ + branch_parent = f"projects/{self.project}/branches/{self.branch}" + + try: + endpoints = list(self.workspace_client.postgres.list_endpoints(parent=branch_parent)) + except Exception as exc: + raise ValueError( + f"Unable to list endpoints for '{branch_parent}'. " + "Ensure the project and branch names are correct." + ) from exc + + # Find the READ_WRITE endpoint + rw_endpoint = None + for ep in endpoints: + ep_status = getattr(ep, "status", None) + ep_type = getattr(ep_status, "endpoint_type", None) + if ep_type and "READ_WRITE" in str(ep_type): + rw_endpoint = ep + break + + if rw_endpoint is None: + raise ValueError( + f"No READ_WRITE endpoint found for '{branch_parent}'. " + "Ensure the branch has an active READ_WRITE endpoint." + ) + + # Extract host from endpoint status + ep_status = rw_endpoint.status + hosts = getattr(ep_status, "hosts", None) + resolved_host = getattr(hosts, "host", None) if hosts else None + + if not resolved_host: + raise ValueError( + f"Host not found on READ_WRITE endpoint for '{branch_parent}'. " + "Ensure the endpoint is in AVAILABLE state." 
+ ) + + self._endpoint_name = rw_endpoint.name + return resolved_host + + # --- Token caching --- def _get_cached_token(self) -> str | None: """Check if the cached token is still valid.""" @@ -141,7 +236,13 @@ def _get_cached_token(self) -> str | None: return None def _mint_token(self) -> str: + if self._is_autoscaling: + return self._mint_token_autoscaling() + return self._mint_token_provisioned() + + def _mint_token_provisioned(self) -> str: try: + assert self.instance_name is not None cred = self.workspace_client.database.generate_database_credential( request_id=str(uuid.uuid4()), instance_names=[self.instance_name], @@ -157,6 +258,23 @@ def _mint_token(self) -> str: return cred.token + def _mint_token_autoscaling(self) -> str: + try: + assert self._endpoint_name is not None + cred = self.workspace_client.postgres.generate_database_credential( + endpoint=self._endpoint_name, + ) + except Exception as exc: + raise ConnectionError( + f"Failed to obtain credential for Lakebase autoscaling endpoint " + f"'{self._endpoint_name}'. Ensure the caller has access." + ) from exc + + if not cred.token: + raise RuntimeError("Failed to generate database credential: no token received") + + return cred.token + def _conninfo(self) -> str: """Build the connection info string.""" return ( @@ -178,19 +296,27 @@ def _infer_username(self) -> str: class LakebasePool(_LakebaseBase): """Sync Lakebase connection pool built on psycopg with rotating credentials. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned vs. Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + - **Provisioned**: Pass ``instance_name``. + - **Autoscaling**: Pass ``project`` and ``branch``. 
""" def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: dict[str, Any], ) -> None: super().__init__( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -272,19 +398,27 @@ def close(self) -> None: class AsyncLakebasePool(_LakebaseBase): """Async Lakebase connection pool built on psycopg with rotating credentials. - instance_name: Name of Lakebase Instance + Supports two modes: Lakebase Provisioned VS Autoscaling + https://docs.databricks.com/aws/en/oltp/#feature-comparison + + - **Provisioned**: Pass ``instance_name``. + - **Autoscaling**: Pass ``project`` and ``branch``. """ def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, **pool_kwargs: object, ) -> None: super().__init__( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -444,6 +578,8 @@ def __init__( *, pool: LakebasePool | None = None, instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, **pool_kwargs: Any, ) -> None: """ @@ -451,18 +587,29 @@ def __init__( Provide EITHER: - pool: An existing LakebasePool instance (advanced usage where multiple clients can connect to same pool) - - instance_name: Name of the Lakebase instance (creates pool internally) + - instance_name: Name of the Lakebase provisioned instance + - project + branch: Lakebase autoscaling project and branch names :param pool: Existing LakebasePool to 
use for connections. - :param instance_name: Name of the Lakebase instance (used to create pool if pool not provided). - :param workspace_client: Optional WorkspaceClient (only used when creating pool internally). + :param instance_name: Name of the Lakebase provisioned instance. + :param project: Lakebase autoscaling project name. Also requires ``branch``. + :param branch: Lakebase autoscaling branch name. Also requires ``project``. :param pool_kwargs: Additional kwargs passed to LakebasePool (only used when creating pool internally). """ - if pool is not None and instance_name is not None: - raise ValueError("Provide either 'pool' or 'instance_name', not both.") + has_connection_params = ( + instance_name is not None or project is not None or branch is not None + ) + if pool is not None and has_connection_params: + raise ValueError( + "Provide either 'pool' or connection parameters " + "('instance_name' or 'project'/'branch'), not both." + ) - if pool is None and instance_name is None: - raise ValueError("Must provide either 'pool' or 'instance_name'.") + if pool is None and not has_connection_params: + raise ValueError( + "Must provide 'pool', 'instance_name' (provisioned), " + "or both 'project' and 'branch' (autoscaling)." + ) self._owns_pool = pool is None @@ -470,7 +617,9 @@ def __init__( self._pool = pool else: self._pool = LakebasePool( - instance_name=instance_name, # type: ignore[arg-type] + instance_name=instance_name, + project=project, + branch=branch, **pool_kwargs, ) @@ -912,7 +1061,9 @@ class AsyncLakebaseSQLAlchemy(_LakebaseBase): def __init__( self, *, - instance_name: str, + instance_name: str | None = None, + project: str | None = None, + branch: str | None = None, workspace_client: WorkspaceClient | None = None, token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, pool_recycle: int = DEFAULT_POOL_RECYCLE_SECONDS, @@ -922,7 +1073,9 @@ def __init__( Initialize AsyncLakebaseSQLAlchemy for Databricks Lakebase. 
Args: - instance_name: Name of the Lakebase instance. + instance_name: Name of the Lakebase provisioned instance. + project: Lakebase autoscaling project name. Also requires ``branch``. + branch: Lakebase autoscaling branch name. Also requires ``project``. workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. @@ -934,6 +1087,8 @@ def __init__( """ super().__init__( instance_name=instance_name, + project=project, + branch=branch, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, ) @@ -944,8 +1099,7 @@ def __init__( self._engine = self._create_engine(**engine_kwargs) logger.info( - "AsyncLakebaseSQLAlchemy initialized: instance=%s host=%s", - instance_name, + "AsyncLakebaseSQLAlchemy initialized: host=%s", self.host, ) diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index a75ead1f..5967f8c7 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -444,7 +444,10 @@ class TestLakebaseClientInit: def test_client_requires_pool_or_instance_name(self): """Client must be given either pool or instance_name.""" - with pytest.raises(ValueError, match="Must provide either 'pool' or 'instance_name'"): + with pytest.raises( + ValueError, + match="Must provide 'pool', 'instance_name' .provisioned., or both 'project' and 'branch' .autoscaling.", + ): LakebaseClient() @@ -1294,3 +1297,353 @@ def test_async_lakebase_sqlalchemy_invalid_instance_raises(): instance_name="bad-instance", workspace_client=workspace, ) + + +# ============================================================================= +# Autoscaling (project/branch) Tests +# ============================================================================= + + +def _make_autoscaling_workspace( + *, + user_name: str = "test@databricks.com", + credential_token: 
str = "autoscaling-token-1", + host: str = "autoscaling.db.host", + endpoint_name: str = "projects/my-project/branches/my-branch/endpoints/rw-ep", +): + """Create a mock workspace client for autoscaling (project/branch) mode.""" + workspace = MagicMock() + workspace.current_user.me.return_value = MagicMock(user_name=user_name) + + # Mock postgres.generate_database_credential + workspace.postgres.generate_database_credential.return_value = MagicMock(token=credential_token) + + # Mock postgres.list_endpoints → returns one READ_WRITE endpoint + rw_endpoint = MagicMock() + rw_endpoint.name = endpoint_name + rw_endpoint.status.endpoint_type = "READ_WRITE" + rw_endpoint.status.hosts.host = host + workspace.postgres.list_endpoints.return_value = [rw_endpoint] + + return workspace + + +# --- Parameter validation tests --- + + +def test_autoscaling_requires_both_project_and_branch(): + """Passing only project without branch raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + LakebasePool( + project="my-project", + workspace_client=workspace, + ) + + +def test_autoscaling_requires_both_branch_and_project(): + """Passing only branch without project raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + LakebasePool( + branch="my-branch", + workspace_client=workspace, + ) + + +def test_no_params_raises_error(): + """Passing no connection parameters raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + LakebasePool(workspace_client=workspace) + + +def test_async_pool_no_params_raises_error(): + """AsyncLakebasePool with no connection parameters raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + 
AsyncLakebasePool(workspace_client=workspace) + + +def test_async_pool_only_project_raises_error(): + """AsyncLakebasePool with only project raises ValueError.""" + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncLakebasePool(project="my-project", workspace_client=workspace) + + +def test_lakebase_client_no_params_raises_error(): + """LakebaseClient with no pool or connection parameters raises ValueError.""" + with pytest.raises( + ValueError, + match="Must provide 'pool', 'instance_name' .provisioned., or both 'project' and 'branch' .autoscaling.", + ): + LakebaseClient() + + +def test_lakebase_client_only_branch_raises_error(monkeypatch): + """LakebaseClient with only branch (no project) raises ValueError.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace() + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + LakebaseClient(branch="my-branch", workspace_client=workspace) + + +def test_async_sqlalchemy_no_params_raises_error(): + """AsyncLakebaseSQLAlchemy with no connection parameters raises ValueError.""" + workspace = _make_autoscaling_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + with pytest.raises(ValueError, match="Must provide either 'instance_name'"): + AsyncLakebaseSQLAlchemy(workspace_client=workspace) + + +def test_async_sqlalchemy_only_project_raises_error(): + """AsyncLakebaseSQLAlchemy with only project (no branch) raises ValueError.""" + workspace = _make_autoscaling_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + with pytest.raises(ValueError, match="Both 'project' and 'branch' are required"): + AsyncLakebaseSQLAlchemy(project="my-project", 
workspace_client=workspace) + + +def test_both_provisioned_and_autoscaling_raises_error(monkeypatch): + """Providing both instance_name and project/branch raises ValueError.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="autoscaling.db.host") + + with pytest.raises( + ValueError, + match="Cannot provide both 'instance_name' .provisioned. and 'project'/'branch' .autoscaling.", + ): + LakebasePool( + instance_name="my-instance", + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + +# --- LakebasePool autoscaling tests --- + + +def test_lakebase_pool_autoscaling_configures_connection_pool(monkeypatch): + """LakebasePool with project/branch resolves host via autoscaling API.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="auto.db.host") + + pool = LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "auto.db.host" + assert pool._is_autoscaling is True + assert pool.username == "test@databricks.com" + assert "host=auto.db.host" in str(pool.pool.conninfo) + + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_lakebase_pool_autoscaling_mints_token(monkeypatch): + """Autoscaling pool uses postgres.generate_database_credential for tokens.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(credential_token="auto-token") + + pool = LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + token = pool._get_token() + assert token == 
"auto-token" + workspace.postgres.generate_database_credential.assert_called_once_with( + endpoint="projects/my-project/branches/my-branch/endpoints/rw-ep" + ) + # Provisioned credential API should NOT be called + workspace.database.generate_database_credential.assert_not_called() + + +def test_lakebase_pool_provisioned_mints_token(monkeypatch): + """Provisioned pool uses database.generate_database_credential for tokens.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace(credential_token="provisioned-token") + + pool = LakebasePool( + instance_name="lake-instance", + workspace_client=workspace, + ) + + token = pool._get_token() + assert token == "provisioned-token" + workspace.database.generate_database_credential.assert_called_once() + # Autoscaling credential API should NOT be called + workspace.postgres.generate_database_credential.assert_not_called() + + +def test_lakebase_pool_autoscaling_no_rw_endpoint_raises(monkeypatch): + """Raises ValueError when no READ_WRITE endpoint is found.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace() + # Return only a READ_ONLY endpoint + ro_endpoint = MagicMock() + ro_endpoint.status.endpoint_type = "READ_ONLY" + ro_endpoint.status.hosts.host = "ro.host" + workspace.postgres.list_endpoints.return_value = [ro_endpoint] + + with pytest.raises(ValueError, match="No READ_WRITE endpoint found"): + LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + +def test_lakebase_pool_autoscaling_list_endpoints_fails_raises(monkeypatch): + """Raises ValueError when list_endpoints fails.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + 
workspace = _make_autoscaling_workspace() + workspace.postgres.list_endpoints.side_effect = Exception("Not found") + + with pytest.raises(ValueError, match="Unable to list endpoints"): + LakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + +# --- AsyncLakebasePool autoscaling tests --- + + +@pytest.mark.asyncio +async def test_async_lakebase_pool_autoscaling_configures_pool(monkeypatch): + """AsyncLakebasePool with project/branch resolves host via autoscaling API.""" + TestAsyncConnectionPool = _make_async_connection_pool_class() + monkeypatch.setattr( + "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool + ) + + workspace = _make_autoscaling_workspace(host="async-auto.db.host") + + pool = AsyncLakebasePool( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert pool.host == "async-auto.db.host" + assert pool._is_autoscaling is True + assert "host=async-auto.db.host" in str(pool.pool.conninfo) + + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +# --- LakebaseClient autoscaling tests --- + + +def test_lakebase_client_autoscaling_creates_pool(monkeypatch): + """LakebaseClient with project/branch creates an autoscaling pool internally.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_autoscaling_workspace(host="client-auto.db.host") + + client = LakebaseClient( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert client.pool.host == "client-auto.db.host" + assert client.pool._is_autoscaling is True + assert client._owns_pool is True + + +def test_lakebase_client_rejects_pool_plus_autoscaling_params(monkeypatch): + """LakebaseClient rejects passing both pool and project/branch.""" + pool = MagicMock(spec=LakebasePool) + + with 
pytest.raises(ValueError, match="Provide either 'pool' or connection parameters"): + LakebaseClient( + pool=pool, + project="my-project", + branch="my-branch", + ) + + +# --- AsyncLakebaseSQLAlchemy autoscaling tests --- + + +def test_async_lakebase_sqlalchemy_autoscaling_resolves_host(): + """AsyncLakebaseSQLAlchemy with project/branch resolves via autoscaling API.""" + workspace = _make_autoscaling_workspace(host="sa-auto.db.host") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + assert sa.host == "sa-auto.db.host" + assert sa._is_autoscaling is True + workspace.postgres.list_endpoints.assert_called_once_with( + parent="projects/my-project/branches/my-branch" + ) + + +def test_async_lakebase_sqlalchemy_autoscaling_mints_correct_token(): + """AsyncLakebaseSQLAlchemy in autoscaling mode uses postgres credential API.""" + workspace = _make_autoscaling_workspace(credential_token="sa-auto-token") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + project="my-project", + branch="my-branch", + workspace_client=workspace, + ) + + token = sa.get_token() + assert token == "sa-auto-token" + workspace.postgres.generate_database_credential.assert_called_once() + + +def test_async_lakebase_sqlalchemy_provisioned_mints_correct_token(): + """AsyncLakebaseSQLAlchemy in provisioned mode uses database credential API.""" + workspace = _make_workspace(credential_token="sa-prov-token") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + token = sa.get_token() + assert token == "sa-prov-token" + workspace.database.generate_database_credential.assert_called_once()