Skip to content

Commit 1882211

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
chore!: Switch tracing APIs in preview AdkApp.
Currently AdkApp uses `cloudtrace.googleapis.com` for GCP tracing. This change switches it to `telemetry.googleapis.com`. It's a breaking change as users might need to enable the new API. PiperOrigin-RevId: 825486202
1 parent c81f912 commit 1882211

File tree

5 files changed

+58
-54
lines changed

5 files changed

+58
-54
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@
151151
"google-cloud-trace < 2",
152152
"opentelemetry-sdk < 2",
153153
"opentelemetry-exporter-gcp-trace < 2",
154+
"opentelemetry-exporter-otlp-proto-http < 2",
154155
"pydantic >= 2.11.1, < 3",
155156
"typing_extensions",
156157
]

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ def test_tracing_setup(
609609
"telemetry.sdk.version": "1.36.0",
610610
"gcp.project_id": "test-project",
611611
"cloud.account.id": "test-project",
612+
"cloud.platform": "gcp.agent_engine",
612613
"service.name": "test_agent_id",
613614
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id",
614615
"service.instance.id": "12345678123456781234567812345678-123123123",

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import base64
1717
import importlib
1818
import json
19-
import dataclasses
2019
import os
2120
from unittest import mock
2221
from typing import Optional
@@ -112,27 +111,11 @@ def simple_span_processor_mock():
112111

113112

114113
@pytest.fixture
115-
def cloud_trace_exporter_mock():
116-
import sys
117-
import opentelemetry
118-
119-
mock_cloud_trace_exporter = mock.Mock()
120-
121-
opentelemetry.exporter = type(sys)("exporter")
122-
opentelemetry.exporter.cloud_trace = type(sys)("cloud_trace")
123-
opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter = (
124-
mock_cloud_trace_exporter
125-
)
126-
127-
sys.modules["opentelemetry.exporter"] = opentelemetry.exporter
128-
sys.modules["opentelemetry.exporter.cloud_trace"] = (
129-
opentelemetry.exporter.cloud_trace
130-
)
131-
132-
yield mock_cloud_trace_exporter
133-
134-
del sys.modules["opentelemetry.exporter.cloud_trace"]
135-
del sys.modules["opentelemetry.exporter"]
114+
def otlp_span_exporter_mock():
115+
with mock.patch(
116+
"opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
117+
) as otlp_span_exporter_mock:
118+
yield otlp_span_exporter_mock
136119

137120

138121
@pytest.fixture
@@ -619,9 +602,9 @@ def test_default_instrumentor_enablement(
619602
)
620603
def test_tracing_setup(
621604
self,
622-
trace_provider_mock: mock.Mock,
623-
cloud_trace_exporter_mock: mock.Mock,
624605
monkeypatch: pytest.MonkeyPatch,
606+
trace_provider_mock: mock.Mock,
607+
otlp_span_exporter_mock: mock.Mock,
625608
):
626609
monkeypatch.setattr(
627610
"uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678")
@@ -636,24 +619,17 @@ def test_tracing_setup(
636619
"telemetry.sdk.version": "1.36.0",
637620
"gcp.project_id": "test-project",
638621
"cloud.account.id": "test-project",
622+
"cloud.platform": "gcp.agent_engine",
639623
"service.name": "test_agent_id",
640624
"cloud.resource_id": "//aiplatform.googleapis.com/projects/test-project/locations/us-central1/reasoningEngines/test_agent_id",
641625
"service.instance.id": "12345678123456781234567812345678-123123123",
642626
"cloud.region": "us-central1",
643627
"some-attribute": "some-value",
644628
}
645629

646-
@dataclasses.dataclass
647-
class RegexMatchingAll:
648-
keys: set[str]
649-
650-
def __eq__(self, regex: object) -> bool:
651-
return isinstance(regex, str) and set(regex.split("|")) == self.keys
652-
653-
cloud_trace_exporter_mock.assert_called_once_with(
654-
project_id=_TEST_PROJECT,
655-
client=mock.ANY,
656-
resource_regex=RegexMatchingAll(keys=set(expected_attributes.keys())),
630+
otlp_span_exporter_mock.assert_called_once_with(
631+
session=mock.ANY,
632+
endpoint="https://telemetry.googleapis.com/v1/traces",
657633
)
658634

659635
assert (
@@ -685,7 +661,6 @@ def test_span_content_capture_enabled_with_tracing(self):
685661
def test_enable_tracing(
686662
self,
687663
caplog,
688-
cloud_trace_exporter_mock,
689664
tracer_provider_mock,
690665
simple_span_processor_mock,
691666
):

vertexai/agent_engines/templates/adk.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]:
320320
attributes={
321321
"gcp.project_id": project_id,
322322
"cloud.account.id": project_id,
323+
"cloud.platform": "gcp.agent_engine",
323324
"service.name": os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", ""),
324325
"service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}",
325326
"cloud.region": os.getenv("GOOGLE_CLOUD_LOCATION", ""),

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,28 @@ def _warn(msg: str):
233233
_warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess]
234234

235235

236+
def _force_flush_traces():
237+
try:
238+
import opentelemetry.trace
239+
except (ImportError, AttributeError):
240+
_warn(
241+
"Could not force flush traces. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
242+
)
243+
return None
244+
245+
try:
246+
import opentelemetry.sdk.trace
247+
except (ImportError, AttributeError):
248+
_warn(
249+
"Could not force flush traces. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'."
250+
)
251+
return None
252+
253+
provider = opentelemetry.trace.get_tracer_provider()
254+
if isinstance(provider, opentelemetry.sdk.trace.TracerProvider):
255+
_ = provider.force_flush()
256+
257+
236258
def _default_instrumentor_builder(
237259
project_id: str,
238260
*,
@@ -300,6 +322,7 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]:
300322
attributes={
301323
"gcp.project_id": project_id,
302324
"cloud.account.id": project_id,
325+
"cloud.platform": "gcp.agent_engine",
303326
"service.name": os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", ""),
304327
"service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}",
305328
"cloud.region": os.getenv("GOOGLE_CLOUD_LOCATION", ""),
@@ -313,28 +336,23 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]:
313336

314337
if enable_tracing:
315338
try:
316-
import opentelemetry.exporter.cloud_trace
317-
except (ImportError, AttributeError):
318-
return _warn_missing_dependency(
319-
"opentelemetry-exporter-gcp-trace", needed_for_tracing=True
320-
)
321-
322-
try:
323-
import google.cloud.trace_v2
339+
import opentelemetry.exporter.otlp.proto.http.trace_exporter
340+
import google.auth.transport.requests
324341
except (ImportError, AttributeError):
325342
return _warn_missing_dependency(
326-
"google-cloud-trace", needed_for_tracing=True
343+
"opentelemetry-exporter-otlp-proto-http", needed_for_tracing=True
327344
)
328345

329346
import google.auth
330347

331348
credentials, _ = google.auth.default()
332-
span_exporter = opentelemetry.exporter.cloud_trace.CloudTraceSpanExporter(
333-
project_id=project_id,
334-
client=google.cloud.trace_v2.TraceServiceClient(
335-
credentials=credentials.with_quota_project(project_id),
336-
),
337-
resource_regex="|".join(resource.attributes.keys()),
349+
span_exporter = (
350+
opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter(
351+
session=google.auth.transport.requests.AuthorizedSession(
352+
credentials=credentials
353+
),
354+
endpoint="https://telemetry.googleapis.com/v1/traces",
355+
)
338356
)
339357
span_processor = opentelemetry.sdk.trace.export.BatchSpanProcessor(
340358
span_exporter=span_exporter,
@@ -874,9 +892,14 @@ async def async_stream_query(
874892
**kwargs,
875893
)
876894

877-
async for event in events_async:
878-
# Yield the event data as a dictionary
879-
yield _utils.dump_event_for_json(event)
895+
try:
896+
async for event in events_async:
897+
# Yield the event data as a dictionary
898+
yield _utils.dump_event_for_json(event)
899+
finally:
900+
# Avoid trace data loss having to do with CPU throttling on instance turndown
901+
if self._tracing_enabled():
902+
_ = await asyncio.to_thread(_force_flush_traces)
880903

881904
def streaming_agent_run_with_events(self, request_json: str):
882905
import json
@@ -937,6 +960,9 @@ async def _invoke_agent_async():
937960
user_id=request.user_id,
938961
session_id=session.id,
939962
)
963+
# Avoid trace data loss having to do with CPU throttling on instance turndown
964+
if self._tracing_enabled():
965+
_ = await asyncio.to_thread(_force_flush_traces)
940966

941967
def _asyncio_thread_main():
942968
try:

0 commit comments

Comments
 (0)