diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a71107c7c..b40fd5201 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,8 @@ ## Release v0.95.0 ### New Features and Improvements +* Added `Config.discovery_url` config field (`DATABRICKS_DISCOVERY_URL` env var). When set, OIDC endpoints are fetched directly from this URL instead of the default host-type-based logic. Mirrors `discoveryUrl` in the Java SDK. +* The OAuth token cache filename now includes the config profile name (if set) and uses a serialized map to prevent hash collisions. All users will need to reauthenticate once after upgrading. ### Security @@ -18,4 +20,4 @@ * Add `reset_checkpoint_selection` field for `databricks.sdk.service.pipelines.StartUpdate`. * [Breaking] Remove `oauth2_app_client_id` and `oauth2_app_integration_id` fields for `databricks.sdk.service.apps.Space`. * Add `create_database()`, `delete_database()`, `get_database()`, `list_databases()` and `update_database()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service. -* Add `postgres` field for `databricks.sdk.service.apps.AppResource`. \ No newline at end of file +* Add `postgres` field for `databricks.sdk.service.apps.AppResource`. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 8b8ff338e..fb0cde725 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -281,6 +281,7 @@ def external_browser(cfg: "Config") -> Optional[CredentialsProvider]: client_secret=client_secret, redirect_url=redirect_url, scopes=scopes, + profile=cfg.profile, ) credentials = token_cache.load() if credentials: diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index c1b1ccb0d..aa0c7f810 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -895,6 +895,7 @@ def __init__( redirect_url: Optional[str] = None, client_secret: Optional[str] = None, scopes: Optional[List[str]] = None, + profile: Optional[str] = None, ) -> None: self._host = host self._client_id = client_id @@ -902,18 +903,20 @@ def __init__( self._redirect_url = redirect_url self._client_secret = client_secret self._scopes = scopes or [] + self._profile = profile @property def filename(self) -> str: - # Include host, client_id, and scopes in the cache filename to make it unique. - hash = hashlib.sha256() - for chunk in [ - self._host, - self._client_id, - ",".join(self._scopes), - ]: - hash.update(chunk.encode("utf-8")) - return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json")) + # Include host, client_id, scopes, and profile in the cache filename to make it unique. + # JSON serialization ensures values are properly escaped and separated. + key = { + "host": self._host, + "client_id": self._client_id, + "scopes": self._scopes, + "profile": self._profile or "", + } + h = hashlib.sha256(json.dumps(key, sort_keys=True).encode("utf-8")) + return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, h.hexdigest() + ".json")) def load(self) -> Optional[SessionCredentials]: """ diff --git a/tests/test_credentials_provider.py b/tests/test_credentials_provider.py index b1ccd1ba5..647aac16a 100644 --- a/tests/test_credentials_provider.py +++ b/tests/test_credentials_provider.py @@ -228,6 +228,34 @@ def test_external_browser_scopes(mocker, scopes, disable_refresh, expected_scope assert mock_oauth_client_class.call_args.kwargs["scopes"] == expected_scopes +def test_external_browser_passes_profile_to_token_cache(mocker): + """Tests that external_browser passes cfg.profile to TokenCache.""" + mock_cfg = Mock() + mock_cfg.auth_type = "external-browser" + mock_cfg.host = "https://test.databricks.com" + mock_cfg.profile = "myprofile" + mock_cfg.client_id = "test-client-id" + mock_cfg.client_secret = None + mock_cfg.azure_client_id = None + mock_cfg.get_scopes.return_value = ["all-apis"] + mock_cfg.disable_oauth_refresh_token = False + + mock_token_cache_class = mocker.patch("databricks.sdk.credentials_provider.oauth.TokenCache") + mock_token_cache = Mock() + mock_token_cache.load.return_value = None + mock_token_cache_class.return_value = mock_token_cache + + mock_oauth_client = Mock() + mock_consent = Mock() + mock_consent.launch_external_browser.return_value = Mock() + mock_oauth_client.initiate_consent.return_value = mock_consent + mocker.patch("databricks.sdk.credentials_provider.oauth.OAuthClient", return_value=mock_oauth_client) + + credentials_provider.external_browser(mock_cfg) + + assert mock_token_cache_class.call_args.kwargs["profile"] == "myprofile" + + def test_oidc_credentials_provider_invalid_id_token_source(): # Use a mock config object to avoid initializing the auth initialization. mock_cfg = Mock() diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 2bbac89bb..68ce1da67 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -40,6 +40,40 @@ def test_token_cache_unique_filename_by_scopes(): assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], **common_args).filename +def test_token_cache_unique_filename_by_profile(): + common_args = dict( + host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234"), + ) + assert TokenCache(profile="dev", **common_args).filename != TokenCache(profile="prod", **common_args).filename + + +def test_token_cache_filename_no_profile_matches_empty_profile(): + common_args = dict( + host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234"), + ) + assert TokenCache(**common_args).filename == TokenCache(profile=None, **common_args).filename + + +def test_token_cache_filename_no_delimiter_collision(): + """Scopes and profile with shared comma content must not collide.""" + common_args = dict( + host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234"), + ) + assert ( + TokenCache(scopes=["a,b"], profile="c", **common_args).filename + != TokenCache(scopes=["a"], profile=",bc", **common_args).filename + ) + + def test_account_oidc_endpoints(requests_mock): requests_mock.get( "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server",