2 changes: 2 additions & 0 deletions tests/unit/vertexai/genai/test_evals.py
@@ -1141,6 +1141,7 @@ async def _async_iterator(iterable):
         )
         assert inference_result.candidate_name == "agent"
         assert inference_result.gcs_source is None
+        mock_vertexai_client.return_value.close.assert_called()
 
     @mock.patch.object(_evals_metric_loaders, "EvalDatasetLoader")
     @mock.patch("vertexai._genai._evals_common.vertexai.Client")
@@ -1228,6 +1229,7 @@ async def _async_iterator(iterable):
         )
         assert inference_result.candidate_name == "agent"
         assert inference_result.gcs_source is None
+        mock_vertexai_client.return_value.close.assert_called()
 
     @mock.patch.object(_evals_utils, "EvalDatasetLoader")
     @mock.patch("vertexai._genai._evals_common.vertexai.Client")
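For context, the two new assertions check that the thread-local `vertexai.Client` created during agent inference is closed once inference finishes. A minimal, self-contained sketch of the same `unittest.mock` verification pattern (the `run_inference` function and `client_factory` here are hypothetical stand-ins, not SDK API):

```python
from unittest import mock


def run_inference(client_factory):
    """Hypothetical code under test: create a client, use it, close it."""
    client = client_factory()
    try:
        return "ok"
    finally:
        client.close()


def test_client_is_closed():
    factory = mock.MagicMock()
    assert run_inference(factory) == "ok"
    # Mirrors the PR's assertion: close() must have been called on the
    # object the mocked constructor returned.
    factory.return_value.close.assert_called()
```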
32 changes: 26 additions & 6 deletions vertexai/_genai/_evals_common.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 #
 """Common utilities for evals."""
+import aiohttp
 import asyncio
 import collections
 import concurrent.futures
@@ -54,20 +55,33 @@
 AGENT_MAX_WORKERS = 10
 
 
+def _close_thread_local_vertexai_client():
+    """Closes thread-local Vertex AI client if it exists."""
+    if hasattr(_thread_local_data, "vertexai_client"):
+        del _thread_local_data.vertexai_client
+    if hasattr(_thread_local_data, "agent_engine_instances"):
+        del _thread_local_data.agent_engine_instances
+
+
 def _get_agent_engine_instance(
     agent_name: str, api_client: BaseApiClient
 ) -> types.AgentEngine:
     """Gets or creates an agent engine instance for the current thread."""
+    # if not hasattr(_thread_local_data, "vertexai_client"):
+    #     _thread_local_data.vertexai_client = vertexai.Client(
+    #         project=api_client.project,
+    #         location=api_client.location,
+    #     )
     if not hasattr(_thread_local_data, "agent_engine_instances"):
         _thread_local_data.agent_engine_instances = {}
     if agent_name not in _thread_local_data.agent_engine_instances:
-        client = vertexai.Client(
+        with vertexai.Client(
             project=api_client.project,
             location=api_client.location,
-        )
-        _thread_local_data.agent_engine_instances[agent_name] = (
-            client.agent_engines.get(name=agent_name)
-        )
+        ) as client:
+            _thread_local_data.agent_engine_instances[agent_name] = (
+                client.agent_engines.get(name=agent_name)
+            )
     return _thread_local_data.agent_engine_instances[agent_name]
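The `with` block above scopes the client to the lookup, while the fetched agent engine instance stays cached for the thread. A minimal sketch of that thread-local caching pattern, with a hypothetical `make_client` factory standing in for `vertexai.Client` (an assumption, not the SDK's API):

```python
import threading

_thread_local = threading.local()


def get_cached_instance(name, make_client):
    """Get or create a per-thread cached instance for `name`."""
    if not hasattr(_thread_local, "instances"):
        _thread_local.instances = {}
    if name not in _thread_local.instances:
        # The client lives only long enough to fetch the instance;
        # the instance itself stays cached for this thread.
        with make_client() as client:
            _thread_local.instances[name] = client.get(name)
    return _thread_local.instances[name]
```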


@@ -244,7 +258,7 @@ def _execute_inference_concurrently(
 
     max_workers = AGENT_MAX_WORKERS if agent_engine else MAX_WORKERS
     with tqdm(total=len(prompt_dataset), desc=progress_desc) as pbar:
-        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
             for index, row in prompt_dataset.iterrows():
                 request_dict_or_raw_text = row[primary_prompt_column]
                 try:
@@ -321,6 +335,12 @@ def agent_run_wrapper(
                         e,
                     )
                     responses[index] = {"error": f"Inference task failed: {e}"}
+            if agent_engine:
+                cleanup_futures = [
+                    executor.submit(_close_thread_local_vertexai_client)
+                    for _ in range(AGENT_MAX_WORKERS)
+                ]
+                concurrent.futures.wait(cleanup_futures)
     return responses  # type: ignore[return-value]
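The cleanup block submits one `_close_thread_local_vertexai_client` task per worker so that each pool thread releases its own cached client before the executor shuts down. Note that this relies on thread-based workers sharing `_thread_local_data`; under a `ProcessPoolExecutor` the workers would not see the parent interpreter's thread-local state. A runnable sketch of the per-worker cleanup pattern under the thread-pool assumption:

```python
import concurrent.futures
import threading

_local = threading.local()


def use_resource(i):
    # Lazily create one resource per worker thread.
    if not hasattr(_local, "resource"):
        _local.resource = object()
    return i


def close_resource():
    # Drop this thread's resource so it can be released.
    if hasattr(_local, "resource"):
        del _local.resource


with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(use_resource, range(10)))
    # One cleanup task per worker: with all workers idle, each thread
    # tends to pick up one task, though this is best-effort rather
    # than a guarantee of exactly-once-per-thread execution.
    cleanup = [executor.submit(close_resource) for _ in range(4)]
    concurrent.futures.wait(cleanup)
```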


8 changes: 7 additions & 1 deletion vertexai/_genai/_evals_metric_handlers.py
@@ -970,7 +970,13 @@ def get_metric_result(
         try:
             payload = self._build_request_payload(eval_case, response_index)
             api_response = self.module._evaluate_instances(
-                metrics=[self.metric], instance=payload.get("instance")
+                metrics=[self.metric],
+                instance=payload.get("instance"),
+                config=types.EvaluateInstancesConfig(
+                    http_options=genai_types.HttpOptions(
+                        base_url="https://us-central1-staging-aiplatform.sandbox.googleapis.com"
+                    )
+                ),
             )
 
             if (
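For reference, the new `config` argument threads per-request HTTP options through to the evaluation call, and the hardcoded staging `base_url` points the request at a sandbox endpoint. A hedged sketch of building such an override, assuming the import paths below (the diff only shows the `types` and `genai_types` aliases, so the exact modules are an assumption):

```python
from google.genai import types as genai_types
from vertexai._genai import types

# Route _evaluate_instances at a non-default endpoint; the URL below is
# a placeholder, not a recommendation to target the sandbox host.
config = types.EvaluateInstancesConfig(
    http_options=genai_types.HttpOptions(
        base_url="https://us-central1-aiplatform.googleapis.com"
    )
)
```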