diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 1453a555ae..00f452e80b 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -1141,6 +1141,7 @@ async def _async_iterator(iterable): ) assert inference_result.candidate_name == "agent" assert inference_result.gcs_source is None + mock_vertexai_client.return_value.close.assert_called() @mock.patch.object(_evals_metric_loaders, "EvalDatasetLoader") @mock.patch("vertexai._genai._evals_common.vertexai.Client") @@ -1228,6 +1229,7 @@ async def _async_iterator(iterable): ) assert inference_result.candidate_name == "agent" assert inference_result.gcs_source is None + mock_vertexai_client.return_value.close.assert_called() @mock.patch.object(_evals_utils, "EvalDatasetLoader") @mock.patch("vertexai._genai._evals_common.vertexai.Client") diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index 89357d9a1d..34bf58feff 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -13,6 +13,7 @@ # limitations under the License. # """Common utilities for evals.""" +import aiohttp import asyncio import collections import concurrent.futures @@ -54,20 +55,33 @@ AGENT_MAX_WORKERS = 10 +def _close_thread_local_vertexai_client(): + """Closes thread-local Vertex AI client if it exists.""" + if hasattr(_thread_local_data, "vertexai_client"): + del _thread_local_data.vertexai_client + if hasattr(_thread_local_data, "agent_engine_instances"): + del _thread_local_data.agent_engine_instances + + def _get_agent_engine_instance( agent_name: str, api_client: BaseApiClient ) -> types.AgentEngine: """Gets or creates an agent engine instance for the current thread.""" + # if not hasattr(_thread_local_data, "vertexai_client"): + # _thread_local_data.vertexai_client = vertexai.Client( + # project=api_client.project, + # location=api_client.location, + # ) if not hasattr(_thread_local_data, "agent_engine_instances"): _thread_local_data.agent_engine_instances = {} if agent_name not in _thread_local_data.agent_engine_instances: - client = vertexai.Client( + with vertexai.Client( project=api_client.project, location=api_client.location, - ) - _thread_local_data.agent_engine_instances[agent_name] = ( - client.agent_engines.get(name=agent_name) - ) + ) as client: + _thread_local_data.agent_engine_instances[agent_name] = ( + client.agent_engines.get(name=agent_name) + ) return _thread_local_data.agent_engine_instances[agent_name] @@ -244,7 +258,7 @@ def _execute_inference_concurrently( max_workers = AGENT_MAX_WORKERS if agent_engine else MAX_WORKERS with tqdm(total=len(prompt_dataset), desc=progress_desc) as pbar: - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: for index, row in prompt_dataset.iterrows(): request_dict_or_raw_text = row[primary_prompt_column] try: @@ -321,6 +335,12 @@ def agent_run_wrapper( e, ) responses[index] = {"error": f"Inference task failed: {e}"} + if agent_engine: + cleanup_futures = [ + executor.submit(_close_thread_local_vertexai_client) + for _ in range(AGENT_MAX_WORKERS) + ] + concurrent.futures.wait(cleanup_futures) return responses # type: ignore[return-value] diff --git a/vertexai/_genai/_evals_metric_handlers.py b/vertexai/_genai/_evals_metric_handlers.py index acbeda2afd..1d4c37bbed 100644 --- a/vertexai/_genai/_evals_metric_handlers.py +++ b/vertexai/_genai/_evals_metric_handlers.py @@ -970,7 +970,13 @@ def get_metric_result( try: payload = self._build_request_payload(eval_case, response_index) api_response = self.module._evaluate_instances( - metrics=[self.metric], instance=payload.get("instance") + metrics=[self.metric], + instance=payload.get("instance"), + config=types.EvaluateInstancesConfig( + http_options=genai_types.HttpOptions( + base_url="https://us-central1-staging-aiplatform.sandbox.googleapis.com" + ) + ), ) if (