From 39baf42fbc0bc84ea65ccf839e0339cffec7adff Mon Sep 17 00:00:00 2001 From: tomas Date: Thu, 2 Apr 2026 21:21:32 +0000 Subject: [PATCH 1/5] feat(sql_execution): Implement retry logic for userpod API - Added a new function `_create_retry_session` to create a requests session with retry capabilities for handling 5xx errors on POST requests. - Updated `_generate_temporary_credentials` and `_get_federated_auth_credentials` to use the new retry session for making requests. - Introduced unit tests to verify the retry session configuration and its usage in the credential generation functions. --- deepnote_toolkit/sql/sql_execution.py | 21 +++++- tests/unit/test_sql_execution.py | 94 +++++++++++++++++++++++++-- 2 files changed, 109 insertions(+), 6 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 07b61fe..285b503 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -11,6 +11,7 @@ import google.oauth2.credentials import numpy as np import requests +from requests.adapters import HTTPAdapter, Retry import wrapt from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -263,13 +264,28 @@ class ExecuteSqlError(Exception): ) +def _create_retry_session() -> requests.Session: + """Create a requests session with retry on 5xx for POST requests.""" + session = requests.Session() + retries = Retry( + total=3, + backoff_factor=0.5, + status_forcelist=[500, 502, 503, 504], + allowed_methods=["POST"], + ) + session.mount("http://", HTTPAdapter(max_retries=retries)) + session.mount("https://", HTTPAdapter(max_retries=retries)) + return session + + def _generate_temporary_credentials(integration_id): url = get_absolute_userpod_api_url(f"integrations/credentials/{integration_id}") # Add project credentials in detached mode headers = get_project_auth_headers() - response = requests.post(url, timeout=10, headers=headers) + session = _create_retry_session() + response = session.post(url, timeout=10, headers=headers) response.raise_for_status() @@ -291,7 +307,8 @@ def _get_federated_auth_credentials( headers = get_project_auth_headers() headers["UserPodAuthContextToken"] = user_pod_auth_context_token - response = requests.post(url, timeout=10, headers=headers) + session = _create_retry_session() + response = session.post(url, timeout=10, headers=headers) response.raise_for_status() diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index a684077..99e1613 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -592,9 +592,9 @@ def test_all_dataframes_serialize_to_parquet(self, key, df): class TestFederatedAuth(unittest.TestCase): @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") - @mock.patch("deepnote_toolkit.sql.sql_execution.requests.post") + @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") def test_get_federated_auth_credentials_returns_validated_response( - self, mock_post, mock_get_url, mock_get_headers + self, mock_create_session, mock_get_url, mock_get_headers ): """Test that _get_federated_auth_credentials properly validates and returns response data.""" from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials @@ -603,12 +603,14 @@ def test_get_federated_auth_credentials_returns_validated_response( mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-integration-id" mock_get_headers.return_value = {"Authorization": "Bearer project-token"} + mock_session = mock.Mock() mock_response = mock.Mock() mock_response.json.return_value = { "integrationType": "trino", "accessToken": "test-access-token-123", } - mock_post.return_value = mock_response + mock_session.post.return_value = mock_response + mock_create_session.return_value = mock_session # Call the function result = _get_federated_auth_credentials( @@ -621,7 +623,7 @@ def test_get_federated_auth_credentials_returns_validated_response( ) # Verify headers include both project auth and user pod auth context token - mock_post.assert_called_once_with( + mock_session.post.assert_called_once_with( "https://api.example.com/integrations/federated-auth-token/test-integration-id", timeout=10, headers={ @@ -1019,3 +1021,87 @@ def test_databricks_connector_dialect_alias_is_registered(self): self.assertEqual(url.drivername, "databricks+connector") self.assertIsNotNone(dialect_cls) + + +class TestCreateRetrySession(unittest.TestCase): + def test_retry_session_has_correct_config(self): + """Test that _create_retry_session configures retries correctly.""" + from deepnote_toolkit.sql.sql_execution import _create_retry_session + + session = _create_retry_session() + + # Check that both http and https adapters are mounted with retry config + for prefix in ("http://", "https://"): + adapter = session.get_adapter(prefix) + retries = adapter.max_retries + self.assertEqual(retries.total, 3) + self.assertEqual(retries.backoff_factor, 0.5) + self.assertEqual(list(retries.status_forcelist), [500, 502, 503, 504]) + self.assertIn("POST", retries.allowed_methods) + + @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") + def test_generate_temporary_credentials_uses_retry_session( + self, mock_get_url, mock_get_headers + ): + """Test that _generate_temporary_credentials uses a retry session.""" + from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials + + mock_get_url.return_value = "https://api.example.com/integrations/credentials/test-id" + mock_get_headers.return_value = {"Authorization": "Bearer token"} + + with mock.patch( + "deepnote_toolkit.sql.sql_execution._create_retry_session" + ) as mock_create_session: + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.json.return_value = { + "username": "user", + "password": "pass", + } + mock_session.post.return_value = mock_response + mock_create_session.return_value = mock_session + + _generate_temporary_credentials("test-id") + + mock_create_session.assert_called_once() + mock_session.post.assert_called_once_with( + "https://api.example.com/integrations/credentials/test-id", + timeout=10, + headers={"Authorization": "Bearer token"}, + ) + + @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") + def test_get_federated_auth_credentials_uses_retry_session( + self, mock_get_url, mock_get_headers + ): + """Test that _get_federated_auth_credentials uses a retry session.""" + from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials + + mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-id" + mock_get_headers.return_value = {"Authorization": "Bearer token"} + + with mock.patch( + "deepnote_toolkit.sql.sql_execution._create_retry_session" + ) as mock_create_session: + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.json.return_value = { + "integrationType": "trino", + "accessToken": "test-token", + } + mock_session.post.return_value = mock_response + mock_create_session.return_value = mock_session + + _get_federated_auth_credentials("test-id", "auth-context-token") + + mock_create_session.assert_called_once() + mock_session.post.assert_called_once_with( + "https://api.example.com/integrations/federated-auth-token/test-id", + timeout=10, + headers={ + "Authorization": "Bearer token", + "UserPodAuthContextToken": "auth-context-token", + }, + ) From aa99cc3ae4b0117750954c6277ae77f7284deaee Mon Sep 17 00:00:00 2001 From: tomas Date: Thu, 2 Apr 2026 21:58:33 +0000 Subject: [PATCH 2/5] Reformat code --- tests/unit/test_sql_execution.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index 99e1613..271d718 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -1047,7 +1047,9 @@ def test_generate_temporary_credentials_uses_retry_session( """Test that _generate_temporary_credentials uses a retry session.""" from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials - mock_get_url.return_value = "https://api.example.com/integrations/credentials/test-id" + mock_get_url.return_value = ( + "https://api.example.com/integrations/credentials/test-id" + ) mock_get_headers.return_value = {"Authorization": "Bearer token"} with mock.patch( @@ -1079,7 +1081,9 @@ def test_get_federated_auth_credentials_uses_retry_session( """Test that _get_federated_auth_credentials uses a retry session.""" from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials - mock_get_url.return_value = "https://api.example.com/integrations/federated-auth-token/test-id" + mock_get_url.return_value = ( + "https://api.example.com/integrations/federated-auth-token/test-id" + ) mock_get_headers.return_value = {"Authorization": "Bearer token"} with mock.patch( From 01be651650e798c7c27ddae8299a4c0bc3de1c1d Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 3 Apr 2026 07:56:15 +0000 Subject: [PATCH 3/5] Minor improvements --- deepnote_toolkit/sql/sql_execution.py | 5 +- tests/unit/test_sql_execution.py | 82 +++++++++++++-------------- 2 files changed, 41 insertions(+), 46 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 285b503..6ea9c73 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -11,7 +11,6 @@ import google.oauth2.credentials import numpy as np import requests -from requests.adapters import HTTPAdapter, Retry import wrapt from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -19,6 +18,7 @@ from google.cloud import bigquery from packaging.version import parse as parse_version from pydantic import BaseModel +from requests.adapters import HTTPAdapter, Retry from sqlalchemy.engine import URL, Connection, create_engine, make_url from sqlalchemy.exc import ResourceClosedError @@ -271,14 +271,13 @@ def _create_retry_session() -> requests.Session: total=3, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504], - allowed_methods=["POST"], ) session.mount("http://", HTTPAdapter(max_retries=retries)) session.mount("https://", HTTPAdapter(max_retries=retries)) return session -def _generate_temporary_credentials(integration_id): +def _generate_temporary_credentials(integration_id) -> tuple[str, str]: url = get_absolute_userpod_api_url(f"integrations/credentials/{integration_id}") # Add project credentials in detached mode diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index 271d718..192b1c0 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -1039,10 +1039,11 @@ def test_retry_session_has_correct_config(self): self.assertEqual(list(retries.status_forcelist), [500, 502, 503, 504]) self.assertIn("POST", retries.allowed_methods) + @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") def test_generate_temporary_credentials_uses_retry_session( - self, mock_get_url, mock_get_headers + self, mock_get_url, mock_get_headers, mock_create_session ): """Test that _generate_temporary_credentials uses a retry session.""" from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials @@ -1052,31 +1053,29 @@ def test_generate_temporary_credentials_uses_retry_session( ) mock_get_headers.return_value = {"Authorization": "Bearer token"} - with mock.patch( - "deepnote_toolkit.sql.sql_execution._create_retry_session" - ) as mock_create_session: - mock_session = mock.Mock() - mock_response = mock.Mock() - mock_response.json.return_value = { - "username": "user", - "password": "pass", - } - mock_session.post.return_value = mock_response - mock_create_session.return_value = mock_session + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.json.return_value = { + "username": "user", + "password": "pass", + } + mock_session.post.return_value = mock_response + mock_create_session.return_value = mock_session - _generate_temporary_credentials("test-id") + _generate_temporary_credentials("test-id") - mock_create_session.assert_called_once() - mock_session.post.assert_called_once_with( - "https://api.example.com/integrations/credentials/test-id", - timeout=10, - headers={"Authorization": "Bearer token"}, - ) + mock_create_session.assert_called_once() + mock_session.post.assert_called_once_with( + "https://api.example.com/integrations/credentials/test-id", + timeout=10, + headers={"Authorization": "Bearer token"}, + ) + @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") def test_get_federated_auth_credentials_uses_retry_session( - self, mock_get_url, mock_get_headers + self, mock_get_url, mock_get_headers, mock_create_session ): """Test that _get_federated_auth_credentials uses a retry session.""" from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials @@ -1086,26 +1085,23 @@ def test_get_federated_auth_credentials_uses_retry_session( ) mock_get_headers.return_value = {"Authorization": "Bearer token"} - with mock.patch( - "deepnote_toolkit.sql.sql_execution._create_retry_session" - ) as mock_create_session: - mock_session = mock.Mock() - mock_response = mock.Mock() - mock_response.json.return_value = { - "integrationType": "trino", - "accessToken": "test-token", - } - mock_session.post.return_value = mock_response - mock_create_session.return_value = mock_session - - _get_federated_auth_credentials("test-id", "auth-context-token") - - mock_create_session.assert_called_once() - mock_session.post.assert_called_once_with( - "https://api.example.com/integrations/federated-auth-token/test-id", - timeout=10, - headers={ - "Authorization": "Bearer token", - "UserPodAuthContextToken": "auth-context-token", - }, - ) + mock_session = mock.Mock() + mock_response = mock.Mock() + mock_response.json.return_value = { + "integrationType": "trino", + "accessToken": "test-token", + } + mock_session.post.return_value = mock_response + mock_create_session.return_value = mock_session + + _get_federated_auth_credentials("test-id", "auth-context-token") + + mock_create_session.assert_called_once() + mock_session.post.assert_called_once_with( + "https://api.example.com/integrations/federated-auth-token/test-id", + timeout=10, + headers={ + "Authorization": "Bearer token", + "UserPodAuthContextToken": "auth-context-token", + }, + ) From 790be7727cfc0cd9ec4acd0bde171492ab76aa79 Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 3 Apr 2026 18:01:55 +0000 Subject: [PATCH 4/5] Explicit methods for session retry --- deepnote_toolkit/sql/sql_execution.py | 1 + tests/unit/test_sql_execution.py | 15 --------------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/deepnote_toolkit/sql/sql_execution.py b/deepnote_toolkit/sql/sql_execution.py index 6ea9c73..642a687 100644 --- a/deepnote_toolkit/sql/sql_execution.py +++ b/deepnote_toolkit/sql/sql_execution.py @@ -271,6 +271,7 @@ def _create_retry_session() -> requests.Session: total=3, backoff_factor=0.5, status_forcelist=[500, 502, 503, 504], + allowed_methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "TRACE"], ) session.mount("http://", HTTPAdapter(max_retries=retries)) session.mount("https://", HTTPAdapter(max_retries=retries)) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index 192b1c0..dfab0a0 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -1024,21 +1024,6 @@ def test_databricks_connector_dialect_alias_is_registered(self): class TestCreateRetrySession(unittest.TestCase): - def test_retry_session_has_correct_config(self): - """Test that _create_retry_session configures retries correctly.""" - from deepnote_toolkit.sql.sql_execution import _create_retry_session - - session = _create_retry_session() - - # Check that both http and https adapters are mounted with retry config - for prefix in ("http://", "https://"): - adapter = session.get_adapter(prefix) - retries = adapter.max_retries - self.assertEqual(retries.total, 3) - self.assertEqual(retries.backoff_factor, 0.5) - self.assertEqual(list(retries.status_forcelist), [500, 502, 503, 504]) - self.assertIn("POST", retries.allowed_methods) - @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") From 1f804b1e68df2a32b0382159f32457c70967ca7b Mon Sep 17 00:00:00 2001 From: tomas Date: Fri, 3 Apr 2026 18:56:17 +0000 Subject: [PATCH 5/5] Add proper tests for requests retry mechanism --- tests/unit/test_sql_execution.py | 247 ++++++++++++++++++++++++++----- 1 file changed, 208 insertions(+), 39 deletions(-) diff --git a/tests/unit/test_sql_execution.py b/tests/unit/test_sql_execution.py index dfab0a0..f3bb736 100644 --- a/tests/unit/test_sql_execution.py +++ b/tests/unit/test_sql_execution.py @@ -1024,13 +1024,44 @@ def test_databricks_connector_dialect_alias_is_registered(self): class TestCreateRetrySession(unittest.TestCase): - @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") + """Tests that exercise the real urllib3 retry loop by mocking at the + connection level (``HTTPConnectionPool._make_request``) rather than + replacing ``_create_retry_session``. This lets the ``Retry`` adapter + actually fire retries on 5xx responses. + """ + + def test_create_retry_session_configuration(self): + """Verify the retry session is wired with the expected parameters.""" + from deepnote_toolkit.sql.sql_execution import _create_retry_session + + session = _create_retry_session() + + for prefix in ("http://", "https://"): + adapter = session.get_adapter(f"{prefix}example.com") + retry = adapter.max_retries + + self.assertEqual(retry.total, 3) + self.assertEqual(retry.backoff_factor, 0.5) + self.assertEqual(set(retry.status_forcelist), {500, 502, 503, 504}) + self.assertIn("POST", retry.allowed_methods) + + # -- _generate_temporary_credentials ------------------------------------ + + @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) + @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") - def test_generate_temporary_credentials_uses_retry_session( - self, mock_get_url, mock_get_headers, mock_create_session + def test_generate_credentials_retries_on_5xx_then_succeeds( + self, + mock_get_url, + mock_get_headers, + mock_make_request, + mock_retry_sleep, ): - """Test that _generate_temporary_credentials uses a retry session.""" + """Two 5xx failures followed by a 200 - the retry loop must + transparently retry and ultimately return valid credentials.""" + from urllib3 import HTTPResponse as Urllib3Response + from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials mock_get_url.return_value = ( @@ -1038,31 +1069,128 @@ def test_generate_temporary_credentials_uses_retry_session( ) mock_get_headers.return_value = {"Authorization": "Bearer token"} - mock_session = mock.Mock() - mock_response = mock.Mock() - mock_response.json.return_value = { - "username": "user", - "password": "pass", - } - mock_session.post.return_value = mock_response - mock_create_session.return_value = mock_session + success_body = json.dumps({"username": "user", "password": "pass"}).encode() + mock_make_request.side_effect = [ + Urllib3Response( + body=io.BytesIO(b"Internal Server Error"), + status=500, + headers={}, + preload_content=False, + ), + Urllib3Response( + body=io.BytesIO(b"Bad Gateway"), + status=502, + headers={}, + preload_content=False, + ), + Urllib3Response( + body=io.BytesIO(success_body), + status=200, + headers={"Content-Type": "application/json"}, + preload_content=False, + ), + ] + + result = _generate_temporary_credentials("test-id") + + self.assertEqual(result, ("user", "pass")) + self.assertEqual(mock_make_request.call_count, 3) + self.assertEqual(mock_retry_sleep.call_count, 2) + + @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) + @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") + def test_generate_credentials_exhausts_retries_on_persistent_5xx( + self, + mock_get_url, + mock_get_headers, + mock_make_request, + mock_retry_sleep, + ): + """All 4 attempts (1 original + 3 retries) return 500 - + must raise ``RetryError``.""" + import requests + from urllib3 import HTTPResponse as Urllib3Response - _generate_temporary_credentials("test-id") + from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials - mock_create_session.assert_called_once() - mock_session.post.assert_called_once_with( - "https://api.example.com/integrations/credentials/test-id", - timeout=10, - headers={"Authorization": "Bearer token"}, + mock_get_url.return_value = ( + "https://api.example.com/integrations/credentials/test-id" ) + mock_get_headers.return_value = {"Authorization": "Bearer token"} - @mock.patch("deepnote_toolkit.sql.sql_execution._create_retry_session") + mock_make_request.side_effect = [ + Urllib3Response( + body=io.BytesIO(b"Server Error"), + status=500, + headers={}, + preload_content=False, + ) + for _ in range(4) + ] + + with self.assertRaises(requests.exceptions.RetryError): + _generate_temporary_credentials("test-id") + + self.assertEqual(mock_make_request.call_count, 4) + self.assertEqual(mock_retry_sleep.call_count, 3) + + @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) + @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") + def test_generate_credentials_no_retry_on_4xx( + self, + mock_get_url, + mock_get_headers, + mock_make_request, + mock_retry_sleep, + ): + """A 400 is not in the retry status list - must fail immediately + without retrying.""" + import requests + from urllib3 import HTTPResponse as Urllib3Response + + from deepnote_toolkit.sql.sql_execution import _generate_temporary_credentials + + mock_get_url.return_value = ( + "https://api.example.com/integrations/credentials/test-id" + ) + mock_get_headers.return_value = {"Authorization": "Bearer token"} + + mock_make_request.side_effect = [ + Urllib3Response( + body=io.BytesIO(b"Bad Request"), + status=400, + headers={}, + preload_content=False, + ), + ] + + with self.assertRaises(requests.exceptions.HTTPError): + _generate_temporary_credentials("test-id") + + self.assertEqual(mock_make_request.call_count, 1) + mock_retry_sleep.assert_not_called() + + # -- _get_federated_auth_credentials ------------------------------------ + + @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) + @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") - def test_get_federated_auth_credentials_uses_retry_session( - self, mock_get_url, mock_get_headers, mock_create_session + def test_federated_auth_retries_on_5xx_then_succeeds( + self, + mock_get_url, + mock_get_headers, + mock_make_request, + mock_retry_sleep, ): - """Test that _get_federated_auth_credentials uses a retry session.""" + """A 503 followed by a 200 - retry loop must recover and return + valid ``FederatedAuthResponseData``.""" + from urllib3 import HTTPResponse as Urllib3Response + from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials mock_get_url.return_value = ( @@ -1070,23 +1198,64 @@ def test_get_federated_auth_credentials_uses_retry_session( ) mock_get_headers.return_value = {"Authorization": "Bearer token"} - mock_session = mock.Mock() - mock_response = mock.Mock() - mock_response.json.return_value = { - "integrationType": "trino", - "accessToken": "test-token", - } - mock_session.post.return_value = mock_response - mock_create_session.return_value = mock_session + success_body = json.dumps( + {"integrationType": "trino", "accessToken": "test-token"} + ).encode() + mock_make_request.side_effect = [ + Urllib3Response( + body=io.BytesIO(b"Service Unavailable"), + status=503, + headers={}, + preload_content=False, + ), + Urllib3Response( + body=io.BytesIO(success_body), + status=200, + headers={"Content-Type": "application/json"}, + preload_content=False, + ), + ] + + result = _get_federated_auth_credentials("test-id", "auth-context-token") - _get_federated_auth_credentials("test-id", "auth-context-token") + self.assertEqual(result.integrationType, "trino") + self.assertEqual(result.accessToken, "test-token") + self.assertEqual(mock_make_request.call_count, 2) + self.assertEqual(mock_retry_sleep.call_count, 1) - mock_create_session.assert_called_once() - mock_session.post.assert_called_once_with( - "https://api.example.com/integrations/federated-auth-token/test-id", - timeout=10, - headers={ - "Authorization": "Bearer token", - "UserPodAuthContextToken": "auth-context-token", - }, + @mock.patch("urllib3.util.retry.Retry.sleep", return_value=None) + @mock.patch("urllib3.connectionpool.HTTPConnectionPool._make_request") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_project_auth_headers") + @mock.patch("deepnote_toolkit.sql.sql_execution.get_absolute_userpod_api_url") + def test_federated_auth_exhausts_retries_on_persistent_5xx( + self, + mock_get_url, + mock_get_headers, + mock_make_request, + mock_retry_sleep, + ): + """All 4 attempts return 504 - must raise ``RetryError``.""" + import requests + from urllib3 import HTTPResponse as Urllib3Response + + from deepnote_toolkit.sql.sql_execution import _get_federated_auth_credentials + + mock_get_url.return_value = ( + "https://api.example.com/integrations/federated-auth-token/test-id" ) + mock_get_headers.return_value = {"Authorization": "Bearer token"} + + mock_make_request.side_effect = [ + Urllib3Response( + body=io.BytesIO(b"Gateway Timeout"), + status=504, + headers={}, + preload_content=False, + ) + for _ in range(4) + ] + + with self.assertRaises(requests.exceptions.RetryError): + _get_federated_auth_credentials("test-id", "auth-context-token") + + self.assertEqual(mock_make_request.call_count, 4)