Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 156 additions & 44 deletions app/platforms/implementations/openeo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import hashlib
from typing import List

from fastapi import Response
Expand All @@ -9,7 +10,7 @@
from loguru import logger
from stac_pydantic import Collection

from app.auth import exchange_token, get_current_user_id
from app.auth import exchange_token
from app.config.schemas import AuthMethod
from app.config.settings import settings
from app.error import AuthException
Expand All @@ -32,13 +33,26 @@ class OpenEOPlatform(BaseProcessingPlatform):
"""

_connection_cache: dict[str, openeo.Connection] = {}
_token_expiry_buffer_seconds = 60

def _build_connection_cache_key(self, user_token: str, url: str) -> str:
token_fingerprint = hashlib.sha256(user_token.encode("utf-8")).hexdigest()
return f"openeo_connection_{token_fingerprint}_{url}"

def _is_auth_error(self, error: OpenEoApiError) -> bool:
return error.http_status_code in (403, 401)

def _connection_expired(self, connection: openeo.Connection) -> bool:
"""
Check if the cached connection is still valid.
This method can be used to determine if a new connection needs to be established.
"""
jwt_bearer_token = connection.auth.bearer.split("/")[-1]
bearer = getattr(getattr(connection, "auth", None), "bearer", None)
if not bearer:
logger.warning("No JWT bearer token found in connection.")
return True

jwt_bearer_token = bearer.split("/")[-1]
if jwt_bearer_token:
try:
# Check if the token is still valid by decoding it
Expand All @@ -49,7 +63,8 @@ def _connection_expired(self, connection: openeo.Connection) -> bool:
if not exp:
logger.warning("JWT bearer token does not contain 'exp' field.")
return True
elif exp < datetime.datetime.now(datetime.timezone.utc).timestamp():
now = datetime.datetime.now(datetime.timezone.utc).timestamp()
if exp <= now + self._token_expiry_buffer_seconds:
logger.warning("JWT bearer token has expired.")
return True # Token is expired
else:
Expand All @@ -58,9 +73,9 @@ def _connection_expired(self, connection: openeo.Connection) -> bool:
except Exception as e:
logger.error(f"JWT token validation failed: {e}")
return True # Token is expired or invalid
else:
logger.warning("No JWT bearer token found in connection.")
return True

logger.warning("No JWT bearer token found in connection.")
return True

async def _authenticate_user(
self, user_token: str, url: str, connection: openeo.Connection
Expand Down Expand Up @@ -99,14 +114,18 @@ async def _authenticate_user(

return connection

async def _setup_connection(self, user_token: str, url: str) -> openeo.Connection:
async def _setup_connection(
self, user_token: str, url: str, force_refresh: bool = False
) -> openeo.Connection:
"""
Setup the connection to the OpenEO backend.
This method can be used to initialize any required client or session.
"""
cache_key = "openeo_connection_" + get_current_user_id(user_token) + "_" + url
if cache_key in self._connection_cache and not self._connection_expired(
self._connection_cache[cache_key]
cache_key = self._build_connection_cache_key(user_token, url)
if (
not force_refresh
and cache_key in self._connection_cache
and not self._connection_expired(self._connection_cache[cache_key])
):
logger.debug(f"Reusing cached OpenEO connection to {url} (key: {cache_key})")
return self._connection_cache[cache_key]
Expand All @@ -117,6 +136,55 @@ async def _setup_connection(self, user_token: str, url: str) -> openeo.Connectio
self._connection_cache[cache_key] = connection
return connection

async def _refresh_connection(self, user_token: str, url: str) -> openeo.Connection:
logger.info(f"Refreshing OpenEO connection for {url} after authentication error")
return await self._setup_connection(user_token, url, force_refresh=True)

async def _execute_job_once(
self,
user_token: str,
title: str,
details: ServiceDetails,
parameters: dict,
format: OutputFormatEnum,
) -> str:
service = await self._build_datacube(user_token, title, details, parameters)
job = service.create_job(title=title, out_format=format)
logger.info(f"Executing OpenEO batch job with title={title}")
job.start()
return job.job_id

async def _execute_synchronous_job_once(
self,
user_token: str,
title: str,
details: ServiceDetails,
parameters: dict,
format: OutputFormatEnum,
) -> Response:
service = await self._build_datacube(user_token, title, details, parameters)
logger.info("Executing synchronous OpenEO job")
response = service.execute(auto_decode=False)
return Response(
content=response.content,
status_code=response.status_code,
media_type=response.headers.get("Content-Type"),
)

async def _get_job_status_once(
self, user_token: str, job_id: str, details: ServiceDetails
) -> ProcessingStatusEnum:
connection = await self._setup_connection(user_token, details.endpoint)
job = connection.job(job_id)
return self._map_openeo_status(job.status())

async def _get_job_results_once(
self, user_token: str, job_id: str, details: ServiceDetails
) -> Collection:
connection = await self._setup_connection(user_token, details.endpoint)
job = connection.job(job_id)
return Collection(**job.get_results().get_metadata())

def _get_client_credentials(self, url: str) -> tuple[str, str, str]:
"""
Get client credentials for the OpenEO backend.
Expand Down Expand Up @@ -186,18 +254,31 @@ async def execute_job(
format: OutputFormatEnum,
) -> str:
try:
service = await self._build_datacube(user_token, title, details, parameters)
job = service.create_job(title=title, out_format=format)
logger.info(f"Executing OpenEO batch job with title={title}")
job.start()

return job.job_id
return await self._execute_job_once(
user_token=user_token,
title=title,
details=details,
parameters=parameters,
format=format,
)
except OpenEoApiError as e:
if e.http_status_code in (403, 401):
raise AuthException(
e.http_status_code,
f"Authentication error when executing: {e.message}",
)
if self._is_auth_error(e):
try:
await self._refresh_connection(user_token, details.endpoint)
return await self._execute_job_once(
user_token=user_token,
title=title,
details=details,
parameters=parameters,
format=format,
)
except OpenEoApiError as retry_error:
if self._is_auth_error(retry_error):
raise AuthException(
retry_error.http_status_code,
f"Authentication error when executing: {retry_error.message}",
)
raise retry_error
raise e
except Exception as e:
raise e
Expand All @@ -211,20 +292,31 @@ async def execute_synchronous_job(
format: OutputFormatEnum,
) -> Response:
try:
service = await self._build_datacube(user_token, title, details, parameters)
logger.info("Executing synchronous OpenEO job")
response = service.execute(auto_decode=False)
return Response(
content=response.content,
status_code=response.status_code,
media_type=response.headers.get("Content-Type"),
return await self._execute_synchronous_job_once(
user_token=user_token,
title=title,
details=details,
parameters=parameters,
format=format,
)
except OpenEoApiError as e:
if e.http_status_code in (403, 401):
raise AuthException(
e.http_status_code,
f"Authentication error when executing: {e.message}",
)
if self._is_auth_error(e):
try:
await self._refresh_connection(user_token, details.endpoint)
return await self._execute_synchronous_job_once(
user_token=user_token,
title=title,
details=details,
parameters=parameters,
format=format,
)
except OpenEoApiError as retry_error:
if self._is_auth_error(retry_error):
raise AuthException(
retry_error.http_status_code,
f"Authentication error when executing: {retry_error.message}",
)
raise retry_error
raise e
except Exception as e:
raise e
Expand Down Expand Up @@ -258,10 +350,23 @@ async def get_job_status(
self, user_token: str, job_id: str, details: ServiceDetails
) -> ProcessingStatusEnum:
logger.debug(f"Fetching job status for openEO job with ID {job_id}")
connection = await self._setup_connection(user_token, details.endpoint)
try:
job = connection.job(job_id)
return self._map_openeo_status(job.status())
return await self._get_job_status_once(user_token, job_id, details)
except OpenEoApiError as e:
if self._is_auth_error(e):
try:
await self._refresh_connection(user_token, details.endpoint)
return await self._get_job_status_once(
user_token, job_id, details
)
except Exception as retry_error:
logger.error(
"Error occurred while fetching job status for "
f"job {job_id} after refresh: {retry_error}"
)
return ProcessingStatusEnum.UNKNOWN
logger.error(f"Error occurred while fetching job status for job {job_id}: {e}")
return ProcessingStatusEnum.UNKNOWN
except Exception as e:
logger.error(f"Error occurred while fetching job status for job {job_id}: {e}")
return ProcessingStatusEnum.UNKNOWN
Expand All @@ -271,15 +376,22 @@ async def get_job_results(
) -> Collection:
try:
logger.debug(f"Fetching job result for openEO job with ID {job_id}")
connection = await self._setup_connection(user_token, details.endpoint)
job = connection.job(job_id)
return Collection(**job.get_results().get_metadata())
return await self._get_job_results_once(user_token, job_id, details)
except OpenEoApiError as e:
if e.http_status_code in (403, 401):
raise AuthException(
e.http_status_code,
f"Authentication error when fetching job results for job {job_id}: {e.message}",
)
if self._is_auth_error(e):
try:
await self._refresh_connection(user_token, details.endpoint)
return await self._get_job_results_once(
user_token, job_id, details
)
except OpenEoApiError as retry_error:
if self._is_auth_error(retry_error):
raise AuthException(
retry_error.http_status_code,
"Authentication error when fetching job "
f"results for job {job_id}: {retry_error.message}",
)
raise retry_error
raise e
except Exception as e:
raise e
Expand Down
Loading
Loading