diff --git a/litellm/llms/custom_httpx/async_client_cleanup.py b/litellm/llms/custom_httpx/async_client_cleanup.py index 456025767640..1f53a3c1919b 100644 --- a/litellm/llms/custom_httpx/async_client_cleanup.py +++ b/litellm/llms/custom_httpx/async_client_cleanup.py @@ -51,6 +51,15 @@ async def close_litellm_async_clients(): except Exception: # Silently ignore errors during cleanup pass + + # Close Vertex AI token refresh session + try: + from litellm.llms.vertex_ai.vertex_llm_base import VertexBase + + await VertexBase.close_token_refresh_session() + except Exception: + # Silently ignore errors during cleanup + pass def register_async_client_cleanup(): diff --git a/litellm/llms/vertex_ai/vertex_llm_base.py b/litellm/llms/vertex_ai/vertex_llm_base.py index 9ddbc461a70d..c1e0c6e9b40d 100644 --- a/litellm/llms/vertex_ai/vertex_llm_base.py +++ b/litellm/llms/vertex_ai/vertex_llm_base.py @@ -4,13 +4,13 @@ Handles Authentication and generating request urls for Vertex AI and Google AI Studio """ +import asyncio import json import os from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple import litellm from litellm._logging import verbose_logger -from litellm.litellm_core_utils.asyncify import asyncify from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.secret_managers.main import get_secret_str from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES, VertexPartnerProvider @@ -23,12 +23,17 @@ ) if TYPE_CHECKING: + from aiohttp import ClientSession from google.auth.credentials import Credentials as GoogleCredentialsObject else: GoogleCredentialsObject = Any class VertexBase: + # Shared aiohttp session for token refresh (reused across all instances) + _shared_token_refresh_session: Optional["ClientSession"] = None + _session_lock = asyncio.Lock() # For thread-safe session creation + def __init__(self) -> None: super().__init__() self.access_token: Optional[str] = None @@ -160,6 +165,47 @@ def 
def _credentials_from_default_auth(self, scopes):
    """Resolve Application Default Credentials for the given OAuth scopes."""
    return google_auth.default(scopes=scopes)

# --- Async credential creation: use google-auth's private async credential
# --- classes (which expose an awaitable refresh()) with the matching
# --- _aiohttp_requests transport. Fall back to sync credentials when the
# --- private modules are absent from the installed google-auth version.

def _credentials_from_authorized_user_async(self, json_obj, scopes):
    """
    Create async credentials from authorized-user info.

    Uses ``google.oauth2._credentials_async`` so token refresh can be
    awaited. Falls back to the sync factory on Import/AttributeError
    (private module — its availability varies by google-auth version).
    """
    try:
        from google.oauth2 import _credentials_async

        verbose_logger.debug(
            "[VERTEX AUTH] Creating async authorized user credentials"
        )
        return _credentials_async.Credentials.from_authorized_user_info(
            json_obj, scopes=scopes
        )
    except (ImportError, AttributeError) as e:
        verbose_logger.warning(
            f"[VERTEX AUTH] google.oauth2._credentials_async not available ({e}), using sync credentials"
        )
        return self._credentials_from_authorized_user(json_obj, scopes)

def _credentials_from_service_account_async(self, json_obj, scopes):
    """
    Create async credentials from service-account info.

    Uses ``google.oauth2._service_account_async`` for awaitable refresh;
    falls back to the sync factory when the private module is unavailable.
    """
    try:
        from google.oauth2 import _service_account_async

        verbose_logger.debug(
            "[VERTEX AUTH] Creating async service account credentials"
        )
        return _service_account_async.Credentials.from_service_account_info(
            json_obj, scopes=scopes
        )
    except (ImportError, AttributeError) as e:
        verbose_logger.warning(
            f"[VERTEX AUTH] google.oauth2._service_account_async not available ({e}), using sync credentials"
        )
        return self._credentials_from_service_account(json_obj, scopes)

def get_default_vertex_location(self) -> str:
    """Default GCP region used when the caller supplies none."""
    return "us-central1"

@classmethod
async def _get_or_create_token_refresh_session(cls) -> "ClientSession":
    """
    Get or lazily create the shared aiohttp session used for token refresh.

    The session is class-level and reused across all instances; creation is
    guarded by ``cls._session_lock`` so concurrent first callers don't race.

    Returns:
        An open ``ClientSession`` with ``auto_decompress=False`` (the Google
        auth transport performs its own decompression).
    """
    from aiohttp import ClientSession

    async with cls._session_lock:
        session = cls._shared_token_refresh_session
        if session is None or session.closed:
            verbose_logger.debug(
                "[VERTEX AUTH] Creating new persistent aiohttp session for token refresh"
            )
            session = ClientSession(
                auto_decompress=False  # required by google-auth's aiohttp transport
            )
            cls._shared_token_refresh_session = session
        else:
            verbose_logger.debug(
                f"[VERTEX AUTH] Reusing persistent aiohttp session (ID: {id(cls._shared_token_refresh_session)})"
            )
        return session

@classmethod
async def close_token_refresh_session(cls) -> None:
    """
    Close the shared token-refresh session.

    Call during application shutdown. Always clears the class attribute so a
    stale closed session object is never retained (the original left it in
    place when the session was already closed).
    """
    async with cls._session_lock:
        session = cls._shared_token_refresh_session
        if session is not None and not session.closed:
            verbose_logger.debug(
                "[VERTEX AUTH] Closing persistent aiohttp session for token refresh"
            )
            await session.close()
        cls._shared_token_refresh_session = None

async def refresh_auth_async(self, credentials: Any) -> None:
    """
    Async version of ``refresh_auth``.

    Strategy: async credentials (``_credentials_async`` /
    ``_service_account_async``) paired with the matching
    ``google.auth.transport._aiohttp_requests.Request`` transport, over the
    shared persistent aiohttp session. Sync credentials are refreshed in an
    executor; if aiohttp is missing entirely, refresh falls back to a plain
    blocking call.

    Args:
        credentials: Credentials object; may expose either an async or a
            sync ``refresh()``.
    """
    try:
        import inspect

        from google.auth.transport._aiohttp_requests import Request
    except ImportError:
        # aiohttp (or the private transport) unavailable -> blocking refresh.
        verbose_logger.warning(
            "[VERTEX AUTH] aiohttp not available, falling back to sync token refresh"
        )
        from google.auth.transport.requests import Request as SyncRequest

        credentials.refresh(SyncRequest())
        return

    session_to_use = await self._get_or_create_token_refresh_session()
    request = Request(session_to_use)

    refresh_fn = getattr(credentials, "refresh", None)
    if refresh_fn is not None and inspect.iscoroutinefunction(refresh_fn):
        verbose_logger.debug(
            "[VERTEX AUTH] Using TRUE async refresh with OLD async credentials"
        )
        await credentials.refresh(request)
    else:
        # Sync-only credentials: run the blocking refresh off the event loop.
        verbose_logger.debug(
            "[VERTEX AUTH] Credentials don't support async refresh, using executor fallback"
        )
        from google.auth.transport.requests import Request as SyncRequest

        # FIX: get_running_loop() — get_event_loop() inside a coroutine is
        # deprecated since Python 3.10 and can bind the wrong loop.
        await asyncio.get_running_loop().run_in_executor(
            None, credentials.refresh, SyncRequest()
        )

async def load_auth_async(
    self,
    credentials: "Optional[VERTEX_CREDENTIALS_TYPES]",
    project_id: Optional[str],
) -> Tuple[Any, str]:
    """
    Async version of ``load_auth``.

    Parses the supplied credentials (path, JSON string, or dict), builds the
    appropriate credential object (async variants where supported), refreshes
    it, and resolves the project id.

    Raises:
        Exception: when a credential string can be neither read as a file
            nor parsed as JSON.
        ValueError: on an unsupported credentials type or unresolvable
            project id.
        TypeError: when the resolved project id is not a string.
    """
    if credentials is not None:
        json_obj = self._parse_vertex_credentials(credentials)

        if json_obj.get("type") == "external_account":
            # Workload Identity Federation; an "aws" environment_id marks an
            # AWS-style config file.
            credential_source = json_obj.get("credential_source", {})
            environment_id = (
                credential_source.get("environment_id", "")
                if isinstance(credential_source, dict)
                else ""
            )
            if isinstance(environment_id, str) and "aws" in environment_id:
                creds = self._credentials_from_identity_pool_with_aws(json_obj)
            else:
                creds = self._credentials_from_identity_pool(json_obj)
        elif json_obj.get("type") == "authorized_user":
            # gcloud auth application-default login style credentials.
            creds = self._credentials_from_authorized_user_async(
                json_obj,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
            if project_id is None:
                # authorized-user creds carry only quota_project_id.
                project_id = creds.quota_project_id
        else:
            creds = self._credentials_from_service_account_async(
                json_obj,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )

        # FIX: resolve project_id from the credential object for ALL branches
        # (the original only did this for service accounts, so identity-pool
        # credentials could never supply a project id).
        if project_id is None:
            project_id = getattr(creds, "project_id", None)
    else:
        creds, creds_project_id = self._credentials_from_default_auth(
            scopes=["https://www.googleapis.com/auth/cloud-platform"]
        )
        if project_id is None:
            project_id = creds_project_id

    await self.refresh_auth_async(creds)

    if not project_id:
        raise ValueError("Could not resolve project_id")
    if not isinstance(project_id, str):
        raise TypeError(
            f"Expected project_id to be a str but got {type(project_id)}"
        )
    return creds, project_id

def _parse_vertex_credentials(self, credentials: "VERTEX_CREDENTIALS_TYPES") -> dict:
    """Normalize credentials (file path / JSON string / dict) to a dict."""
    if isinstance(credentials, str):
        verbose_logger.debug(
            "Vertex: Loading vertex credentials from %s", credentials
        )
        verbose_logger.debug(
            "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
            credentials,
            os.path.exists(credentials),
            os.getcwd(),
        )
        try:
            if os.path.exists(credentials):
                with open(credentials) as f:
                    return json.load(f)
            return json.loads(credentials)
        except Exception:
            raise Exception(
                "Unable to load vertex credentials from environment. Got={}".format(
                    credentials
                )
            )
    if isinstance(credentials, dict):
        return credentials
    raise ValueError("Invalid credentials type: {}".format(type(credentials)))

async def _handle_reauthentication_async(
    self,
    credentials: "Optional[VERTEX_CREDENTIALS_TYPES]",
    project_id: Optional[str],
    credential_cache_key: Tuple,
    error: Exception,
) -> Tuple[str, str]:
    """
    Async version of ``_handle_reauthentication``.

    Clears the cached credentials and retries the token fetch exactly once
    (``_retry_reauth=True`` prevents recursion). On retry failure the
    ORIGINAL error is re-raised for better context.
    """
    verbose_logger.debug(
        f"[ASYNC] Handling reauthentication for project_id: {project_id}. "
        f"Clearing cache and retrying once."
    )
    self._credentials_project_mapping.pop(credential_cache_key, None)
    try:
        return await self.get_access_token_async(
            credentials=credentials,
            project_id=project_id,
            _retry_reauth=True,
        )
    except Exception as retry_error:
        verbose_logger.error(
            f"[ASYNC] Reauthentication retry failed for project_id: {project_id}. "
            f"Original error: {str(error)}. Retry error: {str(retry_error)}"
        )
        raise error

async def _refresh_expired_credentials_async(
    self,
    _credentials: "GoogleCredentialsObject",
    credential_cache_key: Tuple,
    credential_project_id: str,
    credentials: "Optional[VERTEX_CREDENTIALS_TYPES]",
    project_id: Optional[str],
    _retry_reauth: bool,
) -> None:
    """
    Refresh expired credentials and update the shared cache.

    Any refresh failure (including "Reauthentication is needed") simply
    propagates; the caller decides whether to retry via reauthentication.
    (FIX: the original wrapped this in a try/except whose two branches both
    re-raised — dead code removed; behavior unchanged.)

    ``credentials`` / ``_retry_reauth`` are unused here but kept for
    signature compatibility with existing callers.
    """
    verbose_logger.debug(
        f"[ASYNC] Credentials expired, refreshing for project_id: {project_id}"
    )
    await self.refresh_auth_async(_credentials)
    self._credentials_project_mapping[credential_cache_key] = (
        _credentials,
        credential_project_id,
    )

async def get_access_token_async(
    self,
    credentials: "Optional[VERTEX_CREDENTIALS_TYPES]",
    project_id: Optional[str],
    _retry_reauth: bool = False,
) -> Tuple[str, str]:
    """
    Async version of ``get_access_token``.

    Shares ``self._credentials_project_mapping`` with the sync path, so
    credentials loaded by one are reusable by the other.

    Returns:
        ``(access_token, project_id)``.
    """
    cache_credentials = (
        json.dumps(credentials) if isinstance(credentials, dict) else credentials
    )
    credential_cache_key = (cache_credentials, project_id)

    if credential_cache_key in self._credentials_project_mapping:
        cached_entry = self._credentials_project_mapping[credential_cache_key]
        if isinstance(cached_entry, tuple):
            _credentials, credential_project_id = cached_entry
        else:
            # Legacy cache format: bare credentials object.
            # FIX: use getattr for quota_project_id too — not every
            # credential object exposes the attribute.
            _credentials = cached_entry
            credential_project_id = getattr(
                _credentials, "quota_project_id", None
            ) or getattr(_credentials, "project_id", None)
    else:
        _credentials, credential_project_id = await self.load_auth_async(
            credentials, project_id
        )
        if _credentials is None:
            raise ValueError(
                f"Could not resolve credentials for project_id: {project_id}"
            )
        self._credentials_project_mapping[credential_cache_key] = (
            _credentials,
            credential_project_id,
        )

    if project_id is None and credential_project_id:
        project_id = credential_project_id
        # Also cache under the resolved key so later calls that pass the
        # resolved project id hit the cache directly.
        resolved_cache_key = (cache_credentials, project_id)
        if resolved_cache_key not in self._credentials_project_mapping:
            self._credentials_project_mapping[resolved_cache_key] = (
                _credentials,
                credential_project_id,
            )

    if _credentials is None:
        raise ValueError("Credentials are None after loading")

    if _credentials.expired:
        try:
            await self._refresh_expired_credentials_async(
                _credentials,
                credential_cache_key,
                credential_project_id,
                credentials,
                project_id,
                _retry_reauth,
            )
        except Exception as e:
            if "Reauthentication is needed" in str(e) and not _retry_reauth:
                return await self._handle_reauthentication_async(
                    credentials=credentials,
                    project_id=project_id,
                    credential_cache_key=credential_cache_key,
                    error=e,
                )
            raise

    if _credentials.token is None or not isinstance(_credentials.token, str):
        raise ValueError(
            f"Could not resolve credentials token: {_credentials.token}"
        )
    if project_id is None:
        raise ValueError("Could not resolve project_id")

    return _credentials.token, project_id

async def _ensure_access_token_async(
    self,
    credentials: "Optional[VERTEX_CREDENTIALS_TYPES]",
    project_id: Optional[str],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
) -> Tuple[str, str]:
    """
    Async version of ``_ensure_access_token``.

    Gemini (Google AI Studio) needs no OAuth token, so it short-circuits;
    Vertex providers go through the native-async token path.
    """
    if custom_llm_provider == "gemini":
        return "", ""

    verbose_logger.debug(
        "[VERTEX AUTH] Using native async aiohttp implementation for token retrieval"
    )
    return await self.get_access_token_async(
        credentials=credentials,
        project_id=project_id,
    )
test_credential_project_validation(self, is_async): # Test case 2: Allow using credentials from different project with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "project-1") + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "project-1") ): if is_async: result = await vertex_base._ensure_access_token_async( @@ -87,6 +91,8 @@ async def test_cached_credentials(self, is_async): # Test initial credential load and caching with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "project-1") + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "project-1") ): # First call should load credentials if is_async: @@ -133,13 +139,18 @@ async def test_credential_refresh(self, is_async): with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "project-1") - ), patch.object(vertex_base, "refresh_auth") as mock_refresh: + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "project-1") + ), patch.object(vertex_base, "refresh_auth") as mock_refresh, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: def mock_refresh_impl(creds): creds.token = "refreshed-token" creds.expired = False mock_refresh.side_effect = mock_refresh_impl + mock_refresh_async.side_effect = mock_refresh_impl if is_async: token, project = await vertex_base._ensure_access_token_async( @@ -154,7 +165,11 @@ def mock_refresh_impl(creds): custom_llm_provider="vertex_ai", ) - assert mock_refresh.called + # Check the appropriate mock based on async/sync + if is_async: + assert mock_refresh_async.called + else: + assert mock_refresh.called assert token == "refreshed-token" assert not mock_creds.expired @@ -200,13 +215,16 @@ async def test_authorized_user_credentials(self, is_async): with patch.object( vertex_base, "_credentials_from_authorized_user", return_value=mock_creds ) as mock_credentials_from_authorized_user, patch.object( - vertex_base, "refresh_auth" - ) 
as mock_refresh: + vertex_base, "_credentials_from_authorized_user_async", return_value=mock_creds + ) as mock_credentials_from_authorized_user_async, patch.object(vertex_base, "refresh_auth") as mock_refresh, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: def mock_refresh_impl(creds): creds.token = "refreshed-token" mock_refresh.side_effect = mock_refresh_impl + mock_refresh_async.side_effect = mock_refresh_impl # 1. Test that authorized_user-style credentials are correctly handled and uses quota_project_id if is_async: @@ -222,7 +240,11 @@ def mock_refresh_impl(creds): custom_llm_provider="vertex_ai", ) - assert mock_credentials_from_authorized_user.called + # Verify the appropriate method was called + if is_async: + assert mock_credentials_from_authorized_user_async.called + else: + assert mock_credentials_from_authorized_user.called assert token == "refreshed-token" assert project == quota_project_id @@ -263,14 +285,15 @@ async def test_identity_pool_credentials(self, is_async): with patch.object( vertex_base, "_credentials_from_identity_pool", return_value=mock_creds - ) as mock_credentials_from_identity_pool, patch.object( - vertex_base, "refresh_auth" - ) as mock_refresh: + ) as mock_credentials_from_identity_pool, patch.object(vertex_base, "refresh_auth") as mock_refresh, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: def mock_refresh_impl(creds): creds.token = "refreshed-token" mock_refresh.side_effect = mock_refresh_impl + mock_refresh_async.side_effect = mock_refresh_impl if is_async: token, _ = await vertex_base._ensure_access_token_async( @@ -310,14 +333,15 @@ async def test_identity_pool_credentials_with_aws(self, is_async): with patch.object( vertex_base, "_credentials_from_identity_pool_with_aws", return_value=mock_creds - ) as mock_credentials_from_identity_pool_with_aws, patch.object( - vertex_base, "refresh_auth" - ) as mock_refresh: + ) as mock_credentials_from_identity_pool_with_aws, 
patch.object(vertex_base, "refresh_auth") as mock_refresh, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: def mock_refresh_impl(creds): creds.token = "refreshed-token" mock_refresh.side_effect = mock_refresh_impl + mock_refresh_async.side_effect = mock_refresh_impl if is_async: token, _ = await vertex_base._ensure_access_token_async( @@ -351,6 +375,8 @@ async def test_new_cache_format_tuple_storage(self, is_async): with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "project-1") + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "project-1") ): if is_async: token, project = await vertex_base._ensure_access_token_async( @@ -429,6 +455,8 @@ async def test_resolved_project_id_cache_optimization(self, is_async): with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "resolved-project") + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "resolved-project") ): # Call without project_id, should use resolved project from credentials if is_async: @@ -481,13 +509,18 @@ async def test_cache_update_on_credential_refresh(self, is_async): with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "project-1") - ), patch.object(vertex_base, "refresh_auth") as mock_refresh: + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "project-1") + ), patch.object(vertex_base, "refresh_auth") as mock_refresh, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: def mock_refresh_impl(creds): creds.token = "refreshed-token" creds.expired = False mock_refresh.side_effect = mock_refresh_impl + mock_refresh_async.side_effect = mock_refresh_impl if is_async: token, project = await vertex_base._ensure_access_token_async( @@ -502,7 +535,11 @@ def mock_refresh_impl(creds): custom_llm_provider="vertex_ai", ) - assert mock_refresh.called + # Check the appropriate mock based on async/sync + if is_async: + assert 
mock_refresh_async.called + else: + assert mock_refresh.called assert token == "refreshed-token" assert project == "project-1" @@ -532,6 +569,8 @@ async def test_cache_with_different_project_id_combinations(self, is_async): with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "cred-project") + ), patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "cred-project") ): # First call with explicit project_id if is_async: @@ -595,7 +634,9 @@ async def test_project_id_resolution_and_caching_core_issue(self, is_async): with patch.object( vertex_base, "load_auth", return_value=(mock_creds, "resolved-from-credentials") - ) as mock_load_auth: + ) as mock_load_auth_sync, patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "resolved-from-credentials") + ) as mock_load_auth_async: # First call: User provides NO project_id, should resolve from credentials if is_async: @@ -612,6 +653,7 @@ async def test_project_id_resolution_and_caching_core_issue(self, is_async): ) # Should have called load_auth once to resolve project_id + mock_load_auth = mock_load_auth_async if is_async else mock_load_auth_sync assert mock_load_auth.call_count == 1 assert token1 == "token-from-creds" assert project1 == "resolved-from-credentials" @@ -937,3 +979,428 @@ def test_check_custom_proxy_minimal_gemini_key_param( ) assert result_auth_header == expected_auth_header assert result_url == expected_url + + @pytest.mark.asyncio + async def test_async_auth_uses_async_methods(self): + """Test that async auth uses load_auth_async and refresh_auth_async (now the default)""" + vertex_base = VertexBase() + + mock_creds = MagicMock() + mock_creds.token = "async-token" + mock_creds.expired = False + mock_creds.project_id = "async-project" + mock_creds.quota_project_id = "async-project" + + credentials = {"type": "service_account", "project_id": "async-project"} + + with patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, 
"async-project") + ) as mock_load_auth_async, patch.object( + vertex_base, "load_auth" + ) as mock_load_auth_sync: + + token, project = await vertex_base.get_access_token_async( + credentials=credentials, + project_id="async-project", + ) + + # Verify async method was called + assert mock_load_auth_async.called + # Verify sync method was NOT called + assert not mock_load_auth_sync.called + assert token == "async-token" + assert project == "async-project" + + @pytest.mark.asyncio + async def test_refresh_auth_async_with_aiohttp(self): + """Test that refresh_auth_async uses aiohttp when available""" + vertex_base = VertexBase() + + mock_creds = MagicMock() + mock_creds.expired = True + mock_creds.token = None + + async def mock_refresh(request): + # Simulate successful token refresh (ASYNC function) + mock_creds.token = "refreshed-async-token" + mock_creds.expired = False + + # Make refresh an async coroutine (simulating async credentials) + mock_creds.refresh = mock_refresh + + # Call refresh_auth_async + await vertex_base.refresh_auth_async(mock_creds) + + # Verify credentials were refreshed + assert mock_creds.token == "refreshed-async-token" + assert not mock_creds.expired + + @pytest.mark.asyncio + async def test_load_auth_async_service_account(self): + """Test load_auth_async with service account credentials creates async credentials""" + vertex_base = VertexBase() + + credentials = { + "type": "service_account", + "project_id": "test-project", + "client_email": "test@test-project.iam.gserviceaccount.com", + } + + mock_creds = MagicMock() + mock_creds.token = "loaded-token" + mock_creds.expired = False + mock_creds.project_id = "test-project" + + # Patch the ASYNC credential creation method + with patch.object( + vertex_base, "_credentials_from_service_account_async", return_value=mock_creds + ) as mock_service_account_async, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: + + async def mock_refresh_impl(creds): + creds.token = 
"async-refreshed-token" + creds.expired = False + + mock_refresh_async.side_effect = mock_refresh_impl + + creds, project = await vertex_base.load_auth_async( + credentials=credentials, + project_id="test-project" + ) + + # Verify ASYNC service account method was called + assert mock_service_account_async.called + # Verify async refresh was called + assert mock_refresh_async.called + assert creds.token == "async-refreshed-token" + assert project == "test-project" + + @pytest.mark.asyncio + async def test_async_token_refresh_when_expired(self): + """Test that expired tokens are refreshed using async method""" + vertex_base = VertexBase() + + # Create expired credentials + mock_creds = MagicMock() + mock_creds.token = "old-token" + mock_creds.expired = True + mock_creds.project_id = "test-project" + mock_creds.quota_project_id = "test-project" + + credentials = {"type": "service_account", "project_id": "test-project"} + + with patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "test-project") + ) as mock_load_auth_async, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: + + async def mock_refresh_impl(creds): + creds.token = "refreshed-async-token" + creds.expired = False + + mock_refresh_async.side_effect = mock_refresh_impl + + token, project = await vertex_base.get_access_token_async( + credentials=credentials, + project_id="test-project", + ) + + # Verify refresh_auth_async was called for expired credentials + assert mock_refresh_async.called + assert token == "refreshed-async-token" + assert not mock_creds.expired + assert project == "test-project" + + @pytest.mark.asyncio + async def test_async_caching_with_new_implementation(self): + """Test that credential caching works correctly with async implementation""" + vertex_base = VertexBase() + + mock_creds = MagicMock() + mock_creds.token = "cached-async-token" + mock_creds.expired = False + mock_creds.project_id = "cached-project" + mock_creds.quota_project_id = 
"cached-project" + + credentials = {"type": "service_account", "project_id": "cached-project"} + + with patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "cached-project") + ) as mock_load_auth_async: + + # First call - should load credentials + token1, project1 = await vertex_base.get_access_token_async( + credentials=credentials, + project_id="cached-project", + ) + + assert mock_load_auth_async.call_count == 1 + assert token1 == "cached-async-token" + + # Second call - should use cached credentials + token2, project2 = await vertex_base.get_access_token_async( + credentials=credentials, + project_id="cached-project", + ) + + # Should still be only 1 call (used cache) + assert mock_load_auth_async.call_count == 1 + assert token2 == "cached-async-token" + assert project2 == "cached-project" + + # Verify cache entry exists + cache_key = (json.dumps(credentials), "cached-project") + assert cache_key in vertex_base._credentials_project_mapping + + @pytest.mark.asyncio + async def test_async_and_sync_share_same_cache(self): + """Test that async and sync implementations share the same credential cache""" + vertex_base = VertexBase() + + mock_creds = MagicMock() + mock_creds.token = "shared-cache-token" + mock_creds.expired = False + mock_creds.project_id = "shared-project" + mock_creds.quota_project_id = "shared-project" + + credentials = {"type": "service_account", "project_id": "shared-project"} + + with patch.object( + vertex_base, "load_auth_async", return_value=(mock_creds, "shared-project") + ) as mock_load_auth_async, patch.object( + vertex_base, "load_auth", return_value=(mock_creds, "shared-project") + ) as mock_load_auth_sync: + + # First call with async + token1, project1 = await vertex_base.get_access_token_async( + credentials=credentials, + project_id="shared-project", + ) + + assert mock_load_auth_async.call_count == 1 + assert token1 == "shared-cache-token" + + # Second call with sync (should use same cache) + token2, project2 = 
vertex_base.get_access_token( + credentials=credentials, + project_id="shared-project", + ) + + # Should NOT call load_auth because cache was populated by async call + assert mock_load_auth_sync.call_count == 0 + assert token2 == "shared-cache-token" + assert project2 == "shared-project" + + @pytest.mark.asyncio + async def test_load_auth_async_authorized_user(self): + """Test load_auth_async with authorized user credentials creates async credentials""" + vertex_base = VertexBase() + + credentials = { + "type": "authorized_user", + "client_id": "test-client-id", + "client_secret": "test-secret", + "refresh_token": "test-refresh-token", + "quota_project_id": "test-quota-project", + } + + mock_creds = MagicMock() + mock_creds.token = "authorized-user-token" + mock_creds.expired = False + mock_creds.quota_project_id = "test-quota-project" + + # Patch the ASYNC credential creation method + with patch.object( + vertex_base, "_credentials_from_authorized_user_async", return_value=mock_creds + ) as mock_authorized_user_async, patch.object( + vertex_base, "refresh_auth_async" + ) as mock_refresh_async: + + async def mock_refresh_impl(creds): + creds.token = "refreshed-authorized-token" + + mock_refresh_async.side_effect = mock_refresh_impl + + creds, project = await vertex_base.load_auth_async( + credentials=credentials, + project_id=None + ) + + # Verify ASYNC authorized user method was called + assert mock_authorized_user_async.called + # Verify async refresh was called + assert mock_refresh_async.called + assert creds.token == "refreshed-authorized-token" + # Should use quota_project_id when project_id is None + assert project == "test-quota-project" + + @pytest.mark.asyncio + async def test_async_credentials_with_old_transport(self): + """ + Test that async credentials use OLD transport for TRUE async refresh. + This verifies the implementation: OLD async credentials + OLD transport (compatible). 
+ """ + vertex_base = VertexBase() + + # Create mock async credentials (simulating _credentials_async or _service_account_async) + mock_creds = MagicMock() + mock_creds.token = "initial-token" + mock_creds.expired = True + + # Track whether refresh was called and what request type was used + refresh_called = [] + + async def async_refresh(request): + """Simulates async refresh method from google.oauth2._credentials_async""" + refresh_called.append({ + 'request_type': type(request).__name__, + 'has_session': hasattr(request, 'session') or hasattr(request, '_session') + }) + mock_creds.token = "async-refreshed-token" + mock_creds.expired = False + + # Make refresh a coroutine to simulate async credentials + mock_creds.refresh = async_refresh + + # Call refresh_auth_async + await vertex_base.refresh_auth_async(mock_creds) + + # Verify async refresh was called + assert len(refresh_called) == 1, "Async refresh should be called once" + assert refresh_called[0]['request_type'] == 'Request', "Should use Request from OLD transport" + assert refresh_called[0]['has_session'], "Request should have session" + assert mock_creds.token == "async-refreshed-token" + assert not mock_creds.expired + + print(f"✅ Async credentials used with OLD transport (google.auth.transport._aiohttp_requests)") + + # Cleanup + await VertexBase.close_token_refresh_session() + + @pytest.mark.asyncio + async def test_persistent_session_reuse_across_multiple_refreshes(self): + """ + Test that the same aiohttp session is reused across multiple token refreshes. + This verifies the session pooling optimization. 
+ """ + import aiohttp + from unittest.mock import MagicMock + + # Close any existing session to start fresh + await VertexBase.close_token_refresh_session() + + # Track session IDs captured during refresh + session_ids = [] + + # Create a mock credentials object with ASYNC refresh + mock_creds = MagicMock() + mock_creds.token = "test-token" + mock_creds.expired = True + + async def mock_refresh(request): + # Capture the session ID each time refresh is called (ASYNC function) + # Note: OLD transport google.auth.transport._aiohttp_requests.Request uses 'session' attribute + session = getattr(request, 'session', None) or getattr(request, '_session', None) + if session: + session_ids.append(id(session)) + mock_creds.token = f"refreshed-token-{len(session_ids)}" + + # Make refresh an async coroutine (simulating async credentials) + mock_creds.refresh = mock_refresh + + vertex_base = VertexBase() + + # Perform multiple token refreshes + await vertex_base.refresh_auth_async(mock_creds) + await vertex_base.refresh_auth_async(mock_creds) + await vertex_base.refresh_auth_async(mock_creds) + + # Verify all refreshes used the SAME session instance + assert len(session_ids) == 3, f"Expected 3 refreshes, got {len(session_ids)}" + assert session_ids[0] == session_ids[1] == session_ids[2], \ + f"Session IDs should be identical, got: {session_ids}" + + print(f"✅ All 3 refreshes used the same session (ID: {session_ids[0]})") + + # Verify the session has the correct settings + session = await VertexBase._get_or_create_token_refresh_session() + assert isinstance(session, aiohttp.ClientSession) + assert session._auto_decompress is False, "Session should have auto_decompress=False" + assert not session.closed, "Session should still be open" + + print(f"✅ Session has auto_decompress=False: {session._auto_decompress is False}") + + # Test cleanup + await VertexBase.close_token_refresh_session() + assert session.closed, "Session should be closed after cleanup" + + print(f"✅ Session 
properly closed after cleanup") + + # Verify new session is created after cleanup + new_session = await VertexBase._get_or_create_token_refresh_session() + assert id(new_session) != id(session), "New session should be created after cleanup" + assert not new_session.closed, "New session should be open" + + print(f"✅ New session created after cleanup (old ID: {id(session)}, new ID: {id(new_session)})") + + # Cleanup for next test + await VertexBase.close_token_refresh_session() + + @pytest.mark.asyncio + async def test_concurrent_token_refresh_uses_same_session(self): + """ + Test that concurrent token refreshes all use the same session. + This verifies thread-safety of session creation. + """ + import asyncio + import time + from unittest.mock import MagicMock + + # Close any existing session to start fresh + await VertexBase.close_token_refresh_session() + + # Track session IDs from concurrent requests + session_ids = [] + + async def refresh_and_track(vertex_base, creds, index): + await vertex_base.refresh_auth_async(creds) + session = await VertexBase._get_or_create_token_refresh_session() + session_ids.append((index, id(session))) + + # Create multiple vertex base instances (simulating multiple concurrent requests) + vertex_bases = [VertexBase() for _ in range(5)] + + # Create mock credentials for each + mock_creds_list = [] + for i in range(5): + mock_creds = MagicMock() + mock_creds.token = f"test-token-{i}" + mock_creds.expired = True + + async def mock_refresh(request): + # Simulate network delay (ASYNC function) + await asyncio.sleep(0.01) + mock_creds.token = "refreshed" + + # Make refresh an async coroutine (simulating async credentials) + mock_creds.refresh = mock_refresh + mock_creds_list.append(mock_creds) + + # Run all refreshes concurrently + await asyncio.gather(*[ + refresh_and_track(vb, creds, i) + for i, (vb, creds) in enumerate(zip(vertex_bases, mock_creds_list)) + ]) + + # All concurrent requests should have used the SAME session + assert 
len(session_ids) == 5, f"Expected 5 concurrent requests, got {len(session_ids)}" + unique_sessions = set(sid for _, sid in session_ids) + assert len(unique_sessions) == 1, \ + f"All concurrent requests should use same session, got {len(unique_sessions)} unique sessions" + + print(f"✅ All 5 concurrent requests used the same session (ID: {list(unique_sessions)[0]})") + + # Cleanup + await VertexBase.close_token_refresh_session()