diff --git a/app/platforms/implementations/openeo.py b/app/platforms/implementations/openeo.py index 0a77dfd..44d7972 100644 --- a/app/platforms/implementations/openeo.py +++ b/app/platforms/implementations/openeo.py @@ -9,7 +9,7 @@ from loguru import logger from stac_pydantic import Collection -from app.auth import exchange_token +from app.auth import exchange_token, get_current_user_id from app.config.schemas import AuthMethod from app.config.settings import settings from app.error import AuthException @@ -104,16 +104,17 @@ async def _setup_connection(self, user_token: str, url: str) -> openeo.Connectio Setup the connection to the OpenEO backend. This method can be used to initialize any required client or session. """ - if url in self._connection_cache and not self._connection_expired( - self._connection_cache[url] + cache_key = "openeo_connection_" + get_current_user_id(user_token) + "_" + url + if cache_key in self._connection_cache and not self._connection_expired( + self._connection_cache[cache_key] ): - logger.debug(f"Reusing cached OpenEO connection to {url}") - return self._connection_cache[url] + logger.debug(f"Reusing cached OpenEO connection to {url} (key: {cache_key})") + return self._connection_cache[cache_key] logger.debug(f"Setting up OpenEO connection to {url}") connection = openeo.connect(url) connection = await self._authenticate_user(user_token, url, connection) - self._connection_cache[url] = connection + self._connection_cache[cache_key] = connection return connection def _get_client_credentials(self, url: str) -> tuple[str, str, str]: diff --git a/tests/platforms/test_openeo_platform.py b/tests/platforms/test_openeo_platform.py index fa51a36..e196524 100644 --- a/tests/platforms/test_openeo_platform.py +++ b/tests/platforms/test_openeo_platform.py @@ -474,11 +474,15 @@ async def test_authenticate_user_config_format_issue_credentials( @pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -async def test_setup_connection_creates_and_caches(mock_auth, mock_connect, platform): +@patch("app.platforms.implementations.openeo.get_current_user_id") +async def test_setup_connection_creates_and_caches( + mock_current_user_id, mock_auth, mock_connect, platform +): platform._connection_cache = {} mock_conn = MagicMock() mock_connect.return_value = mock_conn mock_auth.return_value = mock_conn + mock_current_user_id.return_value = "user123" url = "https://example.backend" conn = await platform._setup_connection("user-token", url) @@ -486,20 +490,24 @@ async def test_setup_connection_creates_and_caches(mock_auth, mock_connect, plat mock_connect.assert_called_once_with(url) mock_auth.assert_awaited_once_with("user-token", url, mock_conn) assert conn is mock_conn - assert platform._connection_cache[url] is mock_conn + cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + assert platform._connection_cache[cache_key] is mock_conn +@pytest.mark.asyncio @patch.object(OpenEOPlatform, "_connection_expired", return_value=False) @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@pytest.mark.asyncio +@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_uses_cache_if_not_expired( - mock_auth, mock_connect, mock_expired, platform + mock_current_user_id, mock_auth, mock_connect, mock_expired, platform ): + mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" cached_conn = MagicMock() - platform._connection_cache[url] = cached_conn + cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + platform._connection_cache[cache_key] = cached_conn conn = await platform._setup_connection("user-token", url) @@ -510,18 +518,21 @@ async def test_setup_connection_uses_cache_if_not_expired( mock_auth.assert_not_awaited() +@pytest.mark.asyncio @patch.object(OpenEOPlatform, "_connection_expired", return_value=True) @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@pytest.mark.asyncio +@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_recreates_if_expired( - mock_auth, mock_connect, mock_expired, platform + mock_current_user_id, mock_auth, mock_connect, mock_expired, platform ): + mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" old_conn = MagicMock() new_conn = MagicMock() - platform._connection_cache[url] = old_conn + cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + platform._connection_cache[cache_key] = old_conn mock_connect.return_value = new_conn mock_auth.return_value = new_conn @@ -531,15 +542,17 @@ async def test_setup_connection_recreates_if_expired( mock_connect.assert_called_once_with(url) mock_auth.assert_awaited_once_with("user-token", url, new_conn) assert conn is new_conn - assert platform._connection_cache[url] is new_conn + assert platform._connection_cache[cache_key] is new_conn +@pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) -@pytest.mark.asyncio +@patch("app.platforms.implementations.openeo.get_current_user_id") async def test_setup_connection_propagates_auth_error( - mock_auth, mock_connect, platform + mock_current_user_id, mock_auth, mock_connect, platform ): + mock_current_user_id.return_value = "user123" platform._connection_cache = {} url = "https://example.backend" mock_conn = MagicMock() @@ -550,7 +563,8 @@ async def test_setup_connection_propagates_auth_error( await platform._setup_connection("user-token", url) # authenticate failed, connection must not be cached - assert url not in platform._connection_cache + cache_key = "openeo_connection_" + mock_current_user_id.return_value + "_" + url + assert cache_key not in platform._connection_cache @pytest.mark.asyncio