diff --git a/setup.py b/setup.py index 1cef024ab1..e199577e8a 100644 --- a/setup.py +++ b/setup.py @@ -151,6 +151,7 @@ "google-cloud-trace < 2", "opentelemetry-sdk < 2", "opentelemetry-exporter-gcp-trace < 2", + "opentelemetry-exporter-otlp-proto-http < 2", "pydantic >= 2.11.1, < 3", "typing_extensions", ] diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index 6ce5222114..32c0319fc4 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -16,7 +16,6 @@ import base64 import importlib import json -import dataclasses import os from unittest import mock from typing import Optional @@ -112,27 +111,11 @@ def simple_span_processor_mock(): @pytest.fixture -def cloud_trace_exporter_mock(): - import sys - import opentelemetry - - mock_cloud_trace_exporter = mock.Mock() - - opentelemetry.exporter = type(sys)("exporter") - opentelemetry.exporter.cloud_trace = type(sys)("cloud_trace") - opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter = ( - mock_cloud_trace_exporter - ) - - sys.modules["opentelemetry.exporter"] = opentelemetry.exporter - sys.modules["opentelemetry.exporter.cloud_trace"] = ( - opentelemetry.exporter.cloud_trace - ) - - yield mock_cloud_trace_exporter - - del sys.modules["opentelemetry.exporter.cloud_trace"] - del sys.modules["opentelemetry.exporter"] +def otlp_span_exporter_mock(): + with mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) as otlp_span_exporter_mock: + yield otlp_span_exporter_mock @pytest.fixture @@ -619,9 +602,9 @@ def test_default_instrumentor_enablement( ) def test_tracing_setup( self, - trace_provider_mock: mock.Mock, - cloud_trace_exporter_mock: mock.Mock, monkeypatch: pytest.MonkeyPatch, + trace_provider_mock: mock.Mock, + otlp_span_exporter_mock: mock.Mock, ): monkeypatch.setattr( "uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678") @@ -644,17 +627,9 @@ def test_tracing_setup( "some-attribute": "some-value", } - @dataclasses.dataclass - class RegexMatchingAll: - keys: set[str] - - def __eq__(self, regex: object) -> bool: - return isinstance(regex, str) and set(regex.split("|")) == self.keys - - cloud_trace_exporter_mock.assert_called_once_with( - project_id=_TEST_PROJECT, - client=mock.ANY, - resource_regex=RegexMatchingAll(keys=set(expected_attributes.keys())), + otlp_span_exporter_mock.assert_called_once_with( + session=mock.ANY, + endpoint="https://telemetry.googleapis.com/v1/traces", ) assert ( @@ -686,7 +661,6 @@ def test_span_content_capture_enabled_with_tracing(self): def test_enable_tracing( self, caplog, - cloud_trace_exporter_mock, tracer_provider_mock, simple_span_processor_mock, ): diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index a4d31e5034..345ff981f1 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -233,6 +233,28 @@ def _warn(msg: str): _warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess] +def _force_flush_traces(): + try: + import opentelemetry.trace + except (ImportError, AttributeError): + _warn( + "Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + try: + import opentelemetry.sdk.trace + except (ImportError, AttributeError): + _warn( + "Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + provider = opentelemetry.trace.get_tracer_provider() + if isinstance(provider, opentelemetry.sdk.trace.TracerProvider): + _ = provider.force_flush() + + def _default_instrumentor_builder( project_id: str, *, @@ -314,28 +336,23 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: if enable_tracing: try: - import opentelemetry.exporter.cloud_trace - except (ImportError, AttributeError): - return _warn_missing_dependency( - "opentelemetry-exporter-gcp-trace", needed_for_tracing=True - ) - - try: - import google.cloud.trace_v2 + import opentelemetry.exporter.otlp.proto.http.trace_exporter + import google.auth.transport.requests except (ImportError, AttributeError): return _warn_missing_dependency( - "google-cloud-trace", needed_for_tracing=True + "opentelemetry-exporter-otlp-proto-http", needed_for_tracing=True ) import google.auth credentials, _ = google.auth.default() - span_exporter = opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter( - project_id=project_id, - client=google.cloud.trace_v2.TraceServiceClient( - credentials=credentials.with_quota_project(project_id), - ), - resource_regex="|".join(resource.attributes.keys()), + span_exporter = ( + opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter( + session=google.auth.transport.requests.AuthorizedSession( + credentials=credentials + ), + endpoint="https://telemetry.googleapis.com/v1/traces", + ) ) span_processor = opentelemetry.sdk.trace.export.BatchSpanProcessor( span_exporter=span_exporter, @@ -875,9 +892,14 @@ async def async_stream_query( **kwargs, ) - async for event in events_async: - # Yield the event data as a dictionary - yield _utils.dump_event_for_json(event) + try: + async for event in events_async: + # Yield the event data as a dictionary + yield _utils.dump_event_for_json(event) + finally: + # Avoid trace data loss having to do with CPU throttling on instance turndown + if self._tracing_enabled(): + _ = await asyncio.to_thread(_force_flush_traces) def streaming_agent_run_with_events(self, request_json: str): import json @@ -938,6 +960,9 @@ async def _invoke_agent_async(): user_id=request.user_id, session_id=session.id, ) + # Avoid trace data loss having to do with CPU throttling on instance turndown + if self._tracing_enabled(): + _ = await asyncio.to_thread(_force_flush_traces) def _asyncio_thread_main(): try: