Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### Security

### Bug Fixes
* Pass `--profile` to CLI token source when profile is set, and add read-fallback to migrate legacy host-keyed tokens to profile keys.

### Documentation

Expand All @@ -18,4 +19,4 @@
* Add `dbr_autoscale` enum value for `databricks.sdk.service.compute.EventDetailsCause`.
* Change `output_catalog` field for `databricks.sdk.service.cleanrooms.CreateCleanRoomOutputCatalogResponse` to be required.
* [Breaking] Remove `internal_attributes` field for `databricks.sdk.service.sharing.Table`.
* [Breaking] Remove `internal_attributes` field for `databricks.sdk.service.sharing.Volume`.
* [Breaking] Remove `internal_attributes` field for `databricks.sdk.service.sharing.Volume`.
62 changes: 48 additions & 14 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,9 +648,15 @@ def __init__(
access_token_field: str,
expiry_field: str,
disable_async: bool = True,
fallback_cmd: Optional[List[str]] = None,
):
super().__init__(disable_async=disable_async)
self._cmd = cmd
# fallback_cmd is tried when the primary command fails with "unknown flag: --profile",
# indicating the CLI is too old to support --profile. Can be removed once support
# for CLI versions predating --profile is dropped.
# See: https://github.com/databricks/databricks-sdk-go/pull/1497
self._fallback_cmd = fallback_cmd
self._token_type_field = token_type_field
self._access_token_field = access_token_field
self._expiry_field = expiry_field
Expand All @@ -666,9 +672,9 @@ def _parse_expiry(expiry: str) -> datetime:
if last_e:
raise last_e

def refresh(self) -> oauth.Token:
def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
try:
out = _run_subprocess(self._cmd, capture_output=True, check=True)
out = _run_subprocess(cmd, capture_output=True, check=True)
it = json.loads(out.stdout.decode())
expires_on = self._parse_expiry(it[self._expiry_field])
return oauth.Token(
Expand All @@ -681,9 +687,21 @@ def refresh(self) -> oauth.Token:
except subprocess.CalledProcessError as e:
stdout = e.stdout.decode().strip()
stderr = e.stderr.decode().strip()
message = stdout or stderr
message = "\n".join(filter(None, [stdout, stderr]))
raise IOError(f"cannot get access token: {message}") from e

def refresh(self) -> oauth.Token:
try:
return self._exec_cli_command(self._cmd)
except IOError as e:
if self._fallback_cmd is not None and "unknown flag: --profile" in str(e):
logger.warning(
"Databricks CLI does not support --profile flag. Falling back to --host. "
"Please upgrade your CLI to the latest version."
)
return self._exec_cli_command(self._fallback_cmd)
raise


def _run_subprocess(
popenargs,
Expand Down Expand Up @@ -853,17 +871,6 @@ class DatabricksCliTokenSource(CliTokenSource):
"""Obtain the token granted by `databricks auth login` CLI command"""

def __init__(self, cfg: "Config"):
args = ["auth", "token", "--host", cfg.host]
if cfg.experimental_is_unified_host:
# For unified hosts, pass account_id, workspace_id, and experimental flag
args += ["--experimental-is-unified-host"]
if cfg.account_id:
args += ["--account-id", cfg.account_id]
if cfg.workspace_id:
args += ["--workspace-id", str(cfg.workspace_id)]
elif cfg.client_type == ClientType.ACCOUNT:
args += ["--account-id", cfg.account_id]

cli_path = cfg.databricks_cli_path

# If the path is not specified look for "databricks" / "databricks.exe" in PATH.
Expand All @@ -882,14 +889,41 @@ def __init__(self, cfg: "Config"):
elif cli_path.count("/") == 0:
cli_path = self.__class__._find_executable(cli_path)

fallback_cmd = None
if cfg.profile:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: this guards against cfg.host being None! me gusta

# When profile is set, use --profile as the primary command.
# The profile contains the full config (host, account_id, etc.).
args = ["auth", "token", "--profile", cfg.profile]
# Build a --host fallback for older CLIs that don't support --profile.
if cfg.host:
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
else:
args = self.__class__._build_host_args(cfg)

super().__init__(
cmd=[cli_path, *args],
token_type_field="token_type",
access_token_field="access_token",
expiry_field="expiry",
disable_async=cfg.disable_async_token_refresh,
fallback_cmd=fallback_cmd,
)

@staticmethod
def _build_host_args(cfg: "Config") -> List[str]:
"""Build CLI arguments using --host (legacy path)."""
args = ["auth", "token", "--host", cfg.host]
if cfg.experimental_is_unified_host:
# For unified hosts, pass account_id, workspace_id, and experimental flag
args += ["--experimental-is-unified-host"]
if cfg.account_id:
args += ["--account-id", cfg.account_id]
if cfg.workspace_id:
args += ["--workspace-id", str(cfg.workspace_id)]
elif cfg.client_type == ClientType.ACCOUNT:
args += ["--account-id", cfg.account_id]
return args

@staticmethod
def _find_executable(name) -> str:
err = FileNotFoundError("Most likely the Databricks CLI is not installed")
Expand Down
136 changes: 136 additions & 0 deletions tests/test_credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def test_unified_host_passes_all_flags(self, mocker):
)

mock_cfg = Mock()
mock_cfg.profile = None
mock_cfg.host = "https://example.databricks.com"
mock_cfg.experimental_is_unified_host = True
mock_cfg.account_id = "test-account-id"
Expand Down Expand Up @@ -325,6 +326,7 @@ def test_unified_host_without_workspace_id(self, mocker):
)

mock_cfg = Mock()
mock_cfg.profile = None
mock_cfg.host = "https://example.databricks.com"
mock_cfg.experimental_is_unified_host = True
mock_cfg.account_id = "test-account-id"
Expand All @@ -351,6 +353,7 @@ def test_account_client_passes_account_id(self, mocker):
)

mock_cfg = Mock()
mock_cfg.profile = None
mock_cfg.host = "https://accounts.cloud.databricks.com"
mock_cfg.experimental_is_unified_host = False
mock_cfg.account_id = "test-account-id"
Expand All @@ -368,6 +371,139 @@ def test_account_client_passes_account_id(self, mocker):
assert "test-account-id" in cmd
assert "--workspace-id" not in cmd

def test_profile_uses_profile_flag_with_host_fallback(self, mocker):
"""When profile is set, --profile is used as primary and --host as fallback."""
mock_init = mocker.patch.object(
credentials_provider.CliTokenSource,
"__init__",
return_value=None,
)

mock_cfg = Mock()
mock_cfg.profile = "my-profile"
mock_cfg.host = "https://workspace.databricks.com"
mock_cfg.experimental_is_unified_host = False
mock_cfg.databricks_cli_path = "/path/to/databricks"
mock_cfg.disable_async_token_refresh = False

credentials_provider.DatabricksCliTokenSource(mock_cfg)

call_kwargs = mock_init.call_args
cmd = call_kwargs.kwargs["cmd"]
host_cmd = call_kwargs.kwargs["fallback_cmd"]

assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"]
assert host_cmd is not None
assert "--host" in host_cmd
assert "https://workspace.databricks.com" in host_cmd
assert "--profile" not in host_cmd

def test_profile_without_host_no_fallback(self, mocker):
"""When profile is set but host is absent, no fallback is built."""
mock_init = mocker.patch.object(
credentials_provider.CliTokenSource,
"__init__",
return_value=None,
)

mock_cfg = Mock()
mock_cfg.profile = "my-profile"
mock_cfg.host = None
mock_cfg.databricks_cli_path = "/path/to/databricks"
mock_cfg.disable_async_token_refresh = False

credentials_provider.DatabricksCliTokenSource(mock_cfg)

call_kwargs = mock_init.call_args
cmd = call_kwargs.kwargs["cmd"]
host_cmd = call_kwargs.kwargs["fallback_cmd"]

assert cmd == ["/path/to/databricks", "auth", "token", "--profile", "my-profile"]
assert host_cmd is None


# Tests for CliTokenSource fallback on unknown --profile flag
class TestCliTokenSourceFallback:
"""Tests that CliTokenSource falls back to --host when CLI doesn't support --profile."""

def _make_token_source(self, fallback_cmd=None):
ts = credentials_provider.CliTokenSource.__new__(credentials_provider.CliTokenSource)
ts._cmd = ["databricks", "auth", "token", "--profile", "my-profile"]
ts._fallback_cmd = fallback_cmd
ts._token_type_field = "token_type"
ts._access_token_field = "access_token"
ts._expiry_field = "expiry"
return ts

def _make_process_error(self, stderr: str, stdout: str = ""):
import subprocess

err = subprocess.CalledProcessError(1, ["databricks"])
err.stdout = stdout.encode()
err.stderr = stderr.encode()
return err

def test_fallback_on_unknown_profile_flag(self, mocker):
"""When --profile fails with 'unknown flag: --profile', falls back to --host command."""
import json

expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry})

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = [
self._make_process_error("Error: unknown flag: --profile"),
Mock(stdout=valid_response.encode()),
]

fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"]
ts = self._make_token_source(fallback_cmd=fallback_cmd)
token = ts.refresh()
assert token.access_token == "fallback-token"
assert mock_run.call_count == 2
assert mock_run.call_args_list[1][0][0] == fallback_cmd

def test_fallback_triggered_when_unknown_flag_in_stderr_only(self, mocker):
"""Fallback triggers even when CLI also writes usage text to stdout."""
import json

expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
valid_response = json.dumps({"access_token": "fallback-token", "token_type": "Bearer", "expiry": expiry})

mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = [
self._make_process_error(stderr="Error: unknown flag: --profile", stdout="Usage: databricks auth token"),
Mock(stdout=valid_response.encode()),
]

fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"]
ts = self._make_token_source(fallback_cmd=fallback_cmd)
token = ts.refresh()
assert token.access_token == "fallback-token"

def test_no_fallback_on_real_auth_error(self, mocker):
"""When --profile fails with a real error (not unknown flag), no fallback is attempted."""
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")

fallback_cmd = ["databricks", "auth", "token", "--host", "https://workspace.databricks.com"]
ts = self._make_token_source(fallback_cmd=fallback_cmd)
with pytest.raises(IOError) as exc_info:
ts.refresh()
assert "databricks OAuth is not configured" in str(exc_info.value)
assert mock_run.call_count == 1

def test_no_fallback_when_fallback_cmd_not_set(self, mocker):
"""When fallback_cmd is None and --profile fails, the original error is raised."""
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
mock_run.side_effect = self._make_process_error("Error: unknown flag: --profile")

ts = self._make_token_source(fallback_cmd=None)
with pytest.raises(IOError) as exc_info:
ts.refresh()
assert "unknown flag: --profile" in str(exc_info.value)
assert mock_run.call_count == 1


# Tests for cloud-agnostic hosts and removed cloud checks
class TestCloudAgnosticHosts:
Expand Down
Loading