diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 9755aa174..918a7306e 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -8,6 +8,7 @@ ### Security ### Bug Fixes +* Pass `--profile` to CLI token source when profile is set, and add read-fallback to migrate legacy host-keyed tokens to profile keys. ### Documentation @@ -18,4 +19,4 @@ * Add `dbr_autoscale` enum value for `databricks.sdk.service.compute.EventDetailsCause`. * Change `output_catalog` field for `databricks.sdk.service.cleanrooms.CreateCleanRoomOutputCatalogResponse` to be required. * [Breaking] Remove `internal_attributes` field for `databricks.sdk.service.sharing.Table`. -* [Breaking] Remove `internal_attributes` field for `databricks.sdk.service.sharing.Volume`. \ No newline at end of file +* [Breaking] Remove `internal_attributes` field for `databricks.sdk.service.sharing.Volume`. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 15cbbfdd2..8b8ff338e 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -648,9 +648,15 @@ def __init__( access_token_field: str, expiry_field: str, disable_async: bool = True, + fallback_cmd: Optional[List[str]] = None, ): super().__init__(disable_async=disable_async) self._cmd = cmd + # fallback_cmd is tried when the primary command fails with "unknown flag: --profile", + # indicating the CLI is too old to support --profile. Can be removed once support + # for CLI versions predating --profile is dropped. + # See: https://github.com/databricks/databricks-sdk-go/pull/1497 + self._fallback_cmd = fallback_cmd self._token_type_field = token_type_field self._access_token_field = access_token_field self._expiry_field = expiry_field @@ -666,9 +672,9 @@ def _parse_expiry(expiry: str) -> datetime: if last_e: raise last_e - def refresh(self) -> oauth.Token: + def _exec_cli_command(self, cmd: List[str]) -> oauth.Token: try: - out = _run_subprocess(self._cmd, capture_output=True, check=True) + out = _run_subprocess(cmd, capture_output=True, check=True) it = json.loads(out.stdout.decode()) expires_on = self._parse_expiry(it[self._expiry_field]) return oauth.Token( @@ -681,9 +687,21 @@ def refresh(self) -> oauth.Token: except subprocess.CalledProcessError as e: stdout = e.stdout.decode().strip() stderr = e.stderr.decode().strip() - message = stdout or stderr + message = "\n".join(filter(None, [stdout, stderr])) raise IOError(f"cannot get access token: {message}") from e + def refresh(self) -> oauth.Token: + try: + return self._exec_cli_command(self._cmd) + except IOError as e: + if self._fallback_cmd is not None and "unknown flag: --profile" in str(e): + logger.warning( + "Databricks CLI does not support --profile flag. Falling back to --host. " + "Please upgrade your CLI to the latest version." + ) + return self._exec_cli_command(self._fallback_cmd) + raise + def _run_subprocess( popenargs, @@ -853,17 +871,6 @@ class DatabricksCliTokenSource(CliTokenSource): """Obtain the token granted by `databricks auth login` CLI command""" def __init__(self, cfg: "Config"): - args = ["auth", "token", "--host", cfg.host] - if cfg.experimental_is_unified_host: - # For unified hosts, pass account_id, workspace_id, and experimental flag - args += ["--experimental-is-unified-host"] - if cfg.account_id: - args += ["--account-id", cfg.account_id] - if cfg.workspace_id: - args += ["--workspace-id", str(cfg.workspace_id)] - elif cfg.client_type == ClientType.ACCOUNT: - args += ["--account-id", cfg.account_id] - cli_path = cfg.databricks_cli_path # If the path is not specified look for "databricks" / "databricks.exe" in PATH. @@ -882,14 +889,41 @@ def __init__(self, cfg: "Config"): elif cli_path.count("/") == 0: cli_path = self.__class__._find_executable(cli_path) + fallback_cmd = None + if cfg.profile: + # When profile is set, use --profile as the primary command. + # The profile contains the full config (host, account_id, etc.). + args = ["auth", "token", "--profile", cfg.profile] + # Build a --host fallback for older CLIs that don't support --profile. + if cfg.host: + fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)] + else: + args = self.__class__._build_host_args(cfg) + super().__init__( cmd=[cli_path, *args], token_type_field="token_type", access_token_field="access_token", expiry_field="expiry", disable_async=cfg.disable_async_token_refresh, + fallback_cmd=fallback_cmd, ) + @staticmethod + def _build_host_args(cfg: "Config") -> List[str]: + """Build CLI arguments using --host (legacy path).""" + args = ["auth", "token", "--host", cfg.host] + if cfg.experimental_is_unified_host: + # For unified hosts, pass account_id, workspace_id, and experimental flag + args += ["--experimental-is-unified-host"] + if cfg.account_id: + args += ["--account-id", cfg.account_id] + if cfg.workspace_id: + args += ["--workspace-id", str(cfg.workspace_id)] + elif cfg.client_type == ClientType.ACCOUNT: + args += ["--account-id", cfg.account_id] + return args + @staticmethod def _find_executable(name) -> str: err = FileNotFoundError("Most likely the Databricks CLI is not installed") diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index fb91d4c38..b1ccd1ba5 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -292,6 +292,7 @@ def test_unified_host_passes_all_flags(self, mocker): ) mock_cfg = Mock() + mock_cfg.profile = None mock_cfg.host = "https://example.databricks.com" mock_cfg.experimental_is_unified_host = True mock_cfg.account_id = "test-account-id" @@ -325,6 +326,7 @@ def test_unified_host_without_workspace_id(self, mocker): ) mock_cfg = Mock() + mock_cfg.profile = None mock_cfg.host = "https://example.databricks.com" mock_cfg.experimental_is_unified_host = True mock_cfg.account_id = "test-account-id" @@ -351,6 +353,7 @@ def test_account_client_passes_account_id(self, mocker): ) mock_cfg = Mock() + mock_cfg.profile = None mock_cfg.host = "https://accounts.cloud.databricks.com" mock_cfg.experimental_is_unified_host = False mock_cfg.account_id = "test-account-id" @@ -368,6 +371,139 @@ def test_account_client_passes_account_id(self, mocker): assert "test-account-id" in cmd assert "--workspace-id" not in cmd + def test_profile_uses_profile_flag_with_host_fallback(self, mocker): + """When profile is set, --profile is used as primary and --host as fallback.""" + mock_init = mocker.patch.object( + credentials_provider.CliTokenSource, + "__init__", + return_value=None, + ) + + mock_cfg = Mock() + mock_cfg.profile = "my-profile" + mock_cfg.host = "https://workspace.databricks.com" + mock_cfg.experimental_is_unified_host = False + mock_cfg.databricks_cli_path = "/path/to/databricks" + mock_cfg.disable_async_token_refresh = False + + credentials_provider.DatabricksCliTokenSource(mock_cfg) + + call_kwargs = mock_init.call_args + cmd = call_kwargs.kwargs["cmd"] + host_cmd = call_kwargs.kwargs["fallback_cmd"] + + assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"] + assert host_cmd is not None + assert "--host" in host_cmd + assert "https://workspace.databricks.com" in host_cmd + assert "--profile" not in host_cmd + + def test_profile_without_host_no_fallback(self, mocker): + """When profile is set but host is absent, no fallback is built.""" + mock_init = mocker.patch.object( + credentials_provider.CliTokenSource, + "__init__", + return_value=None, + ) + + mock_cfg = Mock() + mock_cfg.profile = "my-profile" + mock_cfg.host = None + mock_cfg.databricks_cli_path = "/path/to/databricks" + mock_cfg.disable_async_token_refresh = False + + credentials_provider.DatabricksCliTokenSource(mock_cfg) + + call_kwargs = mock_init.call_args + cmd = call_kwargs.kwargs["cmd"] + host_cmd = call_kwargs.kwargs["fallback_cmd"] + + assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"] + assert host_cmd is None + + +# Tests for CliTokenSource fallback on unknown --profile flag +class TestCliTokenSourceFallback: + """Tests that CliTokenSource falls back to --host when CLI doesn't support --profile.""" + + def _make_token_source(self, fallback_cmd=None): + ts = credentials_provider.CliTokenSource.__new__(credentials_provider.CliTokenSource) + ts._cmd = ["databricks", "auth", "token", "--profile", "my-profile"] + ts._fallback_cmd = fallback_cmd + ts._token_type_field = "token_type" + ts._access_token_field = "access_token" + ts._expiry_field = "expiry" + return ts + + def _make_process_error(self, stderr: str, stdout: str = ""): + import subprocess + + err = subprocess.CalledProcessError(1, ["databricks"]) + err.stdout = stdout.encode() + err.stderr = stderr.encode() + return err + + def test_fallback_on_unknown_profile_flag(self, mocker): + """When --profile fails with 'unknown flag: --profile', falls back to --host command.""" + import json + + expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") + valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry}) + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + self._make_process_error("Error: unknown flag: --profile"), + Mock(stdout=valid_response.encode()), + ] + + fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] + ts = self._make_token_source(fallback_cmd=fallback_cmd) + token = ts.refresh() + assert token.access_token == "fallback-token" + assert mock_run.call_count == 2 + assert mock_run.call_args_list[1][0][0] == fallback_cmd + + def test_fallback_triggered_when_unknown_flag_in_stderr_only(self, mocker): + """Fallback triggers even when CLI also writes usage text to stdout.""" + import json + + expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S") + valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry}) + + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = [ + self._make_process_error(stderr="Error: unknown flag: --profile", stdout="Usage: databricks auth token"), + Mock(stdout=valid_response.encode()), + ] + + fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] + ts = self._make_token_source(fallback_cmd=fallback_cmd) + token = ts.refresh() + assert token.access_token == "fallback-token" + + def test_no_fallback_on_real_auth_error(self, mocker): + """When --profile fails with a real error (not unknown flag), no fallback is attempted.""" + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host") + + fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"] + ts = self._make_token_source(fallback_cmd=fallback_cmd) + with pytest.raises(IOError) as exc_info: + ts.refresh() + assert "databricks OAuth is not configured" in str(exc_info.value) + assert mock_run.call_count == 1 + + def test_no_fallback_when_fallback_cmd_not_set(self, mocker): + """When fallback_cmd is None and --profile fails, the original error is raised.""" + mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess") + mock_run.side_effect = self._make_process_error("Error: unknown flag: --profile") + + ts = self._make_token_source(fallback_cmd=None) + with pytest.raises(IOError) as exc_info: + ts.refresh() + assert "unknown flag: --profile" in str(exc_info.value) + assert mock_run.call_count == 1 + # Tests for cloud-agnostic hosts and removed cloud checks class TestCloudAgnosticHosts: