From a893378ca873a52063ce1c28747569cb73f47d2c Mon Sep 17 00:00:00 2001 From: bramjanssen Date: Fri, 13 Mar 2026 16:30:56 +0100 Subject: [PATCH 1/3] feat: test to improve auth in openEO --- app/platforms/implementations/openeo.py | 170 ++++++++++++++++++------ tests/platforms/test_openeo_platform.py | 169 ++++++++++++++++++++++- 2 files changed, 299 insertions(+), 40 deletions(-) diff --git a/app/platforms/implementations/openeo.py b/app/platforms/implementations/openeo.py index 44d7972..b936c41 100644 --- a/app/platforms/implementations/openeo.py +++ b/app/platforms/implementations/openeo.py @@ -32,13 +32,22 @@ class OpenEOPlatform(BaseProcessingPlatform): """ _connection_cache: dict[str, openeo.Connection] = {} + _token_expiry_buffer_seconds = 60 + + def _is_auth_error(self, error: OpenEoApiError) -> bool: + return error.http_status_code in (403, 401) def _connection_expired(self, connection: openeo.Connection) -> bool: """ Check if the cached connection is still valid. This method can be used to determine if a new connection needs to be established. """ - jwt_bearer_token = connection.auth.bearer.split("/")[-1] + bearer = getattr(getattr(connection, "auth", None), "bearer", None) + if not bearer: + logger.warning("No JWT bearer token found in connection.") + return True + + jwt_bearer_token = bearer.split("/")[-1] if jwt_bearer_token: try: # Check if the token is still valid by decoding it @@ -49,7 +58,8 @@ def _connection_expired(self, connection: openeo.Connection) -> bool: if not exp: logger.warning("JWT bearer token does not contain 'exp' field.") return True - elif exp < datetime.datetime.now(datetime.timezone.utc).timestamp(): + now = datetime.datetime.now(datetime.timezone.utc).timestamp() + if exp <= now + self._token_expiry_buffer_seconds: logger.warning("JWT bearer token has expired.") return True # Token is expired else: @@ -58,9 +68,9 @@ def _connection_expired(self, connection: openeo.Connection) -> bool: except Exception as e: logger.error(f"JWT token validation failed: {e}") return True # Token is expired or invalid - else: - logger.warning("No JWT bearer token found in connection.") - return True + + logger.warning("No JWT bearer token found in connection.") + return True async def _authenticate_user( self, user_token: str, url: str, connection: openeo.Connection @@ -99,14 +109,18 @@ async def _authenticate_user( return connection - async def _setup_connection(self, user_token: str, url: str) -> openeo.Connection: + async def _setup_connection( + self, user_token: str, url: str, force_refresh: bool = False + ) -> openeo.Connection: """ Setup the connection to the OpenEO backend. This method can be used to initialize any required client or session. """ cache_key = "openeo_connection_" + get_current_user_id(user_token) + "_" + url - if cache_key in self._connection_cache and not self._connection_expired( - self._connection_cache[cache_key] + if ( + not force_refresh + and cache_key in self._connection_cache + and not self._connection_expired(self._connection_cache[cache_key]) ): logger.debug(f"Reusing cached OpenEO connection to {url} (key: {cache_key})") return self._connection_cache[cache_key] @@ -117,6 +131,41 @@ async def _setup_connection(self, user_token: str, url: str) -> openeo.Connectio self._connection_cache[cache_key] = connection return connection + async def _refresh_connection(self, user_token: str, url: str) -> openeo.Connection: + logger.info(f"Refreshing OpenEO connection for {url} after authentication error") + return await self._setup_connection(user_token, url, force_refresh=True) + + async def _execute_job_once( + self, + user_token: str, + title: str, + details: ServiceDetails, + parameters: dict, + format: OutputFormatEnum, + ) -> str: + service = await self._build_datacube(user_token, title, details, parameters) + job = service.create_job(title=title, out_format=format) + logger.info(f"Executing OpenEO batch job with title={title}") + job.start() + return job.job_id + + async def _execute_synchronous_job_once( + self, + user_token: str, + title: str, + details: ServiceDetails, + parameters: dict, + format: OutputFormatEnum, + ) -> Response: + service = await self._build_datacube(user_token, title, details, parameters) + logger.info("Executing synchronous OpenEO job") + response = service.execute(auto_decode=False) + return Response( + content=response.content, + status_code=response.status_code, + media_type=response.headers.get("Content-Type"), + ) + def _get_client_credentials(self, url: str) -> tuple[str, str, str]: """ Get client credentials for the OpenEO backend. @@ -186,18 +235,31 @@ async def execute_job( format: OutputFormatEnum, ) -> str: try: - service = await self._build_datacube(user_token, title, details, parameters) - job = service.create_job(title=title, out_format=format) - logger.info(f"Executing OpenEO batch job with title={title}") - job.start() - - return job.job_id + return await self._execute_job_once( + user_token=user_token, + title=title, + details=details, + parameters=parameters, + format=format, + ) except OpenEoApiError as e: - if e.http_status_code in (403, 401): - raise AuthException( - e.http_status_code, - f"Authentication error when executing: {e.message}", - ) + if self._is_auth_error(e): + try: + await self._refresh_connection(user_token, details.endpoint) + return await self._execute_job_once( + user_token=user_token, + title=title, + details=details, + parameters=parameters, + format=format, + ) + except OpenEoApiError as retry_error: + if self._is_auth_error(retry_error): + raise AuthException( + retry_error.http_status_code, + f"Authentication error when executing: {retry_error.message}", + ) + raise retry_error raise e except Exception as e: raise e @@ -211,20 +273,31 @@ async def execute_synchronous_job( format: OutputFormatEnum, ) -> Response: try: - service = await self._build_datacube(user_token, title, details, parameters) - logger.info("Executing synchronous OpenEO job") - response = service.execute(auto_decode=False) - return Response( - content=response.content, - status_code=response.status_code, - media_type=response.headers.get("Content-Type"), + return await self._execute_synchronous_job_once( + user_token=user_token, + title=title, + details=details, + parameters=parameters, + format=format, ) except OpenEoApiError as e: - if e.http_status_code in (403, 401): - raise AuthException( - e.http_status_code, - f"Authentication error when executing: {e.message}", - ) + if self._is_auth_error(e): + try: + await self._refresh_connection(user_token, details.endpoint) + return await self._execute_synchronous_job_once( + user_token=user_token, + title=title, + details=details, + parameters=parameters, + format=format, + ) + except OpenEoApiError as retry_error: + if self._is_auth_error(retry_error): + raise AuthException( + retry_error.http_status_code, + f"Authentication error when executing: {retry_error.message}", + ) + raise retry_error raise e except Exception as e: raise e @@ -258,10 +331,24 @@ async def get_job_status( self, user_token: str, job_id: str, details: ServiceDetails ) -> ProcessingStatusEnum: logger.debug(f"Fetching job status for openEO job with ID {job_id}") - connection = await self._setup_connection(user_token, details.endpoint) try: + connection = await self._setup_connection(user_token, details.endpoint) job = connection.job(job_id) return self._map_openeo_status(job.status()) + except OpenEoApiError as e: + if self._is_auth_error(e): + try: + connection = await self._refresh_connection(user_token, details.endpoint) + job = connection.job(job_id) + return self._map_openeo_status(job.status()) + except Exception as retry_error: + logger.error( + "Error occurred while fetching job status for " + f"job {job_id} after refresh: {retry_error}" + ) + return ProcessingStatusEnum.UNKNOWN + logger.error(f"Error occurred while fetching job status for job {job_id}: {e}") + return ProcessingStatusEnum.UNKNOWN except Exception as e: logger.error(f"Error occurred while fetching job status for job {job_id}: {e}") return ProcessingStatusEnum.UNKNOWN @@ -275,11 +362,20 @@ async def get_job_results( job = connection.job(job_id) return Collection(**job.get_results().get_metadata()) except OpenEoApiError as e: - if e.http_status_code in (403, 401): - raise AuthException( - e.http_status_code, - f"Authentication error when fetching job results for job {job_id}: {e.message}", - ) + if self._is_auth_error(e): + try: + await self._refresh_connection(user_token, details.endpoint) + connection = await self._setup_connection(user_token, details.endpoint) + job = connection.job(job_id) + return Collection(**job.get_results().get_metadata()) + except OpenEoApiError as retry_error: + if self._is_auth_error(retry_error): + raise AuthException( + retry_error.http_status_code, + "Authentication error when fetching job " + f"results for job {job_id}: {retry_error.message}", + ) + raise retry_error raise e except Exception as e: raise e diff --git a/tests/platforms/test_openeo_platform.py b/tests/platforms/test_openeo_platform.py index e196524..c0d64ae 100644 --- a/tests/platforms/test_openeo_platform.py +++ b/tests/platforms/test_openeo_platform.py @@ -211,10 +211,82 @@ async def test_get_job_status_error(mock_connection, platform): mock_connection.side_effect = RuntimeError("Connection error") details = ServiceDetails(endpoint="foo", application="bar") - with pytest.raises(RuntimeError) as exc_info: - await platform.get_job_status("foobar", "job123", details) + result = await platform.get_job_status("foobar", "job123", details) + assert result == ProcessingStatusEnum.UNKNOWN - assert "Connection error" in str(exc_info.value) + +@pytest.mark.asyncio +@patch.object(OpenEOPlatform, "_refresh_connection", new_callable=AsyncMock) +@patch.object(OpenEOPlatform, "_setup_connection", new_callable=AsyncMock) +async def test_get_job_status_retries_after_auth_error( + mock_setup_connection, mock_refresh_connection, platform +): + first_job = MagicMock() + first_job.status.side_effect = OpenEoApiError( + message="expired", code="TokenExpired", http_status_code=401 + ) + + second_job = MagicMock() + second_job.status.return_value = "running" + + first_connection = MagicMock() + first_connection.job.return_value = first_job + second_connection = MagicMock() + second_connection.job.return_value = second_job + + mock_setup_connection.return_value = first_connection + mock_refresh_connection.return_value = second_connection + + details = ServiceDetails(endpoint="foo", application="bar") + result = await platform.get_job_status("foobar", "job123", details) + + assert result == ProcessingStatusEnum.RUNNING + mock_setup_connection.assert_awaited_once_with("foobar", details.endpoint) + mock_refresh_connection.assert_awaited_once_with("foobar", details.endpoint) + + +@pytest.mark.asyncio +@patch.object(OpenEOPlatform, "_refresh_connection", new_callable=AsyncMock) +@patch.object(OpenEOPlatform, "_setup_connection", new_callable=AsyncMock) +async def test_get_job_status_returns_unknown_when_refresh_fails( + mock_setup_connection, mock_refresh_connection, platform +): + first_job = MagicMock() + first_job.status.side_effect = OpenEoApiError( + message="expired", code="TokenExpired", http_status_code=401 + ) + + first_connection = MagicMock() + first_connection.job.return_value = first_job + + mock_setup_connection.return_value = first_connection + mock_refresh_connection.side_effect = RuntimeError("refresh failed") + + details = ServiceDetails(endpoint="foo", application="bar") + result = await platform.get_job_status("foobar", "job123", details) + + assert result == ProcessingStatusEnum.UNKNOWN + mock_setup_connection.assert_awaited_once_with("foobar", details.endpoint) + mock_refresh_connection.assert_awaited_once_with("foobar", details.endpoint) + + +@pytest.mark.asyncio +@patch.object(OpenEOPlatform, "_setup_connection", new_callable=AsyncMock) +async def test_get_job_status_non_auth_openeo_error_returns_unknown( + mock_setup_connection, platform +): + job = MagicMock() + job.status.side_effect = OpenEoApiError( + message="server-error", code="ServerError", http_status_code=500 + ) + connection = MagicMock() + connection.job.return_value = job + mock_setup_connection.return_value = connection + + details = ServiceDetails(endpoint="foo", application="bar") + result = await platform.get_job_status("foobar", "job123", details) + + assert result == ProcessingStatusEnum.UNKNOWN @pytest.mark.asyncio @@ -545,6 +617,32 @@ async def test_setup_connection_recreates_if_expired( assert platform._connection_cache[cache_key] is new_conn +@pytest.mark.asyncio +@patch("app.platforms.implementations.openeo.openeo.connect") +@patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) +@patch("app.platforms.implementations.openeo.get_current_user_id") +async def test_setup_connection_force_refresh_bypasses_cache( + mock_current_user_id, mock_auth, mock_connect, platform +): + mock_current_user_id.return_value = "user123" + platform._connection_cache = {} + url = "https://example.backend" + old_conn = MagicMock() + new_conn = MagicMock() + cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + platform._connection_cache[cache_key] = old_conn + + mock_connect.return_value = new_conn + mock_auth.return_value = new_conn + + conn = await platform._setup_connection("user-token", url, force_refresh=True) + + mock_connect.assert_called_once_with(url) + mock_auth.assert_awaited_once_with("user-token", url, new_conn) + assert conn is new_conn + assert platform._connection_cache[cache_key] is new_conn + + @pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) @@ -594,6 +692,71 @@ async def test_execute_sync_job_success( mock_connect.assert_called_once_with("fake_token", service_details.endpoint) +@pytest.mark.asyncio +@patch.object(OpenEOPlatform, "_build_datacube", new_callable=AsyncMock) +@patch.object(OpenEOPlatform, "_refresh_connection", new_callable=AsyncMock) +async def test_execute_job_retries_after_auth_error( + mock_refresh_connection, mock_build_datacube, platform, service_details +): + first_service = MagicMock() + first_job = MagicMock() + first_job.start.side_effect = OpenEoApiError( + message="expired", code="TokenExpired", http_status_code=401 + ) + first_service.create_job.return_value = first_job + + second_service = MagicMock() + second_job = MagicMock() + second_job.job_id = "job-retried" + second_service.create_job.return_value = second_job + + mock_build_datacube.side_effect = [first_service, second_service] + + job_id = await platform.execute_job( + user_token="fake_token", + title="Retry Job", + details=service_details, + parameters={}, + format=OutputFormatEnum.GEOTIFF, + ) + + assert job_id == "job-retried" + assert mock_build_datacube.await_count == 2 + mock_refresh_connection.assert_awaited_once_with( + "fake_token", service_details.endpoint + ) + + +@pytest.mark.asyncio +@patch.object(OpenEOPlatform, "_refresh_connection", new_callable=AsyncMock) +async def test_get_job_results_retries_after_auth_error( + mock_refresh_connection, platform, service_details, fake_result +): + first_job = MagicMock() + first_job.get_results.side_effect = OpenEoApiError( + message="expired", code="TokenExpired", http_status_code=401 + ) + + second_metadata = fake_result.model_dump() + second_job = MagicMock() + second_job.get_results.return_value.get_metadata.return_value = second_metadata + + first_conn = MagicMock() + first_conn.job.return_value = first_job + second_conn = MagicMock() + second_conn.job.return_value = second_job + + with patch.object(OpenEOPlatform, "_setup_connection", new_callable=AsyncMock) as mock_setup: + mock_setup.side_effect = [first_conn, second_conn, second_conn] + + result = await platform.get_job_results("fake_token", "job-1", service_details) + + assert result == fake_result + mock_refresh_connection.assert_awaited_once_with( + "fake_token", service_details.endpoint + ) + + @pytest.mark.asyncio @patch.object(OpenEOPlatform, "_setup_connection") @patch.object( From 0bf363eb9dc99aa073a9bc485fa99282ac2081b9 Mon Sep 17 00:00:00 2001 From: bramjanssen Date: Fri, 13 Mar 2026 16:34:56 +0100 Subject: [PATCH 2/3] fix: updated token cachekey to not decode token twice --- app/platforms/implementations/openeo.py | 9 ++++++-- tests/platforms/test_openeo_platform.py | 30 +++++++++---------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/app/platforms/implementations/openeo.py b/app/platforms/implementations/openeo.py index b936c41..8361a0e 100644 --- a/app/platforms/implementations/openeo.py +++ b/app/platforms/implementations/openeo.py @@ -1,4 +1,5 @@ import datetime +import hashlib from typing import List from fastapi import Response @@ -9,7 +10,7 @@ from loguru import logger from stac_pydantic import Collection -from app.auth import exchange_token, get_current_user_id +from app.auth import exchange_token from app.config.schemas import AuthMethod from app.config.settings import settings from app.error import AuthException @@ -34,6 +35,10 @@ class OpenEOPlatform(BaseProcessingPlatform): _connection_cache: dict[str, openeo.Connection] = {} _token_expiry_buffer_seconds = 60 + def _build_connection_cache_key(self, user_token: str, url: str) -> str: + token_fingerprint = hashlib.sha256(user_token.encode("utf-8")).hexdigest() + return f"openeo_connection_{token_fingerprint}_{url}" + def _is_auth_error(self, error: OpenEoApiError) -> bool: return error.http_status_code in (403, 401) @@ -116,7 +121,7 @@ async def _setup_connection( Setup the connection to the OpenEO backend. This method can be used to initialize any required client or session. """ - cache_key = "openeo_connection_" + get_current_user_id(user_token) + "_" + url + cache_key = self._build_connection_cache_key(user_token, url) if ( not force_refresh and cache_key in self._connection_cache diff --git a/tests/platforms/test_openeo_platform.py b/tests/platforms/test_openeo_platform.py index c0d64ae..7c0dec5 100644 --- a/tests/platforms/test_openeo_platform.py +++ b/tests/platforms/test_openeo_platform.py @@ -546,15 +546,13 @@ async def test_authenticate_user_config_format_issue_credentials( @pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_creates_and_caches( - mock_current_user_id, mock_auth, mock_connect, platform + mock_auth, mock_connect, platform ): platform._connection_cache = {} mock_conn = MagicMock() mock_connect.return_value = mock_conn mock_auth.return_value = mock_conn - mock_current_user_id.return_value = "user123" url = "https://example.backend" conn = await platform._setup_connection("user-token", url) @@ -562,7 +560,7 @@ async def test_setup_connection_creates_and_caches( mock_connect.assert_called_once_with(url) mock_auth.assert_awaited_once_with("user-token", url, mock_conn) assert conn is mock_conn - cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + cache_key = platform._build_connection_cache_key("user-token", url) assert platform._connection_cache[cache_key] is mock_conn @@ -570,15 +568,13 @@ async def test_setup_connection_creates_and_caches( @patch.object(OpenEOPlatform, "_connection_expired", return_value=False) @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_uses_cache_if_not_expired( - mock_current_user_id, mock_auth, mock_connect, mock_expired, platform + mock_auth, mock_connect, mock_expired, platform ): - mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" cached_conn = MagicMock() - cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + cache_key = platform._build_connection_cache_key("user-token", url) platform._connection_cache[cache_key] = cached_conn conn = await platform._setup_connection("user-token", url) @@ -594,16 +590,14 @@ async def test_setup_connection_uses_cache_if_not_expired( @patch.object(OpenEOPlatform, "_connection_expired", return_value=True) @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_recreates_if_expired( - mock_current_user_id, mock_auth, mock_connect, mock_expired, platform + mock_auth, mock_connect, mock_expired, platform ): - mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" old_conn = MagicMock() new_conn = MagicMock() - cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + cache_key = platform._build_connection_cache_key("user-token", url) platform._connection_cache[cache_key] = old_conn mock_connect.return_value = new_conn @@ -620,16 +614,14 @@ async def test_setup_connection_recreates_if_expired( @pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_force_refresh_bypasses_cache( - mock_current_user_id, mock_auth, mock_connect, platform + mock_auth, mock_connect, platform ): - mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" old_conn = MagicMock() new_conn = MagicMock() - cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + cache_key = platform._build_connection_cache_key("user-token", url) platform._connection_cache[cache_key] = old_conn mock_connect.return_value = new_conn @@ -646,11 +638,9 @@ async def test_setup_connection_force_refresh_bypasses_cache( @pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_propagates_auth_error( - mock_current_user_id, mock_auth, mock_connect, platform + mock_auth, mock_connect, platform ): - mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" mock_conn = MagicMock() @@ -661,7 +651,7 @@ async def test_setup_connection_propagates_auth_error( await platform._setup_connection("user-token", url) # authenticate failed, connection must not be cached - cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + cache_key = platform._build_connection_cache_key("user-token", url) assert cache_key not in platform._connection_cache From 7bfdad21ac5e1a0d108db5215fe2e4f63ccc24cd Mon Sep 17 00:00:00 2001 From: bramjanssen Date: Fri, 13 Mar 2026 16:49:56 +0100 Subject: [PATCH 3/3] refactor: updated code structure for job status and results --- app/platforms/implementations/openeo.py | 35 ++++++++++++++++--------- tests/platforms/test_openeo_platform.py | 4 +-- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/app/platforms/implementations/openeo.py b/app/platforms/implementations/openeo.py index 8361a0e..e3ee9a2 100644 --- a/app/platforms/implementations/openeo.py +++ b/app/platforms/implementations/openeo.py @@ -171,6 +171,20 @@ async def _execute_synchronous_job_once( media_type=response.headers.get("Content-Type"), ) + async def _get_job_status_once( + self, user_token: str, job_id: str, details: ServiceDetails + ) -> ProcessingStatusEnum: + connection = await self._setup_connection(user_token, details.endpoint) + job = connection.job(job_id) + return self._map_openeo_status(job.status()) + + async def _get_job_results_once( + self, user_token: str, job_id: str, details: ServiceDetails + ) -> Collection: + connection = await self._setup_connection(user_token, details.endpoint) + job = connection.job(job_id) + return Collection(**job.get_results().get_metadata()) + def _get_client_credentials(self, url: str) -> tuple[str, str, str]: """ Get client credentials for the OpenEO backend. @@ -337,15 +351,14 @@ async def get_job_status( ) -> ProcessingStatusEnum: logger.debug(f"Fetching job status for openEO job with ID {job_id}") try: - connection = await self._setup_connection(user_token, details.endpoint) - job = connection.job(job_id) - return self._map_openeo_status(job.status()) + return await self._get_job_status_once(user_token, job_id, details) except OpenEoApiError as e: if self._is_auth_error(e): try: - connection = await self._refresh_connection(user_token, details.endpoint) - job = connection.job(job_id) - return self._map_openeo_status(job.status()) + await self._refresh_connection(user_token, details.endpoint) + return await self._get_job_status_once( + user_token, job_id, details + ) except Exception as retry_error: logger.error( "Error occurred while fetching job status for " @@ -363,16 +376,14 @@ async def get_job_results( ) -> Collection: try: logger.debug(f"Fetching job result for openEO job with ID {job_id}") - connection = await self._setup_connection(user_token, details.endpoint) - job = connection.job(job_id) - return Collection(**job.get_results().get_metadata()) + return await self._get_job_results_once(user_token, job_id, details) except OpenEoApiError as e: if self._is_auth_error(e): try: await self._refresh_connection(user_token, details.endpoint) - connection = await self._setup_connection(user_token, details.endpoint) - job = connection.job(job_id) - return Collection(**job.get_results().get_metadata()) + return await self._get_job_results_once( + user_token, job_id, details + ) except OpenEoApiError as retry_error: if self._is_auth_error(retry_error): raise AuthException( diff --git a/tests/platforms/test_openeo_platform.py b/tests/platforms/test_openeo_platform.py index 7c0dec5..c76ddad 100644 --- a/tests/platforms/test_openeo_platform.py +++ b/tests/platforms/test_openeo_platform.py @@ -234,14 +234,14 @@ async def test_get_job_status_retries_after_auth_error( second_connection = MagicMock() second_connection.job.return_value = second_job - mock_setup_connection.return_value = first_connection + mock_setup_connection.side_effect = [first_connection, second_connection] mock_refresh_connection.return_value = second_connection details = ServiceDetails(endpoint="foo", application="bar") result = await platform.get_job_status("foobar", "job123", details) assert result == ProcessingStatusEnum.RUNNING - mock_setup_connection.assert_awaited_once_with("foobar", details.endpoint) + assert mock_setup_connection.await_count == 2 mock_refresh_connection.assert_awaited_once_with("foobar", details.endpoint)