Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions app/platforms/implementations/openeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from loguru import logger
from stac_pydantic import Collection

from app.auth import exchange_token
from app.auth import exchange_token, get_current_user_id
from app.config.schemas import AuthMethod
from app.config.settings import settings
from app.error import AuthException
Expand Down Expand Up @@ -104,16 +104,17 @@ async def _setup_connection(self, user_token: str, url: str) -> openeo.Connectio
Setup the connection to the OpenEO backend.
This method can be used to initialize any required client or session.
"""
if url in self._connection_cache and not self._connection_expired(
self._connection_cache[url]
cache_key = "openeo_connection_" + get_current_user_id(user_token) + "_" + url
if cache_key in self._connection_cache and not self._connection_expired(
self._connection_cache[cache_key]
):
logger.debug(f"Reusing cached OpenEO connection to {url}")
return self._connection_cache[url]
logger.debug(f"Reusing cached OpenEO connection to {url} (key: {cache_key})")
return self._connection_cache[cache_key]

logger.debug(f"Setting up OpenEO connection to {url}")
connection = openeo.connect(url)
connection = await self._authenticate_user(user_token, url, connection)
self._connection_cache[url] = connection
self._connection_cache[cache_key] = connection
return connection

def _get_client_credentials(self, url: str) -> tuple[str, str, str]:
Expand Down
38 changes: 26 additions & 12 deletions tests/platforms/test_openeo_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,32 +474,40 @@ async def test_authenticate_user_config_format_issue_credentials(
@pytest.mark.asyncio
@patch("app.platforms.implementations.openeo.openeo.connect")
@patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock)
async def test_setup_connection_creates_and_caches(mock_auth, mock_connect, platform):
@patch("app.platforms.implementations.openeo.get_current_user_id")
async def test_setup_connection_creates_and_caches(
mock_current_user_id, mock_auth, mock_connect, platform
):
platform._connection_cache = {}
mock_conn = MagicMock()
mock_connect.return_value = mock_conn
mock_auth.return_value = mock_conn
mock_current_user_id.return_value = "user123"

url = "https://example.backend"
conn = await platform._setup_connection("user-token", url)

mock_connect.assert_called_once_with(url)
mock_auth.assert_awaited_once_with("user-token", url, mock_conn)
assert conn is mock_conn
assert platform._connection_cache[url] is mock_conn
cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url
assert platform._connection_cache[cache_key] is mock_conn


@pytest.mark.asyncio
@patch.object(OpenEOPlatform, "_connection_expired", return_value=False)
@patch("app.platforms.implementations.openeo.openeo.connect")
@patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock)
@pytest.mark.asyncio
@patch("app.platforms.implementations.openeo.get_current_user_id")
async def test_setup_connection_uses_cache_if_not_expired(
mock_auth, mock_connect, mock_expired, platform
mock_current_user_id, mock_auth, mock_connect, mock_expired, platform
):
mock_current_user_id.return_value = "user123"
platform._connection_cache = {}
url = "https://example.backend"
cached_conn = MagicMock()
platform._connection_cache[url] = cached_conn
cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url
platform._connection_cache[cache_key] = cached_conn

conn = await platform._setup_connection("user-token", url)

Expand All @@ -510,18 +518,21 @@ async def test_setup_connection_uses_cache_if_not_expired(
mock_auth.assert_not_awaited()


@pytest.mark.asyncio
@patch.object(OpenEOPlatform, "_connection_expired", return_value=True)
@patch("app.platforms.implementations.openeo.openeo.connect")
@patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock)
@pytest.mark.asyncio
@patch("app.platforms.implementations.openeo.get_current_user_id")
async def test_setup_connection_recreates_if_expired(
mock_auth, mock_connect, mock_expired, platform
mock_current_user_id, mock_auth, mock_connect, mock_expired, platform
):
mock_current_user_id.return_value = "user123"
platform._connection_cache = {}
url = "https://example.backend"
old_conn = MagicMock()
new_conn = MagicMock()
platform._connection_cache[url] = old_conn
cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url
platform._connection_cache[cache_key] = old_conn

mock_connect.return_value = new_conn
mock_auth.return_value = new_conn
Expand All @@ -531,15 +542,17 @@ async def test_setup_connection_recreates_if_expired(
mock_connect.assert_called_once_with(url)
mock_auth.assert_awaited_once_with("user-token", url, new_conn)
assert conn is new_conn
assert platform._connection_cache[url] is new_conn
assert platform._connection_cache[cache_key] is new_conn


@pytest.mark.asyncio
@patch("app.platforms.implementations.openeo.openeo.connect")
@patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock)
@pytest.mark.asyncio
@patch("app.platforms.implementations.openeo.get_current_user_id")
async def test_setup_connection_propagates_auth_error(
mock_auth, mock_connect, platform
mock_current_user_id, mock_auth, mock_connect, platform
):
mock_current_user_id.return_value = "user123"
platform._connection_cache = {}
url = "https://example.backend"
mock_conn = MagicMock()
Expand All @@ -550,7 +563,8 @@ async def test_setup_connection_propagates_auth_error(
await platform._setup_connection("user-token", url)

# authenticate failed, connection must not be cached
assert url not in platform._connection_cache
cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url
assert cache_key not in platform._connection_cache


@pytest.mark.asyncio
Expand Down
Loading